AI/Pytorch

[Pytorch] gather 함수

surrr 2023. 4. 12. 15:17

이번에는 pytorch의 gather 함수에 대해 포스팅 해보도록 하겠습니다.

간단하게 torch.gather는 특정 인덱스를 쉽게 추출하기 위한 함수입니다!


다음은 파이토치 공식 문서에 있는 내용입니다.

공식문서의 내용이 궁금하신 분은 밑 함수부분에 링크를 연결해놓았으니 참조해주시기 바랍니다.

torch.gather(input, dim, index, *, sparse_grad=False, out=None)  Tensor

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

 

여기서 input은 복사하고자 하는 tensor이며,

dim은 어느 방향으로 복사할지,

index는 input tensor을 모아줄 인덱스로 이루어진 tensor 입니다.

(input tensor에서 무엇을 추출할 지 인덱스 정보를 저장하고 있는 tensor)

 

여기서 차원을 맞춰주어야 하는 조건이 있는데,

input이 (C, M, N) 으로 이루어져 있을 때 dim에 따라서 index의 차원을 맞추어 주어야 합니다.

가령,

input이 (C, M, N)이고, dim = 0 이면 index 행렬은 (X, M, N) 으로,

input이 (C, M, N)이고, dim = 1 이면 index 행렬은 (C, X, N) 으로,

input이 (C, M, N)이고, dim = 2 이면 index 행렬은 (C, M, X) 으로 맞추어 주어야 합니다.

즉 dim을 제외한 나머지 부분의 차원을 맞추어주어야 합니다.

 

간단하게 그림을 통해서 알아봅시다.


2차원 행렬의 경우

예시 1

input = [[1,2,3],[4,5,6],[7,8,9]]

dim = 0

index = [[0,1,2],[2,0,1]]

 

예시 2

input = [[1,2,3],[4,5,6],[7,8,9]]

dim = 1

index = [[0,2],[0,1],[2,0]]


3차원 행렬의 경우

2차원 행렬을 통해 gather가 어떻게 작동하는 지 이해했다면 3차원 행렬은 매우 간단합니다!

그저 차원하나를 추가만 하면 됩니다.

그림을 그리기 난해하니 간단한 예제를 통해 알아봅시다.

input = [[[1,2,3],[4,5,6],[7,8,9]],

               [[10,11,12],[13,14,15],[16,17,18]],

               [[19,20,21],[22,23,24],[25,26,27]]]

dim = 0

index = [[[0,1,2],[2,1,0],[1,1,0]]]

 

output tensor => [[[input[0][0][0], input[1][0][1], input[2][0][2]],

                              [input[2][1][0], input[1][1][1], input[0][1][2]],

                              [input[1][2][0], input[1][2][1], input[0][2][2]]]]

 

예시

 

 

References

https://pytorch.org/docs/stable/generated/torch.gather.html

https://stackoverflow.com/questions/50999977/what-does-the-gather-function-do-in-pytorch-in-layman-terms

https://velog.io/@nawnoes/torch.gather란