lerobot策略优化器

torch.optim简介

在学校lerobot的策略优化器前,我们先再复习一下什么是优化器。

什么优化器

优化器官方解释就是在深度学习中让损失函数通过梯度下降思想逐步调整参数以达到最小损失。

简单理解优化器的就是更新计算参数的,根据损失函数的梯度方向调整模型权重和偏置值,公式为:新参数 = 旧参数 - 学习率 × 梯度。通过迭代逐步逼近最优解。在文章https://www.laumy.tech/2050.html我们已经探讨过常用的优化算法。

接下来我们再来从PyTorch使用的角度复习一下。torch.optim 是 PyTorch 官方优化器模块,提供了 SGD、Adam、AdamW 等主流优化算法的实现,所有的优化器都继承自基类 torch.optim.Optimizer。其核心作用是如下:

  • 自动化参数更新:根据反向传播计算的梯度,按特定优化策略(如 Adam 的自适应学习率)更新模型参数。
  • 统一接口抽象:通过 Optimizer 基类封装不同算法,提供一致的使用流程(zero_grad() 清空梯度 → step() 更新参数)。

在Pytorch中优化器统一封装成torch.optim接口调用,可以有以下优势。

  • 避免重复实现复杂算法:无需手动编写 Adam 的动量、二阶矩估计等逻辑,直接调用成熟接口。
  • 灵活支持训练需求:支持单/多参数组优化(如不同模块用不同学习率)、学习率调度、梯度清零等核心训练逻辑。
  • 工程化与可维护性:通过统一接口管理超参数(lr、weight_decay),便于实验对比与代码复用。

torch.optim怎么用

Step 1:定义模型与优化器

import torch
from torch import nn, optim

# 1. 定义模型(示例:简单线性层)
model = nn.Linear(in_features=10, out_features=2)

# 2. 初始化优化器:传入模型参数 + 超参数
optimizer = optim.Adam(
    params=model.parameters(),  # 待优化参数(PyTorch 参数迭代器)
    lr=1e-3,                    # 学习率(核心超参数)
    betas=(0.9, 0.999),         # Adam 动量参数(控制历史梯度影响)
    weight_decay=0.01           # 权重衰减(L2 正则化,可选)
)

定义了一个optimizer使用的是Adam优化器,该优化器学习率设置为1e-3,权重衰减为0.01。

Step 2:训练循环

# 模拟输入数据(batch_size=32,特征维度=10)
inputs = torch.randn(32, 10)
targets = torch.randn(32, 2)  # 模拟目标值

# 训练循环
for epoch in range(10):
    # ① 清空过往梯度(必须!否则梯度累积导致更新异常)
    optimizer.zero_grad()

    # ② 前向传播 + 计算损失
    outputs = model(inputs)
    loss = nn.MSELoss()(outputs, targets)  # 均方误差损失

    # ③ 反向传播计算梯度 + 优化器更新参数
    loss.backward()  # 自动计算所有可训练参数的梯度
    optimizer.step()  # 根据梯度更新参数

    print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

loss是损失,通过调用loss.backward()进行反向传播Pytorch就自动把梯度计算好保存了,然后调用关键的一部optimizer.step()就可以执行参数更新了(即新参数=旧参数-学习率*梯度),需要注意的是,在调用loss.backward()进行反向传播计算梯度时,要先调用optimizer.zero_grad()把之前的梯度值情况,因此每计算一次梯度都是被保存,不情况会导致梯度累积。

Step 3:差异化优化参数分组

# 定义参数组(不同模块用不同学习率)
optimizer = optim.Adam([
    {
        "params": model.backbone.parameters(),  # backbone 参数
        "lr": 1e-5  # 小学习率微调
    },
    {
        "params": model.head.parameters(),       # 任务头参数
        "lr": 1e-3  # 大学习率更新
    }
], betas=(0.9, 0.999))  # 公共超参数(所有组共享)

上面是针对同一个模型内不同模块使用不同的超参数lr。

抽象基类OptimizerConfig

OptimizerConfig 是所有优化器配置的抽象基类,通过 draccus.ChoiceRegistry 实现子类注册机制(类似插件系统),为新增优化器类型提供统一接口。

@dataclass
class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):
    lr: float
    weight_decay: float
    grad_clip_norm: float

    @property
    def type(self) -> str:
        return self.get_choice_name(self.__class__)

    @classmethod
    def default_choice_name(cls) -> str | None:
        return "adam"

    @abc.abstractmethod
    def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
        raise NotImplementedError

继承关系

class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC):

OptimizerConfig 继承abc.ABC和draccus.ChoiceRegistry,前者标记为抽象基类,强制子类实现build的抽象方法,确保接口的一致性。后者提供子类注册机制,通过@OptimizerConfig.register_subclass("名称") 将子类与优化器类型绑定(如 "adam" → AdamConfig),支持配置驱动的动态实例化。

