1- """TensorStore implementation (optional dependency) ."""
1+ """TensorStore implementation -- zarr-python groups + TensorStore array I/O ."""
22
33from __future__ import annotations
44
5- import json
6- from pathlib import Path
75from typing import TYPE_CHECKING , Any
86
97import numpy as np
8+ import zarr
109
1110from iohub .core .config import TensorStoreConfig
1211from iohub .core .protocol import ZarrImplementation
2524 import tensorstore as ts
2625
2726
class _TsAttrs(dict):
    """Persistent attrs for a _TsGroup.

    Reads from ``zarr.json`` (v3) or ``.zattrs`` (v2).
    Writes always go to ``zarr.json``.

    Every in-place mutator (item assignment and deletion, ``update``,
    ``setdefault``, ``pop``, ``popitem``, ``clear`` and ``|=``) writes the
    whole mapping back to disk, so removing a key persists exactly like
    adding one.
    """

    def __init__(self, group: _TsGroup):
        self._group = group
        # dict.__init__ fills the mapping without calling the overridden
        # mutators, so loading never triggers a write.
        super().__init__(self._load())

    def _load(self) -> dict:
        """Return the attributes currently on disk, or ``{}`` if there are none."""
        zarr_json = self._group.path / "zarr.json"
        if zarr_json.exists():
            return json.loads(zarr_json.read_text(encoding="utf-8")).get("attributes", {})
        zattrs = self._group.path / ".zattrs"
        if zattrs.exists():
            return json.loads(zattrs.read_text(encoding="utf-8"))
        return {}

    def _save(self) -> None:
        """Write the mapping into the ``attributes`` field of ``zarr.json``."""
        zarr_json = self._group.path / "zarr.json"
        if zarr_json.exists():
            # Preserve every other field of the existing metadata document.
            meta = json.loads(zarr_json.read_text(encoding="utf-8"))
        else:
            meta = {"zarr_format": 3, "node_type": "group", "attributes": {}}
        meta["attributes"] = dict(self)
        zarr_json.write_text(json.dumps(meta), encoding="utf-8")

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self._save()

    def __delitem__(self, key):
        super().__delitem__(key)
        self._save()

    def __ior__(self, other):
        self.update(other)
        return self

    def update(self, *args, **kwargs):
        super().update(*args, **kwargs)
        self._save()

    def setdefault(self, key, default=None):
        if key not in self:
            self[key] = default  # routed through __setitem__, which saves
        return self[key]

    def pop(self, *args):
        # Only touch the disk when a key was actually removed.
        had_key = bool(args) and args[0] in self
        value = super().pop(*args)
        if had_key:
            self._save()
        return value

    def popitem(self):
        item = super().popitem()
        self._save()
        return item

    def clear(self):
        super().clear()
        self._save()
64-
65-
def _detect_zarr_driver(path: Path) -> str:
    """Pick the TensorStore zarr driver name for the store rooted at *path*.

    ``zarr.json`` takes precedence and selects ``"zarr3"``. Otherwise a
    ``.zattrs`` or ``.zgroup`` file selects ``"zarr2"``. A location with none
    of these markers is reported as ``"zarr3"``.
    """
    if not (path / "zarr.json").exists():
        v2_markers = (".zattrs", ".zgroup")
        if any((path / marker).exists() for marker in v2_markers):
            return "zarr2"
    return "zarr3"
