DeepLearning/Pytorch

[Pytorch] torch.no_grad() versus requires_grad=False

yooj_lee 2023. 11. 13. 14:29
300x250

prompt learning 코드 보는데 갑자기 이해가 안가는 점이 생김. 궁금증은 아래 이슈에서부터 시작됨.

https://github.com/KaiyangZhou/CoOp/issues/7

 

question about gradients on text encoder · Issue #7 · KaiyangZhou/CoOp

Hi, may I ask if the gradients of the original CLIP text encoder are frozen or not? The paper mentioned that the gradients of text encoder is frozen, but I couldn't find that part in the code... Th...

github.com

여기 보면, CLIP 텍스트 인코더가 freeze되어 있다고 함 (the text encoder's gradients are turned off). 해당 이슈 관련 커밋을 보면, requires_grad_를 False로 둬서 CLIP 텍스트 인코더의 grad를 차단한다고 함.

https://github.com/KaiyangZhou/CoOp/commit/be57b16c86f3de7d7135bf51930bf0276cb3b2ac

 

change misleading code by turning off grad in the two encoders · KaiyangZhou/CoOp@be57b16

KaiyangZhou committed Sep 27, 2021

github.com

갑자기 드는 생각이 loss backward를 할 때 text prompt가 encoder를 통과해서 gradient를 산출하게 될텐데 encoder의 gradient를 끄면 어떻게 prompt의 gradient가 계산되는 건지 헷갈리기 시작함.

➡️ requires_grad에 대한 이해가 부족하다고 판단..

열심히 구글링하다가 스택오버플로우의 글 두 개를 접하게 되었다.

https://stackoverflow.com/questions/63785319/pytorch-torch-no-grad-versus-requires-grad-false

 

PyTorch torch.no_grad() versus requires_grad=False

I'm following a PyTorch tutorial which uses the BERT NLP model (feature extractor) from the Huggingface Transformers library. There are two pieces of interrelated code for gradient updates that I d...

stackoverflow.com

https://stackoverflow.com/questions/51748138/pytorch-how-to-set-requires-grad-false

 

pytorch how to set .requires_grad False

I want to set some of my model frozen. Following the official docs: with torch.no_grad(): linear = nn.Linear(1, 1) linear.eval() print(linear.weight.requires_grad) But it prints True

stackoverflow.com

정리하게 되면, model.eval()을 걸게 되면 batchnorm이나 dropout 같이 train과 eval 시 동작방식이 상이한 모듈 관련하여 모드 변경을 하는 것이고 gradient memory 관련해서는 크게 관여하는 바가 없음. 우리는 흔히 inference 시에 torch.no_grad를 걸어주게 되는데 우리는 inference 시에 gradient 계산을 할 이유가 없음 (forward 패스만 필요하기 때문). 따라서 이때 gradient를 계산하게 되면 메모리만 잡아먹게 되므로 꺼주는 것. 다만, no_grad로 context가 잡히는 부분부터는 gradient가 계산이 되지 않는다는 점 (prevent calculating gradients; requires_grad 속성값과는 무관하게 그냥 계산을 안하는 형태). 따라서, no_grad로 wrap해주는 부분 이전까지의 레이어는 아예 gradient가 흘러들어가지 않게 됨.

그림1

이렇게 원천 차단되기 때문에 torch.no_grad는 training을 아예 하지 않는 inference 단이 아니라면 (특히 모델 일부분만 freeze하는 방식일 경우) 그리 선호되는 방식은 아닐 듯함. 반면에, requires_grad=False로 설정하여 모델 parameter 일부를 disable하면 해당 파라미터에 대한 gradient는 아예 저장이 되지 않음. 즉, requires_grad 자체는 gradient를 계산하지 않느냐가 아니라 gradient를 물리적으로 저장할 것인지 아닌지에 대한 문제로 보임.

여기서 조금 헷갈렸던 게 어차피 chain rule을 통해서 계산을 하는 거면 끝으로 갈 수록 그 residual만 계산하면 되는 거 아닌가 했는데 실제로 미분 연산 수행 자체는 그렇게 하더라도 이 gradient 값을 물리적으로 가지고 있느냐 마느냐는 또 다른 문제인 것 같음. 그림1에서 outputs에서 각 노드 별로 grad tensor를 가지고 있게 되니까. 

이거에 대해 좀 더 생각해보자. torch.no_grad는 파라미터 수에 영향을 미치진 않지만, backward를 호출하면서 gradient를 계산할 때에 해당 context에서는 gradient 계산이 전혀 이루어지지 않음. 따라서, gradient를 저장하지 않기 때문에 그만큼 메모리 소모량은 줄어들게 됨. 반면에 no_grad 컨텍스트는 비활성화하고 requires_grad를 True로 놓게 되면 3M 파라미터에 해당하는 gradient만 저장하게 됨 (나머지 82M에 대해서는 gradient 계산은 이루어지지만 물리적인 저장은 하지 않을 것).

b와 d 옵션을 조금 더 살펴보면, b같은 경우에는 모든 파라미터에 대한 grad 텐서를 저장하게끔 설정이 되어 있지만, 애초에 no_grad context에서는 gradient 계산 자체가 이루어지지 않기 때문에 grad 텐서는 어차피 None이 됨. 따라서, 파라미터 수가 112M로 잡히더라도 메모리 사용량은 줄어들게 됨. d 같은 경우에는, 모든 파라미터에 대한 gradient 텐서도 저장이 되게끔 설정되어 있고 실제로 계산도 모두 이루어지기 때문에 112M개의 파라미터에 대한 gradient를 다 저장해야함. 따라서, OOM error가 발생하는 것.

➡️ 즉, no_grad는 wrapping되어 있는 context에서는 gradient 계산이 이루어지지 않게 되는 것. requires_grad는 grad 텐서를 물리적으로 저장하느냐 아니냐에 대한 여부인 듯함. 계산을 하더라도 물리적으로 저장이 안될 수는 있고 계산을 하지 않더라도 물리적으로 저장은 될 수 있음 (이때는 empty tensor이니 정확히 물리적인 공간 할당이 되느냐 안되느냐의 문제 아닐까 싶음. 실제로 테스트해보니까 아무것도 출력을 하지 않긴 하는데 정확히 메모리공간이 할당이 안되는 건지 뭔지는 잘 모르겠음.)

따라서, 내가 처음에 고민했던 부분으로 돌아가면 text encoder를 통해 gradient가 흘러들어가는 것은 맞지만 text encoder의 gradient를 저장하지 않기 때문에 text encoder는 learnable parameter에서 카운트되지 않고 freeze되어 있는 형태로 상정(?)할 수 있음. 다만 이렇게 gradient를 끄게 되면 메모리 소모량이 줄어들기 때문에 어차피 학습 안시키면 (optimizer에 feed되는 건 어차피 prompt learner밖에 없어서 상관 없을 것 같긴 한데) 그냥 requires_grad는 False로 두는 게 현명하다.

 

300x250