|
|
""" |
|
|
fsdp.py |
|
|
|
|
|
Core class definition for a strategy implementing Torch native Fully Sharded Data Parallel Training (with support for |
|
|
fine-grained control over wrapping policies and mixed precision per component). |
|
|
""" |
|
|
import gc |
|
|
import json |
|
|
import math |
|
|
import threading |
|
|
from datetime import datetime |
|
|
from functools import partial |
|
|
from pathlib import Path |
|
|
from typing import Callable, Optional |
|
|
|
|
|
import torch |
|
|
import torch.distributed as dist |
|
|
import torch.nn as nn |
|
|
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
|
|
CheckpointImpl, |
|
|
apply_activation_checkpointing, |
|
|
checkpoint_wrapper, |
|
|
) |
|
|
from torch.distributed.fsdp import ( |
|
|
FullStateDictConfig, |
|
|
FullOptimStateDictConfig, |
|
|
MixedPrecision, |
|
|
ShardingStrategy, |
|
|
StateDictType, |
|
|
) |
|
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP |
|
|
from torch.optim import AdamW |
|
|
from tqdm import tqdm |
|
|
from transformers.optimization import ( |
|
|
get_constant_schedule, |
|
|
get_constant_schedule_with_warmup, |
|
|
get_cosine_with_min_lr_schedule_with_warmup, |
|
|
) |
|
|
|
|
|
from vitra.training.base_strategy import TrainingStrategy |
|
|
from vitra.training.metrics import VLAMetrics |
|
|
from vitra.utils.overwatch import initialize_overwatch |
|
|
|
|
|
|
|
|
# Module-scoped logger handle (project "overwatch" utility), keyed to this module's name.
overwatch = initialize_overwatch(__name__)
|
|
|
|
|
|
|
|
def get_constant_schedule_with_freeze_warmup(
    optimizer: torch.optim.Optimizer,
    num_warmup_steps: int,
    last_epoch: int = -1,
) -> torch.optim.lr_scheduler.LambdaLR:
    """Build a scheduler whose LR multiplier is 0.0 for the first `num_warmup_steps` steps, then 1.0 forever.

    This is a hard "freeze" warmup: the learning rate is exactly zero during the
    warmup window and jumps straight to the optimizer's base LR afterwards.
    """

    def _multiplier(step: int) -> float:
        # bool -> float gives exactly 0.0 while frozen and 1.0 once past the window.
        return float(step >= num_warmup_steps)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _multiplier, last_epoch=last_epoch)
|
|
|
|
|
|
|
|
def split_modality_collator(
    vla,
    cognition_token_weight_decay: bool = False,
    move_word_embedding_to_action_model: bool = False,
    verbose: bool = True
):
    """
    Split model parameters into vlm backbone and other (action model) groups with separate decay settings.

    Only parameters with `requires_grad=True` are considered. A parameter gets no
    weight decay when it is 1-D (or scalar) or is a bias; `cognition_token`
    parameters follow the `cognition_token_weight_decay` flag instead.

    Returns:
        Tuple of (backbone_decay, backbone_no_decay, other_decay, other_no_decay) parameter lists
    """
    backbone_decay, backbone_no_decay, other_decay, other_no_decay = [], [], [], []

    def _belongs_to_backbone(param_name: str) -> bool:
        """Check if the parameter is part of the vision or text backbone."""
        # Optionally route word embeddings to the action-model group even though
        # their name places them inside the backbone.
        if move_word_embedding_to_action_model and "embed_tokens" in param_name:
            return False
        return "backbone" in param_name

    for param_name, tensor in vla.named_parameters():
        if not tensor.requires_grad:
            continue

        # Biases and 1-D tensors (norms, scales) are exempt from weight decay.
        skip_decay = tensor.ndim <= 1 or param_name.endswith(".bias")
        if "cognition_token" in param_name:
            skip_decay = not cognition_token_weight_decay

        in_backbone = _belongs_to_backbone(param_name)
        if skip_decay and in_backbone:
            backbone_no_decay.append(tensor)
            if verbose:
                overwatch.info(f"Parameter `{param_name}` is part of the backbone and has no decay; added to `backbone_no_decay`")
        elif skip_decay:
            other_no_decay.append(tensor)
            if verbose:
                overwatch.info(f"Parameter `{param_name}` is not part of the backbone and has no decay; added to `other_no_decay`")
        elif in_backbone:
            backbone_decay.append(tensor)
            if verbose:
                overwatch.info(f"Parameter `{param_name}` is part of the backbone and has decay; added to `backbone_decay`")
        else:
            other_decay.append(tensor)
            if verbose:
                overwatch.info(f"Parameter `{param_name}` is not part of the backbone and has decay; added to `other_decay`")

    return backbone_decay, backbone_no_decay, other_decay, other_no_decay