73-
74-
class _TsGroup:
    """Lightweight group handle (tensorstore has no native group concept).

    A group is a directory on the local filesystem. Child directories are
    classified as groups or arrays from their ``zarr.json`` (``zarr3``) or
    ``.zgroup`` / ``.zarray`` (any other driver) files, and arrays are opened
    through the owning implementation's ``open_array``.
    """

    # Metadata document written for every group created by this class.
    _EMPTY_GROUP_META = {"zarr_format": 3, "node_type": "group", "attributes": {}}

    def __init__(
        self,
        path: Path,
        mode: str,
        impl: TensorStoreImplementation,
        zarr_driver: str = "zarr3",
        root: Path | None = None,
    ):
        """Wrap *path*, creating an empty v3 group there when the mode allows it.

        *root* is the store root used to compute ``name``; it defaults to
        *path*, which makes this handle the root group.

        Raises ``FileExistsError`` when *mode* is ``"w-"`` and *path* exists.
        """
        if mode == "w-" and path.exists():
            raise FileExistsError(f"Store already exists: {path}")
        if mode in ("w", "w-", "a") and not path.exists():
            path.mkdir(parents=True, exist_ok=True)
            (path / "zarr.json").write_text(json.dumps(self._EMPTY_GROUP_META))
        self.path = path
        self.mode = mode
        self._impl = impl
        self.zarr_driver = zarr_driver
        self._root = root if root is not None else path
        self._attrs_cache = None

    def _child(self, sub: Path) -> _TsGroup:
        """Return a handle for the existing sub-directory *sub*.

        The child shares this group's store root, so ``name`` stays relative
        to the root, and it inherits read-only mode so arrays reached through
        a read-only store are not opened for writing.
        """
        return _TsGroup(
            path=sub,
            mode="r" if self.mode == "r" else "a",
            impl=self._impl,
            zarr_driver=self.zarr_driver,
            root=self._root,
        )

    def create_group(self, name: str, overwrite: bool = False) -> _TsGroup:
        """Return the child group *name*, creating its directory and metadata.

        An existing child is returned untouched unless *overwrite* is true,
        in which case its ``zarr.json`` is rewritten as an empty group.
        """
        sub = self.path / name
        if overwrite or not sub.exists():
            sub.mkdir(parents=True, exist_ok=True)
            (sub / "zarr.json").write_text(json.dumps(self._EMPTY_GROUP_META))
        return self._child(sub)

    def __contains__(self, name: str) -> bool:
        return self.get(name) is not None

    def __delitem__(self, name: str) -> None:
        """Delete the child directory *name*; raise ``KeyError`` if absent."""
        import shutil

        sub = self.path / name
        if not sub.exists():
            raise KeyError(name)
        shutil.rmtree(sub)

    def __getitem__(self, name: str):
        result = self.get(name)
        if result is None:
            raise KeyError(name)
        return result

    def get(self, name: str, default=None):
        """Return the child array or group called *name*, else *default*."""
        sub = self.path / name
        if not sub.is_dir():
            return default
        if self.zarr_driver == "zarr3":
            try:
                meta = json.loads((sub / "zarr.json").read_text(encoding="utf-8"))
                node_type = meta.get("node_type")
                if node_type == "array":
                    return self._impl.open_array(self, name)
                if node_type == "group":
                    return self._child(sub)
            except (OSError, ValueError):
                # Best effort: missing or invalid metadata, or an open that
                # fails with one of these errors, counts as "not found".
                pass
        else:
            if (sub / ".zarray").exists():
                return self._impl.open_array(self, name)
            if (sub / ".zgroup").exists():
                return self._child(sub)
        return default

    @property
    def store(self) -> _TsGroup:
        return self

    @property
    def root(self) -> Path:
        return self.path

    @property
    def name(self) -> str:
        """Slash-separated path of this group relative to the store root."""
        try:
            rel = self.path.relative_to(self._root)
        except ValueError:
            # Not located under the recorded root: report the raw path.
            return str(self.path)
        return "/" + rel.as_posix() if rel.parts else "/"

    @property
    def basename(self) -> str:
        return self.path.name

    @property
    def attrs(self) -> _TsAttrs:
        """Attributes mapping for this group, loaded on first access."""
        if self._attrs_cache is None:
            self._attrs_cache = _TsAttrs(self)
        return self._attrs_cache

    def tree(self, level: int | None = None) -> str:
        """Render the directory hierarchy below this group as text.

        *level* limits how many levels are listed; ``None`` lists them all.
        """
        lines = [self.basename]
        self._tree_lines(self.path, "", level, 0, lines)
        return "\n".join(lines)

    def _tree_lines(self, p: Path, prefix: str, max_level: int | None, depth: int, lines: list) -> None:
        """Append one line per non-hidden sub-directory of *p*, recursively."""
        if max_level is not None and depth >= max_level:
            return
        try:
            children = sorted(
                entry.name for entry in p.iterdir() if entry.is_dir() and not entry.name.startswith(".")
            )
        except OSError:
            return
        last = len(children) - 1
        for i, child in enumerate(children):
            connector = "└── " if i == last else "├── "
            lines.append(f"{prefix}{connector}{child}")
            # Keep the vertical rule going while siblings remain below.
            extension = "    " if i == last else "│   "
            self._tree_lines(p / child, prefix + extension, max_level, depth + 1, lines)
