Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 35 additions & 16 deletions torchfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,15 @@
LuaFunction = namedtuple('LuaFunction',
['size', 'dumped', 'upvalues'])

class mycontainer():
def __init__(self, val):
self.val = val
def __hash__(self):
return id(self.val)
def __eq__(self, other):
return id(self.val) == id(other.val)
def __ne__(self, other):
return id(self.val) != id(other.val)

class hashable_uniq_dict(dict):
"""
Expand All @@ -60,27 +69,37 @@ class hashable_uniq_dict(dict):
This way, dicts can be keys of other dicts.
"""

def __iter__(self):
return iter(self.keys())

def __getitem__(self, k):
for _k,v in self.items():
if str(_k) == str(k):
return v

def __setitem__(self, k, v):
dict.__setitem__(self, mycontainer(k), v)

def items(self):
return [(k.val, v) for k,v in dict.items(self)]

def keys(self):
return [k.val for k in dict.keys(self)]

def values(self):
return [v for v in dict.values(self)]

def __hash__(self):
return id(self)

def __getattr__(self, key):
if key in self:
return self[key]
if isinstance(key, (str, bytes)):
return self.get(key.encode('utf8'))

def __eq__(self, other):
return id(self) == id(other)

def __ne__(self, other):
return id(self) != id(other)

def _disabled_binop(self, other):
raise TypeError(
'hashable_uniq_dict does not support these comparisons')
__cmp__ = __ne__ = __le__ = __gt__ = __lt__ = _disabled_binop


class TorchObject(object):
"""
Simple torch object, used by `add_trivial_class_reader`.
Expand All @@ -97,16 +116,16 @@ def __init__(self, typename, obj=None, version_number=0):
self._version_number = version_number

def __getattr__(self, k):
if k in self._obj:
if k in self._obj.keys():
return self._obj[k]
if isinstance(k, (str, bytes)):
return self._obj.get(k.encode('utf8'))

return self._obj[k.encode('utf8')]
def __getitem__(self, k):
if k in self._obj:
if k in self._obj.keys():
return self._obj[k]
if isinstance(k, (str, bytes)):
return self._obj.get(k.encode('utf8'))
return self._obj[k.encode('utf8')]

def torch_typename(self):
return self._typename
Expand All @@ -118,7 +137,7 @@ def __str__(self):
return repr(self)

def __dir__(self):
keys = list(self._obj.keys())
keys = self._obj.keys()
keys.append('torch_typename')
return keys

Expand Down