Notice
Recent Posts
Recent Comments
«   2024/09   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30
Tags
more
Archives
Today
Total
관리 메뉴

No Limitation

[ Debugging ] Expected more than 1 spatial element when training, got input size torch.Size([~,~,..]) 문제 다루기 본문

프로그래밍

[ Debugging ] Expected more than 1 spatial element when training, got input size torch.Size([~,~,..]) 문제 다루기

yesungcho 2022. 7. 14. 16:52

참고 포스팅

https://koowater.tistory.com/4?category=903773 

 

[PyTorch] ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 32])

Traceback (most recent call last): File "train.py", line 193, in g_loss = torch.mean(torch.abs(netD(gen_imgs) - gen_imgs)) File "/home/cvmi-koo/.local/lib/python3.8/site-packages/torch/nn/modules/mo..

koowater.tistory.com

https://discuss.pytorch.org/t/error-expected-more-than-1-value-per-channel-when-training/26274

 

Error: Expected more than 1 value per channel when training

I have a model that works perfectly when there are multiple input. However, if there is only one datapoint as input, I will get the above error. Does anyone have an idea on what’s going on here?

discuss.pytorch.org

 

 

메모리 문제로 batch를 줄여가며 기존 12에서 -> 4까지 줄일 때, 메모리 문제를 극복했지만 다음과 같은 문제가 발생하였습니다.

배치를 조절하다보면 종종 있는 이슈인데, 이 경우 다음과 같은 해결책으로 접근을 시도할 수 있습니다.

 

1) 주로 batch size를 설정하고 남는 batch에 대해서 연산이 오류가 나는 경우이므로, DataLoader 파라미터로 drop_last=True 설정

 

1번은 직관적이니 따로 예시는 넣지 않겠습니다.

train_data_loader = DataLoader(..., drop_last=True)
val_data_loader = DataLoader(..., drop_last=True)

대충 끝단에 저런 파라미터를 넣어주면 됩니다. 

 

2) batchnorm 부분에 대해 eval 모드로 설정하고 batch 통과 이후 다시 train 전환

 

구체적으로 살펴보면, 다음과 같이 수정이 가능합니다.

 

실제 모형에 있는 코드를 가져왔는데요

class BlockUNet1(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=False, relu=False, drop=False, bn=True):
        super(BlockUNet1, self).__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)

        self.dropout = nn.Dropout2d(0.5)
        self.batch = nn.InstanceNorm2d(out_channels)

        self.upsample = upsample
        self.relu = relu
        self.drop = drop
        self.bn = bn

    def forward(self, x):
        if self.relu == True:
            y = F.relu(x)
        elif self.relu == False:
            y = F.leaky_relu(x, 0.2)
        if self.upsample == True:
            y = self.deconv(y)
            if self.bn == True:
                y = self.batch(y)
            if self.drop == True:
                y = self.dropout(y)

        elif self.upsample == False:
            y = self.conv(y)
            if self.bn == True:
                y = self.batch(y)
            if self.drop == True:
                y = self.dropout(y)

 

이렇게 코드가 되어 있는 부분이었습니다. 저기서 저 같은 경우는 InstanceNorm2d -> BatchNorm2d로 바꾸고

y = self.batch(y)로 되어 있는 부분에 .eval()을 적용하여 수정하였습니다. 아래 부분이 수정된 부분입니다. 

 

self.batch = nn.InstanceNorm2d(out_channels)
self.batch = nn.BatchNorm2d(out_channels) 
로 바꾸고
 
y = self.batch(y)
여기를
try :
    y = self.batch(y)
except :
    self.batch.eval()
    y = self.batch(y)
    self.batch.train()
 로 바꾸어 해당 에러를 해결할 수 있었습니다. 
 

실제 프로그래밍하면서 자주 발생할 수 있는 문제이니 잊지 말기....!