189-
190-
19127def _fill_value_for_spec (data_type : str , fill_value : int | float ) -> object :
19228 """Return a TensorStore-compatible fill value for the given dtype string."""
19329 if data_type == "bool" :
@@ -228,11 +64,28 @@ def _spec_to_ts(spec: ArraySpec, path: str) -> dict:
22864 return {"driver" : "zarr3" , "kvstore" : {"driver" : "file" , "path" : path }, "metadata" : metadata }
22965
23066
231- _TS_IMPL_BASE = ZarrImplementation [_TsGroup , ts .TensorStore ] if _TS_AVAILABLE else object # type: ignore[assignment]
def _resolve_array_path(group: zarr.Group, name: str) -> str:
    """Return the filesystem path of array *name* inside *group* as a string.

    The path is the store's ``root`` joined with the group's path within the
    store (skipped when empty) and then *name*.

    Raises ``TypeError`` when the group's store has no ``root`` attribute.
    """
    store = group.store
    if hasattr(store, "root"):
        group_path = group.path
        segments = (group_path, name) if group_path else (name,)
        target = store.root
        for segment in segments:
            target = target / segment
        return str(target)
    raise TypeError(f"TensorStore requires a LocalStore (filesystem) backend, got {type(store).__name__!r}.")
77+
78+
# Base class for TensorStoreImplementation. The parametrised protocol needs
# ``ts.TensorStore``, so a plain ``object`` is used when ``_TS_AVAILABLE`` is
# false (presumably set by a guarded tensorstore import -- confirm).
if _TS_AVAILABLE:
    _TS_IMPL_BASE = ZarrImplementation[zarr.Group, ts.TensorStore]
else:
    _TS_IMPL_BASE = object  # type: ignore[assignment]
23280
23381
23482class TensorStoreImplementation (_TS_IMPL_BASE ):
235- """TensorStore-backed I/O implementation."""
83+ """Hybrid implementation: zarr-python groups + TensorStore array I/O.
84+
85+ Group operations (metadata, hierarchy) are delegated to zarr-python.
86+ Array operations (create, read, write, downsample) use TensorStore
87+ for high-performance I/O with configurable concurrency and caching.
88+ """
23689
23790 def __init__ (self , config : TensorStoreConfig | None = None ):
23891 self .config = config or TensorStoreConfig ()
@@ -267,56 +120,34 @@ def _context(self) -> ts.Context:
267120 self ._ctx = ts .Context (ctx_opts )
268121 return self ._ctx
269122
270- # -- Group operations --------------------------------------------------
271-
272- def open_group (self , path : StorePath , mode : str , zarr_format : int | None = None ) -> _TsGroup :
273- p = Path (path )
274- return _TsGroup (path = p , mode = mode , impl = self , zarr_driver = _detect_zarr_driver (p ), root = p )
275-
def _iter_children(self, group: _TsGroup, node_type: str) -> list[str]:
    """Return sorted child names matching node_type ('group' or 'array').

    With the ``zarr3`` driver each non-hidden sub-directory is classified by
    the ``node_type`` field of its ``zarr.json``; unreadable or invalid
    metadata is skipped. With ``zarr2`` a child matches when it contains the
    ``.zgroup`` / ``.zarray`` sentinel file. Any other driver, or a group
    path that is not a directory, yields an empty list.
    """
    base = Path(group.path)
    if not base.is_dir():
        return []
    driver = group.zarr_driver
    names: list[str] = []
    if driver == "zarr3":
        for child in base.iterdir():
            if not child.is_dir() or child.name.startswith("."):
                continue
            try:
                meta = json.loads((child / "zarr.json").read_text())
                matches = meta.get("node_type") == node_type
            except (OSError, ValueError):
                continue
            if matches:
                names.append(child.name)
    elif driver == "zarr2":
        marker = {"group": ".zgroup", "array": ".zarray"}[node_type]
        names = [child.name for child in base.iterdir() if (child / marker).exists()]
    return sorted(names)
298-
299- def group_keys (self , group : _TsGroup ) -> list [str ]:
300- return self ._iter_children (group , "group" )
301-
302- def array_keys (self , group : _TsGroup ) -> list [str ]:
303- return self ._iter_children (group , "array" )
304-
305- def close (self , group : _TsGroup ) -> None :
306- pass # TensorStore handles are not persistent connections
307-
308- def get_zarr_format (self , group : _TsGroup ) -> int :
309- return 3 # TensorStore only supports zarr v3
123+ # -- Group operations (delegated to zarr-python) -----------------------
124+
def open_group(self, path: StorePath, mode: str, zarr_format: int | None = None) -> zarr.Group:
    """Open the group at *path* through zarr-python.

    *path*, *mode* and *zarr_format* are forwarded unchanged to
    ``zarr.open_group`` and its result is returned as is.
    """
    opened = zarr.open_group(path, mode=mode, zarr_format=zarr_format)
    return opened
