DeepLearning/Pytorch

[Pytorch] 여러 gpu 사용할 때 torch.save & torch.load (feat. map_location)

yooj_lee 2022. 10. 20. 20:45
300x250

굳이 multi-gpu에서 분산 학습을 한 게 아니더라도, 여러 gpu를 사용할 수 있는 환경이라면 한 번쯤 겪어볼 만한 error에 대한 해결책 정리.

0번 gpu에서 모델을 학습시키고 torch.save를 통해 저장한 상황이라고 할 때, 1번 gpu에서 torch.load를 통해 해당 weight을 가져와서 학습을 재개하고자 하는 경우 혹은 재학습시키고자 하는 경우 device가 일치하지 않는다는 error가 뜨는 경우가 존재함.

이는, torch.save를 통해서 모델을 가져올 때 map location을 지정해주지 않아서 생기는 문제. torch DDP docs를 보다가 아래와 같은 note를 발견함.

  If you use torch.save on one process to checkpoint the module, and torch.load on some other processes to recover it, make sure that map_location is configured properly for every process. Without map_location, torch.load would recover the module to devices where the module was saved from.

 

즉, torch.load를 할 때에 map_location 인자에 현재 사용하고자 하는 cuda device를 지정해주지 않으면 자동적으로 torch.save를 했던 device에 모듈을 recover하고자 하므로 문제가 발생하게 됨.

따라서, 아래와 같이 map_location을 지정을 잘 해주자.

"""
Ref from Pytorch documents (https://pytorch.org/docs/stable/generated/torch.load.html)
"""

# Load all tensors onto the CPU
>>> torch.load('tensors.pt', map_location=torch.device('cpu'))
# Load all tensors onto the CPU, using a function
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
# Load all tensors onto GPU 1
>>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
# Map tensors from GPU 1 to GPU 0
>>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})

 

300x250