lerobot训练

初始化

@parser.wrap()
def train(cfg: TrainPipelineConfig):
    cfg.validate()  # 验证配置合法性(如路径、超参数范围)
    init_logging()  # 初始化日志系统(本地文件+控制台输出)
    if cfg.seed is not None:
        set_seed(cfg.seed)  # 固定随机种子(确保训练可复现)
    device = get_safe_torch_device(cfg.policy.device, log=True)  # 自动选择训练设备(GPU/CPU)
    torch.backends.cudnn.benchmark = True  # 启用CuDNN自动优化(加速卷积运算)
    torch.backends.cuda.matmul.allow_tf32 = True  # 启用TF32精度(加速矩阵乘法)

初始化阶段主要是解析参数,初始化日志,确定训练的设备。

参数解析依旧是使用了装饰器parser.wrap,通过命令后参数构建生成TrainPipelineConfig类,该类是LeRobot 框架中训练流程的核心配置类,继承自 HubMixin(支持 Hugging Face Hub 交互),通过 dataclass 定义训练全流程的参数(如数据集路径、模型超参、训练步数等)。其核心作用是:

  • 参数聚合:统一管理数据集、策略模型、优化器、评估等模块的配置,避免参数分散。
  • 合法性校验:通过 validate 方法确保配置参数有效(如路径存在、超参数范围合理)。
  • 可复现性支持:固定随机种子、保存/加载配置,确保训练过程可复现。
  • Hub 集成:支持从 Hub 加载预训练配置或推送配置至 Hub,便于共享和断点续训。

核心属性

  • dataset:DatasetConfig, 数据集配置(如 repo_id="laumy/record-07271539"、图像预处理参数)
  • env: envs.EnvConfig,评估环境配置(如仿真环境类型、任务名称,仅用于训练中评估)
  • policy: PreTrainedConfig,策略模型配置(如 ACT 的 Transformer 层数、视觉编码器类型)。
  • output_dir: Path,训练输出目录(保存 checkpoint、日志、评估视频)。
  • resume,是否从 checkpoint 续训(需指定 checkpoint_path)
  • seed,随机种子(控制模型初始化、数据 shuffle、评估环境随机性,确保复现)。
  • num_workers,数据加载线程数,用于加速数据预处理。
  • batch_size,训练批次大小,即单步输入样本数。
  • steps,总训练步数,每次参数更新计为1步。
  • log_freq,日志记录频率,每 200 步打印一次训练指标,如 loss、梯度范数。
  • eval_freq,评估频率(每 20,000 步在环境中测试策略性能,计算成功率、平均奖励)
  • save_checkpoint,是否保存 checkpoint(模型权重、优化器状态),用于续训。
  • wandb,Weights & Biases 日志配置(控制是否上传指标、视频至 WandB)
  • use_policy_training_preset,是否使用策略内置的训练预设(如 ACT 策略默认 AdamW 优化器、学习率)。
  • optimizer,优化器配置(如学习率、权重衰减,仅当 use_policy_training_preset=False 时需手动设置)。
  • scheduler,学习率调度器配置(如余弦退火,同上)。

核心方法

  • post_init:初始化实例后设置 checkpoint_path(断点续训时的路径),为后续配置校验做准备。
  • validate:确保所有配置参数合法且一致,例如续训时校验 checkpoint 路径存在,自动生成唯一输出目录(避免覆盖),强制要求非预设模式下手动指定优化器。
  • _save_pretrained:将配置保存为 JSON 文件(train_config.json),用于 Hub 共享或本地存储。
  • from_pretrained::从 Hub 或本地路径加载配置(支持断点续训或复用已有配置)。

数据与模型准备

# 1. 加载离线数据集(如机器人操作轨迹数据)
logging.info("Creating dataset")
dataset = make_dataset(cfg)  # 从HuggingFace Hub或本地路径加载数据集

# 2. 初始化策略模型(如ACT、Diffusion Policy)
logging.info("Creating policy")
policy = make_policy(
    cfg=cfg.policy,  # 策略配置(如ACT的transformer层数、视觉编码器类型)
    ds_meta=dataset.meta,  # 数据集元信息(输入/输出维度、特征类型)
)

# 3. 创建优化器和学习率调度器
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)  # 默认AdamW优化器
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)  # 混合精度训练梯度缩放器
  • make_dataset(cfg):根据配置加载数据集(如 laumy/record-07271539),返回 LeRobotDataset 对象,包含观测(图像、关节状态)和动作序列。
  • make_policy(...):根据 policy.type(如 act)和数据集元信息初始化模型,自动适配输入维度(如图像分辨率、状态维度)和输出维度(如动作空间大小)。
  • make_optimizer_and_scheduler(cfg, policy):创建优化器(默认AdamW,学习率 1e-5)和调度器(默认无,可配置余弦退火等),支持对不同参数组设置不同学习率(如视觉 backbone 微调)。

数据加载

数据加载调用的是make_dataset函数,其是LeRobot 框架中数据集创建的核心工厂函数,负责根据训练配置(TrainPipelineConfig)初始化离线机器人数据集。它整合了图像预处理、时序特征处理(delta timestamps)和数据集元信息加载,最终返回可直接用于训练的 LeRobotDataset 对象。

