Generating Long Sequences with Sparse Transformers
Arxiv Link: https://arxiv.org/abs/1904.10509
논문이 하려고 하는 이야기
기존 Attention Matrix -> O(n^2)
새로운 Attention (Sparse) -> O(n루트n)
어떻게 했니?
- Attention Sparse Factorization
- 어텐션 구조 다르게
- initialization 다르게
- Attention Matrix Recomputation
- 좀더 빠른 Attention 커널
=> 합쳐서 Sparse Transformer!
Factorized Self-Attention
(a)가 기존의 Transformer
(b)는 이미지/음악파일 같이 특정한 길이가 의미를 갖는 경우 해당 길이/주기만큼을 Attention
(c)는 Text처럼 특정한 길이 의미 없이 정해진 시퀀스 길이에 따른 Attention 결정
위쪽 이미지는 6*6 사이즈의 "이미지"
아래쪽은 Connectivity Matrix -> 실제로 펼치면 어떤 Attention을 취하는지 보여주는 셈 (어텐션 패턴)
이 논문에서 연구 방향은 Sparse Attention Pattern 자체에만 집중함!
- X: Input Embedding
- S: Connectivity Pattern -> Embedding 받아서 Output
- W(q,k,v): QKV Matrix (d는 쿼리/키 사이즈)
AutoRegressive 모델 -> 이전것만 보도록 포지션 제한
Factorized Self Attention은 p개의 떨어진 Attention Head들을 가진다!
→ 코드상에서 보면,
all
, fixed
, local
, strided
로 구별해서 처리한다.(이미지 등 주기에 의미가 있는 경우)
- all: 기존과 동일
- fixed: Attention context를 두고 stride는 그 이내로 정해서 진행
- local: 일정 길이만큼 잘라서 진행
- strided: 일정 길이의 Attention을 잘라서 이동
Two-dimensional Factorized Attention
- 2차원(NxN) 행렬에서 Attention을 쪼개는 방법
- 주기성이 있는(가로 Pixel 수 or 음악의 길이 등) 경우
= 하나는 현재부터 과거 K개 까지 Att & 하나는 처음부터 띄엄띄엄 J개 Att
= Strided Attention
- 텍스트같은 경우
→ Strided Pattern에서는 성능 안좋음
= 일정 길이 K 이내에 모두 Attention & 특정 토큰에서는 이후 전체 토큰에 Attention
Sparse Transformer
1. Factorized Attention Heads
앞서 본 Factorized Attention들 적용
- 방법1) 각 Attention 별 위 방법 중 하나만 사용
- 방법2) Merged Head: 여러가지 방법 중 선택된 Attention에 해당하는 모든 Token대해서 → 합쳐서 Attention을 취한다!
- 조금 더 많은 computational cost → 하지만 constant 정도라서 큰 문제가 아님
- 방법3) MultiHead Attention: Head별로 다른 Attention 사용
- 최종 Feature concat
2. 수백개의 Layer로 확장
- Pre-activation Residual Block
- Layer 태우기 전에 Activation 미리 태우는 블록
- 여기서는 실제로는 아래와 같이 진행함! a = norm → att → dropout b = a+H → norm → ff → dropout return a+b
- Residual을 다르게 태우는 방법 중 하나.
3. Diverse Data types...? (왜 쓰는지 잘 모르겠음)
4. Recomputing Attention Weights
결국 정리하면 Gradient Checkpointing
→ 메모리 쓰는 대신 Disk i/o가 생긴다
그렇지만 최대 길이 16,384만큼 학습할 수 있었다!
- 여기서는 Attention & FF Block을 Backward 과정에서 Recompute
- Attention 내에서는 Dropout 없애서 구현 간단히, 대신 Residual addition 뒤에 붙임.
5. 효율적인 sparse att 커널
직접 GPU 커널 만들었다고...