Skip to content

Commit dfbcfd8

Browse files
authored
feat: add StorageAdapter plugin system for third-party storage protocols
Adds a StorageAdapter ABC and entry-point plugin system (datajoint.storage group) so third-party packages can register new storage protocols without modifying DataJoint internals. Built-in protocols (file, s3, gcs, azure) remain hardcoded; full unification is tracked in #1440 with Phase 0 (per-protocol unit-test scaffolding) as a hard prerequisite, gated on the atomicity contract in safe_write/safe_copy and recursive-op semantics. - StorageAdapter ABC with four extension points: create_filesystem, validate_spec, full_path, get_url - Lazy entry-point discovery via _discover_adapters; adapters auto-load when their protocol is referenced in dj.config.stores - _require_adapter helper provides symmetric missing-adapter errors across _create_filesystem, _full_path, and get_url; _full_path now reaches file-protocol logic only via an explicit elif, not as a catch-all else, so unknown protocols can no longer silently take it - _apply_common_store_defaults keeps built-in and plugin paths in sync on shared defaults; the location default is intentionally not applied to plugins so adapters can declare it in required_keys - 23 unit tests covering the registry, validation defaults, backend delegation, symmetric error handling on unknown protocols, entry-point discovery, and graceful failure of a bad entry point
1 parent a251f44 commit dfbcfd8

4 files changed

Lines changed: 463 additions & 17 deletions

File tree