核心属性

lr: float                  # 学习率(核心超参数)
weight_decay: float        # 权重衰减(L2正则化系数)
grad_clip_norm: float      # 梯度裁剪阈值(防止梯度爆炸)

上面3个参数都是优化器的基础配置,避免子类重复定义。

  • lr/weight_decay:直接传递给 torch.optim 优化器(如 Adam 的 lr 参数)
  • grad_clip_norm:不参与优化器创建,而是在训练流程中用于梯度裁剪(如 train.py 中 torch.nn.utils.clip_grad_norm_)

核心方法

(1)type属性用于表示优化器类型

@property
def type(self) -> str:
    return self.get_choice_name(self.__class__)

通过 draccus.ChoiceRegistry 的 get_choice_name 方法,获取子类注册的优化器类型名称(如 AdamConfig 的 type 为 "adam")。

在实际应用中,在配置解析时,通过 type 字段(如 {"type": "adam"})即可匹配到对应子类(AdamConfig),实现“配置→实例”的自动映射。

(2)default_choice_name默认优化器类型

@classmethod
def default_choice_name(cls) -> str | None:
    return "adam"

当配置中为显式指定type时,默认是用adam类型即AdamConfig,旨在简化用户配置,无需手动指定常见优化器类型。

(3)build抽象接口,优化器创建接口

@abc.abstractmethod
def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
    """Build the optimizer. It can be a single optimizer or a dictionary of optimizers."""
    raise NotImplementedError

强制子类实现build的方法。

实例化子类

optimizers.py中一共定义了4种优化器配置子类,adam,adamw,sgd, multi_adam,其中前3个是单参数优化器,最后一个是多参数优化器,最终均通过build方法创建torch.optim实例

单参数优化器

@OptimizerConfig.register_subclass("adam")
@dataclass
class AdamConfig(OptimizerConfig):
    lr: float = 1e-3
    betas: tuple[float, float] = (0.9, 0.999)
    eps: float = 1e-8
    weight_decay: float = 0.0
    grad_clip_norm: float = 10.0

    def build(self, params: dict) -> torch.optim.Optimizer:
        kwargs = asdict(self)  # 将 dataclass 字段转为字典
        kwargs.pop("grad_clip_norm")  # 移除梯度裁剪阈值(非优化器参数)
        return torch.optim.Adam(params, **kwargs)  # 创建 PyTorch Adam 实例

AdamConfig 是 OptimizerConfig 的核心子类,封装了 Adam 优化器的配置与实例化逻辑,通过 draccus 注册机制与工程训练流程深度集成。

@OptimizerConfig.register_subclass("adam")将 AdamConfig 类与字符串 "adam" 绑定,实现 配置驱动的动态实例化,当配置文件中 optimizer.type: "adam" 时,draccus 会自动解析并实例化 AdamConfig。继承自 OptimizerConfig 的 ChoiceRegistry 机制,确保子类可通过 type 字段被唯一标识。

@dataclass自动生成 initrepr 等方法,简化超参数管理,无需手动编写构造函数,直接通过类字段定义超参数(如 lr=1e-3)。

在AdamConfig中默认初始化了一些参数值,其中lr、betas、eps、weight_decay直接对应 torch.optim.Adam 的参数,通过 build 方法传递给 PyTorch 优化器,而grad_clip_norm不参与优化器创建,而是用于训练时的梯度裁剪(如 train.py 中 torch.nn.utils.clip_grad_norm_),实现“优化器参数”与“训练流程参数”的职责分离。

在最后的build方法中,调用torch.optim.Adam(params, **kwargs) 实例化优化器。在此之前,先调用asdict(self)将 AdamConfig 实例的字段(如 lr、betas)转换为字典 {"lr": 1e-3, "betas": (0.9, 0.999), ...},再调用kwargs.pop("grad_clip_norm")剔除 grad_clip_norm(梯度裁剪阈值),因其不属于torch.optim.Adam 的参数(优化器仅负责参数更新,梯度裁剪是训练流程的独立步骤)。

多参数优化器

@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
    optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict)  # 组内超参数

    def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
        optimizers = {}
        for name, params in params_dict.items():
            # 合并默认超参数与组内超参数(组内参数优先)
            group_config = self.optimizer_groups.get(name, {})
            optimizer_kwargs = {
                "lr": group_config.get("lr", self.lr),  # 组内 lr 或默认 lr
                "betas": group_config.get("betas", (0.9, 0.999)),
                "weight_decay": group_config.get("weight_decay", self.weight_decay),
            }
            optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)  # 为每组创建独立优化器
        return optimizers  # 返回优化器字典:{"backbone": optimizer1, "head": optimizer2, ...}

