Sparse Transformer

Generating Long Sequences with Sparse Transformers

논문이 하려고 하는 이야기

기존 Attention Matrix -> O(n^2)
새로운 Attention (Sparse) -> O(n루트n)

어떻게 했니?

  • Attention Sparse Factorization
  • 어텐션 구조 다르게
  • initialization 다르게
  • Attention Matrix Recomputation
  • 좀더 빠른 Attention 커널
=> 합쳐서 Sparse Transformer!

Factorized Self-Attention

notion image
(a)가 기존의 Transformer
(b)는 이미지/음악파일 같이 특정한 길이가 의미를 갖는 경우 해당 길이/주기만큼을 Attention
(c)는 Text처럼 특정한 길이 의미 없이 정해진 시퀀스 길이에 따른 Attention 결정
위쪽 이미지는 6*6 사이즈의 "이미지"
아래쪽은 Connectivity Matrix -> 실제로 펼치면 어떤 Attention을 취하는지 보여주는 셈 (어텐션 패턴)
이 논문에서 연구 방향은 Sparse Attention Pattern 자체에만 집중함!
notion image
  • X: Input Embedding
  • S: Connectivity Pattern -> Embedding 받아서 Output
  • W(q,k,v): QKV Matrix (d는 쿼리/키 사이즈)
AutoRegressive 모델 -> 이전것만 보도록 포지션 제한
Factorized Self Attention은 p개의 떨어진 Attention Head들을 가진다!
→ 코드상에서 보면, all, fixed, local, strided 로 구별해서 처리한다.
(이미지 등 주기에 의미가 있는 경우)
notion image
  • all: 기존과 동일
  • fixed: Attention context를 두고 stride는 그 이내로 정해서 진행
    • notion image
  • local: 일정 길이만큼 잘라서 진행
  • strided: 일정 길이의 Attention을 잘라서 이동

Two-dimensional Factorized Attention

  • 2차원(NxN) 행렬에서 Attention을 쪼개는 방법
  • 주기성이 있는(가로 Pixel 수 or 음악의 길이 등) 경우
    • = 하나는 현재부터 과거 K개 까지 Att & 하나는 처음부터 띄엄띄엄 J개 Att
      = Strided Attention
       
notion image
  • 텍스트같은 경우
    • → Strided Pattern에서는 성능 안좋음
      = 일정 길이 K 이내에 모두 Attention & 특정 토큰에서는 이후 전체 토큰에 Attention
       
       
notion image

Sparse Transformer

1. Factorized Attention Heads

앞서 본 Factorized Attention들 적용
  • 방법1) 각 Attention 별 위 방법 중 하나만 사용
  • 방법2) Merged Head: 여러가지 방법 중 선택된 Attention에 해당하는 모든 Token대해서 → 합쳐서 Attention을 취한다!
    • 조금 더 많은 computational cost → 하지만 constant 정도라서 큰 문제가 아님
  • 방법3) MultiHead Attention: Head별로 다른 Attention 사용
    • 최종 Feature concat

2. 수백개의 Layer로 확장

  1. Pre-activation Residual Block
      • Layer 태우기 전에 Activation 미리 태우는 블록
      • 여기서는 실제로는 아래와 같이 진행함! a = norm → att → dropout b = a+H → norm → ff → dropout return a+b
        • notion image
      • Residual을 다르게 태우는 방법 중 하나.
       
notion image

3. Diverse Data types...? (왜 쓰는지 잘 모르겠음)

4. Recomputing Attention Weights

결국 정리하면 Gradient Checkpointing
→ 메모리 쓰는 대신 Disk i/o가 생긴다
그렇지만 최대 길이 16,384만큼 학습할 수 있었다!
  • 여기서는 Attention & FF Block을 Backward 과정에서 Recompute
  • Attention 내에서는 Dropout 없애서 구현 간단히, 대신 Residual addition 뒤에 붙임.

5. 효율적인 sparse att 커널

직접 GPU 커널 만들었다고...

6. Mixed Precision = FP16