No Limitation
[ Debugging ] Expected more than 1 spatial element when training, got input size torch.Size([~,~,..]) 문제 다루기 본문
프로그래밍
[ Debugging ] Expected more than 1 spatial element when training, got input size torch.Size([~,~,..]) 문제 다루기
yesungcho 2022. 7. 14. 16:52참고 포스팅
https://koowater.tistory.com/4?category=903773
https://discuss.pytorch.org/t/error-expected-more-than-1-value-per-channel-when-training/26274
메모리 문제로 batch를 줄여가며 기존 12에서 -> 4까지 줄일 때, 메모리 문제를 극복했지만 다음과 같은 문제가 발생하였습니다.
배치를 조절하다보면 종종 있는 이슈인데, 이 경우 다음과 같은 해결책으로 접근을 시도할 수 있습니다.
1) 주로 batch size를 설정하고 남는 batch에 대해서 연산이 오류가 나는 경우이므로, DataLoader 파라미터로 drop_last=True 설정
1번은 직관적이니 따로 예시는 넣지 않겠습니다.
train_data_loader = DataLoader(..., drop_last=True)
val_data_loader = DataLoader(..., drop_last=True)
대충 끝단에 저런 파라미터를 넣어주면 됩니다.
2) batchnorm 부분에 대해 eval 모드로 설정하고 batch 통과 이후 다시 train 전환
구체적으로 살펴보면, 다음과 같이 수정이 가능합니다.
실제 모형에 있는 코드를 가져왔는데요
class BlockUNet1(nn.Module):
def __init__(self, in_channels, out_channels, upsample=False, relu=False, drop=False, bn=True):
super(BlockUNet1, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)
self.deconv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)
self.dropout = nn.Dropout2d(0.5)
self.batch = nn.InstanceNorm2d(out_channels)
self.upsample = upsample
self.relu = relu
self.drop = drop
self.bn = bn
def forward(self, x):
if self.relu == True:
y = F.relu(x)
elif self.relu == False:
y = F.leaky_relu(x, 0.2)
if self.upsample == True:
y = self.deconv(y)
if self.bn == True:
y = self.batch(y)
if self.drop == True:
y = self.dropout(y)
elif self.upsample == False:
y = self.conv(y)
if self.bn == True:
y = self.batch(y)
if self.drop == True:
y = self.dropout(y)
이렇게 코드가 되어 있는 부분이었습니다. 저기서 저 같은 경우는 InstanceNorm2d -> BatchNorm2d로 바꾸고
y = self.batch(y)로 되어 있는 부분에 .eval()을 적용하여 수정하였습니다. 아래 부분이 수정된 부분입니다.
self.batch = nn.InstanceNorm2d(out_channels)
을
self.batch = nn.BatchNorm2d(out_channels)
로 바꾸고
y = self.batch(y)
여기를
try :
y = self.batch(y)
except :
self.batch.eval()
y = self.batch(y)
self.batch.train()
로 바꾸어 해당 에러를 해결할 수 있었습니다. 실제 프로그래밍하면서 자주 발생할 수 있는 문제이니 잊지 말기....!
'프로그래밍' 카테고리의 다른 글
[Ubuntu] SSH Server 접속 오류 해결 (0) | 2022.09.06 |
---|---|
[ Debugging ] Size of tensors must match except in dimension 에러 해결 (0) | 2022.07.15 |
wget 명령어 오류 관련 (0) | 2022.06.15 |
[Greedy] 신입사원 - 백준 (0) | 2022.03.02 |
[Graph] MST - Kruskal Algorithm - 백준 (0) | 2022.02.27 |