DeepLearning/Pytorch

[Pytorch] nn.Module & super().__init__()

yooj_lee 2021. 8. 21. 02:05
300x250

우리는 pytorch에서 각자 레이어 혹은 모델을 구성할 때, nn.Module을 상속받는다. 왜 상속을 받을까? 또 상속받을 때, super().__init__()은 왜 해주는 것일까? 해당 코드를 작성함으로써 어떤 속성을 갖게되는 걸까?

이번 글에서는 이 두 가지를 중점적으로 정리해볼 것이다. 아래의 코드는 간단히 convolution block을 구현한 것이다.

import torch
import torch.nn as nn

class ConvBlock(nn.Module):
	"""
	conv2d - batchnorm - relu - max_pooling2d
	로 구성된 conv block을 만들어봄
	"""
	def __init__(self, kernel_size, stride, padding, pool, pool_stride):
		super().__init__() # 1. why? super().__init__()
		self.kernel_size = kernel_size
		self.stride = stride
		self.padding = padding
		self.pool = pool
		self.pool_stride = pool_stride
		self.conv_block = nn.Sequential(
        	nn.Conv2d(3,16,kernel_size=self.kernel_size,stride=self.stride,padding=self.padding),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                nn.MaxPool2d(self.pool, self.pool_stride)
                )

	def forward(self,x):
		output = self.conv_block(x)
		return output

 

super().__init__()

super()로 기반 클래스(부모 클래스)를 초기화해줌으로써, 기반 클래스의 속성을 subclass가 받아오도록 한다. (초기화를 하지 않으면, 부모 클래스의 속성을 사용할 수 없음)

cf.) super().__init__() vs super(MyClass,self).__init__()
좀 더 명확하게 super를 사용하기 위해서는 단순히 super().__init__()을 하는 것이 아니라 super(파생클래스, self).__init__() 을 해준다.이와 같이 적어주면 기능적으로 차이는 없지만, 생클래스와 self를 넣어서 현재 클래스가 어떤 클래스인지 명확하게 표시해줄 수 있다.

 

class MyModel(nn.Module()):

글 처음의 코드를 보면 ConvBlock은 nn.Module을 상속받는다.
왜 상속 받아서 사용할까? 상속 받으면 무슨 일이 일어나길래?

 

관련 스택오버플로우 질문을 하나 참고하자면,

 

class LR(nn.Module):
	
	def __init__(self, input_size, output_size):
		
		super(LR, self).__init__()
		self.test = 1
		self.linear = nn.Linear(input_size, output_size)

	def forward(self, x):
		out = self.linear(x)
		return out

 

위의 코드에서 super(LR, self).__init__()를 작성하지 않는다면 self.linear = nn.Linear(input_size, output_size)에서 "AttributeError: cannot assign module before Module.__init__() call" 을 일으키게 된다.

이러한 이유는 self.linear = ~ 를 실행하면, 내부적으로는 선언한 클래스의 __setattr__함수를 실행하게 된다 (__setattr__는 nn.Module을 extend하면서 상속받는 것 중에 하나).

 

def __setattr__(self, name: str, value: Union[Tensor, 'Module']) -> None:
    # [...]
    modules = self.__dict__.get('_modules')
    if isinstance(value, Module):
        if modules is None:
            raise AttributeError("cannot assign module before Module.__init__() call")
        remove_from(self.__dict__, self._parameters, self._buffers, self._non_persistent_buffers_set)
        modules[name] = value

 

위의 __setattr__ 함수의 정의에 따르면, nn.Linear가 nn.Module의 인스턴스인 경우에는 클래스가 _modules 속성을 가지고 있지 않으면 AttributeError를 raise하도록 되어 있다. 즉, _modules가 초기화되어 있지 않으면 AttributeError가 발생한다.

nn.Module의 __init__ 함수를 보면, self._modules는 nn.Module의 __init__에서 선언 및 초기화됨을 알 수 있다.

 

def __init__(self):
    """
    Initializes internal Module state, shared by both nn.Module and ScriptModule.
    """
    torch._C._log_api_usage_once("python.nn_module")

    self.training = True
    self._parameters = OrderedDict()
    self._buffers = OrderedDict()
    self._non_persistent_buffers_set = set()
    self._backward_hooks = OrderedDict()
    self._forward_hooks = OrderedDict()
    self._forward_pre_hooks = OrderedDict()
    self._state_dict_hooks = OrderedDict()
    self._load_state_dict_pre_hooks = OrderedDict()
    self._modules = OrderedDict() ## _modules 선언 및 초기화

 

 이처럼, nn.Module을 extend함으로써 custom model or custom layer는 위와 같은 여러 속성들 또한 상속받게 된다. 위와 같은 속성들은 parameter, modules 등의 네트워크 학습에 있어서 중요한 특성을 포함한다.

파이토치에서 제공하는 layer를 사용하여 모델 빌드를 간편하게 하기 위해서는 위와 같이 nn.Module을 상속받고, 이를 초기화 함으로써 nn.Module에서 상속받는 특성들을 초기화해주는 것이 필요하다.


reference

1. https://stackoverflow.com/questions/63058355/why-is-the-super-constructor-necessary-in-pytorch-custom-modules
2. https://dojang.io/mod/page/view.php?id=2386

300x250