DeepLearning/Computer Vision

[Pytorch] Vision Transformer (ViT) 코드 구현

yooj_lee 2022. 9. 27. 12:32
300x250

ViT의 코드를 pytorch로 구현해보았습니다. vit의 경우에는 구현하면서 꽤 애를 먹었습니다. 최대한 논문 성능을 재현해보고자 했으나 ImageNet-1k (그 이상은 현실적으로 안됨..) 데이터셋을 논문 training detail 대로 학습을 시킬 경우 학습 완료 시까지 6일(혹은 12일..) 정도 소요되어 많은 실험이 불가능해서 좀 아쉬웠습니다. 이번 구현의 경우, 최적화 관련한 이슈가 발생해서 다양한 실험을 해보지는 못하고 정말 모델을 학습시키는 것 그 자체에 있어서 삽질을 많이 했던 것 같습니다.. 또 나름 ImageNet-1k 자체도 무거운 데이터셋이기 때문에 분산 학습을 진행했는데 이 과정에서 data parallel 에러를 핸들링하는 데에도 꽤 시간을 많이 잡아먹었던 것 같네요..ㅎㅎ

소스코드 전체는 아래의 레포지토리 참고 부탁드립니다.
https://github.com/YoojLee/vit

 

GitHub - YoojLee/vit: vit reimplementation

vit reimplementation. Contribute to YoojLee/vit development by creating an account on GitHub.

github.com

또한, 하기한 내용에 오류 혹은 질문이 있으실 경우 언제든 댓글 부탁드립니다.


프로젝트 구조

  • augmentation.py: 간단한 Dataset Transform을 구현
  • dataset.py: ImageNet-1k dataset 클래스 구현
  • model.py: ViT 모델 구현
  • scheduler.py: linear warm-up + cosine annealing 등의 스케쥴러 구현.
  • train.py: single gpu 상황을 가정한 train.py
  • train_multi.py: multi-gpu 하에서의 train.py
  • utils.py: metrics 계산, checkpoint load 등의 여러 함수 구현

당시 너무 정신 없이 구현하느라 train.py와 train_multi.py를 따로 그냥 구현해버렸는데, 나중에 보수 작업을 거쳐서 argument로 지정해서 하나의 train.py 파일로 multi-gpu 혹은 single-gpu를 둘 다 대응할 수 있도록 해야될 것 같습니다.

ViT 모델 구현에서는 model.py, scheduler.py, train_multi.py를 위주로 포스팅을 작성할 예정입니다. 이번 포스팅은 그 중에서도 model.py에 집중하여 서술하겠습니다.

학습 환경

일전에 cycle gan 구현 당시와 동일합니다. 


 

model.py

model.py를 얘기하기 전에 ViT의 구조를 간단히 살펴보겠습니다.

ViT의 architecture

위를 보면, 구현해야할 부분은 1) 이미지를 flattened patch sequence로 만들어주는 과정 2) 패치 임베딩 3) multi-head attention 4) MLP (Encoder 내부) 5) MLP Head 입니다. 5개의 모듈을 만들고 최종적으로 조합하는 방식으로 최종 구현했습니다.

 

Patch Sequence 및 Embedding

ViT의 입력으로 패치 시퀀스를 만들어주는 부분에서 처음에는 다음과 같이 무식하게! 물리적으로 구현했었습니다.

def make_patches(img:np.ndarray, p:int)->np.ndarray:
    """
        generate image patches in a serial manner
        - Args
            img (np.ndarray): an image array. (C,H,W) in which H=W=224
            p (int): size of patch
        
        - Returns
            patches (np.ndarray): 2-D array of flattened image patches (each patch has a size of P^2*C)
    """
    patches = np.array([])
    x1, y1 = 0, 0
    _,h,w = img.shape

    for y in range(0, h, p):
        for x in range(0, w, p):
            if (h-y) < p or (w-x) < p:
                break
            
            y1 = min(y+p, h)
            x1 = min(x+p, w)

            tiles = img[:, y:y1, x:x1]

            if patches.size == 0:
                patches = tiles.reshape(1,-1)
                
            else:
                patches = np.vstack([patches, tiles.reshape(1,-1)]) # reshape(-1) or ravel도 사용 가능. flatten은 카피 떠서 쓰는 거

    return patches

 