根据数据集配置初始化图像预处理管道(如Resize、Normalize、RandomCrop等)。若 cfg.dataset.image_transforms.enable=True(通过命令行或配置文件设置),则创建 ImageTransforms 实例,加载预设的图像变换参数(如分辨率、是否翻转等),否则不进行图像预处理。

image_transforms = (
    ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)

单数据集加载

if isinstance(cfg.dataset.repo_id, str):
    # 加载数据集元信息(特征定义、统计数据、帧率等)
    ds_meta = LeRobotDatasetMetadata(
        cfg.dataset.repo_id, root=cfg.dataset.root, revision=cfg.dataset.revision
    )
    # 计算时序偏移(delta timestamps)
    delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
    # 创建 LeRobotDataset 实例
    dataset = LeRobotDataset(
        cfg.dataset.repo_id,  # 数据集标识(HuggingFace Hub repo_id 或本地路径)
        root=cfg.dataset.root,  # 本地缓存根目录
        episodes=cfg.dataset.episodes,  # 指定加载的轨迹片段(如 ["ep001", "ep002"])
        delta_timestamps=delta_timestamps,  # 时序特征偏移(见下文详解)
        image_transforms=image_transforms,  # 图像预处理管道
        revision=cfg.dataset.revision,  # 数据集版本(如 Git commit hash)
        video_backend=cfg.dataset.video_backend,  # 视频解码后端(如 "pyav")
    )

首先实例化LeRobotDatasetMetadata,加载数据集的信息,包括特征定义如observation.images.laptop 的形状、action的维度等,以及统计信息如图像均值/方差、动作范围,还有帧率fps以便时序偏移计算。

其次调用resolve_delta_timestamps根据模型计算时序特征偏移,例如如果策略需要当前帧及前2帧的观测,则生成[-0.04, -0.02, 0](单位:秒),用于从数据中提取多时序特征。

接着实例化LeRobotDataset,其实现数据加载、时序特征拼接、图像预处理等功能,为悬链提供批次化数据,具体见https://www.laumy.tech/2332.html#h37

多数据集支持

else:
    raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
    # 以下为预留的多数据集加载代码(暂未实现)
    dataset = MultiLeRobotDataset(
        cfg.dataset.repo_id,  # 多数据集标识列表(如 ["repo1", "repo2"])
        image_transforms=image_transforms,
        video_backend=cfg.dataset.video_backend,
    )
    logging.info(f"多数据集索引映射: {pformat(dataset.repo_id_to_index)}")

预留多数据集合并功能(如融合不同场景的机器人轨迹数据),目前未实现,直接抛出 NotImplementedError。

imsageNet统计量替换

if cfg.dataset.use_imagenet_stats:
    for key in dataset.meta.camera_keys:  # 遍历所有相机图像特征(如 "observation.images.laptop")
        for stats_type, stats in IMAGENET_STATS.items():  # IMAGENET_STATS = {"mean": [...], "std": [...]}
            dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

其目的主要将数据集图像的归一化统计量(均值/方差)替换为 ImageNet 预训练模型的统计量。当使用预训练视觉编码器(如 ResNet)时,用 ImageNet 统计量归一化图像,可提升模型迁移学习效果(避免因数据集自身统计量导致的分布偏移)。

模型加载

# 初始化策略模型(如ACT、Diffusion Policy)
logging.info("Creating policy")
policy = make_policy(
    cfg=cfg.policy,  # 策略配置(如ACT的transformer层数、视觉编码器类型)
    ds_meta=dataset.meta,  # 数据集元信息(输入/输出维度、特征类型)
)

make_policy 是 LeRobot 框架中策略模型实例化的核心工厂函数,负责根据配置(PreTrainedConfig)、数据集元信息(ds_meta)或环境配置(env_cfg),动态创建并初始化策略模型(如 ACT、Diffusion、TDMPC 等)。其核心作用是自动适配策略输入/输出维度(基于数据或环境特征),并支持加载预训练权重或初始化新模型。

在函数中,根据策略类型cfg.type,如 "act"、"diffusion")动态获取对应的策略类,如若 cfg.type = "act",get_policy_class 返回 ACTPolicy 类(ACT 策略的实现)。如果模型需要明确输入特征(如图像、状态)和输出特征(如动作)的维度,需进一步通过数据集或环境解析。

if ds_meta is not None:
    features = dataset_to_policy_features(ds_meta.features)  # 从数据集元信息提取特征
    kwargs["dataset_stats"] = ds_meta.stats  # 数据集统计量(用于输入归一化,如图像均值/方差)
else:
    if not cfg.pretrained_path:
        logging.warning("无数据集统计量,归一化模块可能初始化异常")  # 无预训练时,缺少数据统计量会导致归一化参数异常
    features = env_to_policy_features(env_cfg)  # 从环境配置提取特征(如 Gym 环境的观测/动作空间)

如果定义了离线数据进行解析特征,否则基于环境解析特征。

获取到特征后进行配置输入/输出特征。

cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}

