Skip to content

Commit 482886d

Browse files
committed
Replace pickle with JSON for cache serialization
Signed-off-by: Tamaki Nishino <otamachan@gmail.com>
1 parent 35baa7a commit 482886d

2 files changed

Lines changed: 142 additions & 8 deletions

File tree

rosidl_parser/rosidl_parser/cache.py

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,75 @@
1717
import json
1818
import os
1919
import pathlib
20-
import pickle
2120
import shutil
2221
import tempfile
2322
from typing import Any, List, Optional, TypedDict
2423

24+
import rosidl_parser.definition as _def
25+
26+
# Build class registry and precomputed slots for safe deserialization
27+
_CLASSES = {}
28+
_SLOTS = {}
29+
30+
for _name in dir(_def):
31+
_obj = getattr(_def, _name)
32+
if isinstance(_obj, type):
33+
_CLASSES[_name] = _obj
34+
# Precompute all slot names across MRO
35+
_slots = []
36+
for _klass in _obj.__mro__:
37+
_klass_slots = getattr(_klass, '__slots__', ())
38+
if isinstance(_klass_slots, str):
39+
_slots.append(_klass_slots)
40+
else:
41+
_slots.extend(_klass_slots)
42+
_SLOTS[_obj] = tuple(_slots)
43+
44+
45+
def _encode(obj):
46+
if obj is None or isinstance(obj, (bool, int, float, str)):
47+
return obj
48+
if isinstance(obj, pathlib.Path):
49+
return {'__path__': str(obj)}
50+
if isinstance(obj, list):
51+
return [_encode(item) for item in obj]
52+
if isinstance(obj, tuple):
53+
return {'__tuple__': [_encode(item) for item in obj]}
54+
if isinstance(obj, dict):
55+
return {k: _encode(v) for k, v in obj.items()}
56+
slots = _SLOTS.get(type(obj))
57+
if slots is None:
58+
raise TypeError(f'Cannot JSON-encode object of type {type(obj).__name__}')
59+
data = {'__class__': type(obj).__name__}
60+
for slot in slots:
61+
if hasattr(obj, slot):
62+
data[slot] = _encode(getattr(obj, slot))
63+
return data
64+
65+
66+
def _decode(data):
67+
if data is None or isinstance(data, (bool, int, float, str)):
68+
return data
69+
if isinstance(data, list):
70+
return [_decode(item) for item in data]
71+
if isinstance(data, dict):
72+
if '__path__' in data:
73+
return pathlib.Path(data['__path__'])
74+
if '__tuple__' in data:
75+
return tuple(_decode(item) for item in data['__tuple__'])
76+
cls_name = data.get('__class__')
77+
if cls_name:
78+
if cls_name not in _CLASSES:
79+
raise ValueError(f'Unknown class: {cls_name}')
80+
cls = _CLASSES[cls_name]
81+
obj = cls.__new__(cls)
82+
for slot in _SLOTS[cls]:
83+
if slot in data:
84+
object.__setattr__(obj, slot, _decode(data[slot]))
85+
return obj
86+
return {k: _decode(v) for k, v in data.items()}
87+
return data
88+
2589

2690
class CacheConfig(TypedDict):
2791
cache_dir: Optional[str]
@@ -157,7 +221,7 @@ def compute_cache_key(*args) -> Optional[str]:
157221
elif isinstance(arg, dict):
158222
hasher.update(json.dumps(arg, sort_keys=True).encode('utf-8'))
159223
else:
160-
hasher.update(pickle.dumps(arg, protocol=pickle.HIGHEST_PROTOCOL))
224+
hasher.update(json.dumps(arg, sort_keys=True, default=str).encode('utf-8'))
161225

162226
return hasher.hexdigest()
163227

