No Limitation
[ Debugging ] Size of tensors must match except in dimension 에러 해결 본문
[ Debugging ] Size of tensors must match except in dimension 에러 해결
yesungcho 2022. 7. 15. 11:55본 포스팅은 개인 공부용으로 작성되었습니다.
U-Net 아키텍처를 사용하다 위와 같은 에러가 발생한 경우가 있었다.
U-Net은 convolution을 해줄 때의 feature map과 후에 deconvolution을 해주는 부분 간의 concat이 이루어지는 부분이 있다.
예를 들면 아래 그림과 같은 부분에서다.
위 그림은 U-Net 원 논문에 있는 그림인데 빨간색 그림처럼 표시된 부분에서 concat이 이루어지는데 상위 level의 feature map과의 concat을 해주어야 하기 때문에 dimension이 맞아야 하는 조건이 있다.
하지만 git에서 긁어온 U-Net을 사용한 코드에서 이러한 부분에 에러가 발생하는 경우가 존재하였다.
예를 들어 x (input) -> y1 -> y2, ... -> y8가지 내려갔다고 했을 때
이제 다시 deconv를 수행해주어야 하므로 y8을 deconv 연산을 통해 dy8을 만들고
이를 이전 level인 y7과 concat을 해주어야 하는데
위 print한 결과를 보면
y7의 shape은 [1,64,3,4]
dy8의 shape은 [1,64,2,4]인 문제가 있었다. 바로 이러한 부분에서는 이렇게 dimension이 다른 이슈가 존재할 수 있기 때문에 적절한 padding을 수행해주어야 한다.
따라서 본 U-Net을 구현한 github의 코드를 참고하니 다음과 같은 부분이 존재하였다.
즉 위의 코드에서도 x1와 x2를 67번 줄에서 concat을 수행할 때 dimension이 맞지 않은 경우가 있으므로 저렇게 F.pad를 이용해 패딩을 수행해준다. 따라서 필자도 위의 코드를 적용해서 위 문제를 다음과 같이 shape을 맞추어주어
다음과 같이 문제를 해결할 수 있었다.
이러한 padding 문제를 푸는 방법은 많이 있는데, 예를 들어 32의 배수로 shape을 통일해주어야 하는 경우도 있다.
예를 들어 필자의 경우도 입력이 (1,3,413,550)으로 되어 있는 경우가 있는데 2,3 dim 쪽을 32의 배수로 맞추어, 예를 들어 (1,3,448,576)으로 바꾸어주는 zero-padding을 수행해주고 싶을 경우가 있는데 이런 경우는 F.pad() 함수를 사용하면 이를 맞출 수 있다.
이에 대해 자세한 사항은 다음 블로그를 참고하면 좋을 것 같다. 꼭 참고하길...!
https://hichoe95.tistory.com/116
위 task를 풀기 위해 다음과 같은 절차를 밟으면 padding을 수행해줄 수 있다.
import torch
import torch.nn.functional as F
t4d = torch.empty(1,3,413,550) # original
p1d = (1,1,34,1)
print(F.pad(t4d,p1d,'constant',0).shape) ## print result : torch.Size([1, 3, 448, 552])
i_1 = F.pad(t4d,p1d,'constant',0)
p2d = (12,12)
print(F.pad(i_1,p2d,'constant',0).shape) ## print result : torch.Size([1, 3, 448, 576])
이러한 텐서의 shape이나 dimension을 다루는 연산에도 익숙해지자...!
참고 포스팅
https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
https://hichoe95.tistory.com/116
'프로그래밍' 카테고리의 다른 글
[Ubuntu] nvidia-smi, 그래픽 드라이버 연결 문제 해결 (0) | 2022.09.15 |
---|---|
[Ubuntu] SSH Server 접속 오류 해결 (0) | 2022.09.06 |
[ Debugging ] Expected more than 1 spatial element when training, got input size torch.Size([~,~,..]) 문제 다루기 (0) | 2022.07.14 |
wget 명령어 오류 관련 (0) | 2022.06.15 |
[Greedy] 신입사원 - 백준 (0) | 2022.03.02 |