PyTorch/XLA SPMD @ TPU
🤗

PyTorch/XLA SPMD @ TPU

Tags
MLDL Framework
Dev
GCP
Cloud
TPU
Published
Published September 9, 2023

SPMD?

https://pytorch.org/blog/pytorch-xla-spmd
  • GSPMD라는 방식의 Parallel 방식이 PyTorch/XLA에 추가됨
  • 아래와 같이, TPU chip간에서 원래는 DP/MP등을 직접 지정해주고 Physical Mesh를 구성해야했으나,
  • 이것을 가상 Logical Mesh로 한 뒤 (사실 이까지는 새롭지 않다)
  • DP든 MP든 샤딩이든 뭐든 각각의 Axis로 처리하고
  • “모델에 힌트”를 주면 → 알아서 그래프에서 모델을 쪼개서 학습한다.
notion image

설치 @ TPU VM

  • VM 이미지: pytorch 1.13 (tpu-vm-pt-1.13)
  • TPU 환경: TPU VM, v3-8
sudo apt-get update sudo apt-get install -y libopenblas-base
# FROM: https://github.com/pytorch/xla/issues/5553#issuecomment-1714281894 # 신규 VENV 내에서 하는 것 추천 pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
  • 아직 Torch/XLA master branch(nightly)에서만 지원한다.

TPU에서 SPMD 쓰기

1. SPMD 활성화 시키기

import torch_xla.runtime as xr # Enable PyTorch/XLA SPMD execution mode. xr.use_spmd() assert xr.is_spmd() == True

2. SPMD에게 Mesh 모양 힌트 주기

import torch import torch_xla.core.xla_model as xm import torch_xla.runtime as xr import torch_xla.experimental.xla_sharding as xs from torch_xla.experimental.xla_sharding import Mesh # Enable XLA SPMD execution mode. xr.use_spmd() # Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape. mesh_shape = (2, 4) num_devices = xr.global_runtime_device_count() device_ids = np.array(range(num_devices)) mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) t = torch.randn(8, 4).to(xm.xla_device()) # Mesh partitioning, each device holds 1/8-th of the input partition_spec = (0, 1) m1_sharded = xs.mark_sharding(t, mesh, partition_spec) assert isinstance(m1_sharded, XLAShardedTensor) == True
  • SPMD의 경우 유저가 하나하나 Mesh 모양을 지정하는것도 가능하지만, 모델의 Graph를 보고 SPMD가 알아서 추정하게 하는 것도 가능하다.
  • 이때 Mesh의 Shape 지정을 통해 (device 갯수를 적당히 쪼개면 OK) Device ids를 적당히 나눠준다.
    • 위 예시에서는 x, y 두 축으로 했지만, EasyLM의 경우는 dp, fsdp, mp 로 사용했었음.
  • 이렇게 Mesh로 해주고 나서, 생성한 텐서 tpartition_spec 축에 따라서 샤딩해준다.
    • partition_spec ← 이게 SPMD 힌트
  • 이때, Mesh = Physical mesh라서 물리적인 크기에 따라서 숫자를 맞춰줘야함
    • 반대로 Hybrid Mesh로 가상 Mesh를 만들어서 처리할 수도 있음
    • 하지만 TPU Topology상이랑 맞게 구성해주는게 연산력 효율이 높음 (MP등에서 너무 많이 왓다갓다하면 성능이 낮을수 있음)

3. 일반적인 학습

# Sharding annotate the linear layer weights. model = SimpleLinear().to(xm.xla_device()) xs.mark_sharding(model.fc1.weight, mesh, partition_spec) # Training loop model.train() for step, (data, target) in enumerate(loader): # Assumes `loader` returns data, target on XLA device optimizer.zero_grad() # Sharding annotate input data, we can shard any input # dimensions. Sharidng the batch dimension enables # in data parallelism, sharding the feature dimension enables # spatial partitioning. xs.mark_sharding(data, mesh, partition_spec) ouput = model(data) loss = loss_fn(output, target) optimizer.step() xm.mark_step()
  • 학습을 하되, 모델을 만든 후에 xs.mark_sharding 통해서 Mesh 모양 + 파티션힌트 줘서 모델도 쪼개 올린다.
  • SPMD에서는 입력 데이터 DP 조차도 단순한 하나의 mesh 축으로 간주하기 때문에 → xs.mark_sharding , 즉 모델에서 사용했던 샤딩 함수와 완전 똑같은 함수를 사용해 분산시킨다.
#Examples #—------------------------------ mesh_shape = (4, 2) num_devices = xr.global_runtime_device_count() device_ids = np.array(range(num_devices)) mesh = Mesh(device_ids, mesh_shape, ('x', 'y')) # 4-way data parallel input = torch.randn(8, 32).to(xm.xla_device()) xs.mark_sharding(input, mesh, (0, None)) # 2-way model parallel linear = nn.Linear(32, 10).to(xm.xla_device()) xs.mark_sharding(linear.weight, mesh, (None, 1))
  • 위 예시에서 볼 수 있는 것처럼
    • Data Parallel처럼 input tensor를 복제하거나
      • 이유: mesh_shape 가 (4, 2)인데 현재는 힌트에 (0, None)을 넣어서 첫번째 Mesh 축(=x), 즉 4개 장비에 대해서, Input Tensor의 0 번째 축, 즉 배치사이즈를 → None 형식(복제)으로 → Data Parallel
    • Model을 Tensor Parallel로 쪼개거나..
      • 이유: 힌트에 (None, 1)을 넣어서 두번째 Mesh 축(=y), 즉 2개 장비에 대해서 1, 즉 ‘Split’ 형식으로 → Tensor Parallel
# Provide optional mesh axis names and use them in the partition spec mesh = Mesh(device_ids, (4, 2), ('data', 'model')) partition_spec = ('model', 'data') xs.mark_sharding(input_tensor, mesh, partition_spec)
  • Mesh 할 때, 이렇게 data, model 와 같은 모양으로 만들고 나서,
  • Partition_spec에 이름을 mesh 축 이름으로 넣어줄 수도 있다!
    • 이렇게 하면 이제 모든 쪽으로

Partition Spec이란?

# partition_spec (Tuple[int, str, None]): A tuple of device_mesh dimension index or `None`. Each index is an int or str if the mesh axis is named. # This specifies how each input rank is sharded (index to mesh_shape) or replicated (None). # For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise. >> input = torch.randn(8, 10) >> mesh_shape = (4, 2) >> partition_spec = (0, None)
“…how each input rank is sharded (index to mesh_shape) or replicated (None).”
  • 즉 숫자로 들어간 경우는 해당