Skip to content

Commit 545d24e

Browse files
authored
Merge pull request #30 from haeussma/from-enzymeml-inhomogeneous-data-shape
Fix `ValidationError` for inhomogeneous EnzymeML measurement data
2 parents 2927846 + 22364ed commit 545d24e

6 files changed

Lines changed: 290 additions & 38 deletions

File tree

catalax/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"from_enzymeml",
2323
]
2424

25-
__version__ = "0.5.4"
25+
__version__ = "0.5.5"
2626

2727
PARAMETERS = InAxes.PARAMETERS
2828
TIME = InAxes.TIME

catalax/dataset/dataset.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -494,19 +494,36 @@ def from_enzymeml(
494494
) -> Dataset:
495495
"""Create a dataset from an EnzymeML document.
496496
497+
First scans all measurements to determine the global maximum data-array
498+
length (``global_max_len``), then constructs each ``Measurement`` with
499+
arrays padded to that common length. Within each measurement, species
500+
with shorter time arrays are position-aligned onto the canonical (longest)
501+
time axis with ``NaN`` at unsampled positions. Across measurements,
502+
shorter time arrays are extended with monotonic continuation and data
503+
padded with ``NaN``.
504+
497505
Args:
498-
enzmldoc: EnzymeML document containing experimental data
506+
enzmldoc: EnzymeML document containing experimental data.
499507
500508
Returns:
501-
A new Dataset object with measurements extracted from the EnzymeML document
509+
Dataset with uniformly-shaped measurements ready for JAX operations.
502510
"""
511+
global_max_len = max(
512+
(
513+
len(sp.data)
514+
for meas in enzmldoc.measurements
515+
for sp in meas.species_data
516+
if sp.data is not None
517+
),
518+
default=0,
519+
)
503520

504-
missing_initial_conditions = []
505-
measurements = []
521+
missing_initial_conditions: list[str] = []
522+
measurements: list[Measurement] = []
506523

507524
for meas in enzmldoc.measurements:
508525
if any(sp.initial is not None for sp in meas.species_data):
509-
measurements.append(Measurement.from_enzymeml(meas))
526+
measurements.append(Measurement.from_enzymeml(meas, global_max_len))
510527
else:
511528
missing_initial_conditions.append(meas.id)
512529

