Skip to content

Commit e398750

Browse files
committed
🧪 TrajectoryData: add pbc to existing tests and increase coverage
Existing tests that call `set_trajectory` with `cells` now pass an explicit `pbc` argument to avoid the deprecation warning introduced in 08c58ff. Also adds tests for `numsteps`/`numsites` on empty trajectories, input validation errors, return types of `get_step_data`, and `plot_positions_XYZ`.
1 parent 2a3875f commit e398750

4 files changed

Lines changed: 105 additions & 9 deletions

File tree

tests/cmdline/commands/test_data.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,10 +514,17 @@ def create_trajectory_data():
514514
[[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]],
515515
]
516516
)
517+
pbc = [True, True, False]
517518

518519
# I set the node
519520
traj.set_trajectory(
520-
stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times, velocities=velocities
521+
stepids=stepids,
522+
cells=cells,
523+
symbols=symbols,
524+
positions=positions,
525+
times=times,
526+
velocities=velocities,
527+
pbc=pbc,
521528
)
522529

523530
traj.store()

tests/orm/nodes/data/test_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ def _generate_class_instance(data_class):
5050
times = stepids * 0.01
5151
cells = numpy.array([[[3.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 3.0]]])
5252
positions = numpy.array([[[0.0, 0.0, 0.0]]])
53-
instance.set_trajectory(stepids=stepids, cells=cells, symbols=['H'], positions=positions, times=times)
53+
instance.set_trajectory(
54+
stepids=stepids, cells=cells, symbols=['H'], positions=positions, times=times, pbc=[True, True, False]
55+
)
5456
return instance
5557

5658
if data_class is orm.UpfData:

tests/orm/nodes/data/test_trajectory.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from aiida.common.warnings import AiidaDeprecationWarning
77
from aiida.orm import StructureData, TrajectoryData, load_node
8+
from aiida.orm.nodes.data.array.trajectory import plot_positions_XYZ
89

910

1011
@pytest.fixture
@@ -77,6 +78,44 @@ def test_get_attribute_tryexcept_default(self):
7778
positions = 'FAILED_tryexc'
7879
assert positions == 1
7980

81+
def test_numsteps_numsites_empty(self):
82+
"""Test that `numsteps` and `numsites` return zero on an empty trajectory."""
83+
trajectory = TrajectoryData()
84+
assert trajectory.numsteps == 0
85+
assert trajectory.numsites == 0
86+
87+
def test_internal_validate_symbols_not_sequence(self):
88+
"""Test that passing a generator (non-Sequence) for symbols raises TypeError."""
89+
trajectory = TrajectoryData()
90+
positions = np.array([[[0.0, 0.0, 0.0]]])
91+
with pytest.raises(TypeError, match='symbols must be of type list'):
92+
trajectory.set_trajectory(symbols='H', positions=positions)
93+
94+
def test_internal_validate_wrong_array_types(self):
95+
"""Test TypeError is raised when arrays have wrong dtype or are not ndarrays."""
96+
trajectory = TrajectoryData()
97+
positions = np.array([[[0.0, 0.0, 0.0]]])
98+
99+
with pytest.raises(TypeError, match='positions must be a numpy array of floats'):
100+
trajectory.set_trajectory(symbols=['H'], positions=positions.astype(int))
101+
102+
with pytest.raises(TypeError, match='stepids must be a numpy array of integers'):
103+
trajectory.set_trajectory(symbols=['H'], positions=positions, stepids=np.array([0.0]))
104+
105+
cells = np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
106+
with pytest.raises(TypeError, match='cells must be a numpy array of floats'):
107+
trajectory.set_trajectory(
108+
symbols=['H'], positions=positions, cells=cells.astype(int), pbc=[True, True, True]
109+
)
110+
111+
times = np.array([0.0])
112+
with pytest.raises(TypeError, match='times must be a numpy array of floats'):
113+
trajectory.set_trajectory(symbols=['H'], positions=positions, times=times.astype(int))
114+
115+
velocities = np.array([[[0.0, 0.0, 0.0]]])
116+
with pytest.raises(TypeError, match='velocities must be a numpy array of floats'):
117+
trajectory.set_trajectory(symbols=['H'], positions=positions, velocities=velocities.astype(int))
118+
80119
def test_units(self):
81120
"""Test the setting of units attributes."""
82121
tjd = TrajectoryData()
@@ -110,6 +149,8 @@ def test_trajectory_get_step_data(self, trajectory_data):
110149
stepid, time, cell, symbols, positions, velocities = trajectory.get_step_data(-2)
111150
assert stepid == trajectory_data['stepids'][-2]
112151
assert time == trajectory_data['times'][-2]
152+
assert type(stepid) is int
153+
assert type(time) is float
113154
assert np.array_equal(cell, trajectory_data['cells'][-2, :, :])
114155
assert np.array_equal(symbols, trajectory_data['symbols'])
115156
assert np.array_equal(trajectory.pbc, trajectory_data['pbc'])
@@ -232,6 +273,15 @@ def test_trajectory_pbc_set_trajectory(self):
232273
}
233274
)
234275
trajectory.set_trajectory(**data)
276+
data.update(
277+
{
278+
'cells': np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]]),
279+
'pbc': [True, False],
280+
}
281+
)
282+
with pytest.raises(ValueError, match='`pbc` must be a list/tuple of length three with boolean values.'):
283+
trajectory.set_trajectory(**data)
284+
235285
assert trajectory.get_step_structure(0).pbc == (True, False, False)
236286

