PyTorch Basics 1: How to Write the Code

There are broadly three ways to design a network in PyTorch:

   1) Barebones PyTorch
   2) PyTorch Module API
   3) PyTorch Sequential API

 

Let's take a brief look at each.

 

1) Barebones PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F

# assumes dtype, device and the hyperparameters (C, H, channel_1, channel_2,
# kernel_size_1/2, pad_size_1/2, num_classes) are already defined

# weight initialization
conv_w1 = torch.empty((channel_1, C, kernel_size_1, kernel_size_1), dtype=dtype).to(device)
conv_b1 = torch.empty(channel_1, dtype=dtype).to(device)
conv_w1 = nn.init.kaiming_normal_(conv_w1)
conv_b1 = nn.init.zeros_(conv_b1)

conv_w2 = torch.empty((channel_2, channel_1, kernel_size_2, kernel_size_2), dtype=dtype).to(device)
conv_b2 = torch.empty(channel_2, dtype=dtype).to(device)
conv_w2 = nn.init.kaiming_normal_(conv_w2)
conv_b2 = nn.init.zeros_(conv_b2)

# spatial size after each conv (stride 1; square H x H input assumed)
H1 = (H - kernel_size_1 + pad_size_1*2) + 1
H2 = (H1 - kernel_size_2 + pad_size_2*2) + 1

fc_w = torch.empty(num_classes, channel_2*H2*H2, dtype=dtype).to(device)
fc_b = torch.empty(num_classes, dtype=dtype).to(device)
fc_w = nn.init.kaiming_normal_(fc_w)
fc_b = nn.init.zeros_(fc_b)

for tensor in [conv_w1, conv_b1, conv_w2, conv_b2, fc_w, fc_b]:
    tensor.requires_grad = True

# forward
x = F.conv2d(x, conv_w1, conv_b1, padding=pad_size_1)
x = F.relu(x)
x = F.conv2d(x, conv_w2, conv_b2, padding=pad_size_2)
x = F.relu(x)
x = x.flatten(1)
x = F.linear(x, fc_w, fc_b)
scores = x
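
Because the weights here are plain tensors with requires_grad=True, the training step also has to be written by hand. Below is a minimal sketch of one manual SGD update under that setup; the names learning_rate and y (integer class labels for the batch) are assumptions for illustration and are not part of the snippet above.

# one training step with manual SGD (sketch; learning_rate and y are assumed)
params = [conv_w1, conv_b1, conv_w2, conv_b2, fc_w, fc_b]

loss = F.cross_entropy(scores, y)   # y: integer class labels for the batch
loss.backward()                     # autograd fills p.grad for every tensor in params

with torch.no_grad():               # update outside the autograd graph
    for p in params:
        p -= learning_rate * p.grad
        p.grad.zero_()              # reset gradients for the next step

With the Module and Sequential APIs below, this bookkeeping is replaced by an optimizer over model.parameters().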

2) PyTorch Module API

class ThreeLayerConvNet(nn.Module):
    def __init__(self, in_shape, conv_params, num_classes):
        super().__init__()

        C, H, W = in_shape
        channel_1, kernel_size_1, pad_size_1 = conv_params[0]
        channel_2, kernel_size_2, pad_size_2 = conv_params[1]

        self.conv1 = nn.Conv2d(C, channel_1, kernel_size_1, padding=pad_size_1)
        self.conv2 = nn.Conv2d(channel_1, channel_2, kernel_size_2, padding=pad_size_2)

        H1 = (H - kernel_size_1 + pad_size_1*2) + 1
        H2 = (H1 - kernel_size_2 + pad_size_2*2) + 1
        self.fc = nn.Linear(channel_2*H2*H2, num_classes)

        nn.init.kaiming_normal_(self.conv1.weight)
        nn.init.zeros_(self.conv1.bias)

        nn.init.kaiming_normal_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)

        nn.init.kaiming_normal_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        scores = self.fc(x.view(x.shape[0], -1))
        
        return scores
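
With the Module API, parameters are registered automatically, so using the model is just instantiation plus a call. A small usage sketch; the input shape and conv_params values below are made-up example numbers:

# usage sketch: CIFAR-like input, two conv layers, 10 classes (values assumed)
model = ThreeLayerConvNet(in_shape=(3, 32, 32),
                          conv_params=[(32, 5, 2), (16, 3, 1)],
                          num_classes=10)

x = torch.randn(64, 3, 32, 32)   # a dummy batch of 64 images
scores = model(x)                # forward() is called under the hood
print(scores.shape)              # torch.Size([64, 10])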

3) PyTorch Sequential API 

from collections import OrderedDict

# output spatial size after each conv, as in the Module API example above
H1 = (H - kernel_size_1 + pad_size_1*2) + 1
H2 = (H1 - kernel_size_2 + pad_size_2*2) + 1

model = nn.Sequential(OrderedDict([
        ('conv1', nn.Conv2d(C, channel_1, kernel_size_1, padding=pad_size_1)),
        ('relu1', nn.ReLU()),
        ('conv2', nn.Conv2d(channel_1, channel_2, kernel_size_2, padding=pad_size_2)),
        ('relu2', nn.ReLU()),
        ('flatten', nn.Flatten()),
        ('fc1', nn.Linear(channel_2*H2*H2, num_classes)),
    ]))
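
Whichever style built the model, training looks the same from here on. A rough sketch of a training loop with an optimizer over model.parameters(); the optimizer settings and the loader variable are assumptions for illustration:

import torch.optim as optim

# minimal training loop sketch (optimizer choice and loader are assumed)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

for x, y in loader:               # loader yields (images, labels) batches
    scores = model(x)
    loss = F.cross_entropy(scores, y)

    optimizer.zero_grad()         # clear old gradients
    loss.backward()               # backprop through the whole model
    optimizer.step()              # SGD update on every registered parameter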