AttributeError / XlaFullyShardedDataParallel / no attribute 'full_optim_state_dict'
Error log
AttributeError: type object 'XlaFullyShardedDataParallel' has no attribute 'full_optim_state_dict'
Cause
According to the comment in trainer.py, full_optim_state_dict is only supposed to be deprecated after PyTorch 2.2.
But the versions in use here are PyTorch 2.0 and XLA 2.0.
Deprecated or not, it should still be callable at this point.
Either way, FSDP shards the parameters (and with them the optimizer states), so they have to be gathered back into a single full state dict before saving.
This is the same reason DeepSpeed's save-fp16-checkpoint option used an AllGather to collect everything onto one device before writing the checkpoint.
So why does it claim the attribute doesn't exist? A quick check is sketched below.
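One likely explanation, offered here as an assumption rather than something verified in this post: the Trainer calls self.model.__class__.full_optim_state_dict(self.model, self.optimizer), and that method lives on PyTorch's native torch.distributed.fsdp.FullyShardedDataParallel. torch_xla ships its own XlaFullyShardedDataParallel class, which does not expose the method, so the call fails on XLA regardless of the deprecation timeline. A minimal check:

# Minimal check (assumes torch and torch_xla 2.0 are installed).
from torch.distributed.fsdp import FullyShardedDataParallel as TorchFSDP
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as XlaFSDP

# PyTorch's native FSDP provides the gather-to-full-state-dict API ...
print(hasattr(TorchFSDP, "full_optim_state_dict"))  # True on PyTorch 2.0
# ... but the XLA FSDP wrapper used on TPU does not, hence the AttributeError above.
print(hasattr(XlaFSDP, "full_optim_state_dict"))    # False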
Temporary workaround: remove the FSDP save path from the Trainer code
With the patch below, the save itself completes without errors.
But is the optimizer state actually saved correctly? My bet is no.
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 8ed88d931..ca463f71b 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -2306,11 +2306,11 @@ class Trainer:
                 save_fsdp_optimizer(
                     self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir
                 )
-            else:
-                # FSDP has a different interface for saving optimizer states.
-                # Needs to be called on all ranks to gather all states.
-                # full_optim_state_dict will be deprecated after Pytorch 2.2!
-                full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
+            # else:
+            #     # FSDP has a different interface for saving optimizer states.
+            #     # Needs to be called on all ranks to gather all states.
+            #     # full_optim_state_dict will be deprecated after Pytorch 2.2!
+            #     full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)

         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
@@ -2336,10 +2336,11 @@ class Trainer:
                 torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
         elif self.args.should_save and not self.is_deepspeed_enabled:
             # deepspeed.save_checkpoint above saves model/optim/sched
-            if self.fsdp and not self.is_fsdp_enabled:
-                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
-            else:
-                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            # if self.fsdp and not self.is_fsdp_enabled:
+            #     torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
+            # else:
+            #     torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            # torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

             with warnings.catch_warnings(record=True) as caught_warnings:
                 torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
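Note that with both branches commented out, the non-Accelerate FSDP path never writes an optimizer state dict at all, which is why the "is it really saved correctly?" doubt above seems justified. For reference only, one direction that would keep the optimizer state on XLA without full_optim_state_dict is a per-rank sharded save via xm.save(..., master_only=False). This is a sketch under the assumption that model is the XlaFullyShardedDataParallel-wrapped module and that optimizer / output_dir come from the surrounding Trainer code; it is not a tested patch:

# Sketch: save each rank's shard instead of gathering a full optimizer state dict.
# `model`, `optimizer`, and `output_dir` are placeholders from the surrounding code.
import os
import torch_xla.core.xla_model as xm

rank = xm.get_ordinal()
world_size = xm.xrt_world_size()

ckpt = {
    "model": model.state_dict(),                   # this rank's parameter shards
    "optimizer": optimizer.state_dict(),           # this rank's optimizer state shards
    "shard_metadata": model.get_shard_metadata(),  # needed to consolidate shards later
}
# master_only=False so every rank writes its own file instead of only rank 0.
xm.save(ckpt, os.path.join(output_dir, f"ckpt_rank-{rank:08d}-of-{world_size:08d}.pth"),
        master_only=False)
# Model shards can later be merged offline with
# torch_xla.distributed.fsdp.consolidate_sharded_model_checkpoints;
# optimizer shards can be reloaded per rank when resuming on the same topology.

As far as I can tell this is the save-sharded-then-consolidate pattern from torch_xla's FSDP examples, as opposed to the gather-then-save pattern that trainer.py expects from PyTorch's native FSDP.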