@@ -518,17 +535,14 @@ def from_enzymeml(
518535
small_molecules = [sp.id for sp in enzmldoc.small_molecules]
519536
proteins = [sp.id for sp in enzmldoc.proteins]
520537
complexes = [sp.id for sp in enzmldoc.complexes]
521-
all_states = small_molecules + proteins + complexes
522538

523-
dataset = cls(
539+
return cls(
524540
id=enzmldoc.name,
525541
name=enzmldoc.name,
526-
states=all_states,
542+
states=small_molecules + proteins + complexes,
527543
measurements=measurements,
528544
)
529545

530-
return dataset
531-
532546
@classmethod
533547
def from_dataframe(
534548
cls,

catalax/dataset/measurement.py

Lines changed: 120 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -390,41 +390,135 @@ def from_dataframe(
390390
**kwargs,
391391
)
392392

393+
@staticmethod
394+
def _pad_species_arrays(
395+
measurement: pe.Measurement,
396+
global_max_len: int,
397+
) -> tuple[jax.Array, dict[str, jax.Array]]:
398+
"""Normalize species data onto a single canonical time axis.
399+
400+
EnzymeML species within the same measurement may be sampled at different
401+
time points (different lengths and/or different values). This method:
402+
403+
1. Selects the **longest** species time array as the canonical (unified)
404+
time axis for this measurement.
405+
2. **Validates** that every shorter species time array is a subset of the
406+
canonical axis (all their time points appear in the canonical array,
407+
within floating-point tolerance ``atol=1e-10``).
408+
3. **Position-aligns** each species' data onto the canonical axis, placing
409+
``NaN`` at canonical time positions where that species was not sampled.
410+
4. **Extends** the canonical time axis to ``global_max_len`` using
411+
monotonic continuation (``+1.0`` per step) and pads all data arrays
412+
with ``NaN``. This cross-measurement normalisation ensures all
413+
``ctx.Measurement`` objects in a ``Dataset`` share the same array
414+
length.
415+
416+
Args:
417+
measurement: A pyenzyme ``Measurement`` whose species may have
418+
heterogeneous time/data array lengths.
419+
global_max_len: Target length for all output arrays. Determined by
420+
``Dataset.from_enzymeml()`` across all measurements in the
421+
document.
422+
423+
Returns:
424+
``(time, data_dict)`` where ``time`` is a 1-D JAX array of length
425+
``global_max_len`` and every value in ``data_dict`` is a 1-D JAX
426+
array of the same length.
427+
428+
Raises:
429+
ValueError: If any species' time array contains values not present
430+
in the canonical time array (i.e., it is not a subset).
431+
"""
432+
non_empty = [
433+
sp for sp in measurement.species_data
434+
if sp.data is not None and len(sp.data) > 0
435+
]
436+
437+
# 1. Select canonical time array (longest).
438+
time_candidates = [
439+
sp.time for sp in non_empty if sp.time is not None and len(sp.time) > 0
440+
]
441+
if time_candidates:
442+
canonical_list: list[float] = max(time_candidates, key=len)
443+
else:
444+
# No species has time data; synthesise a zero-based integer axis.
445+
max_data_len = max((len(sp.data) for sp in non_empty), default=0)
446+
canonical_list = list(range(max_data_len))
447+
448+
canonical_np = np.array(canonical_list, dtype=float)
449+
450+
# 2 & 3. Validate subset + position-align each species.
451+
data_dict: dict[str, jax.Array] = {}
452+
for sp in non_empty:
453+
sp_data_np = np.array(sp.data, dtype=float)
454+
455+
if sp.time is None or len(sp.time) == 0:
456+
# No time info — align to the start of the canonical axis.
457+
aligned = np.full(len(canonical_np), np.nan)
458+
aligned[: len(sp_data_np)] = sp_data_np
459+
elif len(sp.time) == len(canonical_np) and np.allclose(
460+
sp.time, canonical_np, atol=1e-10
461+
):
462+
# Same axis as canonical — no alignment needed.
463+
aligned = sp_data_np
464+
else:
465+
sp_time_np = np.array(sp.time, dtype=float)
466+
aligned = np.full(len(canonical_np), np.nan)
467+
for j, (t, v) in enumerate(zip(sp_time_np, sp_data_np)):
468+
idx = np.where(np.isclose(canonical_np, t, atol=1e-10))[0]
469+
if idx.size == 0:
470+
raise ValueError(
471+
f"Time point {t!r} of species '{sp.species_id}' "
472+
f"(measurement '{measurement.id}') not found in the "
473+
f"canonical time array {canonical_list!r}. All species "
474+
"time arrays must be subsets of the longest time array "
475+
"within the same measurement."
476+
)
477+
aligned[idx[0]] = v
478+
479+
data_dict[sp.species_id] = jnp.array(aligned)
480+
481+
# 4. Extend canonical time to global_max_len.
482+
pad_len = global_max_len - len(canonical_np)
483+
if pad_len > 0:
484+
start = canonical_np[-1] + 1.0 if canonical_np.size > 0 else 0.0
485+
extension = np.arange(start, start + pad_len, 1.0)
486+
canonical_np = np.concatenate([canonical_np, extension])
487+
data_dict = {
488+
sid: jnp.concatenate([arr, jnp.full(pad_len, jnp.nan)])
489+
for sid, arr in data_dict.items()
490+
}
491+
492+
return jnp.array(canonical_np), data_dict
493+
393494
@classmethod
394-
def from_enzymeml(cls, measurement: pe.Measurement) -> "Measurement":
395-
"""Create a Measurement object from a pyenzyme Measurement object.
495+
def from_enzymeml(
496+
cls,
497+
measurement: pe.Measurement,
498+
global_max_len: int,
499+
) -> "Measurement":
500+
"""Create a Measurement from a pyenzyme Measurement object.
501+
502+
Delegates array normalisation to ``_pad_species_arrays``, which selects
503+
the longest species time array as the unified time axis, validates that
504+
shorter species time arrays are subsets of it, position-aligns their data
505+
(``NaN`` at unsampled positions), and pads all arrays to ``global_max_len``.
396506
397507
Args:
398508
measurement (pe.Measurement): PyEnzyme measurement object.
509+
global_max_len (int): Common length for all output arrays; determined
510+
by ``Dataset.from_enzymeml()`` across the entire document.
399511
400512
Returns:
401-
Measurement: New Measurement object with data from the PyEnzyme measurement.
513+
Measurement: New Measurement with a single unified time axis and
514+
NaN-padded data arrays.
402515
"""
403516
initials = {
404-
species.species_id: species.initial
405-
for species in measurement.species_data
406-
if species.initial is not None
517+
sp.species_id: sp.initial
518+
for sp in measurement.species_data
519+
if sp.initial is not None
407520
}
408-
409-
data = {
410-
species.species_id: jnp.array(species.data)
411-
for species in measurement.species_data
412-
if species.data is not None and len(species.data) > 0
413-
}
414-
415-
time = next(
416-
iter(
417-
[
418-
jnp.array(data.time)
419-
for data in measurement.species_data
420-
if data.time is not None and len(data.time) > 0
421-
]
422-
),
423-
None,
424-
)
425-
426-
if measurement.id is None:
427-
measurement.id = str(uuid4())
521+
time, data = cls._pad_species_arrays(measurement, global_max_len)
428522

429523
return cls(
430524
initial_conditions=initials,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "catalax"
3-
version = "0.5.4"
3+
version = "0.5.5"
44
description = "A JAX-based framework for (neural) ODE modelling in biocatalysis."
55
authors = [{ email = "jan.range@simtech.uni-stuttgart.de", name = "Jan Range" }]
66
license = "MIT"

tests/unit/dataset/test_dataset_enzymeml.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,44 @@ def test_dataset_from_enzymeml(self):
8989
assert len(ds.measurements) == len(doc.measurements)
9090
assert len(ds.states) == len(doc.small_molecules) + len(doc.proteins)
9191
assert len(ds.measurements) == len(doc.measurements)
92+
93+
def test_inhomogeneous_species_no_error(self):
    """Species arrays of unequal length in one measurement must load without a ValidationError."""
    document = pe.EnzymeMLDocument(name="test")
    dense = document.add_to_small_molecules(id="s1", name="s1")
    sparse = document.add_to_small_molecules(id="s2", name="s2")

    pe_meas = document.add_to_measurements(id="m1", name="m1")
    # Five samples for s1 ...
    pe_meas.add_to_species_data(
        species_id=dense.id,
        name="s1",
        initial=10.0,
        data=[10, 8, 6, 4, 2],
        time=[0, 1, 2, 3, 4],
    )
    # ... but only three for s2, taken at a subset of s1's time points.
    pe_meas.add_to_species_data(
        species_id=sparse.id,
        name="s2",
        initial=0.0,
        data=[0, 4, 8],
        time=[0, 2, 4],
    )

    dataset = ctx.Dataset.from_enzymeml(document)  # must not raise
    first = dataset.measurements[0]
    assert len(first.time) == 5
    assert len(first.data["s1"]) == len(first.data["s2"])
    assert len(first.data["s2"]) == 5
113+
114+
def test_multiple_measurements_uniform_length(self):
    """Measurements with different sample counts must all come out at the common length."""
    document = pe.EnzymeMLDocument(name="test")
    molecule = document.add_to_small_molecules(id="s1", name="s1")

    # (measurement id, initial value, data, time) - one short and one long series.
    series = (
        ("m_short", 10.0, [10, 8, 6], [0, 1, 2]),
        ("m_long", 5.0, [5, 4, 3, 2, 1], [0, 1, 2, 3, 4]),
    )
    for meas_id, initial, values, time_points in series:
        pe_meas = document.add_to_measurements(id=meas_id, name=meas_id)
        pe_meas.add_to_species_data(
            species_id=molecule.id,
            name="s1",
            initial=initial,
            data=values,
            time=time_points,
        )

    ds = ctx.Dataset.from_enzymeml(document)
    lengths = set()
    for m in ds.measurements:
        lengths.add(len(m.time))
    assert lengths == {5}, f"Expected uniform length 5, got {lengths}"
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""Tests for Measurement._pad_species_arrays() and from_enzymeml()."""
2+
import jax.numpy as jnp
3+
import pyenzyme as pe
4+
import pytest
5+
6+
from catalax.dataset.measurement import Measurement
7+
8+
9+
def _make_pe_measurement(species: list[dict]) -> pe.Measurement:
    """Assemble a ``pe.Measurement`` holding one ``species_data`` entry per spec.

    Every dict in ``species`` must provide the keys ``species_id``,
    ``initial``, ``data`` and ``time``. A small molecule with the same id is
    registered on a throwaway document for each entry.
    """
    document = pe.EnzymeMLDocument(name="test")
    pe_meas = document.add_to_measurements(id="m1", name="m1")
    for entry in species:
        identifier = entry["species_id"]
        molecule = document.add_to_small_molecules(id=identifier, name=identifier)
        pe_meas.add_to_species_data(
            species_id=molecule.id,
            name=identifier,
            initial=entry["initial"],
            data=entry["data"],
            time=entry["time"],
        )
    return pe_meas
27+
28+
29+
class TestPadSpeciesArrays:
    """Unit tests for ``Measurement._pad_species_arrays``."""

    def test_homogeneous_no_padding_needed(self):
        """Species that already share one length come back NaN-free at that length."""
        specs = [
            {"species_id": "s1", "initial": 10.0, "data": [10, 8, 6], "time": [0, 1, 2]},
            {"species_id": "s2", "initial": 0.0, "data": [0, 2, 4], "time": [0, 1, 2]},
        ]
        pe_meas = _make_pe_measurement(specs)
        time, data = Measurement._pad_species_arrays(pe_meas, global_max_len=3)
        assert len(time) == 3
        for sid in ("s1", "s2"):
            assert len(data[sid]) == 3
        for sid in ("s1", "s2"):
            assert not jnp.any(jnp.isnan(data[sid]))

    def test_subset_species_aligned_with_nan(self):
        """A species sampled on a subset of the axis gets NaN where it was not measured."""
        # s1 defines the canonical (longest) axis; s2 is sampled at every other point.
        specs = [
            {"species_id": "s1", "initial": 10.0, "data": [10, 8, 6, 4, 2], "time": [0, 1, 2, 3, 4]},
            {"species_id": "s2", "initial": 0.0, "data": [0, 4, 8], "time": [0, 2, 4]},
        ]
        pe_meas = _make_pe_measurement(specs)
        time, data = Measurement._pad_species_arrays(pe_meas, global_max_len=5)
        assert list(time) == [0, 1, 2, 3, 4]
        sparse = data["s2"]
        # Measured positions keep their values ...
        for position, expected in ((0, 0.0), (2, 4.0), (4, 8.0)):
            assert float(sparse[position]) == expected
        # ... and the unmeasured ones (t=1, t=3) are NaN.
        for position in (1, 3):
            assert jnp.isnan(sparse[position])

    def test_cross_measurement_padding_to_global_max(self):
        """A global_max_len above the local axis length extends time and NaN-pads data."""
        specs = [
            {"species_id": "s1", "initial": 10.0, "data": [10, 8, 6], "time": [0, 1, 2]},
        ]
        pe_meas = _make_pe_measurement(specs)
        time, data = Measurement._pad_species_arrays(pe_meas, global_max_len=5)
        assert len(time) == 5
        # Time continues monotonically past the last real sample.
        for position, expected in ((3, 3.0), (4, 4.0)):
            assert float(time[position]) == expected
        # The appended data slots are NaN.
        for position in (3, 4):
            assert jnp.isnan(data["s1"][position])

    def test_raises_when_species_time_not_subset(self):
        """A species time point missing from the canonical axis raises ValueError."""
        # t=1.5 of s2 does not occur in s1's time array, so alignment must fail.
        specs = [
            {"species_id": "s1", "initial": 10.0, "data": [10, 8, 6, 4], "time": [0, 1, 2, 3]},
            {"species_id": "s2", "initial": 0.0, "data": [0, 2, 4], "time": [0, 1.5, 3]},
        ]
        pe_meas = _make_pe_measurement(specs)
        with pytest.raises(ValueError, match="not found in the canonical time"):
            Measurement._pad_species_arrays(pe_meas, global_max_len=4)
79+
80+
81+
class TestMeasurementFromEnzymeML:
    """Unit tests for ``Measurement.from_enzymeml``."""

    def test_homogeneous_roundtrip(self):
        """Species with identical time arrays convert without any NaN."""
        specs = [
            {"species_id": "s1", "initial": 10.0, "data": [10, 8, 6], "time": [0, 1, 2]},
            {"species_id": "s2", "initial": 0.0, "data": [0, 2, 4], "time": [0, 1, 2]},
        ]
        converted = Measurement.from_enzymeml(
            _make_pe_measurement(specs), global_max_len=3
        )
        assert list(converted.time) == [0, 1, 2]
        for sid in ("s1", "s2"):
            assert not jnp.any(jnp.isnan(converted.data[sid]))

    def test_inhomogeneous_no_validation_error(self):
        """Regression for the original bug: unequal species lengths must not raise ValidationError."""
        specs = [
            {"species_id": "s1", "initial": 10.0, "data": [10, 8, 6, 4, 2], "time": [0, 1, 2, 3, 4]},
            {"species_id": "s2", "initial": 0.0, "data": [0, 4, 8], "time": [0, 2, 4]},
        ]
        pe_meas = _make_pe_measurement(specs)
        converted = Measurement.from_enzymeml(pe_meas, global_max_len=5)  # must not raise
        assert len(converted.time) == 5
        assert len(converted.data["s1"]) == len(converted.data["s2"])
        assert len(converted.data["s2"]) == 5
        # s2 was not sampled at the second and fourth canonical positions.
        for position in (1, 3):
            assert jnp.isnan(converted.data["s2"][position])

0 commit comments

Comments
 (0)