PyTorch로 네트워크를 설계하는 방법은 크게 3가지로 나눌 수 있습니다.
1) Barebones PyTorch
2) PyTorch Module API
3) PyTorch Sequential API
아래에서는 같은 3층 ConvNet(conv → ReLU → conv → ReLU → fully-connected)을 세 가지 방식으로 각각 구현해 보며 간략히 둘러보겠습니다.
1) Barebones PyTorch
# Weight initialization for the barebones three-layer ConvNet.
# Every tensor is allocated directly on `device` (rather than allocated on the
# default device and then copied with `.to(device)`), avoiding an extra
# allocation and transfer. Weights get Kaiming-normal init; biases start at zero.
conv_w1 = nn.init.kaiming_normal_(
    torch.empty((channel_1, C, kernel_size_1, kernel_size_1), dtype=dtype, device=device)
)
conv_b1 = nn.init.zeros_(torch.empty(channel_1, dtype=dtype, device=device))
conv_w2 = nn.init.kaiming_normal_(
    torch.empty((channel_2, channel_1, kernel_size_2, kernel_size_2), dtype=dtype, device=device)
)
conv_b2 = nn.init.zeros_(torch.empty(channel_2, dtype=dtype, device=device))
# Spatial size after a stride-1 convolution: out = in + 2*pad - kernel + 1.
H1 = (H - kernel_size_1 + pad_size_1*2) + 1
H2 = (H1 - kernel_size_2 + pad_size_2*2) + 1
# NOTE(review): the FC input size uses H2*H2, i.e. assumes square inputs (H == W).
fc_w = nn.init.kaiming_normal_(
    torch.empty(num_classes, channel_2*H2*H2, dtype=dtype, device=device)
)
fc_b = nn.init.zeros_(torch.empty(num_classes, dtype=dtype, device=device))
# Mark all parameters as trainable so autograd tracks gradients for them.
for tensor in [conv_w1, conv_b1, conv_w2, conv_b2, fc_w, fc_b]:
    tensor.requires_grad = True
# Forward pass: two (conv -> ReLU) stages, flatten everything but the batch
# dimension, then a fully-connected layer producing the class scores.
x = F.relu(F.conv2d(x, conv_w1, conv_b1, padding=pad_size_1))
x = F.relu(F.conv2d(x, conv_w2, conv_b2, padding=pad_size_2))
scores = x = F.linear(x.flatten(start_dim=1), fc_w, fc_b)
2) PyTorch Module API
class ThreeLayerConvNet(nn.Module):
    """Three-layer ConvNet: conv -> ReLU -> conv -> ReLU -> fully-connected.

    Args:
        in_shape: (C, H, W) of a single input image.
        conv_params: two (out_channels, kernel_size, padding) triples, one per
            conv layer (both convs use stride 1).
        num_classes: number of output scores per example.
    """

    def __init__(self, in_shape, conv_params, num_classes):
        super().__init__()
        C, H, W = in_shape
        channel_1, kernel_size_1, pad_size_1 = conv_params[0]
        channel_2, kernel_size_2, pad_size_2 = conv_params[1]
        self.conv1 = nn.Conv2d(C, channel_1, kernel_size_1, padding=pad_size_1)
        self.conv2 = nn.Conv2d(channel_1, channel_2, kernel_size_2, padding=pad_size_2)
        # Spatial size after a stride-1 conv: out = in + 2*pad - kernel + 1.
        # Height and width are tracked separately so non-square inputs get the
        # correct FC input size (previously H2*H2 was used and W was ignored).
        H1 = H + 2 * pad_size_1 - kernel_size_1 + 1
        W1 = W + 2 * pad_size_1 - kernel_size_1 + 1
        H2 = H1 + 2 * pad_size_2 - kernel_size_2 + 1
        W2 = W1 + 2 * pad_size_2 - kernel_size_2 + 1
        self.fc = nn.Linear(channel_2 * H2 * W2, num_classes)
        # Kaiming-normal weights and zero biases for every layer.
        for layer in (self.conv1, self.conv2, self.fc):
            nn.init.kaiming_normal_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Return class scores of shape (N, num_classes) for a batch x of shape (N, C, H, W)."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # Flatten all non-batch dimensions before the fully-connected layer.
        scores = self.fc(x.flatten(1))
        return scores
3) PyTorch Sequential API
# Sequential API: the same conv -> ReLU -> conv -> ReLU -> FC network with
# named layers. conv2 maps channel_1 -> channel_2 using its own kernel size and
# padding (it previously reused conv1's arguments, a channel mismatch).
# Spatial size after a stride-1 conv: out = in + 2*pad - kernel + 1; this equals
# the input size when "same" padding is used, so H*W is the special case.
out_h = (H + 2 * pad_size_1 - kernel_size_1 + 1) + 2 * pad_size_2 - kernel_size_2 + 1
out_w = (W + 2 * pad_size_1 - kernel_size_1 + 1) + 2 * pad_size_2 - kernel_size_2 + 1
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(C, channel_1, kernel_size_1, padding=pad_size_1)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(channel_1, channel_2, kernel_size_2, padding=pad_size_2)),
    ('relu2', nn.ReLU()),
    ('flatten', nn.Flatten()),
    ('fc1', nn.Linear(channel_2 * out_h * out_w, num_classes)),
]))