将解析后的特征划分为输入特征(如图像、状态)和输出特征(仅动作,FeatureType.ACTION),并更新到策略配置 cfg 中。例如若 features 包含 "observation.images.laptop"(图像)和 "action"(动作),则:input_features = {"observation.images.laptop": ...},output_features = {"action": ...}。

接着对模型进行实例化,也就是预训练模型的加载或新模型的初始化。

if cfg.pretrained_path:
    # 加载预训练策略(如从 HuggingFace Hub 或本地路径)
    kwargs["pretrained_name_or_path"] = cfg.pretrained_path
    policy = policy_cls.from_pretrained(**kwargs)  # 调用策略类的 from_pretrained 方法加载权重
else:
    # 初始化新模型(随机权重)
    policy = policy_cls(** kwargs)  # 传入配置和特征信息初始化模型结构

最后将模型迁移到目标设备,如cuda:0或cpu。

policy.to(cfg.device)  # 将模型移至目标设备(如 "cuda:0"、"cpu")
assert isinstance(policy, nn.Module)  # 确保返回的是 PyTorch 模型
return policy

创建优化器

#  创建优化器和学习率调度器
logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)  # 默认AdamW优化器
grad_scaler = GradScaler(device.type, enabled=cfg.policy.use_amp)  # 混合精度训练梯度缩放器

负责初始化训练核心组件:优化器(更新模型参数)、学习率调度器(动态调整学习率)和梯度缩放器(混合精度训练支持)。三者共同构成策略模型的“参数更新引擎”,直接影响训练效率和收敛稳定性。

优化器与调度器创建

通过工厂函数 make_optimizer_and_scheduler 动态创建优化器和学习率调度器,参数来源于配置 cfg 和策略模型 policy。该函数根据 TrainPipelineConfig 中的 optimizer 和 scheduler 配置,或策略预设(use_policy_training_preset=True),生成优化器和调度器。其中优化器默认使用 AdamW(带权重衰减的 Adam),参数包括学习率(cfg.optimizer.lr)、权重衰减系数(cfg.optimizer.weight_decay)等,优化对象为 policy.parameters()(策略模型的可学习参数);而调度器如果配置了 scheduler.type(如 "cosine"),则创建对应学习率调度器(如余弦退火调度器),否则返回 None。

若 cfg.use_policy_training_preset=True(默认),则直接使用策略内置的优化器参数(如 ACT 策略默认 lr=3e-4,weight_decay=1e-4),无需手动配置 optimizer 和 scheduler。

梯度缩放器初始化

GradScaler用于解决低精度(如 float16)训练中的梯度下溢问题。device.type参数模型所在设备类型("cuda"/"mps"/"cpu"),确保缩放器与设备匹配。参数enabled=cfg.policy.use_amp确定是否启用混合精度训练(由策略配置 use_amp 控制)。若为 False,缩放器将禁用(梯度不缩放)。

本质原理是混合精度训练时,前向传播使用低精度(加速计算),但梯度可能因数值过小而下溢(变为 0)。GradScaler 通过梯度缩放(放大损失值 → 梯度按比例放大 → 更新时反缩放)避免下溢,同时保证参数更新精度。

数据加载器配置

# 创建时序感知采样器(针对机器人轨迹数据)
if hasattr(cfg.policy, "drop_n_last_frames"):  # 如ACT策略需丢弃轨迹末尾帧
    sampler = EpisodeAwareSampler(
        dataset.episode_data_index,  # 轨迹索引信息
        drop_n_last_frames=cfg.policy.drop_n_last_frames,  # 丢弃每段轨迹末尾N帧
        shuffle=True,  # 轨迹内随机打乱
    )
else:
    sampler = None  # 普通随机采样

# 构建DataLoader(多线程加载+内存锁定)
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=cfg.num_workers,  # 数据加载线程数(加速IO)
    batch_size=cfg.batch_size,  # 批次大小
    sampler=sampler,
    pin_memory=device.type != "cpu",  # 内存锁定(加速CPU→GPU数据传输)
)
dl_iter = cycle(dataloader)  # 循环迭代器(数据集遍历完后自动重启)
  • EpisodeAwareSampler:确保采样的batch包含完整轨迹片段(避免时序断裂),适配机器人操作等时序依赖任务。
  • cycle(dataloader):将DataLoader转换为无限迭代器,支持训练步数(cfg.steps)远大于数据集长度的场景。

采样器选择

hasattr(cfg.policy, "drop_n_last_frames")用于检查模型中是否支持drop_n_last_frames 属性(如 ACT 策略需丢弃每段轨迹的最后 N 帧,避免无效数据)。如果支持,则启用时序感知采样EpisodeAwareSampler。器策略依赖时序连续的轨迹数据(如机器人操作的连贯动作序列)。核心参数如下:

  • dataset.episode_data_index:数据集轨迹索引(记录每段轨迹的起始/结束位置),确保采样时不跨轨迹断裂时序。
  • drop_n_last_frames=cfg.policy.drop_n_last_frames:丢弃每段轨迹的最后 N 帧(如因传感器延迟导致的无效帧)。
  • shuffle=True:轨迹内部随机打乱(但保持轨迹内时序连续性),平衡随机性与时序完整性。