위의 코드 블럭에서 보시다시피, 정말 슬라이싱을 통해서 이미지 타일을 만들고, numpy array로 stacking하는 구조였습니다. 이 과정에서 결국 이미지 전체를 모두 탐색해야하고, 또 numpy array 하나에 stack되는 구조이기 때문에 메모리 낭비도 발생하게 됩니다. 당연히 이 부분에서 병목이 발생할 수 밖에 없고, 데이터셋을 초기화할 때 호출하는 구조이다보니 백만장이 넘는 이미지에 대해 $O(n^2)$으로 탐색하게 되면서 데이터셋 로딩 과정이 매우 느려졌습니다.

따라서 이러한 부분을 데이터셋 측에서 핸들링하는 게 아니라 모델에 입력할 때는 기존의 이미지 $(b,c,h,w)$ 형태로 넣어주되 patch embedding을 할 때에 $(b,n,hw/p^2)$ 형태로 변환해주는 과정을 하나 더 추가하는 방식으로 해결하고자 했습니다. 이 과정에서 einops의 Rearrange라는 클래스를 이용했습니다.

 

class Embeddings(nn.Module):
    def __init__(self, p:int, input_dim: int, model_dim:int, n_patches:int, dropout_p:float):
        """
        patch embedding과 positional embedding 생성하는 부분 (Encoder 인풋 만들어주는 부분이라고 생각하면 됨)
        """
        super().__init__()
        self.input_dim = input_dim
        self.model_dim = model_dim
        self.n_patches = n_patches

        self.to_patch_embedding = nn.Sequential(
            OrderedDict(
                {
                    "rearrange": Rearrange('b c (h1 p1) (w1 p2) -> b (h1 w1) (p1 p2 c)', p1 = p, p2 = p), # h1 = h / p1, w1 = w / p2
                    "projection": nn.Linear(self.input_dim, self.model_dim)
                })
        )

        # 여기서는 cls_token을 정의
        self.cls_token = nn.Parameter(torch.randn(1,1,model_dim))

        # self.position_embedding
        self.pos_emb = nn.Parameter(torch.randn(1,n_patches+1,model_dim))

        # dropout
        self.dropout = nn.Dropout(dropout_p)
    
    def forward(self, x):
        """
        x: an image tensor
        """
        proj = self.to_patch_embedding(x)
        b, _, _ = proj.shape

        # embedding
        cls_token = self.cls_token.repeat(b,1,1) # 이렇게 해야 b,1,model_dim으로 됨.

        # summation
        patch_emb = torch.cat((cls_token, proj), dim = 1)
        
        return self.dropout(self.pos_emb + patch_emb)

Embeddings라는 클래스를 정의하여 패치 시퀀스 변환 및 linear projection, class token과 positional embedding 부분을 구현했습니다. class token과 positional embedding의 경우에는 학습 가능한 파라미터로, nn.Parameter로 register해주는 방식으로 해당 부분을 구현했습니다. 

 

Multi-Head Attention

Multi-Head Attention 설명

MultiHeadAttention는 기본적으로 q,k,v로 입력 matrix를 쪼개고 이에 각각 attention weight을 얻어내는 것입니다. 이를 병렬적으로 head 개수만큼 수행을 하고 마지막에 다시 합쳐주는 과정을 거쳐주면 됩니다. 

이를 위해 처음에는 설정해둔 model dimension의 3배 차원으로 projection을 시켜주고, 이를 다시 각각 $(b,n,d_h)$의 차원을 갖는 q,k,v로 쪼개줘야 합니다. 다만, multi-head이기 때문에 최종적으로는 $(b,k,n,d_h)$의 텐서를 q,k,v 3개씩 얻어내야 합니다. ($k$: head 개수) 

이후, query와 key를 곱하여 attention matrix를 얻어내고 이를 다시 value에 행렬곱하여 최종적인 self-attention 값을 얻어내게 됩니다. self-attention을 구했으면 이를 최종적으로 projection 시켜서 FFN(MLP) 블록으로 들어가게 됩니다. (후속 논문에서 해당 projection은 크게 모델 성능에 영향을 미치지 않는다고 밝히기도 했습니다.)

