torch.nn.ParameterDict 是 PyTorch 中的一个容器类,用于存储和管理一组 torch.nn.Parameter 对象。它类似于 Python 的字典(dict),但专门用于存储 Parameter 对象,并且可以与 PyTorch 的神经网络模块(torch.nn.Module)无缝集成。

主要特点

  1. 键值对存储ParameterDict 以键值对的形式存储 Parameter 对象,键是字符串,值是 Parameter 对象。
  2. 动态管理:可以动态地添加、删除和更新 Parameter 对象。
  3. Module 集成ParameterDict 可以与 torch.nn.Module 一起使用,方便地管理模型中的参数。
  4. 序列化支持ParameterDict 支持序列化和反序列化,可以方便地保存和加载模型参数。

常用方法

  • add_module(name, parameter):向 ParameterDict 中添加一个 Parameter 对象,并指定一个名称。
  • update(parameters):将另一个 ParameterDict 或字典中的参数更新到当前 ParameterDict 中。
  • keys():返回 ParameterDict 中所有键的列表。
  • values():返回 ParameterDict 中所有 Parameter 对象的列表。
  • items():返回 ParameterDict 中所有键值对的列表。
  • pop(key):移除并返回指定键对应的 Parameter 对象。
  • clear():清空 ParameterDict 中的所有参数。

示例代码

import torch
import torch.nn as nn

# 创建一个 ParameterDict
param_dict = nn.ParameterDict()

# 添加参数
param_dict['weight'] = nn.Parameter(torch.randn(2, 2))
param_dict['bias'] = nn.Parameter(torch.zeros(2))

# 打印 ParameterDict
print(param_dict)

# 访问参数
print(param_dict['weight'])

# 更新参数
new_params = {'weight': nn.Parameter(torch.ones(2, 2))}
param_dict.update(new_params)

# 打印更新后的参数
print(param_dict['weight'])

# 移除参数
param_dict.pop('bias')

# 打印剩余的键
print(list(param_dict.keys()))

Module 集成

ParameterDict 可以与 torch.nn.Module 一起使用,方便地管理模型中的参数。例如:

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.params = nn.ParameterDict({
            'weight': nn.Parameter(torch.randn(2, 2)),
            'bias': nn.Parameter(torch.zeros(2))
        })

    def forward(self, x):
        return x @ self.params['weight'] + self.params['bias']

model = MyModel()
print(model.params)

总结

torch.nn.ParameterDict 是一个非常有用的工具,特别适合在需要动态管理大量参数的场景中使用。它与 torch.nn.Module 的集成使得模型参数的存储和管理变得更加方便和灵活。