src/datajoint/settings.py

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -420,24 +420,29 @@ def get_store_spec(self, store: str | None = None, *, use_filepath_default: bool
420420

421421
spec = dict(self.stores[store])
422422

423-
# Set defaults for optional fields (common to all protocols)
424-
spec.setdefault("subfolding", None) # No subfolding by default
425-
spec.setdefault("partition_pattern", None) # No partitioning by default
426-
spec.setdefault("token_length", 8) # Default token length
427-
428-
# Set defaults for storage section prefixes
429-
spec.setdefault("hash_prefix", "_hash") # Hash-addressed storage section
430-
spec.setdefault("schema_prefix", "_schema") # Schema-addressed storage section
431-
spec.setdefault("filepath_prefix", None) # Filepath storage (unrestricted by default)
423+
self._apply_common_store_defaults(spec)
432424

433425
# Validate protocol
434426
protocol = spec.get("protocol", "").lower()
435427
supported_protocols = ("file", "s3", "gcs", "azure")
436428
if protocol not in supported_protocols:
437-
raise DataJointError(
438-
f'Missing or invalid protocol in config.stores["{store}"]. '
439-
f"Supported protocols: {', '.join(supported_protocols)}"
429+
from .storage_adapter import get_storage_adapter
430+
431+
adapter = get_storage_adapter(protocol)
432+
if adapter is None:
433+
raise DataJointError(
434+
f'Unknown protocol "{protocol}" in config.stores["{store}"]. '
435+
f"Built-in: {', '.join(supported_protocols)}. "
436+
f"Install a plugin package for additional protocols."
437+
)
438+
adapter.validate_spec(spec)
439+
self._validate_prefix_separation(
440+
store_name=store,
441+
hash_prefix=spec.get("hash_prefix"),
442+
schema_prefix=spec.get("schema_prefix"),
443+
filepath_prefix=spec.get("filepath_prefix"),
440444
)
445+
return spec
441446

442447
# Set protocol-specific defaults
443448
if protocol == "s3":
@@ -582,6 +587,16 @@ def normalize(p: str) -> str:
582587
f"Storage section prefixes must be mutually exclusive."
583588
)
584589

590+
@staticmethod
591+
def _apply_common_store_defaults(spec: dict[str, Any]) -> None:
592+
"""Apply defaults shared by every store protocol (built-in and plugin)."""
593+
spec.setdefault("subfolding", None)
594+
spec.setdefault("partition_pattern", None)
595+
spec.setdefault("token_length", 8)
596+
spec.setdefault("hash_prefix", "_hash")
597+
spec.setdefault("schema_prefix", "_schema")
598+
spec.setdefault("filepath_prefix", None)
599+
585600
def load(self, filename: str | Path) -> None:
586601
"""
587602
Load settings from a JSON file.

src/datajoint/storage.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -330,6 +330,15 @@ def fs(self) -> fsspec.AbstractFileSystem:
330330
self._fs = self._create_filesystem()
331331
return self._fs
332332

333+
def _require_adapter(self):
    """Return the registered storage adapter for ``self.protocol``.

    Raises:
        DataJointError: if no adapter is registered for this protocol.
    """
    from .storage_adapter import get_storage_adapter

    found = get_storage_adapter(self.protocol)
    if found is not None:
        return found
    raise errors.DataJointError(f"Unsupported storage protocol: {self.protocol}")
341+
333342
def _create_filesystem(self) -> fsspec.AbstractFileSystem:
334343
"""Create fsspec filesystem based on protocol."""
335344
if self.protocol == "file":
@@ -368,7 +377,7 @@ def _create_filesystem(self) -> fsspec.AbstractFileSystem:
368377
)
369378

370379
else:
371-
raise errors.DataJointError(f"Unsupported storage protocol: {self.protocol}")
380+
return self._require_adapter().create_filesystem(self.spec)
372381

373382
def _full_path(self, path: str | PurePosixPath) -> str:
374383
"""
@@ -397,12 +406,13 @@ def _full_path(self, path: str | PurePosixPath) -> str:
397406
if location:
398407
return f"{bucket}/{location}/{path}"
399408
return f"{bucket}/{path}"
400-
else:
401-
# Local filesystem - prepend location if specified
409+
elif self.protocol == "file":
402410
location = self.spec.get("location", "")
403411
if location:
404412
return str(Path(location) / path)
405413
return path
414+
else:
415+
return self._require_adapter().full_path(self.spec, path)
406416

407417
def get_url(self, path: str | PurePosixPath) -> str:
408418
"""
@@ -448,8 +458,7 @@ def get_url(self, path: str | PurePosixPath) -> str:
448458
elif self.protocol == "azure":
449459
return f"az://{full_path}"
450460
else:
451-
# Fallback: use protocol prefix
452-
return f"{self.protocol}://{full_path}"
461+
return self._require_adapter().get_url(self.spec, full_path)
453462

454463
def put_file(self, local_path: str | Path, remote_path: str | PurePosixPath, metadata: dict | None = None) -> None:
455464
"""

src/datajoint/storage_adapter.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
"""Plugin system for third-party storage protocols.
2+
3+
Third-party packages register adapters via entry points::
4+
5+
[project.entry-points."datajoint.storage"]
6+
myprotocol = "my_package:MyStorageAdapter"
7+
8+
The adapter is auto-discovered when DataJoint encounters the protocol name
9+
in a store configuration. No explicit import is needed.
10+
"""
11+
12+
from abc import ABC, abstractmethod
13+
from typing import Any
14+
import logging
15+
16+
import fsspec
17+
18+
from . import errors
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class StorageAdapter(ABC):
    """Abstract base for storage protocol adapters.

    Third-party packages subclass this — setting ``protocol``,
    ``required_keys``, and ``allowed_keys`` and implementing
    ``create_filesystem`` — and expose the subclass through a
    ``datajoint.storage`` entry point to add a new storage protocol.
    """

    protocol: str
    required_keys: tuple[str, ...] = ()
    allowed_keys: tuple[str, ...] = ()

    @abstractmethod
    def create_filesystem(self, spec: dict[str, Any]) -> fsspec.AbstractFileSystem:
        """Return an fsspec filesystem instance for this protocol."""
        ...

    def validate_spec(self, spec: dict[str, Any]) -> None:
        """Check protocol-specific config fields, raising DataJointError on problems."""
        # Required keys must all be present.
        absent = [key for key in self.required_keys if key not in spec]
        if absent:
            raise errors.DataJointError(f'{self.protocol} store is missing: {", ".join(absent)}')
        # Anything beyond the adapter's keys plus the common store keys is rejected.
        permitted = _COMMON_STORE_KEYS | set(self.allowed_keys)
        unknown = [key for key in spec if key not in permitted]
        if unknown:
            raise errors.DataJointError(f'Invalid key(s) for {self.protocol}: {", ".join(unknown)}')

    def full_path(self, spec: dict[str, Any], relpath: str) -> str:
        """Join the store's optional ``location`` with *relpath*."""
        base = spec.get("location", "")
        if not base:
            return relpath
        return f"{base}/{relpath}"

    def get_url(self, spec: dict[str, Any], path: str) -> str:
        """Return a display URL of the form ``<protocol>://<path>``."""
        return f"{self.protocol}://{path}"
58+
59+
60+
_COMMON_STORE_KEYS = frozenset(
61+
{
62+
"protocol",
63+
"location",
64+
"subfolding",
65+
"partition_pattern",
66+
"token_length",
67+
"hash_prefix",
68+
"schema_prefix",
69+
"filepath_prefix",
70+
"stage",
71+
}
72+
)
73+
74+
# Maps protocol name -> adapter instance; populated lazily by _discover_adapters().
_adapter_registry: dict[str, StorageAdapter] = {}
# Guards one-time entry-point discovery (set True in get_storage_adapter).
_adapters_loaded: bool = False
76+
77+
78+
def get_storage_adapter(protocol: str) -> StorageAdapter | None:
    """Look up a registered storage adapter by protocol name.

    On first use, loads adapters advertised under the ``datajoint.storage``
    entry-point group; subsequent calls hit the registry directly. Returns
    ``None`` when no adapter is registered for *protocol*.
    """
    global _adapters_loaded
    if _adapters_loaded:
        return _adapter_registry.get(protocol)
    _discover_adapters()
    _adapters_loaded = True
    return _adapter_registry.get(protocol)
85+
86+
87+
def _discover_adapters() -> None:
    """Load storage adapters from ``datajoint.storage`` entry points.

    Populates ``_adapter_registry`` keyed by each adapter's declared
    ``protocol``. The first adapter registered for a protocol wins; a
    broken entry point is logged and skipped so one faulty plugin cannot
    break DataJoint startup.
    """
    try:
        from importlib.metadata import entry_points
    except ImportError:  # stdlib since 3.8; fallback kept for safety
        logger.debug("importlib.metadata not available, skipping adapter discovery")
        return

    try:
        # Python 3.10+: select a single group directly.
        eps = entry_points(group="datajoint.storage")
    except TypeError:
        # Legacy API (< 3.10): entry_points() returns a mapping of groups.
        eps = entry_points().get("datajoint.storage", [])

    for ep in eps:
        try:
            adapter_cls = ep.load()
            adapter = adapter_cls()
        except Exception as e:
            logger.warning("Failed to load storage adapter '%s': %s", ep.name, e)
            continue
        # Fix: deduplicate on the adapter's declared protocol, which is the
        # registry key — the entry-point name may differ from the protocol,
        # so checking ep.name could let a later plugin silently overwrite an
        # earlier registration. First registration for a protocol wins.
        if adapter.protocol not in _adapter_registry:
            _adapter_registry[adapter.protocol] = adapter
            logger.debug("Loaded storage adapter: %s", adapter.protocol)

0 commit comments

Comments
 (0)