127+
def group_keys(self, group: zarr.Group) -> list[str]:
    """Return the names yielded by ``group.group_keys()`` in sorted order."""
    names = list(group.group_keys())
    names.sort()
    return names
130+
def array_keys(self, group: zarr.Group) -> list[str]:
    """Return the names yielded by ``group.array_keys()`` in sorted order."""
    names = list(group.array_keys())
    names.sort()
    return names
133+
def close(self, group: zarr.Group) -> None:
    """Close the store that backs *group*."""
    backing_store = group.store
    backing_store.close()
136+
def get_zarr_format(self, group: zarr.Group) -> int:
    """Return the ``zarr_format`` recorded in the group's metadata."""
    metadata = group.metadata
    return metadata.zarr_format
310139
311140 # -- Array lifecycle ---------------------------------------------------
312141
def create_array(self, group: zarr.Group, name: str, spec: ArraySpec, *, overwrite: bool = False) -> ts.TensorStore:
    """Create array *name* inside *group* with TensorStore and return the handle.

    The spec is built by ``_spec_to_ts`` for the array's filesystem path.
    *overwrite* is passed to the open call as ``delete_existing``.
    """
    array_path = _resolve_array_path(group, name)
    # Forget any handle cached under this path so that a later open_array
    # call opens the array created here.
    self._array_cache.pop(array_path, None)
    return _ts_open(
        _spec_to_ts(spec, array_path),
        create=True,
        delete_existing=overwrite,
        context=self._context(),
    )
316147
317148 def create_array_v2 (
318149 self ,
319- group : _TsGroup ,
150+ group : zarr . Group ,
320151 name : str ,
321152 * ,
322153 shape : tuple [int , ...],
@@ -325,33 +156,48 @@ def create_array_v2(
325156 fill_value : int = 0 ,
326157 overwrite : bool = False ,
327158 ) -> ts .TensorStore :
159+ shuffle_map = {"noshuffle" : 0 , "shuffle" : 1 , "bitshuffle" : 2 }
160+ comp = self .config .compressor
161+ path = _resolve_array_path (group , name )
162+ self ._array_cache .pop (path , None )
163+ # TensorStore zarr2 driver requires bool fill_value for bool dtype
164+ resolved_dtype = np .dtype (dtype )
165+ if resolved_dtype .kind == "b" :
166+ fill_value = bool (fill_value )
328167 spec = {
329168 "driver" : "zarr2" ,
330- "kvstore" : {"driver" : "file" , "path" : str ( Path ( group . path ) / name ) },
169+ "kvstore" : {"driver" : "file" , "path" : path },
331170 "metadata" : {
332171 "shape" : list (shape ),
333172 "chunks" : list (chunks ),
334- "dtype" : np .dtype (dtype ).str , # zarr2 uses NumPy dtype strings e.g. "<u2"
335- "compressor" : {"id" : "blosc" , "cname" : "lz4" , "clevel" : 5 , "shuffle" : 1 },
173+ "dtype" : resolved_dtype .str ,
174+ "compressor" : {
175+ "id" : "blosc" ,
176+ "cname" : comp .cname ,
177+ "clevel" : comp .clevel ,
178+ "shuffle" : shuffle_map .get (comp .shuffle , 2 ),
179+ },
336180 "fill_value" : fill_value ,
337181 "order" : "C" ,
338182 "filters" : None ,
339183 },
340184 }
341185 return _ts_open (spec , create = True , delete_existing = overwrite , context = self ._context ())
342186
def open_array(self, group: zarr.Group, name: str) -> ts.TensorStore:
    """Return a TensorStore handle for the existing array *name* in *group*.

    Handles are cached by filesystem path, so later calls for the same array
    return the handle opened first. The driver follows the group's zarr
    format, and the array is opened for writing unless the group's store
    reports ``read_only``.
    """
    key = _resolve_array_path(group, name)
    try:
        return self._array_cache[key]
    except KeyError:
        pass
    is_v3 = group.metadata.zarr_format == 3
    read_only = getattr(group.store, "read_only", False)
    handle = _ts_open(
        {
            "driver": "zarr3" if is_v3 else "zarr2",
            "kvstore": {"driver": "file", "path": key},
        },
        open=True,
        read=True,
        write=not read_only,
        context=self._context(),
    )
    self._array_cache[key] = handle
    return handle
0 commit comments