MultiAdamConfig 是 OptimizerConfig 的关键子类,专为多参数组优化场景设计,支持为模型不同模块(如 backbone 与 head)创建独立的 Adam 优化器,实现差异化超参数配置。

首先跟前面单参数的属性不同点是多了一个optimizer_groups,这是一个超参数字典,存储多组不同的超参数,示例如下。

optimizer_groups={
    "backbone": {"lr": 1e-5, "weight_decay": 1e-4},  # 低学习率微调 backbone
    "head": {"lr": 1e-3, "betas": (0.95, 0.999)}      # 高学习率更新 head,自定义动量参数
}

build的方法主要逻辑如下:

def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]:
    optimizers = {}
    for name, params in params_dict.items():
        # 1. 获取组内超参数(无则使用默认)
        group_config = self.optimizer_groups.get(name, {})
        # 2. 合并默认与组内超参数
        optimizer_kwargs = {
            "lr": group_config.get("lr", self.lr),
            "betas": group_config.get("betas", (0.9, 0.999)),
            "eps": group_config.get("eps", 1e-5),
            "weight_decay": group_config.get("weight_decay", self.weight_decay),
        }
        # 3. 为该组创建 Adam 优化器
        optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs)
    return optimizers  # 返回:{组名: 优化器实例}

其中params_dict是超参数组的来源,是字典类型,键为参数组名称(需与 optimizer_groups 键匹配),值为该组参数列表(如模型某模块的 parameters())。通常是策略类的get_optim_params方法提供,如下:

# 策略类中拆分参数组(示例逻辑)
def get_optim_params(self):
    return {
        "backbone": self.backbone.parameters(),
        "head": self.head.parameters()
    }

主要的核心逻辑是对于每个参数组,优先使用 optimizer_groups 中的超参数(如 group_config.get("lr")),无则回退到默认值(如 self.lr),然后为每个参数组创建独立的 torch.optim.Adam 实例,确保参数更新互不干扰。

优化器状态管理

状态保存

将优化器的某一个时刻参数进行存储,方便过程查看以及重新加载模型训练等等。

def save_optimizer_state(
    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer],  # 优化器实例或字典
    save_dir: Path  # 根保存目录
) -> None:
    if isinstance(optimizer, dict):
        # 1. 处理多参数优化器字典(如 MultiAdamConfig 创建的优化器)
        for name, opt in optimizer.items():  # 遍历优化器名称与实例(如 "backbone": opt1)
            optimizer_dir = save_dir / name  # 创建子目录:根目录/优化器名称(如 save_dir/backbone)
            optimizer_dir.mkdir(exist_ok=True, parents=True)  # 确保目录存在(含父目录创建)
            _save_single_optimizer_state(opt, optimizer_dir)  # 委托单优化器保存逻辑
    else:
        # 2. 处理单参数优化器(如 AdamConfig 创建的优化器)
        _save_single_optimizer_state(optimizer, save_dir)  # 直接使用根目录保存

区分单参数和多参数优化器,对应多参数优化器为每个优化器创建独立子目录(如 save_dir/backbone、save_dir/head),避免不同优化器的状态文件冲突。如果是单优化器,则直接调用 _save_single_optimizer_state,状态文件保存于 save_dir 根目录,结构简洁。

def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
    """Save a single optimizer's state to disk."""
    # 1. 获取优化器完整状态字典(含参数组和内部状态)
    state = optimizer.state_dict()

    # 2. 分离参数组(超参数配置)与剩余状态(张量数据)
    param_groups = state.pop("param_groups")  # 参数组:学习率、权重衰减等超参数(非张量)
    flat_state = flatten_dict(state)  # 剩余状态:动量、二阶矩等张量(展平嵌套字典,便于序列化)

    # 3. 保存张量状态(safetensors)与参数组(JSON)
    save_file(flat_state, save_dir / OPTIMIZER_STATE)  # 张量数据:高效二进制存储(如 "optimizer_state.safetensors")
    write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS)  # 参数组:JSON 格式(如 "optimizer_param_groups.json"方便可视化查看)

存储张量和非张量的格式

  • param_groups:包含优化器的超参数配置(如 lr、weight_decay、betas),是列表嵌套字典的结构,存储为JSON格式,JSON 序列化后可直接查看超参数,便于训练过程追溯。其文件名为optimizer_param_groups.json。
  • state:包含优化器的内部状态张量(如 Adam 的 exp_avg、exp_avg_sq 动量缓冲区),是嵌套字典结构,通过 flatten_dict 展平后用 safetensors 保存,safetensors专为张量设计的存储格式,支持高效读写、内存映射,避免 PyTorch torch.save 的 pickle 兼容性问题。其文件名为optimizer_state.safetensors。