class MSA(nn.Module):
    def __init__(self, model_dim:int, n_heads:int, dropout_p:float, drop_hidden:bool):
        """
        Multi-Head Self-Attention Block.
        - Args
            model_dim (int): model dimension D
            n_heads (int): number of heads
            dropout_p (float): a probability of a dropout masking
        """
        super().__init__()

        self.model_dim = model_dim
        self.n_heads = n_heads
        self.dropout_p = dropout_p
        self.drop_hidden = drop_hidden
        self.scale = (model_dim/n_heads) ** -0.5

        self.norm = nn.LayerNorm(model_dim) # LayerNorm은 입력 배치 내에서 통계량을 계산하는 것.
        self.linear_qkv = nn.Linear(model_dim, 3*model_dim, bias=False) # [U_qkv] for changing the dimension 
        self.projection = nn.Identity() if self.drop_hidden else nn.Linear(model_dim, model_dim)
    
    def forward(self, z):
        b,n,_ = z.shape
        qkv = self.linear_qkv(self.norm(z)) # [B, N, 3*D_h] -> 3개의 [B,k,N,D_h]로 쪼개는 게 목표임

        # destack qkv into q,k,v (3 vectors)
        qkv_destack = qkv.reshape(b,n,3,self.n_heads,-1) # 이렇게 해줘야 b,n 차원은 건드리지 않고 벡터 차원만 가지고 크기 조작이 가능함. 마지막 차원은 d_h와 동일함.
        q,k,v = qkv_destack.chunk(3, dim=2) # [b,n,1,k,d_h] 차원의 벡터 3개를 리턴
        
        q = q.squeeze().transpose(1,2) # 이 방식은 batch size가 1일 때는 문제가 될 것 같다.
        k = k.squeeze().transpose(1,2)
        v = v.squeeze().transpose(1,2)
        
        # q, k attention
        # k.mT → from [b,k,n,d_h] to [b,k,d_h,n] (마지막 2개 차원을 transpose)
        qk_T = torch.matmul(q,k.mT) # [b,k,n,n]
    
        attention = F.softmax(qk_T*self.scale, dim=-1) # 각 row 내에서 softmax 취해서 attention score 구함. [b,k,n,1]

        if self.dropout_p:
            attention = F.dropout(attention, p=self.dropout_p)

        # compute a weighted sum of v
        msa = torch.matmul(attention, v) # [b,k,n,d_h]
        msa = msa.transpose(1,2) # [b,n,k,d_h]

        # concatenate k attention heads
        msa_cat = msa.reshape(b,n,self.model_dim) # [b,n,d] where d_h = d/k

        # projection
        output = self.projection(msa_cat)

        if self.dropout_p:
            output = F.dropout(output, p=self.dropout_p)
        
        return z+output # skip-connection

 

저는 그냥 torch로 reshape을 구현했지만, 다른 코드에서는 einops를 활용해서 보다 쉽게 reshape을 하는 경우가 많았습니다. 구현하실 때 참고하시길 바랍니다. attention 자체는 matrix multiplication과 reshape..의 반복이었습니다. reshape을 하면서 어떤 식으로 텐서가 재구성될지에 대해 약간 확신이 들지 않아서 그런 부분이 조금 어려웠던 것 같습니다. softmax 역시 attention 메커니즘을 제대로 이해하지 못하면 어느 차원으로 적용을 해주어야 하는지에 대해 헷갈릴 듯합니다. 가장 중요한 건, 어쨌든 $(b,n,d)$ 차원을 지속적으로 유지를 시켜주어야 한다는 점입니다. 해당 사항을 염두에 두고 구현하다 보면 헷갈리다가도 갈피를 잡을 수 있는 것 같습니다. (아님 말고..)

 

FFN (MLP)

feed-forward network이기 때문에 구현은 간단합니다. 다만, 시퀀스 내 n개의 패치가 전부 하나의 mlp 가중치를 공유한다는 점을 염두에 두면 될 것 같습니다. 이를 position-wise feed-forward network라고 하는데요, 이는 transformer가 2017년 발표되면서부터 제안된 방식입니다. 모든 position이 동일한 가중치를 공유하면서 각자 독립적으로 weight와 곱해지는 형식입니다. MLP Mixer에서도 이런 형식으로 가중치를 공유하는 MLP 구조를 사용했습니다.

transformer 원 논문에서는 ReLU를 activation function으로 활용했지만, 여기서는 GELU를 활용했습니다.

class FFN(nn.Module):
    def __init__(self, model_dim, hidden_dim, dropout_p):
        super().__init__()

        self.model_dim = model_dim
        self.hidden_dim = hidden_dim

        self.norm = nn.LayerNorm(self.model_dim)
        self.fc1 = nn.Linear(self.model_dim, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.model_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout_p)

        self.block = nn.Sequential(
            self.norm,
            self.fc1,
            self.activation,
            self.dropout,
            self.fc2,
            self.dropout # 여기에는 gelu를 쓰지 X.
        )

    def forward(self, z):
        return z + self.block(z)

 

Encoder