@@ -171,9 +235,9 @@ def save_object_to_cache(cache_key: str, cache_subdir: str, obj: Any) -> None:
171235
tmp_dir = None
172236
try:
173237
tmp_dir = pathlib.Path(tempfile.mkdtemp(dir=cache_dir))
174-
cache_file = tmp_dir / 'object.pkl'
175-
with open(cache_file, 'wb') as f:
176-
pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
238+
cache_file = tmp_dir / 'object.json'
239+
with open(cache_file, 'w', encoding='utf-8') as f:
240+
json.dump(_encode(obj), f)
177241
if cache_entry_dir.exists():
178242
shutil.rmtree(cache_entry_dir)
179243
tmp_dir.rename(cache_entry_dir)
@@ -190,12 +254,12 @@ def restore_object_from_cache(cache_key: str, cache_subdir: str) -> Optional[Any
190254
if not cache_dir:
191255
return None
192256
cache_entry_dir = cache_dir / cache_key
193-
cache_file = cache_entry_dir / 'object.pkl'
257+
cache_file = cache_entry_dir / 'object.json'
194258
if not cache_file.exists():
195259
return None
196260
try:
197-
with open(cache_file, 'rb') as f:
198-
return pickle.load(f)
261+
with open(cache_file, 'r', encoding='utf-8') as f:
262+
return _decode(json.load(f))
199263
except Exception as e:
200264
debug_print(f'[rosidl cache] Failed to load from cache: {e}')
201265
return None

rosidl_parser/test/test_cache.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import os
16+
import pathlib
1617
import time
1718

1819
import pytest
@@ -136,3 +137,72 @@ def test_cleanup_removes_oldest(cache_dir):
136137

137138
assert not old_entry.exists()
138139
assert new_entry.exists()
140+
141+
142+
def test_idl_file_cache_round_trip(cache_dir):
143+
from rosidl_parser.definition import (
144+
Array,
145+
BasicType,
146+
BoundedSequence,
147+
BoundedString,
148+
Constant,
149+
IdlContent,
150+
IdlFile,
151+
IdlLocator,
152+
Include,
153+
Member,
154+
Message,
155+
NamespacedType,
156+
Structure,
157+
UnboundedSequence,
158+
UnboundedString,
159+
)
160+
161+
locator = IdlLocator(pathlib.Path('/base'), pathlib.Path('msg/Test.idl'))
162+
content = IdlContent()
163+
164+
ns_type = NamespacedType(['test_pkg', 'msg'], 'TestMessage')
165+
structure = Structure(ns_type, members=[
166+
Member(BasicType('int32'), 'field_a'),
167+
Member(BasicType('boolean'), 'field_b'),
168+
Member(UnboundedString(), 'field_c'),
169+
Member(BoundedString(100), 'field_d'),
170+
Member(Array(BasicType('double'), 5), 'field_e'),
171+
Member(BoundedSequence(BasicType('uint8'), 10), 'field_f'),
172+
Member(UnboundedSequence(BasicType('int64')), 'field_g'),
173+
Member(NamespacedType(['other_pkg', 'msg'], 'OtherType'), 'field_h'),
174+
])
175+
msg = Message(structure)
176+
msg.constants = [
177+
Constant('MY_CONST', BasicType('int32'), 42),
178+
]
179+
180+
content.elements.append(Include('other_pkg/msg/OtherType.idl'))
181+
content.elements.append(msg)
182+
183+
idl_file = IdlFile(locator, content)
184+
185+
save_object_to_cache('idl_key', 'sub', idl_file)
186+
restored = restore_object_from_cache('idl_key', 'sub')
187+
188+
assert restored is not None
189+
assert isinstance(restored, IdlFile)
190+
assert str(restored.locator.basepath) == '/base'
191+
assert str(restored.locator.relative_path) == 'msg/Test.idl'
192+
assert len(restored.content.elements) == 2
193+
assert isinstance(restored.content.elements[0], Include)
194+
assert restored.content.elements[0].locator == 'other_pkg/msg/OtherType.idl'
195+
msg_r = restored.content.elements[1]
196+
assert isinstance(msg_r, Message)
197+
assert msg_r.structure.namespaced_type.name == 'TestMessage'
198+
assert len(msg_r.structure.members) == 8
199+
assert isinstance(msg_r.structure.members[0].type, BasicType)
200+
assert msg_r.structure.members[0].type.typename == 'int32'
201+
assert isinstance(msg_r.structure.members[4].type, Array)
202+
assert msg_r.structure.members[4].type.size == 5
203+
assert isinstance(msg_r.structure.members[5].type, BoundedSequence)
204+
assert msg_r.structure.members[5].type.maximum_size == 10
205+
assert isinstance(msg_r.structure.members[6].type, UnboundedSequence)
206+
assert len(msg_r.constants) == 1
207+
assert msg_r.constants[0].name == 'MY_CONST'
208+
assert msg_r.constants[0].value == 42

0 commit comments

Comments
 (0)