若模型策略中不支持时序感知采样,那么则采用普通的随机采样,使用默认随机采样(shuffle=True,sampler=None),DataLoader 直接对数据集全局打乱。

数据加载管道

dataloader = torch.utils.data.DataLoader基于采样器配置,创建 PyTorch DataLoader,实现多线程并行数据加载,为训练循环提供高效的批次数据。

  • num_workers=cfg.num_workers:使用配置的线程数并行加载数据(如 8 线程),避免数据加载成为训练瓶颈。
  • pin_memory=True(当使用 GPU 时):将数据加载到固定内存页,加速数据从 CPU 异步传输到 GPU,减少等待时间。
  • drop_last=False:保留最后一个可能不完整的批次(尤其在小数据集场景,避免数据浪费)。
  • batch_size:控制每个批次的样本数量。batch_size=8 时,所有张量第一维度为 8。
  • sampler:控制采样顺序(时序连续/随机)。EpisodeAwareSampler 确保动作序列来自同一段轨迹。
  • num_workers:控制并行加载的子进程数,影响数据加载速度(非数据结构)。num_workers=8 比单线程加载快 5-10 倍(取决于硬件)。注意其并行只会影响数据加载速度,不会影响训练速度,当前数据加载是持续循环进行,而非一次性完成,但是这个相对训练时间。通过控制并行数据加载进程数,减少 GPU 等待时间,提升整体效率。

dataloader 是 torch.utils.data.DataLoader 类的实例,本质是一个可迭代对象(iterable),用于按批次加载数据。其核心作用是将原始数据集(LeRobotDataset)转换为训练可用的批次化数据,支持多线程并行加载、自定义采样顺序等功能。

DataLoader 通过以下步骤将原始数据集(LeRobotDataset)转换为批次化数据

  • 数据集索引采样:原始数据集 dataset(LeRobotDataset 实例),包含所有机器人轨迹数据。通过 sampler 参数(如 EpisodeAwareSampler 或默认随机采样)生成样本索引序列,决定数据加载顺序。
  • 多线程并行加载:num_workers=cfg.num_workers:启动 num_workers 个子进程并行执行 dataset.getitem(index),从磁盘/内存中加载单个样本数据(如读取图像、解析状态)。
  • 样本拼接:默认行为:DataLoader 使用 torch.utils.data.default_collate 函数,将多个单样本字典(来自不同子进程)拼接为批次字典,对每个特征键(如 "observation.images.laptop"),将所有单样本张量(shape=[C, H, W])堆叠为批次张量(shape=[B, C, H, W]);非张量数据(如列表)会被转换为张量或保留为列表(取决于数据类型)。
  • 内存优化:pin_memory=device.type != "cpu":当使用 GPU 时,启用内存锁定(pin memory),将加载的张量数据存入 CPU 的固定内存页,加速后续异步传输到 GPU 的过程(batch[key].to(device, non_blocking=True))

循环迭代器

dl_iter = cycle(dataloader)

将 DataLoader 转换为无限迭代器,支持按“训练步数”(cfg.steps)而非“ epochs” 训练。离线训练通常以固定步数(如 100,000 步)为目标,而非遍历数据集次数(epochs)。当数据集较小时,cycle(dataloader) 可在数据集遍历结束后自动重启,确保训练步数达标。

那dataloader输出的批次数据结构长什么样的?

当通过 next(dl_iter) 获取批次数据时(如代码中 first_batch = next(dl_iter)),返回的是一个字典类型的批次数据,结构如下:

    dl_iter = cycle(dataloader)

    first_batch = next(dl_iter)
    for key, value in first_batch.items():
        if isinstance(value, torch.Tensor):
            print(f"{key}: shape={value.shape}, dtype={value.dtype}")
        else:
            print(f"{key}: type={type(value)}")

打印如下:

observation.images.handeye: shape=torch.Size([8, 3, 480, 640]), dtype=torch.float32
observation.images.fixed: shape=torch.Size([8, 3, 480, 640]), dtype=torch.float32
action: shape=torch.Size([8, 100, 6]), dtype=torch.float32
observation.state: shape=torch.Size([8, 6]), dtype=torch.float32
timestamp: shape=torch.Size([8]), dtype=torch.float32
frame_index: shape=torch.Size([8]), dtype=torch.int64
episode_index: shape=torch.Size([8]), dtype=torch.int64
index: shape=torch.Size([8]), dtype=torch.int64
task_index: shape=torch.Size([8]), dtype=torch.int64
action_is_pad: shape=torch.Size([8, 100]), dtype=torch.bool
task: type=<class 'list'>

可以看到一个batch有8组数据,因为TrainPipelineConfig::batch_size设置的8,控制每个批次的样本数量,batch_size定义了每次参数更新是输入模型的样本数量,如这里的8,就是表示输入8个样本更新一次参数。batch_size的设定要根据模型大小、GPU内存、数据特征综合确定。

如果batch_size过小,批次(如 batch_size=1)的梯度受单个样本噪声影响大,导致参数更新方向不稳定,Loss 曲线剧烈震荡(如下图示例),难以收敛到稳定最小值,如果模型使用了BN层,过小的批次会导致BN统计量(均值/方差)估计不准,影响特征表达。同时GPU利用率低,训练速度会变慢;