237287
def test_trajectory_without_pbc(self, trajectory_data):
@@ -243,3 +293,19 @@ def test_trajectory_without_pbc(self, trajectory_data):
243293
assert trajectory.pbc is None
244294
structure = trajectory.get_step_structure(0)
245295
assert structure.pbc == (True, True, True)
296+
297+
298+
def test_plot_positions_xyz(monkeypatch):
299+
"""Test that `plot_positions_XYZ` runs."""
300+
import matplotlib
301+
302+
matplotlib.use('Agg')
303+
monkeypatch.setattr('matplotlib.pyplot.show', lambda **kwargs: None)
304+
305+
n_steps, n_atoms = 20, 3
306+
times = np.linspace(0, 1, n_steps)
307+
positions = np.random.rand(n_steps, n_atoms, 3)
308+
colors = [(1.0, 0.0, 0.0), (0.0, 1.0, 0.0), (0.0, 0.0, 1.0)]
309+
310+
plot_positions_XYZ(times, positions, indices_to_show=[0, 1, 2], color_list=colors, label='test')
311+
plot_positions_XYZ(times, positions, indices_to_show=[0], color_list=colors, label='test', mintime=0.2, maxtime=0.8)

tests/test_dataclasses.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2739,10 +2739,17 @@ def test_creation(self):
27392739
[[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]],
27402740
]
27412741
)
2742+
pbc = [True, True, False]
27422743

27432744
# I set the node
27442745
n.set_trajectory(
2745-
stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times, velocities=velocities
2746+
stepids=stepids,
2747+
cells=cells,
2748+
symbols=symbols,
2749+
positions=positions,
2750+
times=times,
2751+
velocities=velocities,
2752+
pbc=pbc,
27462753
)
27472754

27482755
# Generic checks
@@ -2772,7 +2779,7 @@ def test_creation(self):
27722779

27732780
########################################################
27742781
# I set the node, this time without times or velocities (the same node)
2775-
n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions)
2782+
n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, pbc=pbc)
27762783
# Generic checks
27772784
assert n.numsites == 3
27782785
assert n.numsteps == 2
@@ -2785,7 +2792,7 @@ def test_creation(self):
27852792

27862793
# Same thing, but for a new node
27872794
n = TrajectoryData()
2788-
n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions)
2795+
n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, pbc=pbc)
27892796
# Generic checks
27902797
assert n.numsites == 3
27912798
assert n.numsteps == 2
@@ -2798,7 +2805,7 @@ def test_creation(self):
27982805

27992806
########################################################
28002807
# I set the node, this time without velocities (the same node)
2801-
n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times)
2808+
n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times, pbc=pbc)
28022809
# Generic checks
28032810
assert n.numsites == 3
28042811
assert n.numsteps == 2
@@ -2811,7 +2818,7 @@ def test_creation(self):
28112818

28122819
# Same thing, but for a new node
28132820
n = TrajectoryData()
2814-
n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times)
2821+
n.set_trajectory(stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times, pbc=pbc)
28152822
# Generic checks
28162823
assert n.numsites == 3
28172824
assert n.numsteps == 2
@@ -2935,10 +2942,17 @@ def test_conversion_to_structure(self):
29352942
[[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]],
29362943
]
29372944
)
2945+
pbc = [True, True, False]
29382946

29392947
# I set the node
29402948
n.set_trajectory(
2941-
stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times, velocities=velocities
2949+
stepids=stepids,
2950+
cells=cells,
2951+
symbols=symbols,
2952+
positions=positions,
2953+
times=times,
2954+
velocities=velocities,
2955+
pbc=pbc,
29422956
)
29432957

29442958
from_step = n.get_step_structure(1)
@@ -3116,10 +3130,17 @@ def test_export_to_file():
31163130
[[0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [-0.5, -0.5, -0.5]],
31173131
]
31183132
)
3133+
pbc = [True, True, False]
31193134

31203135
# I set the node
31213136
n.set_trajectory(
3122-
stepids=stepids, cells=cells, symbols=symbols, positions=positions, times=times, velocities=velocities
3137+
stepids=stepids,
3138+
cells=cells,
3139+
symbols=symbols,
3140+
positions=positions,
3141+
times=times,
3142+
velocities=velocities,
3143+
pbc=pbc,
31233144
)
31243145

31253146
# It is not obvious how to check that the bands are correct.

0 commit comments

Comments
 (0)