# FROM: https://github.com/pytorch/xla/issues/5553#issuecomment-1714281894
# 신규 VENV 내에서 하는 것 추천
pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp38-cp38-linux_x86_64.whl
pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp38-cp38-linux_x86_64.whl
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Device mesh, this and partition spec as well as the input tensor shape define the individual shard shape.
mesh_shape = (2, 4)
num_devices = xr.global_runtime_device_count()
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
t = torch.randn(8, 4).to(xm.xla_device())
# Mesh partitioning, each device holds 1/8-th of the input
partition_spec = (0, 1)
m1_sharded = xs.mark_sharding(t, mesh, partition_spec)
assert isinstance(m1_sharded, XLAShardedTensor) == True
SPMD의 경우 유저가 하나하나 Mesh 모양을 지정하는것도 가능하지만,
모델의 Graph를 보고 SPMD가 알아서 추정하게 하는 것도 가능하다.
이때 Mesh의 Shape 지정을 통해 (device 갯수를 적당히 쪼개면 OK) Device ids를 적당히 나눠준다.
위 예시에서는 x, y 두 축으로 했지만, EasyLM의 경우는 dp, fsdp, mp 로 사용했었음.
이렇게 Mesh로 해주고 나서, 생성한 텐서 t 를 partition_spec 축에 따라서 샤딩해준다.
partition_spec ← 이게 SPMD 힌트
이때, Mesh = Physical mesh라서 물리적인 크기에 따라서 숫자를 맞춰줘야함
반대로 Hybrid Mesh로 가상 Mesh를 만들어서 처리할 수도 있음
하지만 TPU Topology상이랑 맞게 구성해주는게 연산력 효율이 높음 (MP등에서 너무 많이 왓다갓다하면 성능이 낮을수 있음)
3. 일반적인 학습
# Sharding annotate the linear layer weights.
model = SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
# Training loop
model.train()
for step, (data, target) in enumerate(loader):
# Assumes `loader` returns data, target on XLA device
optimizer.zero_grad()
# Sharding annotate input data, we can shard any input
# dimensions. Sharidng the batch dimension enables
# in data parallelism, sharding the feature dimension enables
# spatial partitioning.
xs.mark_sharding(data, mesh, partition_spec)
ouput = model(data)
loss = loss_fn(output, target)
optimizer.step()
xm.mark_step()
학습을 하되, 모델을 만든 후에 xs.mark_sharding 통해서 Mesh 모양 + 파티션힌트 줘서 모델도 쪼개 올린다.
SPMD에서는 입력 데이터 DP 조차도 단순한 하나의 mesh 축으로 간주하기 때문에
→ xs.mark_sharding , 즉 모델에서 사용했던 샤딩 함수와 완전 똑같은 함수를 사용해 분산시킨다.
이유: mesh_shape 가 (4, 2)인데 현재는 힌트에 (0, None)을 넣어서 첫번째 Mesh 축(=x), 즉 4개 장비에 대해서, Input Tensor의 0 번째 축, 즉 배치사이즈를 → None 형식(복제)으로 → Data Parallel
Model을 Tensor Parallel로 쪼개거나..
이유: 힌트에 (None, 1)을 넣어서 두번째 Mesh 축(=y), 즉 2개 장비에 대해서 1, 즉 ‘Split’ 형식으로 → Tensor Parallel
# Provide optional mesh axis names and use them in the partition spec
mesh = Mesh(device_ids, (4, 2), ('data', 'model'))
partition_spec = ('model', 'data')
xs.mark_sharding(input_tensor, mesh, partition_spec)
Mesh 할 때, 이렇게 data, model 와 같은 모양으로 만들고 나서,
Partition_spec에 이름을 mesh 축 이름으로 넣어줄 수도 있다!
이렇게 하면 이제 모든 쪽으로
Partition Spec이란?
# partition_spec (Tuple[int, str, None]): A tuple of device_mesh dimension index or `None`. Each index is an int or str if the mesh axis is named.
# This specifies how each input rank is sharded (index to mesh_shape) or replicated (None).
# For example, we can shard an 8x10 tensor 4-way row-wise, and replicate column-wise.
>> input = torch.randn(8, 10)
>> mesh_shape = (4, 2)
>> partition_spec = (0, None)
“…how each input rank is sharded (index to mesh_shape) or replicated (None).”