如果batch_size过大最直接的影响就是GPU内存会溢出,同时会导致收敛速度变慢或陷入次优解。

一般情况下,若数据集样本量少(如仅 1k 样本),可设 batch_size=32(全量数据集的 3

batch_size 状态 典型问题 解决方案
过大(OOM) GPU 内存溢出,收敛慢 减小批次/降低图像分辨率/梯度累积
过小(<4) Loss 波动大,GPU 利用率低 增大批次至 8-32(需满足内存)
合理(8-32) 梯度稳定,GPU 利用率高(80 维持默认或根据模型/数据微调

batch_size 的核心是平衡内存、速度与收敛性,建议从默认值开始,结合硬件条件和训练监控动态调整。

开始训练

训练模式设置

policy.train()

将策略模型切换为训练模式,确保所有层(如 Dropout、BatchNorm)按训练逻辑运行。PyTorch模型模式有差异,分为训练模式和评估模式。

  • 训练模式:启用 Dropout(随机丢弃神经元防止过拟合)、BatchNorm 更新运行时统计量(均值/方差)。
  • 评估模式(policy.eval()):关闭 Dropout、BatchNorm 使用训练阶段累积的统计量。

在训练循环前显式调用,避免因模型残留评估模式导致训练效果异常(如 Dropout 未激活导致过拟合)。

指标跟踪初始化

train_metrics = {
    "loss": AverageMeter("loss", ":.3f"),          # 训练损失(格式:保留3位小数)
    "grad_norm": AverageMeter("grdn", ":.3f"),      # 梯度范数(格式:缩写"grdn",保留3位小数)
    "lr": AverageMeter("lr", ":0.1e"),              # 学习率(格式:科学计数法,保留1位小数)
    "update_s": AverageMeter("updt_s", ":.3f"),     # 单步更新耗时(格式:缩写"updt_s",保留3位小数)
    "dataloading_s": AverageMeter("data_s", ":.3f"),# 数据加载耗时(格式:缩写"data_s",保留3位小数)
}

通过 AverageMeter 类(来自 lerobot.utils.logging_utils)定义需跟踪的核心训练指标,支持实时平均计算和格式化输出训练信息,这个类实例最终通过参数传递给MetricsTracker。AverageMeter 功能为内部维护 sum(累积和)、count(样本数)、avg(平均值),通过 update 方法更新指标,并按指定格式(如 ":.3f")输出。例:每步训练后调用 train_metrics["loss"].update(loss.item()),自动累积并计算平均 loss。

train_tracker = MetricsTracker(
    cfg.batch_size,                # 批次大小(用于计算每样本指标)
    dataset.num_frames,            # 数据集总帧数(用于进度比例计算)
    dataset.num_episodes,          # 数据集总轨迹数(辅助日志上下文)
    train_metrics,                 # 上述定义的指标跟踪器字典
    initial_step=step              # 初始步数(支持断点续训时从上次步数开始跟踪)
)

创建 MetricsTracker 实例(来自 lerobot.utils.logging_utils),用于聚合、格式化和记录所有训练指标,主要的作用如下:

  • 指标更新:训练循环中通过 train_tracker.loss = loss.item() 便捷更新单个指标。
  • 平均计算:自动对 AverageMeter 指标进行滑动平均(如每 log_freq 步输出平均 loss)。
  • 日志输出:调用 logging.info(train_tracker) 时,按统一格式打印所有指标(如 loss: 0.523 | grdn: 1.234 | lr: 3.0e-4)。
  • 断点续训支持:通过 initial_step=step 确保从断点恢复训练时,指标统计不重复计算。

总结下该代码段是训练前的关键初始化步骤,通过 AverageMeter 和 MetricsTracker 构建训练全流程的指标监控框架,为后续训练循环中的指标更新、日志记录和性能调优提供参考。

启动循环训练

logging.info("Start offline training on a fixed dataset")
# 启动训练循环,从当前 step(初始为 0 或断点续训的步数)迭代至 cfg.steps(配置文件中定义的总训练步数,如 100,000 步)。
for _ in range(step, cfg.steps):
    # 1. 加载数据batch
    batch = next(dl_iter)  # 从循环迭代器获取batch
    batch = {k: v.to(device, non_blocking=True) for k, v in batch.items() if isinstance(v, torch.Tensor)}  # 数据移至设备

    # 2. 单步训练(前向传播→损失计算→反向传播→参数更新)
    train_tracker, output_dict = update_policy(
        train_tracker,
        policy,
        batch,
        optimizer,
        cfg.optimizer.grad_clip_norm,  # 梯度裁剪阈值(默认1.0)
        grad_scaler=grad_scaler,
        use_amp=cfg.policy.use_amp,  # 启用混合精度训练
    )

    # 3. 训练状态更新与记录
    step += 1
    if is_log_step:  # 按log_freq记录训练指标(loss、梯度范数等)
        logging.info(train_tracker)
    if is_saving_step:  # 按save_freq保存模型checkpoint
        save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)
    if is_eval_step:  # 按eval_freq在环境中评估策略性能
        eval_info = eval_policy(eval_env, policy, cfg.eval.n_episodes)  # 执行评估
  • update_policy(...):核心训练函数,实现:
    -- 前向传播:policy.forward(batch) 计算损失(如动作预测MSE损失+VAE KL散度)。
    -- 反向传播:grad_scaler.scale(loss).backward() 缩放损失梯度(混合精度训练)。
    -- 梯度裁剪:torch.nn.utils.clip_grad_norm_ 限制梯度范数(防止梯度爆炸)。
    -- 参数更新:grad_scaler.step(optimizer) 更新模型参数,optimizer.zero_grad() 清空梯度缓存。
  • save_checkpoint(...):保存模型权重、优化器状态、学习率调度器状态和当前步数,支持断点续训。
  • eval_policy(...):在仿真环境中测试策略性能,计算平均奖励、成功率等指标,并保存评估视频。

上面代码是train 函数的核心训练循环,负责执行离线训练的完整流程:从数据加载、模型参数更新,到指标记录、模型保存与策略评估。循环以“训练步数”(step)为驱动,从初始步数(0 或断点续训的步数)运行至目标步数(cfg.steps),确保模型充分训练并实时监控性能。

循环初始化,按步数迭代

for _ in range(step, cfg.steps):

启动训练循环,从当前 step(初始为 0 或断点续训的步数)迭代至 cfg.steps(配置文件中定义的总训练步数,如 100,000 步)。

数据加载与耗时记录

start_time = time.perf_counter()
batch = next(dl_iter)  # 从无限迭代器获取批次数据
train_tracker.dataloading_s = time.perf_counter() - start_time  # 记录数据加载耗时
  • dl_iter = cycle(dataloader):cycle 将 DataLoader 转换为无限迭代器,数据集遍历完毕后自动重启,确保训练步数达标(而非受限于数据集大小)。
  • dataloading_s 指标:通过 train_tracker 记录单批次加载时间,用于监控数据加载是否成为训练瓶颈(若该值接近模型更新时间 update_s,需优化数据加载)。

批次数据设备迁移

for key in batch:
    if isinstance(batch[key], torch.Tensor):
        batch[key] = batch[key].to(device, non_blocking=True)

将批次中的张量数据(如图像、动作)异步传输到目标设备(GPU/CPU)。其中non_blocking=True:启用异步数据传输,允许 CPU 在数据传输至 GPU 的同时执行后续计算(如模型前向传播准备),提升硬件利用率。

模型参数更新

train_tracker, output_dict = update_policy(
    train_tracker,
    policy,
    batch,
    optimizer,
    cfg.optimizer.grad_clip_norm,
    grad_scaler=grad_scaler,
    lr_scheduler=lr_scheduler,
    use_amp=cfg.policy.use_amp,
)

调用 update_policy 函数执行单次参数更新,流程包括:

  • 前向传播:计算模型输出和损失(loss = policy.forward(batch))。
  • 混合精度训练:通过 torch.autocast 启用低精度计算(若 use_amp=True),加速训练并节省显存。
  • 反向传播:梯度缩放(grad_scaler.scale(loss).backward())避免数值下溢,梯度裁剪(clip_grad_norm_)防止梯度爆炸。
  • 参数更新:优化器.step() 更新参数,学习率调度器.step() 动态调整学习率。
  • 指标记录:返回更新后的训练指标(loss、梯度范数、学习率等)。

步数递增与状态跟踪

step += 1  # 步数递增(在更新后,确保日志/评估对应已完成的更新)
train_tracker.step()  # 更新指标跟踪器的当前步数

step用于更新当前系统的训练步数,是整个训练过程的基础计算器。train_tracker.step()通知训练指标跟踪器进入新阶段。

is_log_step = cfg.log_freq > 0 and step 
is_saving_step = step 
is_eval_step = cfg.eval_freq > 0 and step 

通过计算这些标志,用于后续的逻辑控制。

  • is_log_step:日志打印标志位,当配置的日志频率(cfg.log_freq)大于0且当前步数是cfg.log_freq 倍数时激活日志的打印。cfg.log_freq是来自用户的命令行参数的配置。默认是200步打印一次。
  • is_saving_step:保存标志位,当配置的保存频率相等或是其倍数时激活保存。默认是20000步保存一次。
  • is_eval_step:模型评估标志为。默认是20000评估一次。

标志通过模块化设计实现了训练过程的精细化控制,参数均来自TrainConfig配置类。

训练指标日志(按频率触发)

is_log_step = cfg.log_freq > 0 and step 
if is_log_step:
    logging.info(train_tracker)  # 控制台打印平均指标(如 loss: 0.523 | grdn: 1.234)
    if wandb_logger:
        wandb_log_dict = train_tracker.to_dict()
        if output_dict:
            wandb_log_dict.update(output_dict)  # 合并模型输出指标(如策略特定的辅助损失)
        wandb_logger.log_dict(wandb_log_dict, step)  # 上传至 WandB
    train_tracker.reset_averages()  # 重置平均计数器,准备下一轮统计

根据前面计算的is_log_step满足时,默认是200步,则调用logging.info()输出训练的指标,如损失、准确率。如果启动了wandb,则将跟踪器数据转换为字典,合并额外输出(output_dict)后记录到WandB,关联当前步数。最后调用train_tracker.reset_averages() 清除跟踪累计值,为下一周期计数做准备。

模型 checkpoint 保存(按频率触发)

if cfg.save_checkpoint and is_saving_step:
    logging.info(f"Checkpoint policy after step {step}")
    checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
    save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler)  # 保存模型、优化器、调度器状态
    update_last_checkpoint(checkpoint_dir)  # 更新 "last" 软链接指向最新 checkpoint
    if wandb_logger:
        wandb_logger.log_policy(checkpoint_dir)  # 上传 checkpoint 至 WandB  artifacts