|
|
|
|
|
|
|
|
class VLAFSDPStrategy(TrainingStrategy):
    """FSDP (Fully Sharded Data Parallel) training strategy for VLA models.

    Responsibilities:
      - Wrap the VLA in torch-native FSDP (hybrid sharding) with optional bf16
        mixed precision and non-reentrant activation checkpointing.
      - Build an AdamW optimizer with four parameter groups (backbone decay /
        no-decay, action-model decay / no-decay) and the requested LR schedule.
      - Run the training loop with gradient accumulation, grad clipping, metric
        logging, and asynchronous rank-0 checkpoint saving.
    """

    def __init__(
        self,
        vla,
        device_id: int,
        stage: str,
        epochs: int,
        max_steps: Optional[int],
        global_batch_size: int,
        per_device_batch_size: int,
        learning_rate: float,
        weight_decay: float,
        max_grad_norm: float,
        lr_scheduler_type: str,
        warmup_ratio: float,
        enable_gradient_checkpointing: bool = True,
        enable_mixed_precision_training: bool = True,
        reduce_in_full_precision: bool = False,
        action_model_learning_rate: Optional[float] = None,
        action_model_weight_decay: Optional[float] = None,
        mixed_precision_dtype: torch.dtype = torch.bfloat16,
        sharding_strategy: str = "shard-grad-op",
        state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT,
        cognition_token_weight_decay: bool = False,
        llm_freeze_step: int = 0,
        move_word_embedding_to_action_model: bool = False,
        optimizer_betas: tuple = (0.9, 0.999),
    ) -> None:
        super().__init__(
            vla=vla,
            device_id=device_id,
            stage=stage,
            epochs=epochs,
            max_steps=max_steps,
            global_batch_size=global_batch_size,
            per_device_batch_size=per_device_batch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            max_grad_norm=max_grad_norm,
            lr_scheduler_type=lr_scheduler_type,
            warmup_ratio=warmup_ratio,
            enable_gradient_checkpointing=enable_gradient_checkpointing,
            enable_mixed_precision_training=enable_mixed_precision_training,
            reduce_in_full_precision=reduce_in_full_precision,
            mixed_precision_dtype=mixed_precision_dtype,
        )

        # Action-model hyperparameters fall back to the backbone values when unset.
        self.action_model_learning_rate = action_model_learning_rate if action_model_learning_rate is not None else learning_rate
        self.action_model_weight_decay = action_model_weight_decay if action_model_weight_decay is not None else weight_decay
        self.cognition_token_weight_decay = cognition_token_weight_decay
        self.llm_freeze_step = llm_freeze_step
        self.move_word_embedding_to_action_model = move_word_embedding_to_action_model
        self.optimizer_betas = optimizer_betas

        # Map the string flag onto FSDP hybrid sharding strategies:
        #   "shard-grad-op" -> _HYBRID_SHARD_ZERO2 (shard grads/optimizer state only)
        #   "full-shard"    -> HYBRID_SHARD (shard params too)
        if sharding_strategy == "shard-grad-op":
            self.fsdp_sharding_strategy = ShardingStrategy._HYBRID_SHARD_ZERO2
        elif sharding_strategy == "full-shard":
            self.fsdp_sharding_strategy = ShardingStrategy.HYBRID_SHARD
        else:
            raise ValueError(f"FSDP sharding strategy '{sharding_strategy}' is not supported!")

        assert state_dict_type == StateDictType.FULL_STATE_DICT, "Sharded state saving is not yet implemented!"
        self.fsdp_state_dict_type = state_dict_type
        # At save time, gather full (unsharded) state on rank 0, offloaded to CPU.
        self.fsdp_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        self.fsdp_save_optimizer_policy = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)

    def save_checkpoint(
        self,
        run_dir: Path,
        global_step: int,
        epoch: int,
        only_trainable: bool = True,
        is_epoch_end: bool = False,
    ) -> None:
        """Save model weights, optimizer state, and metadata under `run_dir/checkpoints/`.

        NOTE(review): despite its name and the original docstring, `only_trainable`
        is currently ignored — the full model state dict is always saved.
        Weights/optimizer files are written asynchronously on rank 0 by background
        threads; the trailing barrier does NOT wait for those writes to finish.
        """
        assert isinstance(self.vla, FSDP), "FSDPStrategy.save_checkpoint assumes VLM is already wrapped in FSDP!"
        if is_epoch_end:
            checkpoint_name = f"epoch={epoch}-step={global_step}.end.ckpt"
        else:
            checkpoint_name = f"epoch={epoch}-step={global_step}.ckpt"
        checkpoint_dir = run_dir / "checkpoints" / checkpoint_name
        if overwatch.is_rank_zero():
            checkpoint_dir.mkdir(parents=True, exist_ok=True)

        def save_with_time(state_dict, path):
            # Runs on a background thread; logs wall-clock start/end for visibility.
            overwatch.info(f"Saving state dict to {path} start at {datetime.now()}")
            torch.save(state_dict, path)
            overwatch.info(f"Saving state dict to {path} end at {datetime.now()}")

        # Gather full model + optimizer state across shards (rank-0, CPU-offloaded).
        # All ranks must enter this block — state-dict gathering is collective.
        with FSDP.state_dict_type(self.vla, self.fsdp_state_dict_type, self.fsdp_save_policy, self.fsdp_save_optimizer_policy):
            overwatch.info("Gathering model state")
            model_state = self.vla.state_dict()
            overwatch.info("Preparing save checkpoint")
            overwatch.info("Gathering optimizer state")
            optim_state = FSDP.optim_state_dict(self.vla, self.optimizer)

        meta_state = {
            "epoch": epoch,
            "global_step": global_step
        }
        if overwatch.is_rank_zero():
            with open(checkpoint_dir / "meta.json", "w") as f:
                json.dump(meta_state, f)
        dist.barrier()
        if overwatch.is_rank_zero():
            # Fire-and-forget writers so training is not blocked on disk I/O.
            threading.Thread(target=save_with_time, args=(model_state, checkpoint_dir / 'weights.pt')).start()
            threading.Thread(target=save_with_time, args=(optim_state, checkpoint_dir / 'optimizer.pt')).start()

        dist.barrier()

    def load_optimizer_and_scheduler(self, checkpoint_folder: str) -> None:
        """Load optimizer state from `<checkpoint_folder>/optimizer.pt` (no-op with a warning if absent).

        NOTE(review): only the optimizer is restored here — the LR scheduler is
        not reloaded despite the method name.
        """
        assert isinstance(self.vla, FSDP), "FSDPStrategy.load_optimizer_and_scheduler assumes VLM is already wrapped in FSDP!"

        checkpoint_folder = Path(checkpoint_folder)
        optimizer_path = checkpoint_folder / "optimizer.pt"

        if not optimizer_path.exists():
            overwatch.warning(f"Optimizer checkpoint not found at {optimizer_path}!")
            return

        optim_state_dict = torch.load(optimizer_path, map_location="cpu")

        # `rank0_only=False`: every rank loads the full state and re-shards locally.
        with FSDP.state_dict_type(
            self.vla,
            self.fsdp_state_dict_type,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
            FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False)
        ):
            # Convert the saved full optimizer state into this rank's sharded layout.
            optim_state_dict = FSDP.optim_state_dict_to_load(self.vla, self.optimizer, optim_state_dict)

            self.optimizer.load_state_dict(optim_state_dict)

        overwatch.info(f"Loaded optimizer state dict from {optimizer_path}")

    def run_setup(
        self,
        run_dir: Path,
        n_train_examples: int,
        auto_wrap_policy_modules,
        checkpointing_policy_modules,
    ) -> None:
        """Setup FSDP training: wrap the model, create the optimizer and LR scheduler.

        Args:
            run_dir: Run directory (unused here; kept for interface parity).
            n_train_examples: Dataset size, used to derive total training steps.
            auto_wrap_policy_modules: Module classes at whose granularity FSDP shards.
            checkpointing_policy_modules: Module class(es) to wrap with activation
                checkpointing (list/set or single class), or None to skip.
        """
        from torch.distributed.fsdp.wrap import ModuleWrapPolicy

        # Shard at the granularity of the provided module classes.
        auto_wrap_policy = ModuleWrapPolicy(auto_wrap_policy_modules)

        if self.enable_mixed_precision_training and self.mixed_precision_dtype == torch.bfloat16:
            # Optionally keep gradient reduction (and buffers) in fp32 for numerical safety.
            reduce_buffer_dtype = torch.bfloat16 if not self.reduce_in_full_precision else torch.float32
            fsdp_precision_policy = MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=reduce_buffer_dtype,
                buffer_dtype=reduce_buffer_dtype
            )
        else:
            # Full fp32 everywhere (mixed precision disabled or non-bf16 dtype requested).
            fsdp_precision_policy = MixedPrecision(
                param_dtype=torch.float32,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32
            )

        self.vla = FSDP(
            self.vla,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=fsdp_precision_policy,
            sharding_strategy=self.fsdp_sharding_strategy,
            device_id=torch.cuda.current_device(),
            limit_all_gathers=True,
            use_orig_params=True,
        )

        if self.enable_gradient_checkpointing:
            # Non-reentrant checkpointing is the variant that composes with FSDP.
            non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
            if checkpointing_policy_modules is not None:
                def check_fn(submodule: nn.Module) -> bool:
                    # Checkpoint any submodule whose class is in the policy
                    # (accepts a list/set of classes or a single class).
                    if isinstance(checkpointing_policy_modules, (list, set)):
                        return any(isinstance(submodule, module) for module in checkpointing_policy_modules)
                    return isinstance(submodule, checkpointing_policy_modules)

                apply_activation_checkpointing(self.vla, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)

        dist.barrier()

        # Round the dataset size up to a multiple of the global batch size, then
        # derive the total number of optimizer steps (unless max_steps overrides).
        n_train_examples = math.ceil(n_train_examples / self.global_batch_size) * self.global_batch_size
        if self.max_steps is None:
            num_training_steps = (n_train_examples * self.epochs) // self.global_batch_size
        else:
            num_training_steps = self.max_steps

        backbone_decay, backbone_no_decay, other_decay, other_no_decay = split_modality_collator(
            self.vla,
            cognition_token_weight_decay=self.cognition_token_weight_decay,
            move_word_embedding_to_action_model=self.move_word_embedding_to_action_model,
            verbose=False
        )
        # Fixed group layout; MultiGroupLRScheduler relies on this ordering.
        groups = [
            {"params": backbone_decay, "weight_decay": self.weight_decay, "lr": self.learning_rate},
            {"params": backbone_no_decay, "weight_decay": 0.0, "lr": self.learning_rate},
            {"params": other_decay, "weight_decay": self.action_model_weight_decay, "lr": self.action_model_learning_rate},
            {"params": other_no_decay, "weight_decay": 0.0, "lr": self.action_model_learning_rate},
        ]

        self.optimizer = AdamW(groups, betas=self.optimizer_betas)

        if self.lr_scheduler_type == "linear-warmup+cosine-decay" or self.lr_scheduler_type == "warmup_cosine":
            num_warmup_steps = int(num_training_steps * self.warmup_ratio)
            # Cosine decay down to 10% of peak LR after linear warmup.
            self.lr_scheduler = get_cosine_with_min_lr_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps,
                num_training_steps,
                min_lr_rate=0.1
            )
            # Start from zero so the first optimizer step (taken before the first
            # scheduler step) does not apply the full peak learning rate.
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = 0.0

        elif self.lr_scheduler_type == "constant":
            num_warmup_steps = 0
            self.lr_scheduler = get_constant_schedule(self.optimizer)

        elif self.lr_scheduler_type == "linear-warmup+constant" or self.lr_scheduler_type == "warmup_constant":
            num_warmup_steps = int(num_training_steps * self.warmup_ratio)
            self.lr_scheduler = get_constant_schedule_with_warmup(
                self.optimizer, num_warmup_steps=num_warmup_steps
            )
            for param_group in self.optimizer.param_groups:
                param_group["lr"] = 0.0

        elif self.lr_scheduler_type == "backbone-freeze-warmup":
            # Freeze (LR = 0) the backbone for `llm_freeze_step` steps while the
            # action model trains at a constant LR throughout.
            num_warmup_steps = self.llm_freeze_step

            # Shadow optimizers exist only to drive the two child schedulers;
            # `self.optimizer` (above) is the one actually stepped in training.
            backbone_groups = [
                {"params": backbone_decay, "weight_decay": self.weight_decay, "lr": self.learning_rate},
                {"params": backbone_no_decay, "weight_decay": 0.0, "lr": self.learning_rate},
            ]
            action_model_groups = [
                {"params": other_decay, "weight_decay": self.action_model_weight_decay, "lr": self.action_model_learning_rate},
                {"params": other_no_decay, "weight_decay": 0.0, "lr": self.action_model_learning_rate},
            ]

            backbone_optimizer = AdamW(backbone_groups, betas=self.optimizer_betas)
            action_model_optimizer = AdamW(action_model_groups, betas=self.optimizer_betas)

            backbone_scheduler = get_constant_schedule_with_freeze_warmup(
                backbone_optimizer, num_warmup_steps=num_warmup_steps
            )
            action_model_scheduler = get_constant_schedule(action_model_optimizer)

            self.lr_scheduler = MultiGroupLRScheduler(
                self.optimizer, backbone_scheduler, action_model_scheduler
            )
            # BUGFIX: mirror the warmup branches above — zero the backbone groups of
            # the *real* optimizer immediately. Previously the first optimizer step
            # (taken before the first scheduler step) updated the "frozen" backbone
            # at full LR because only the shadow optimizer's LRs were zeroed.
            if num_warmup_steps > 0:
                for group_idx in (0, 1):
                    self.optimizer.param_groups[group_idx]["lr"] = 0.0
        else:
            raise ValueError(f"Learning Rate Schedule with type `{self.lr_scheduler_type}` is not supported!")

        scheduler_info = f" |-> LR Scheduler Type = {self.lr_scheduler_type}\n"
        # BUGFIX: previously compared against "backbone-freeze-warmup+action-constant",
        # a value no branch above accepts, so this detailed message was unreachable.
        if self.lr_scheduler_type == "backbone-freeze-warmup":
            scheduler_info += f" |-> Backbone: Constant schedule with freeze warmup ({num_warmup_steps} steps)\n"
            scheduler_info += " |-> Action Head: Constant schedule\n"
        else:
            scheduler_info += f" |-> LR Scheduler Warmup Steps (Ratio) = {num_warmup_steps} ({self.warmup_ratio})\n"

        overwatch.info(
            "FSDP Full-Shard Strategy =>> Finalized Training Setup:\n"
            f" |-> Global (Effective) Batch Size = {self.global_batch_size}\n"
            f" |-> Per-Device Batch Size = {self.per_device_batch_size}\n"
            f" |-> Distributed World Size = {overwatch.world_size()}\n"
            f" |-> Gradient Accumulation Steps = {self.grad_accumulation_steps}\n\n"
            f" |-> LLM Backbone FSDP Gradient Checkpointing = {self.enable_gradient_checkpointing}\n"
            f" |-> Use FSDP Mixed Precision = {self.enable_mixed_precision_training}\n"
            f" |-> Parameter Precision = {fsdp_precision_policy.param_dtype}\n"
            f" |-> Reduction Precision = {fsdp_precision_policy.reduce_dtype}\n"
            f" |-> Buffer Precision = {fsdp_precision_policy.buffer_dtype}\n\n"
            f" |-> Default AdamW LR = {self.learning_rate}\n"
            f" |-> AdamW Weight Decay = {self.weight_decay}\n"
            f" |-> AdamW Betas = {self.optimizer_betas}\n"
            + scheduler_info +
            f" |-> LLM Learning Rate = {self.learning_rate}\n"
            f" |-> Action Model Learning Rate = {self.action_model_learning_rate}\n"
            f" |-> LLM Weight Decay = {self.weight_decay}\n"
            f" |-> Action Model Weight Decay = {self.action_model_weight_decay}\n"
            f" |-> Cognition Token Weight Decay = {self.cognition_token_weight_decay}\n"
            f" |-> Dataset Size = {n_train_examples} Examples\n"
            f" |-> Max Steps = {num_training_steps}\n"
        )

    def clip_grad_norm(self) -> None:
        """Clip gradients with FSDP's sharding-aware `clip_grad_norm_`."""
        self.vla.clip_grad_norm_(max_norm=self.max_grad_norm)

    def run_training(
        self,
        dataloader,
        metrics: VLAMetrics,
        save_interval: int = 2500,
        epoch_save_interval: int = 1,
        start_epoch: int = 0,
        start_global_step: int = 0,
        save_full_model: bool = True,
    ) -> None:
        """Run the VLA training loop for the given dataloader; log losses and action metrics to metrics.

        Args:
            dataloader: Yields batches with keys `input_ids`, `pixel_values`,
                `attention_mask`, `actions`, `action_masks`, `current_state_mask`,
                `current_state`, `fov`.
            metrics: Metrics tracker; also owns `run_dir` and `global_step`.
            save_interval: Save a checkpoint every N optimizer steps.
            epoch_save_interval: Save an end-of-epoch checkpoint every N epochs.
            start_epoch / start_global_step: Resume points.
            save_full_model: If True, checkpoints are saved with `only_trainable=False`.
        """
        status = metrics.get_status()
        with tqdm(
            # Progress total counts optimizer steps (not micro-batches).
            total=(self.epochs * (len(dataloader) // self.grad_accumulation_steps)) if self.max_steps is None else self.max_steps,
            desc=status,
            leave=False,
            disable=not overwatch.is_rank_zero(),
            initial=start_global_step,
        ) as progress:
            train_idx = 0
            for epoch in range(start_epoch, self.epochs):
                self.vla.train()
                self.optimizer.zero_grad()
                for batch_idx, batch in enumerate(dataloader):
                    # NOTE(review): device placement of batch tensors is assumed to
                    # happen upstream (collator/dataloader) — not done here.
                    input_ids = batch["input_ids"]
                    rgb = batch["pixel_values"]
                    attention_mask = batch["attention_mask"]
                    action_labels = batch["actions"]
                    action_masks = batch["action_masks"]
                    current_state_mask = batch["current_state_mask"]
                    current_state = batch["current_state"]
                    fov = batch["fov"]

                    prediction = self.vla.forward(
                        rgb,
                        input_ids,
                        attention_mask=attention_mask,
                        action_labels=action_labels,
                        action_masks=action_masks,
                        current_state_mask=current_state_mask,
                        current_state=current_state,
                        data_source=['action'],
                        fov=fov,
                    )
                    loss = prediction["loss"]

                    metrics.commit(
                        loss=loss,
                        left_hand_6d=prediction["left_hand_6d"],
                        left_hand_joints=prediction["left_hand_joints"],
                        right_hand_6d=prediction["right_hand_6d"],
                        right_hand_joints=prediction["right_hand_joints"],
                    )

                    # Scale so gradients accumulated over N micro-batches sum to the
                    # full-batch gradient.
                    normalized_loss = loss / self.grad_accumulation_steps
                    normalized_loss.backward()

                    # Optimizer step only at gradient-accumulation boundaries.
                    if (train_idx + 1) % self.grad_accumulation_steps == 0:
                        self.clip_grad_norm()

                        self.optimizer.step()
                        self.lr_scheduler.step()
                        self.optimizer.zero_grad()

                        # Collect per-group LRs for logging.
                        lr_dict = {}
                        if isinstance(self.lr_scheduler, MultiGroupLRScheduler):
                            lr = self.lr_scheduler.get_last_lr()
                            lr_dict['backbone_decay_lr'] = lr[0]
                            lr_dict['backbone_no_decay_lr'] = lr[1]
                            lr_dict['action_decay_lr'] = lr[2]
                            lr_dict['action_no_decay_lr'] = lr[3]
                            current_lr = lr_dict['backbone_decay_lr']
                        else:
                            current_lr = self.lr_scheduler.get_last_lr()[0]

                        metrics.commit(update_step_time=True, global_step=metrics.global_step + 1, epoch=epoch, lr=current_lr, **lr_dict)
                        status = metrics.push()

                        # Checkpoint on max-steps termination or every `save_interval` steps.
                        if (terminate := (self.max_steps is not None and metrics.global_step >= self.max_steps)) or (
                            (metrics.global_step % save_interval) == 0
                        ):
                            self.save_checkpoint(
                                metrics.run_dir, metrics.global_step, epoch, only_trainable=not save_full_model
                            )
                            dist.barrier()

                        if terminate:
                            return
                    train_idx += 1

                    progress.set_description(status)
                    progress.update()

                if epoch % epoch_save_interval == 0:
                    self.save_checkpoint(
                        metrics.run_dir, metrics.global_step, epoch, only_trainable=not save_full_model, is_epoch_end=True
                    )
                # Release Python + CUDA caches between epochs to bound memory growth.
                gc.collect()
                torch.cuda.empty_cache()
|
|
|
class MultiGroupLRScheduler:
    """
    A custom learning rate scheduler that applies different scheduling strategies
    to different parameter groups in the optimizer.

    The two child schedulers drive *separate* (shadow) optimizers; on each
    `step()` this wrapper copies their learning rates into the corresponding
    groups of the real `optimizer`: groups 0-1 follow `backbone_scheduler`,
    groups 2-3 follow `action_model_scheduler`.
    """

    def __init__(self, optimizer, backbone_scheduler, action_model_scheduler):
        self.optimizer = optimizer
        self.backbone_scheduler = backbone_scheduler
        self.action_model_scheduler = action_model_scheduler

        # Fixed group layout in `optimizer`:
        # [backbone_decay, backbone_no_decay, other_decay, other_no_decay]
        self.backbone_group_indices = [0, 1]
        self.action_model_group_indices = [2, 3]

    @staticmethod
    def _pick_lr(lrs, i):
        """A single-group child scheduler broadcasts its one LR; otherwise index per group."""
        return lrs[0] if len(lrs) == 1 else lrs[i]

    def step(self):
        """Step both schedulers and update the corresponding parameter groups"""
        self.backbone_scheduler.step()
        self.action_model_scheduler.step()

        # Copy the child schedulers' current LRs into the wrapped optimizer.
        backbone_lrs = self.backbone_scheduler.get_last_lr()
        for i, group_idx in enumerate(self.backbone_group_indices):
            self.optimizer.param_groups[group_idx]['lr'] = self._pick_lr(backbone_lrs, i)

        action_model_lrs = self.action_model_scheduler.get_last_lr()
        for i, group_idx in enumerate(self.action_model_group_indices):
            self.optimizer.param_groups[group_idx]['lr'] = self._pick_lr(action_model_lrs, i)

    def get_last_lr(self):
        """Return the last learning rates for all four parameter groups.

        BUGFIX: both arms of each original conditional returned index 0 (a
        copy-paste slip), so multi-group child schedulers reported the first
        group's LR for every group. Now mirrors the indexing used in `step()`.
        """
        backbone_lrs = self.backbone_scheduler.get_last_lr()
        action_model_lrs = self.action_model_scheduler.get_last_lr()

        return [
            self._pick_lr(backbone_lrs, 0),
            self._pick_lr(backbone_lrs, 1),
            self._pick_lr(action_model_lrs, 0),
            self._pick_lr(action_model_lrs, 1),
        ]
|
|
|