教程 4: 自定义模型
2024-07-11 19:02:46
假设您想增加一个新的叫 的优化器,它的参数分别为 , , 和 。 您首先需要在一个文件里实现这个新的优化器,例如在 里面:
from mmcv.runner import OPTIMIZERS
from torch.optim import Optimizer
@OPTIMIZERS.register_module
class MyOptimizer(Optimizer):
def __init__(self, a, b, c)
然后增加这个模块到 里面,这样注册器 (registry) 将会发现这个新的模块并添加它:
from .my_optimizer import MyOptimizer
之后您可以在配置文件的 域里使用 , 如下所示,在配置文件里,优化器被 域所定义:
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
为了使用您自己的优化器,域可以被修改为:
optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)
我们已经支持了 PyTorch 自带的全部优化器,唯一修改的地方是在配置文件里的 域。例如,如果您想使用 ,尽管数值表现会掉点,还是可以如下修改:
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)
使用者可以直接按照 PyTorch 文档教程 去设置参数。