默认是启动了save_checkpoint, 每20000步将训练结束的状态进行保存一次,便于支持断点续训和模型版本管理。其中save_checkpoint函数和输出目录结构如下:

    005000/  #  training step at checkpoint
    ├── pretrained_model/
    │   ├── config.json  # 存储模型架构定义,包括网络层数、隐藏维度、激活函数类型等拓扑结构信息
    │   ├── model.safetensors  # 采用SafeTensors格式存储模型权重,包含所有可学习参数的张量数据,具有内存安全和高效加载特性
    │   └── train_config.json  # 序列化的训练配置对象,包含超参数(学习率、批大小)、数据路径、预处理策略等完整训练上下文
    └── training_state/
        ├── optimizer_param_groups.json  #  记录参数分组信息,包括不同层的学习率、权重衰减等差异化配置
        ├── optimizer_state.safetensors  # 保存优化器动态状态,如Adam的一阶矩(momentum)和二阶矩(variance)估计,SGD的动量缓冲区等
        ├── rng_state.safetensors  # 捕获PyTorch全局RNG和CUDA(如使用)的随机数状态,确保恢复训练时数据采样和权重初始化的一致性
        ├── scheduler_state.json  #学习率调度器的内部状态,包括当前调度阶段、预热状态、周期信息等
        └── training_step.json  #当前训练迭代次数,用于精确定位训练数据读取位置

pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)  # 保存模型架构和权重
cfg.save_pretrained(pretrained_dir)      # 保存训练配置
save_training_state(checkpoint_dir, step, optimizer, scheduler)  # 保存训练动态状态

update_last_checkpoint用于维护一个指向最新检查点目录的符号链接(symlink),在训练过程中跟踪和管理最新的模型检查点。

def update_last_checkpoint(checkpoint_dir: Path) -> Path:
    # 1. 构建符号链接路径:在检查点父目录下创建名为 LAST_CHECKPOINT_LINK 的链接
    last_checkpoint_dir = checkpoint_dir.parent / LAST_CHECKPOINT_LINK
    # 2. 如果符号链接已存在,则先删除旧链接
    if last_checkpoint_dir.is_symlink():
        last_checkpoint_dir.unlink()

    # 3. 计算当前检查点目录相对于父目录的相对路径
    relative_target = checkpoint_dir.relative_to(checkpoint_dir.parent)

    # 4. 创建新的符号链接,指向当前检查点目录
    last_checkpoint_dir.symlink_to(relative_target)

策略评估(按频率触发)

is_eval_step = cfg.eval_freq > 0 and step 
if cfg.env and is_eval_step:
    step_id = get_step_identifier(step, cfg.steps)
    logging.info(f"Eval policy at step {step}")
    with torch.no_grad(), torch.autocast(...) if use_amp else nullcontext():
        eval_info = eval_policy(eval_env, policy, cfg.eval.n_episodes, ...)  # 在环境中执行策略评估
    # 记录评估指标(平均奖励、成功率、耗时)并上传至 WandB
    eval_tracker = MetricsTracker(...)
    eval_tracker.avg_sum_reward = eval_info["aggregated"]["avg_sum_reward"]
    ...
    logging.info(eval_tracker)
    if wandb_logger:
        wandb_logger.log_dict(...)
        wandb_logger.log_video(...)  # 上传评估视频

在强化学习和机器人控制领域,策略评估(Policy Evaluation) 是指在特定环境中系统性测试智能体策略(Policy)性能的过程。它通过执行预设数量的评估回合,收集关键指标(如奖励、成功率、执行时间等),客观衡量策略的实际效果。总结一下就是有以下4个核心作用。

  • 性能监控:跟踪训练过程中策略性能的变化趋势,判断模型是否收敛或退化
  • 过拟合检测:通过独立评估集验证策略泛化能力,避免在训练数据上过拟合
  • 决策依据:基于评估指标决定是否保存模型、调整超参数或终止训练
  • 行为分析:通过可视化记录(如代码中的评估视频)观察策略执行细节,发现异常行为

负责在指定训练步骤对策略进行系统性评估,并记录关键指标与可视化结果。代码块实现了周期性策略评估机制,当满足环境配置(cfg.env)和评估步骤标志(is_eval_step)时才会触发。

按 cfg.eval_freq(这里默认是20000步)在环境中评估策略性能,核心功能:

  • 无梯度推理:torch.no_grad() 禁用梯度计算,节省显存并加速评估。
  • 指标计算:通过 eval_policy 获取平均奖励(avg_sum_reward)、成功率(pc_success)等关键指标。
  • 可视化:保存评估视频(如机器人执行任务的轨迹)并上传至 WandB,直观观察策略行为。

