DeepLearning/Pytorch

[Pytorch] hook (forward_hook, backward_hook)

yooj_lee 2021. 9. 24. 17:49
300x250

Hook

패키지 중간에 자기가 원하는 코드 끼워넣을 수 있는 부분 정도로 이해하면 될 듯하다! (register hook)

  • hook: 일반적으로 hook은 프로그램, 혹은 특정 함수 실행 후에 걸어놓는 경우를 일컬음.
  • pre-hook: 프로그램 실행 전에 걸어놓는 hook
  • forward hook
    1. register_forward_hook: forward 호출 후에 forward output 계산 후 걸어두는 hook
      # register_forward_hook should have the following signature 
      hook(module, input, output) -> None or modified output
      input은 positional arguments만 담을 수 있으며 (index 같은?) keyword arguments 등은 담을 수 없고, forward 서만 적용이 된다.
      hook은 forward output 수정 가능, input 또한 수정 가능하지만 forward에는 영향 없음.

    2. register_forward_pre_hook: forward 호출 전에 걸어두는 hook
      # The hook will be called every time before forward() is invoked. It should have the following signature:
      hook(module, input) -> None or modified input

      input은 positional arguments만 담을 수 있으며 (index 같은?) keyword arguments 등은 담을 수 없고, forward 에서만 적용이 된다.
      여기서 hook은 input을 수정할 수 있고, 출력은 튜플 혹은 single modified value를 리턴함. single value여도 tuple로 wrapping되어서 나간다는 점 참고.
  • backward hook
    1. register_full_backward_hook (module에 적용)
      module input에 대한 gradient가 계산될 때마다 hook이 호출됨.
      # The hook should have the following signature:
      hook(module, grad_input, grad_output) -> tuple(Tensor) or None​
      grad_input과 grad_output은 각각 input과 output에 대한 gradient를 포함하고 있는 튜플. hook은 hook의 인자, 즉 grad_input과 grad_output을 수정할 수는 없지만 새로운 그래디언트를 리턴해서 grad_input 대신 사용할 수 있음 (이후 computation에서 사용 가능). 역시 positional arguments만 허용 가능하며 keyword arguments는 허용되지 않음.
      또한 여기서 input 또는 output을 직접적으로 수정하는 건 error 발생.

    2. register_hook (in Tensor)
      → Tensor의 경우에는 only backward hook (Tensor의 gradient가 계산될 때마다 hook 호출됨. hook은 gradient를 바꿀 수는 없지만 새로운 gradient를 생성 가능하며, 기존 grad 대신 사용 가능함.)

 

hook이 어디에 사용이 될까?

  • 디버깅 (레이어 shape, output 등을 출력하는 hook을 넣어주는 방식)
  • feature extraction
    """
    author: Frank Odom
    https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904
    """
    from typing import Dict, Iterable, Callable
    
    class FeatureExtractor(nn.Module):
        def __init__(self, model: nn.Module, layers: Iterable[str]):
            super().__init__()
            self.model = model
            self.layers = layers
            self._features = {layer: torch.empty(0) for layer in layers}
    
            for layer_id in layers:
                layer = dict([*self.model.named_modules()])[layer_id]
                layer.register_forward_hook(self.save_outputs_hook(layer_id))
    
        def save_outputs_hook(self, layer_id: str) -> Callable:
            def fn(_, __, output):
                self._features[layer_id] = output
            return fn
    
        def forward(self, x: Tensor) -> Dict[str, Tensor]:
            _ = self.model(x)
            return self._features​
  • gradient clipping (이 경우에는 torch.Tensor.register_hook)
  • visualising activation (forward hook)
    """
    author: Ayoosh Kathuria
    https://blog.paperspace.com/pytorch-hooks-gradient-clipping-debugging/
    """
    import torch 
    import torch.nn as nn
    
    class myNet(nn.Module):
      def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3,10,2, stride = 2)
        self.relu = nn.ReLU()
        self.flatten = lambda x: x.view(-1)
        self.fc1 = nn.Linear(160,5)
        self.seq = nn.Sequential(nn.Linear(5,3), nn.Linear(3,2))
        
       
      
      def forward(self, x):
        x = self.relu(self.conv(x))
        x = self.fc1(self.flatten(x))
        x = self.seq(x)
      
    
    net = myNet()
    visualisation = {}
    
    def hook_fn(m, i, o):
      visualisation[m] = o 
    
    def get_all_layers(net):
      for name, layer in net._modules.items():
        #If it is a sequential, don't register a hook on it
        # but recursively register hook on all it's module children
        if isinstance(layer, nn.Sequential):
          get_all_layers(layer)
        else:
          # it's a non sequential. Register a hook
          layer.register_forward_hook(hook_fn)
    
    get_all_layers(net)
    
      
    out = net(torch.randn(1,3,8,8))
    
    # Just to check whether we got all layers
    visualisation.keys()      #output includes sequential layers​

 

300x250