class Encoder(nn.Module):
    def __init__(self, n_layers, model_dim, n_heads, hidden_dim, dropout_p, drop_hidden):
        super().__init__()

        self.n_layers = n_layers
        self.model_dim = model_dim
        self.n_heads = n_heads
        self.hidden_dim = hidden_dim
        self.dropout_p = dropout_p
        self.drop_hidden = drop_hidden
        
        layers = []

        for _ in range(self.n_layers):
            layers.append(nn.Sequential(
                MSA(self.model_dim, self.n_heads, self.dropout_p, self.drop_hidden),
                FFN(self.model_dim, self.hidden_dim, self.dropout_p)
            )) # 반복해서 넣어줘야 하는 레이어들은 속성으로 할당해서 쓰면 안됨. 그때마다 객체를 새롭게 선언해야 함!
        
        self.encoder = nn.Sequential(*layers)

    def forward(self, z):
        return self.encoder(z)

 

 Encoder의 경우에는 레이어 개수만큼 MSA와 FFN을 반복해서 넣어주었습니다. Normalization과 skip connection을 따로 클래스로 정의해서 추가해줘도 되지만, 저는 MSA와 FFN을 구현할 때 normalization과 skip connection을 이미 적용했기에 여기에서는 MSA와 FFN만이 쌓이는 구조로 구현되었습니다.

제가 처음에 실수한 부분은 하나의 레이어가 별도의 파라미터인데 self.msa = MSA(...) 이런 식으로 클래스 속성으로 정의해두고 이를 반복적으로 쌓아나가서 결국엔 동일한 레이어가 반복되는 구조가 되었습니다..파라미터 개수를 체크해서 이런 실수를 잡아낼 수 있었습니다.

 

 

Classification Head

마지막에 classification을 수행하는 classification head 부분입니다. classification head의 input은 두 가지 옵션이 있습니다.

1) class token만을 넣어준다
2) encoder output을 global average pooling을 적용해서 넣어준다

입니다.

Input type을 결정 짓고 나면, pre-norm을 거쳐서 hidden layer를 통과함으로써 classification을 수행하게 됩니다.

class ClassificationHead(nn.Module):
    def __init__(self, model_dim, n_class, training_phase, dropout_p, pool:str):
        super().__init__()

        self.model_dim = model_dim
        self.n_class = n_class
        self.training_phase = training_phase

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.pool = pool

        self.norm = nn.LayerNorm(self.model_dim)
        self.hidden = nn.Linear(self.model_dim, self.n_class)
        self.dropout = nn.Dropout(dropout_p)
        self.relu = nn.ReLU(inplace=True)

        self.block = nn.Sequential(self.norm, self.hidden)
    
    def forward(self, encoder_output):
        y = encoder_output.mean(dim=1) if self.pool == 'mean' else encoder_output[:, 0] # cls_token으로 predict할 경우 첫번째 요소만 slicing

        return self.block(y) # pre-norm 적용

 

Vision Transformer

위의 클래스를 다 종합해서 최종적인 ViT 클래스를 구현한 결과입니다. embedding, encoder, classification head를 ordered dict를 활용하여 하나의 nn.Sequential로 묶어주었습니다.

class ViT(nn.Module):
    def __init__(self, p, model_dim, hidden_dim, n_class, n_heads, n_layers, n_patches, dropout_p=.1, training_phase='p', pool='cls', drop_hidden=True):
        super().__init__()
        input_dim = (p**2)*3
        
        self.vit = nn.Sequential(
            OrderedDict({
                "embedding": Embeddings(p, input_dim, model_dim, n_patches, dropout_p),
                "encoder": Encoder(n_layers, model_dim, n_heads, hidden_dim, dropout_p, drop_hidden),
                "c_head": ClassificationHead(model_dim, n_class, training_phase, dropout_p, pool)
            })
        )
    
    def forward(self, x):
        return self.vit(x)

 

위와 같이 구현하면, 파라미터 개수가 대략 85.8M 정도가 나오게 됩니다.


오늘은 ViT 모델 구현 코드에 대해 글을 작성해보았습니다. reshape과 attention, matmul 등의 연산으로만 철저히 구성되어 있고 모듈화 시키기 매우 좋은 구조라고 생각합니다. 다음 포스팅에서는 ViT의 스케쥴러에 대해 작성해보겠습니다. 또, 스케쥴링 구성에 따라 어떤 식으로 학습이 이루어졌는지에 대해서도 간단히 얘기해보고자 합니다.

300x250