训练更新

训练环境准备

    device = get_device_from_parameters(policy)
    policy.train()
    with torch.autocast(device_type=device.type) if use_amp else nullcontext():
        loss, output_dict = policy.forward(batch)

先调用get_device_from_parameters从模型参数自动推断当前计算设备,确保后续张量操作与模型参数在同医社保上,避免跨设备数据传输错误。

接着调用policy.train()将模型切换为训练模式,主要是启动dropout层随机失活功能,激活BatchNormalization层的移动平均统计更新等,与评估模式区别推理时需要调用policy.eval()。

最后使用了with预计用于创建一个上下文管理器,在进入代码块是调用管理器enter()方法,退出时调用exit()方法,这种机制确保资源被正确获取和释放或在特定上下文中执行代码,其代码等价于如下。

if use_amp:
    context_manager = torch.autocast(device_type=device.type)
else:
    context_manager = nullcontext()
with context_manager:
    loss, output_dict = policy.forward(batch)

也就是当条件激活use_amp启用混合精度训练,否则使用空上下文,torch.autocast会动态选择最优精度,对数值稳定性要求高的操作保留FP32。

最后就是调用loss, output_dict = policy.forward(batch)这是模型前向计算的核心函数,返回的是损失值和输出字典。

batch包含的是训练数据(图像、关节动作等),policy.forward()处理输入并生成预测值并计算与真实值的差距loss,然后loss将用于后续梯队的计算。

梯队计算

    grad_scaler.scale(loss).backward()
    grad_scaler.unscale_(optimizer)
    grad_norm = torch.nn.utils.clip_grad_norm_(
        policy.parameters(),
        grad_clip_norm,
        error_if_nonfinite=False,
    )

这段代码是训练处理计算梯队的核心流程,主要就是用于计算梯队。

首先grad_scaler.scale(loss).backward()是将损失放大然后间接放大梯队,grad_scaler是PyTorch的torch.cuda.amp.GradScaler对象,用于解决FP16训练时的梯度下溢问题,scale(loss)将损失值放大scaler倍(通常是2^n),避免梯度在反向传播中因数值过小而下溢为0,最后backward()触发反向传播,计算所有可训练参数的梯度(此时梯度已被放大)。

其次再调用grad_scaler.unscale_(optimizer)将放大的梯度恢复到原始尺寸,由于损失被放大,方向传播的梯队也被同等放大,使用unscale_对所有参数梯度执行反缩放,相当于处于scaler。

最后是调用 torch.nn.utils.clip_grad_norm_对梯度进行裁剪,防止极端梯队导致参数的震荡。

参数优化与更新

    with lock if lock is not None else nullcontext():
        grad_scaler.step(optimizer)
    grad_scaler.update()
    optimizer.zero_grad()
    if lr_scheduler is not None:
        lr_scheduler.step()

先试用with lock条件性启动线程锁,确保参数更新的线程安全。

grad_scaler.step(optimizer)执行参数更新,为了下一次重新计算损失。

接着调用grad_scaler.update()动态调整梯度的缩放因子,主要是自动平衡FP16的数值范围限制,在避免梯度下溢和溢出之间找到最优缩放比例。

最后就是调用optimizer.zero_grad()清除优化器中所有参数梯度缓存,因为Pytorch梯队计算默认都是累积模式(param.grad会累加),需要手动清零与loss.backward()配对,形成成"清零→前向→反向→更新→清零"的完整循环。

lr_scheduler.step()是按照批次更新学习率,可以使用循环学习率、余弦退化调度以及梯队的自适应调度等策略。

对于训练的闭环可以看成:scale(loss)→backward()→unscale_()→clip_grad_norm_()→step()→update()形成完整的混合精度训练流程,解决FP16数值范围限制问题。

最后update_policy返回的是train_metrics和output_dict。前者是传递给日志系统(如TensorBoard/WandB)进行可视化,后者是包含模型前向传播的详细输出(如预测值、中间特征),可用于后续分析

训练状态维护

    if has_method(policy, "update"):
        policy.update()
    train_metrics.loss = loss.item()
    train_metrics.grad_norm = grad_norm.item()
    train_metrics.lr = optimizer.param_groups[0]["lr"]
    return train_metrics, output_dict

如果policy有update的方法,则调用进行更新,这里主要是做兼容性设计。

最后train_metrics主要是记录量化训练过程中的关键特征,为监控、调试和优化提供数据支撑。

训练收尾

# 训练结束后清理
if eval_env:
    eval_env.close()  # 关闭评估环境(释放资源)
logging.info("End of training")

# 推送模型至HuggingFace Hub(若启用)
if cfg.policy.push_to_hub:
    policy.push_model_to_hub(cfg)  # 保存模型配置、权重至Hub,支持后续部署
  • eval_env.close():关闭仿真环境(如Gym/DM Control),释放显存和CPU资源。
  • policy.push_model_to_hub(cfg):将训练好的模型(权重+配置)推送至HuggingFace Hub,支持跨设备共享和部署。