Huggingface Transformers Train with FSDP on PyTorch/XLA @ TPU

Tags: MLDL Framework, Dev, Cloud, TPU
Published September 9, 2023

Environment

  • TPUv3-8
    • Image: tpu-vm-pt-1.13
  • Python 3.8 (preinstalled in the OS image)
  • PyTorch 2.0
  • PyTorch/XLA 2.0
    • Both were the latest stable releases at the time of writing (2023-09-09).

Installation

  • Launch a TPUv3-8 with the image above (a TPU VM, not a TPU Node).
  • Create and activate a virtual environment: virtualenv venv && source venv/bin/activate
  • Install PyTorch and PyTorch/XLA:
pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch-2.0-cp38-cp38-linux_x86_64.whl
pip3 install https://storage.googleapis.com/tpu-pytorch/wheels/tpuvm/torch_xla-2.0-cp38-cp38-linux_x86_64.whl
  • Install the Hugging Face packages:
git clone https://github.com/huggingface/transformers.git
cd transformers
git checkout v4.31-release
pip3 install -e .
pip3 install datasets evaluate scikit-learn
pip3 install accelerate==0.21.0
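Before training, it can help to confirm that the virtual environment really contains the versions listed above. The script below is a minimal preflight sketch using only the standard library (importlib.metadata, available from Python 3.8). The helper names are hypothetical. The distribution names, in particular torch_xla, are assumptions about how the wheels register themselves. The version match is deliberately loose, so 2.0 also accepts 2.0.1.

```python
from importlib import metadata
from typing import Dict, Optional

# Versions used in this post; distribution names are assumptions.
EXPECTED: Dict[str, str] = {
    "torch": "2.0",
    "torch_xla": "2.0",
    "transformers": "4.31",
    "accelerate": "0.21.0",
}


def matches(installed: Optional[str], expected: str) -> bool:
    """True if installed equals expected or only adds further components (2.0 ~ 2.0.1)."""
    if installed is None:
        return False
    return installed == expected or installed.startswith(expected + ".")


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_versions() -> Dict[str, bool]:
    """Print expected vs. found versions and return a per-package pass/fail map."""
    results = {}
    for dist, expected in EXPECTED.items():
        found = installed_version(dist)
        results[dist] = matches(found, expected)
        print(f"{dist:<13} expected {expected:<7} found {found}")
    return results
```

Calling check_versions() inside the activated venv prints one line per package. A False entry means the package is missing or its version differs from the setup described here.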

Training setup

1. Model config

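  • Obtain a config.json in the Hugging Face Transformers format and place it in the working directory (the training command passes it as --config_name ./config.json).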
// example config.json
{
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 3072,
  "n_head": 24,
  "n_layer": 18,
  "n_positions": 1024,
  "resid_pdrop": 0.1,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "vocab_size": 50257
}
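A rough parameter count shows why this model needs FSDP on a TPUv3-8. The helper gpt2_param_count below is hypothetical, not a Transformers API. It assumes the standard GPT-2 block layout: MLP width 4 × n_embd, since n_inner is not set, and an LM head tied to the token embedding. For this config it comes to roughly 2.2B parameters, about 8.8 GB for the fp32 weights alone. Adding fp32 gradients brings the total to about 17.6 GB. That already exceeds the 16 GiB of HBM on a single TPUv3 core, before any optimizer state or activations.

```python
import json


def gpt2_param_count(cfg: dict) -> int:
    """Estimate GPT2LMHeadModel parameters from config.json values.

    Assumes n_inner is unset (MLP width = 4 * n_embd) and that the LM head
    shares weights with the token embedding (the GPT-2 default).
    """
    d = cfg["n_embd"]
    attention = 4 * d * d + 4 * d  # c_attn (d -> 3d) + c_proj (d -> d), with biases
    mlp = 8 * d * d + 5 * d        # c_fc (d -> 4d) + c_proj (4d -> d), with biases
    layer_norms = 2 * (2 * d)      # ln_1 and ln_2, weight + bias each
    per_block = attention + mlp + layer_norms
    embeddings = (cfg["vocab_size"] + cfg["n_positions"]) * d  # wte + wpe
    final_ln = 2 * d               # ln_f weight + bias
    return embeddings + cfg["n_layer"] * per_block + final_ln


# Only the size-related keys of the config.json above are needed here.
cfg = json.loads(
    '{"n_embd": 3072, "n_head": 24, "n_layer": 18, "n_positions": 1024, "vocab_size": 50257}'
)
assert cfg["n_embd"] % cfg["n_head"] == 0  # head dim = 3072 / 24 = 128

params = gpt2_param_count(cfg)
print(f"{params:,} parameters, ~{params * 4 / 1e9:.1f} GB as fp32 weights")
```

Running it prints `2,196,691,968 parameters, ~8.8 GB as fp32 weights`. As a sanity check, the same formula gives 124,439,808 for the original 124M GPT-2 (n_embd 768, n_layer 12).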

2. FSDP config

  • Write an FSDP config for XLA, referring to the URL above.
// fsdp_config.json
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "GPT2Block"
  ],
  "xla": true,
  "xla_fsdp_settings": {
    "compute_dtype": "bfloat16",
    "buffer_dtype": "bfloat16",
    "shard_param_on_dim_0": true,
    "pin_layout_in_collective_ops": true
  },
  "xla_fsdp_grad_ckpt": true
}
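JSON has no comment syntax, so the "// fsdp_config.json" label line must not end up in the real file. One way to avoid copy-paste mistakes is to generate the file from Python. The sketch below does that. The helpers make_xla_fsdp_config and write_json are hypothetical. The comments summarize what each key does in Transformers v4.31 as I understand it. The xla_fsdp_settings entries are forwarded to XlaFullyShardedDataParallel, with the dtype strings converted to torch dtypes.

```python
import json
from pathlib import Path


def make_xla_fsdp_config(layer_cls: str = "GPT2Block") -> dict:
    """Build the fsdp_config.json contents used in this post (hypothetical helper)."""
    return {
        # Auto-wrap policy: each GPT2Block becomes its own FSDP unit.
        "fsdp_transformer_layer_cls_to_wrap": [layer_cls],
        # Use PyTorch/XLA's XlaFullyShardedDataParallel instead of torch FSDP.
        "xla": True,
        # Forwarded as keyword arguments to XlaFullyShardedDataParallel;
        # Transformers converts the dtype strings into torch dtypes.
        "xla_fsdp_settings": {
            "compute_dtype": "bfloat16",
            "buffer_dtype": "bfloat16",
            "shard_param_on_dim_0": True,
            "pin_layout_in_collective_ops": True,
        },
        # Gradient checkpointing on every auto-wrapped block (requires "xla": true).
        "xla_fsdp_grad_ckpt": True,
    }


def write_json(cfg: dict, path: str) -> Path:
    """Write cfg as pretty-printed JSON and return the file path."""
    out = Path(path)
    out.write_text(json.dumps(cfg, indent=2) + "\n")
    return out
```

For example, write_json(make_xla_fsdp_config(), "fsdp_config.json") produces the file that the training command passes via --fsdp_config. For a different architecture, swap in that model's block class name.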

3. Set the PJRT_DEVICE environment variable

export PJRT_DEVICE=TPU

4. Train!

python3 -u examples/pytorch/xla_spawn.py \
  --num_cores 8 \
  examples/pytorch/language-modeling/run_clm.py \
  --num_train_epochs 1 \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --per_device_train_batch_size 32 \
  --per_device_eval_batch_size 32 \
  --do_train \
  --do_eval \
  --output_dir ./test-clm \
  --overwrite_output_dir \
  --config_name ./config.json \
  --cache_dir /tmp \
  --tokenizer_name gpt2 \
  --block_size 1024 \
  --optim adafactor \
  --save_strategy "epoch" \
  --logging_strategy "steps" \
  --fsdp "full_shard" \
  --fsdp_config fsdp_config.json
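The same command can be launched from Python as an argv list, which avoids fragile backslash line continuations. This also makes it easy to reason about the effective batch size. The sketch below is an assumption-laden convenience wrapper, not part of Transformers. build_train_argv, tokens_per_step, and launch are hypothetical names. The script must be run from the transformers checkout, and PJRT_DEVICE=TPU is set as in step 3. With run_clm.py every training sequence is a block_size chunk. So 32 sequences per core × 8 cores gives a global batch of 256 sequences, or 262,144 tokens per optimizer step.

```python
import os
import shlex
import subprocess
from typing import List

NUM_CORES = 8  # a TPUv3-8 exposes 8 TPU cores


def build_train_argv(per_device_batch_size: int = 32, block_size: int = 1024) -> List[str]:
    """Return the training command above as an argv list."""
    return [
        "python3", "-u", "examples/pytorch/xla_spawn.py", "--num_cores", str(NUM_CORES),
        "examples/pytorch/language-modeling/run_clm.py",
        "--num_train_epochs", "1",
        "--dataset_name", "wikitext",
        "--dataset_config_name", "wikitext-2-raw-v1",
        "--per_device_train_batch_size", str(per_device_batch_size),
        "--per_device_eval_batch_size", str(per_device_batch_size),
        "--do_train", "--do_eval",
        "--output_dir", "./test-clm", "--overwrite_output_dir",
        "--config_name", "./config.json",
        "--cache_dir", "/tmp",
        "--tokenizer_name", "gpt2",
        "--block_size", str(block_size),
        "--optim", "adafactor",
        "--save_strategy", "epoch",
        "--logging_strategy", "steps",
        "--fsdp", "full_shard",
        "--fsdp_config", "fsdp_config.json",
    ]


def tokens_per_step(per_device_batch_size: int, num_cores: int, block_size: int,
                    grad_accum_steps: int = 1) -> int:
    """Tokens per optimizer step across all cores (each sequence is block_size tokens)."""
    return per_device_batch_size * num_cores * grad_accum_steps * block_size


def launch(dry_run: bool = True) -> None:
    """Print the command; with dry_run=False, run it with PJRT_DEVICE=TPU set."""
    argv = build_train_argv()
    print(shlex.join(argv))  # shlex.join requires Python 3.8+
    if not dry_run:
        subprocess.run(argv, env={**os.environ, "PJRT_DEVICE": "TPU"}, check=True)
```

Calling launch() only prints the command. Calling launch(dry_run=False) starts training.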

Err & Fix

AttributeError / XlaFullyShardedDataParallel / no attribute 'full_optim_state_dict'

Error log

AttributeError: type object 'XlaFullyShardedDataParallel' has no attribute 'full_optim_state_dict'

Cause

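  • The comment in the Trainer code (see the diff below) says full_optim_state_dict "will be deprecated after Pytorch 2.2!".
  • But this setup uses version 2.0 of both PyTorch and PyTorch/XLA.
  • So even if the method is deprecated, it should still be usable here.
  • And it is needed: FSDP shards the parameters (and with them the optimizer state) across devices, so they must be gathered back into one piece before saving.
    • This is the same reason DeepSpeed all-gathers the weights onto one device when saving a 16-bit checkpoint (e.g. ZeRO-3's stage3_gather_16bit_weights_on_model_save).
  • So why does it say the attribute doesn't exist?
    • A likely explanation: that comment refers to PyTorch's native torch.distributed.fsdp.FullyShardedDataParallel, which does provide full_optim_state_dict. XlaFullyShardedDataParallel is a separate implementation in torch_xla and, as the error shows, does not define this method. The Trainer's XLA FSDP path therefore calls a method that only exists on the PyTorch class.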
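To make the "gather before saving" point concrete, here is a conceptual sketch in plain Python. Lists stand in for tensors, the function names are hypothetical, and this is not how torch_xla implements sharding. A parameter is split along dim 0 into equal, zero-padded shards, one per rank. No single rank holds the full tensor. Only concatenating every shard, which is what an all-gather produces, and then dropping the padding recovers the original.

```python
from typing import List

Matrix = List[List[float]]


def shard_dim0(rows: Matrix, world_size: int) -> List[Matrix]:
    """Split a parameter along dim 0 into world_size equal shards, zero-padding the tail."""
    per_rank = -(-len(rows) // world_size)  # ceiling division
    pad = per_rank * world_size - len(rows)
    padded = rows + [[0.0] * len(rows[0]) for _ in range(pad)]
    return [padded[r * per_rank:(r + 1) * per_rank] for r in range(world_size)]


def gather_full(shards: List[Matrix], num_rows: int) -> Matrix:
    """Concatenate every rank's shard (what an all-gather yields) and drop the padding."""
    full = [row for shard in shards for row in shard]
    return full[:num_rows]


# A 5x2 "parameter" sharded over 4 ranks: 2 rows per rank, 3 rows of padding in total.
weight = [[float(i), float(i) + 0.5] for i in range(5)]
shards = shard_dim0(weight, world_size=4)
print([len(s) for s in shards])  # each rank holds 2 rows; rank 3 holds only padding
assert all(len(s) < len(weight) for s in shards)  # no rank has the full tensor
assert gather_full(shards, num_rows=len(weight)) == weight
```

Optimizer state that is kept per parameter shard has the same property. That is why saving a full optimizer checkpoint needs a gather step like full_optim_state_dict.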

Temporary workaround: remove the FSDP optimizer-state save from the Trainer code

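  • With the change below, the save step at least completes without errors.
  • But is the checkpoint actually saved correctly? I'd vote no.
    • Under FSDP each rank holds only its own shard of the optimizer state. On TPU, the Trainer writes the optimizer state with xm.save, which by default saves only from the master process, so the resulting file most likely contains a single shard rather than the full state.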
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8ed88d931..ca463f71b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2306,11 +2306,11 @@ class Trainer:
                 save_fsdp_optimizer(
                     self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
                 )
-            else:
-                # FSDP has a different interface for saving optimizer states.
-                # Needs to be called on all ranks to gather all states.
-                # full_optim_state_dict will be deprecated after Pytorch 2.2!
-                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
+            # else:
+            #     # FSDP has a different interface for saving optimizer states.
+            #     # Needs to be called on all ranks to gather all states.
+            #     # full_optim_state_dict will be deprecated after Pytorch 2.2!
+            #     full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
 
         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
@@ -2336,10 +2336,11 @@ class Trainer:
                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.is_deepspeed_enabled:
             # deepspeed.save_checkpoint above saves model/optim/sched
-            if self.fsdp and not self.is_fsdp_enabled:
-                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
-            else:
-                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            # if self.fsdp and not self.is_fsdp_enabled:
+            #     torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
+            # else:
+            #     torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            # torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
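Rather than commenting the branch out for every backend, a less destructive patch could check whether the wrapper class provides full_optim_state_dict at all. If it does not, each rank's optimizer shard could be kept in its own file instead. On TPU that might use xm.save(..., master_only=False) with a per-rank path. The sketch below only models that decision in plain Python. optimizer_save_plan, the stand-in classes, and the rank-suffixed file name are all hypothetical. Resuming from such shards would also need matching load logic, which is not shown.

```python
from typing import Tuple

OPTIMIZER_NAME = "optimizer.pt"  # file name the Trainer uses for optimizer state


def optimizer_save_plan(model_cls: type, rank: int, world_size: int) -> Tuple[str, str]:
    """Hypothetical helper: decide how this rank should save optimizer state.

    Returns (strategy, file_name); an empty file_name means "do not write".
    """
    if hasattr(model_cls, "full_optim_state_dict"):
        # torch FSDP: every rank joins the gather, only rank 0 writes the full state.
        return "gather_full", OPTIMIZER_NAME if rank == 0 else ""
    # No gather method (as with XlaFullyShardedDataParallel, per the error above):
    # every rank keeps its own shard in a rank-suffixed file instead.
    return "per_rank_shard", f"optimizer_rank-{rank:05d}-of-{world_size:05d}.pt"


class TorchStyleFSDP:  # stand-in for a wrapper that has the gather method
    @staticmethod
    def full_optim_state_dict(model, optim):
        return {}


class XlaStyleFSDP:  # stand-in for a wrapper without it
    pass


print(optimizer_save_plan(TorchStyleFSDP, rank=0, world_size=8))
print(optimizer_save_plan(XlaStyleFSDP, rank=3, world_size=8))
```

The first call yields ('gather_full', 'optimizer.pt') and the second ('per_rank_shard', 'optimizer_rank-00003-of-00008.pt'). Unlike the commented-out version, the torch FSDP path keeps working unchanged.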
 

References