DeepLearning/Pytorch

[Pytorch] Learning Rate Scheduler 커스텀하기

yooj_lee 2022. 5. 26. 15:41
300x250

CycleGAN을 구현하던 도중 learning rate decay를 linear하게 주었다는 부분을 확인할 수 있었다. 그러나 내가 찾아본 바로는 torch 내의 learning rate scheduler는 linear decay를 적용하는 scheduler가 없었다 (아마 그런 scheduling이 좋지 않으니까 구현이 안되어 있는 게 아닐까 한다). 하지만 논문을 최대한 반영해서 구현해보고 싶어서 직접 구현해보기로 했다. 코드 구현 자체는 어렵지 않으나 pytorch 내에서 어떻게 learning rate scheduler가 작동하는지의 원리를 파악할 필요가 있었다.


_LRScheduler

파이토치 내에 구현되어 있는 scheduler의 베이스 클래스이다. 

"""
code reference: https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html
"""

class _LRScheduler(object):

    def __init__(self, optimizer, last_epoch=-1, verbose=False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        # Initialize epoch and base learning rates
        if last_epoch == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        self.last_epoch = last_epoch

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `lr_scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)
            # Get the unbound method for the same purpose.
            func = method.__func__
            cls = instance_ref().__class__
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._step_count += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True
            return wrapper

        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._step_count = 0
        self._step_count = 0
        self.verbose = verbose

        self.step()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self):
        """ Return last computed learning rate by current scheduler.
        """
        return self._last_lr

    def get_lr(self):
        # Compute learning rate using chainable form of the scheduler
        raise NotImplementedError

    def print_lr(self, is_verbose, group, lr, epoch=None):
        """Display the current learning rate.
        """
        if is_verbose:
            if epoch is None:
                print('Adjusting learning rate'
                      ' of group {} to {:.4e}.'.format(group, lr))
            else:
                epoch_str = ("%.2f" if isinstance(epoch, float) else
                             "%.5d") % epoch
                print('Epoch {}: adjusting learning rate'
                      ' of group {} to {:.4e}.'.format(epoch_str, group, lr))


    def step(self, epoch=None):
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._step_count == 1:
            if not hasattr(self.optimizer.step, "_with_counter"):
                warnings.warn("Seems like `optimizer.step()` has been overridden after learning rate scheduler "
                              "initialization. Please, make sure to call `optimizer.step()` before "
                              "`lr_scheduler.step()`. See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

            # Just check if there were two first lr_scheduler.step() calls before optimizer.step()
            elif self.optimizer._step_count < 1:
                warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
                              "In PyTorch 1.1.0 and later, you should call them in the opposite order: "
                              "`optimizer.step()` before `lr_scheduler.step()`.  Failure to do this "
                              "will result in PyTorch skipping the first value of the learning rate schedule. "
                              "See more details at "
                              "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
        self._step_count += 1

        class _enable_get_lr_call:

            def __init__(self, o):
                self.o = o

            def __enter__(self):
                self.o._get_lr_called_within_step = True
                return self

            def __exit__(self, type, value, traceback):
                self.o._get_lr_called_within_step = False

        with _enable_get_lr_call(self):
            if epoch is None:
                self.last_epoch += 1
                values = self.get_lr()
            else:
                warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
                self.last_epoch = epoch
                if hasattr(self, "_get_closed_form_lr"):
                    values = self._get_closed_form_lr()
                else:
                    values = self.get_lr()

        for i, data in enumerate(zip(self.optimizer.param_groups, values)):
            param_group, lr = data
            param_group['lr'] = lr
            self.print_lr(self.verbose, i, lr, epoch)

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

베이스 클래스에서 기억할 사항은 많지 않다. 스케쥴러가 step을 밟아갈 수록 last_epoch과 _step_count가 업데이트된다는 점. step 함수 내에서 self.get_lr()을 호출하여 learning rate update가 수행된다는 점. 그리고 베이스 클래스 내에서 get_lr(self)는 호출 시 NotImplementedError를 일으키게 되어 있다. 즉, 상속 받아서 하위 클래스 작성 시 구현해야 하는 부분이다.

 


Custom LinearDecayLRScheduler

100 에폭 이후 선형적으로 learning rate가 줄어들도록 하기 위해서 아래와 같이 두 개의 스케쥴러 클래스를 구현했다.

  • LinearDecayLR: initial_lr부터 target_lr까지 동일한 값의 learning rate가 epoch마다 줄어드는 형식의 스케쥴러
  • DelayedLinearDecayLR: LinearDecayLR을 상속받는 서브클래스로 구현. 처음부터 learning rate decay를 적용하는 것이 아니라 지정한 epoch 이후부터 learning rate decay가 선형적으로 발생하도록 하는 스케쥴러.

 

LinearDecayLR

LinearDecayLR의 경우, learning rate의 감소량을 정하기 위해 learning rate의 초깃값을 지정해줘야했는데, 처음에는 베이스 클래스 내의 base_lr의 속성값을 활용하고자 했다. 하지만, 이 경우 에러가 발생하였다.

# 1.
class LinearDecayLR(_LRScheduler):
	def __init__(self, optimizer, target_lr, last_epoch, total_iters, verbose=False):
    	self.optimizer = optimizer
        self.target_lr = target_lr
        self.total_iters = total_iters
    	super().__init__(optimizer, last_epoch, verbose)
        self.subtract_lr = self._get_decay_constant()
        
    def _get_decay_constant(self):
    	"""
        cf.)
        self.base_lrs: List[float]
        """
    	return [float((init_lr-self.target_lr)/self.total_iters) for init_lr in self.base_lrs]
        
# 2.
class LinearDecayLR(_LRScheduler):
	def __init__(self, optimizer, target_lr, last_epoch, total_iters, verbose=False):
    	self.optimizer = optimizer
        self.target_lr = target_lr
        self.total_iters = total_iters
        self.subtract_lr = self._get_decay_constant()
    	super().__init__(optimizer, last_epoch, verbose)
        
    def _get_decay_constant(self):
    	"""
        cf.)
        self.base_lrs: List[float]
        """
    	return [float((init_lr-self.target_lr)/self.total_iters) for init_lr in self.base_lrs]

1번의 경우, 부모 클래스(super class)를 초기화한 후 self.subtract_lr을 정의하도록 했다. 이 경우에 self.subtract_lr라는 속성이 없다는 AttributeError가 발생했다. 이는 부모 클래스인 _LRScheduler 초기화 시 LinearDecayLR의 get_lr 메소드를 호출하게 되는데 이 get_lr 메소드 내에 self.subtract_lr이 사용되기 때문에 발생하는 문제였던 것이다. 애초에 super().__init__()을 수행할 때에 에러가 발생하는 것이었다.

2번의 경우엔 base_lrs은 부모 클래스 초기화 시 함께 초기화되는 속성이므로 당연히 super().__init__()를 호출하기도 전에 subtract_lr 계산을 위해 self._get_decay_constant()를 호출하면 self.base_lrs가 존재하지 않는다고 에러가 발생할 수 밖에 없다.

따라서, self.base_lrs를 활용할 수 없다고 판단을 내렸고 initial learning rate를 클래스 외부에서 인자로 받아오도록 했다. 최종적으로 다음과 같이 LinearDecayLR Scheduler를 구현했다.

class LinearDecayLR(_LRScheduler):
    """
    Custom LR Scheduler which linearly decay to a target learning rate.

    override 해줘야 하는 부분은 get_lr, _get_closed_form_lr
    """
    def __init__(self, optimizer:Optimizer, initial_lr: float, target_lr:float, total_iters:int, last_epoch:int=-1, verbose:bool=False):
        
        if initial_lr < 0:
            raise ValueError("Initial Learning rate expected to be a non-negative integer.")
            
        if target_lr < 0:
            raise ValueError("Target Learning rate expected to be a non-negative integer.")

        if target_lr > initial_lr:
            raise ValueError("Target Learning Rate must be larger than Initial Learning Rate.")

        self.init_lr = initial_lr
        self.target_lr = target_lr
        self.total_iters = total_iters
        self.subtract_lr = self._get_decay_constant()
        
        super(LinearDecayLR, self).__init__(optimizer, last_epoch, verbose) # 부모클래스의 init에 필요한 arg 넘겨줌.
                    
    def _get_decay_constant(self):
        return float((self.init_lr-self.target_lr)/self.total_iters)
        
    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the learning rate computed by the scheduler, "
                        "please use 'get_last_lr()'.", UserWarning)

        return [group['lr']-self.subtract_lr for group in self.optimizer.param_groups]

 

DelayedLinearDecayLR

CycleGAN에서는 200 에폭 중 100 에폭 이후부터 초기의 0.0002부터 0까지 선형적으로 줄어드는 learning rate decay를 사용했다 (decay에서 delay가 존재함). 따라서 위의 LinearDecayLR을 상속받도록 하고, decay에 delay만 주는 형태로 스케쥴러를 구현했다.

class DelayedLinearDecayLR(LinearDecayLR):
    def __init__(self, optimizer:Optimizer, initial_lr: float, target_lr: float, total_iters:int, last_epoch:int=-1, decay_after:int=100, verbose:bool=False):
        self.decay_after = decay_after
        
        super(DelayedLinearDecayLR, self).__init__(optimizer, initial_lr, target_lr, total_iters, last_epoch, verbose)

    
    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the learning rate computed by the scheduler, "
                        "please use 'get_last_lr()'.", UserWarning)

        if self.decay_after <= self.last_epoch < (self.decay_after + self.total_iters): # 여기에 total iter도 고려해줘야함.
            return [group['lr']-self.subtract_lr for group in self.optimizer.param_groups]

        else:
            return [group['lr'] for group in self.optimizer.param_groups]

get_lr 메소드만 override하였고, last_epoch을 활용해서 몇번째 에폭 이후로 decay를 수행할 것인지를 나타내는 decay_after보다 last_epoch이 크거나 같을 경우, total_iters 동안 learning rate decay가 일어나도록 스케쥴러를 구현했다. 저자 구현 코드를 봤을 때, LambdaLR (epoch을 변수로 갖는 lambda 함수를 이용해서 learning rate scheduling 구현)을 활용해서 선형적으로 learning rate가 줄어드는 함수를 넘겨줌으로써 스케쥴러를 구현한 것 같았다.

 

cf.) LambdaLR
사용자 정의 함수를 받아 해당 함수에 맞춰 learning rate scheduling을 진행. 비슷한 방식으로 MultiplicativeLR도 구현되어 있음. 그 둘의 차이는 다음과 같음.

  • LambdaLR's function: $lr_{epoch} = lr_{initial} * Lambda(epoch)$
  • MultiplicativeLR's function: $lr_{epoch} = lr_{epoch-1} * Lambda(epoch)$

LambdaLR의 경우에는 함수가 learning rate 초깃값과 곱해져서 결정이 되는 거지만, MultiplicativeLR은 이전 epoch의 learning rate 값에 적용이 되기 때문에 recursive하게 작용이 된다는 점이다.

 


정리

  • 파이토치 내에서 learning rate scheduler가 어떤 식으로 동작하는지 파악할 수 있었다.
  • 직접적으로 learning rate scheduler를 구현해보며 클래스 상속에 대해 보다 이해가 깊어졌다.

References

300x250