`from torch.utils.serialization import load_lua` doesn't work in current PyTorch versions, because the `torch.utils.serialization` module has been removed.
Here is a possible fix using the `torchfile` package:
class pytorch_lua_wrapper:
    """Minimal replacement for ``load_lua`` built on ``torchfile``.

    Loads a serialized Lua Torch model and exposes its top-level
    submodules by index through ``get(idx)``, mirroring the
    ``model.get(i)`` access pattern used with the old ``load_lua`` result.
    """

    def __init__(self, lua_path):
        """Load the Lua Torch file at *lua_path* with ``torchfile.load``.

        The loaded object is stored on ``self.lua_model``.
        """
        # Local import makes the (third-party) torchfile dependency explicit;
        # without it this method raises NameError.
        import torchfile

        self.lua_model = torchfile.load(lua_path)

    def get(self, idx):
        """Return the field table of the top-level submodule at *idx*.

        *idx* is used directly as a Python index into ``modules``, so it is
        0-based (Lua's own ``get`` is 1-based). This assumes torchfile
        converted the Lua ``modules`` table into a Python list; TODO confirm
        for your model file.

        Fields such as ``weight`` come back as torchfile loads them, which
        presumably means NumPy arrays. That is why callers wrap them with
        ``torch.from_numpy`` before building a ``torch.nn.Parameter``.
        """
        return self.lua_model._obj.modules[idx]._obj
Now you can replace this line:
vgg1 = load_lua(args.vgg1)
with
vgg1 = pytorch_lua_wrapper(args.vgg1)
and this line
self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float())
with
self.conv1.weight =torch.nn.Parameter(torch.from_numpy(vgg1.get(0).weight).float())
`from torch.utils.serialization import load_lua` doesn't work in current PyTorch versions, because the `torch.utils.serialization` module has been removed.
Here is a possible fix using the `torchfile` package:
Now you can replace this line:
vgg1 = load_lua(args.vgg1)
with
vgg1 = pytorch_lua_wrapper(args.vgg1)
and this line
self.conv1.weight = torch.nn.Parameter(vgg1.get(0).weight.float())
with
self.conv1.weight =torch.nn.Parameter(torch.from_numpy(vgg1.get(0).weight).float())