状态加载

def load_optimizer_state(
    optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer],  # 待恢复的优化器(单实例或字典)
    save_dir: Path  # 状态文件根目录
) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]:
    if isinstance(optimizer, dict):
        # 1. 处理多优化器字典(如 MultiAdamConfig 创建的优化器)
        loaded_optimizers = {}
        for name, opt in optimizer.items():  # 遍历优化器名称与实例(如 "backbone": opt1)
            optimizer_dir = save_dir / name  # 子目录路径:根目录/优化器名称(如 save_dir/backbone)
            if optimizer_dir.exists():  # 仅当目录存在时加载(避免新增优化器时出错)
                loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir)
            else:
                loaded_optimizers[name] = opt  # 目录不存在时返回原优化器
        return loaded_optimizers
    else:
        # 2. 处理单优化器(如 AdamConfig 创建的优化器)
        return _load_single_optimizer_state(optimizer, save_dir)  # 直接从根目录加载

同样是区分单参数和多参数,对于多参数组根据save_dir / name 定位每个优化器的独立子目录(与 save_optimizer_state 的保存结构对应),如果是单参数优化器直接调用 _load_single_optimizer_state,从 save_dir 根目录加载状态文件。

def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer:
    """Load a single optimizer's state from disk."""
    # 1. 获取当前优化器的状态字典结构(用于校验与适配)
    current_state_dict = optimizer.state_dict()

    # 2. 加载并恢复张量状态(safetensors → 嵌套字典)
    flat_state = load_file(save_dir / OPTIMIZER_STATE)  # 加载展平的张量状态(如 "optimizer_state.safetensors")
    state = unflatten_dict(flat_state)  # 恢复为嵌套字典(与保存时的 flatten_dict 对应)

    # 3. 处理优化器内部状态(如动量缓冲区)
    if "state" in state:
        # 将字符串键转为整数(safetensors 保存时键为字符串,PyTorch 期望参数索引为整数)
        loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}}
    else:
        loaded_state_dict = {"state": {}}  # 新创建的优化器可能无状态,初始化为空

    # 4. 处理参数组(超参数配置,如学习率、权重衰减)
    if "param_groups" in current_state_dict:
        # 从 JSON 反序列化参数组,并确保结构与当前优化器匹配
        param_groups = deserialize_json_into_object(
            save_dir / OPTIMIZER_PARAM_GROUPS,  # 加载参数组 JSON 文件(如 "optimizer_param_groups.json")
            current_state_dict["param_groups"]  # 以当前参数组结构为模板,确保兼容性
        )
        loaded_state_dict["param_groups"] = param_groups

    # 5. 将恢复的状态字典加载到优化器
    optimizer.load_state_dict(loaded_state_dict)
    return optimizer

张量的状态恢复部分,通过 unflatten_dict 将保存时展平的状态(flatten_dict)恢复为嵌套字典,匹配 PyTorch 优化器状态的原始结构。接着通过state["state"] 的键在保存时被序列化为字符串(如 "0"),加载时需转回整数(如 0),以匹配 PyTorch 参数索引的整数类型。

对于参数组恢复先通过JSON 反序列化,deserialize_json_into_object 将 JSON 文件中的参数组配置(如 [{"lr": 1e-3, ...}, ...])反序列化为 Python 对象。再以当前优化器的 current_state_dict["param_groups"] 为模板,确保加载的参数组与当前优化器的参数结构兼容(如参数组数量、超参数字段匹配),避免因配置变更导致的加载失败。

最后合并 state(张量数据)和 param_groups(超参数配置)为完整状态字典,通过 optimizer.load_state_dict 完成优化器状态恢复。

工程调用

创建流程

# 1. 策略提供参数(如多参数组)
params = policy.get_optim_params()  # 例如:{"backbone": [params1...], "head": [params2...]}

# 2. 配置解析:根据 config.optimizer.type 实例化对应子类(如 MultiAdamConfig)
cfg.optimizer = MultiAdamConfig(
    lr=1e-3,
    optimizer_groups={"backbone": {"lr": 1e-5}, "head": {"lr": 1e-3}}
)

# 3. 创建优化器实例
optimizer = cfg.optimizer.build(params)  # 返回:{"backbone": Adam, "head": Adam}

训练流程

def update_policy(...):
    # 前向传播计算损失
    loss, output_dict = policy.forward(batch)
    # 反向传播与梯度裁剪
    grad_scaler.scale(loss).backward()
    grad_scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(policy.parameters(), grad_clip_norm)
    # 参数更新
    grad_scaler.step(optimizer)
    optimizer.zero_grad()  # 清空梯度