PyTorch 报错:ModuleAttributeError: ‘DataParallel‘ object has no attribute ‘ xxx (已解决)

2020/10/29 15:18
阅读数 3.7K

PyTorch 报错:ModuleAttributeError: 'DataParallel' object has no attribute ' xxx (已解决)

 

这个问题中 ,‘XXX’ 一般就是代码里面的需要优化的模型名称,例如,我的模型里定义了 optimizer_G 和 optimizer_D 两个网络(生成器网络和判别器网络)。

问题原因:

在 train.py 中,调用它们时,直觉地写成了 model.optimizer_G 的格式,如下:

model = create_model(opt)
model = model.cuda()
visualizer = Visualizer(opt)
if opt.fp16:    
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1')             
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_Dh = model.optimizer_G, model.optimizer_D

然而,其实这时 model 转换成了 model.module。

 

解决方法:

在 ‘ model. ’ 后面加一个 ‘ module. ’ 。

将 model.optimizer_G 改成 model.module.optimizer_G

将 model.optimizer_D 改成 model.module.optimizer_D

model = create_model(opt)
model = model.cuda()
visualizer = Visualizer(opt)
if opt.fp16:    
    model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.module.optimizer_G, model.module.optimizer_D], opt_level='O1')             
    model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
else:
    optimizer_G, optimizer_Dh = model.module.optimizer_G, model.module.optimizer_D

 

展开阅读全文
打赏
0
0 收藏
分享
加载中
更多评论
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部