Skip to content

Commit b666df9

Browse files
committed
feat(cli): add iohub compute-pyramid command
Wraps `Position.compute_pyramid` so users can build multiscale pyramid levels in place from the command line, mirroring the ergonomics of `iohub set-scale` (per-position glob, plate expansion, mode="r+"). iohub compute-pyramid -i input.zarr/*/*/* --levels 4 iohub compute-pyramid -i input.zarr/*/*/* -l 3 -m median --dims y,x Tests cover: - happy path with explicit levels and method - plate-glob expansion (passing the plate root) - partial-axis downsampling via --dims y,x (Z preserved) - --help and invalid --dims rejection
1 parent 1c739a2 commit b666df9

2 files changed

Lines changed: 187 additions & 0 deletions

File tree

src/iohub/cli/cli.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import logging
12
import pathlib
23

34
import click
@@ -8,6 +9,8 @@
89
from iohub.reader import print_info
910
from iohub.rename_wells import rename_wells
1011

12+
_logger = logging.getLogger(__name__)
13+
1114
VERSION = __version__
1215

1316
_DATASET_PATH = click.Path(exists=True, file_okay=False, resolve_path=True, path_type=pathlib.Path)
@@ -163,6 +166,65 @@ def set_scale(
163166
dataset.set_scale(image, name, value)
164167

165168

169+
_PYRAMID_METHODS = ["mean", "median", "mode", "min", "max", "stride"]
170+
_PYRAMID_DIM_CHOICES = ["t", "z", "y", "x"]
171+
172+
173+
def _parse_dims(ctx, param, value):
174+
if value is None:
175+
return None
176+
tokens = [t.strip().lower() for t in value.split(",") if t.strip()]
177+
if not tokens:
178+
return None
179+
invalid = [t for t in tokens if t not in _PYRAMID_DIM_CHOICES]
180+
if invalid:
181+
raise click.BadParameter(f"Unknown dim(s) {invalid}. Valid choices: {_PYRAMID_DIM_CHOICES}.")
182+
return set(tokens)
183+
184+
185+
@cli.command(name="compute-pyramid")
186+
@click.help_option("-h", "--help")
187+
@input_position_dirpaths()
188+
@click.option(
189+
"--levels",
190+
"-l",
191+
required=True,
192+
type=click.IntRange(min=2),
193+
help="Total number of pyramid levels including level 0 (e.g. 4 = level 0 + 3 extra).",
194+
)
195+
@click.option(
196+
"--method",
197+
"-m",
198+
required=False,
199+
default="mean",
200+
show_default=True,
201+
type=click.Choice(_PYRAMID_METHODS),
202+
help="Downsampling method.",
203+
)
204+
@click.option(
205+
"--dims",
206+
"-d",
207+
required=False,
208+
default=None,
209+
callback=_parse_dims,
210+
help=("Comma-separated axes to downsample (e.g. 'y,x' for YX-only). Defaults to 'z,y,x' inside iohub."),
211+
)
212+
def compute_pyramid(input_position_dirpaths, levels, method, dims):
213+
"""Compute multiscale pyramid levels in place for OME-Zarr positions.
214+
215+
The level-0 array is preserved; new downsampled levels are appended.
216+
217+
```
218+
iohub compute-pyramid -i input.zarr/*/*/* --levels 4
219+
iohub compute-pyramid -i input.zarr/*/*/* -l 3 -m median --dims y,x
220+
```
221+
"""
222+
for input_position_dirpath in input_position_dirpaths:
223+
_logger.info(f"Computing pyramid for {input_position_dirpath}")
224+
with open_ome_zarr(input_position_dirpath, layout="fov", mode="r+") as dataset:
225+
dataset.compute_pyramid(levels=levels, method=method, dims=dims)
226+
227+
166228
@cli.command(name="rename-wells")
167229
@click.help_option("-h", "--help")
168230
@click.option(

tests/cli/test_cli.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44
from unittest.mock import patch
55

6+
import numpy as np
67
import pytest
78
from click.testing import CliRunner
89

@@ -217,3 +218,127 @@ def test_cli_rename_wells(csv_data_file_1):
217218

218219
assert result.exit_code == 0
219220
assert "Renaming" in result.output
221+
222+
223+
def _make_pyramid_test_plate(store_path: Path) -> tuple[tuple[int, ...], list[str]]:
224+
"""Create a tiny plate with a single position and a level-0 array.
225+
226+
Returns the level-0 shape and the position key (row, col, fov).
227+
"""
228+
shape = (1, 2, 8, 64, 64)
229+
rng = np.random.default_rng(0)
230+
data = rng.integers(0, 255, size=shape, dtype=np.uint16)
231+
position_key = ("A", "1", "0")
232+
with open_ome_zarr(
233+
store_path,
234+
layout="hcs",
235+
mode="w",
236+
channel_names=["ch1", "ch2"],
237+
) as plate:
238+
position = plate.create_position(*position_key)
239+
position.create_image("0", data)
240+
return shape, list(position_key)
241+
242+
243+
def test_cli_compute_pyramid(tmp_path, caplog):
244+
store_path = tmp_path / "compute_pyramid.zarr"
245+
shape, position_key = _make_pyramid_test_plate(store_path)
246+
position_path = store_path.joinpath(*position_key)
247+
248+
runner = CliRunner()
249+
result = runner.invoke(
250+
cli,
251+
[
252+
"compute-pyramid",
253+
"-i",
254+
str(position_path),
255+
"--levels",
256+
"3",
257+
"--method",
258+
"mean",
259+
],
260+
)
261+
assert result.exit_code == 0, result.output
262+
assert any("Computing pyramid" in record.message for record in caplog.records)
263+
264+
with open_ome_zarr(position_path, layout="fov", mode="r") as pos:
265+
dataset_paths = pos.metadata.multiscales[0].get_dataset_paths()
266+
assert dataset_paths == ["0", "1", "2"]
267+
assert pos["0"].shape == shape
268+
# Cascade YX/Z downsampling halves spatial axes per level.
269+
assert pos["1"].shape[-2:] == (shape[-2] // 2, shape[-1] // 2)
270+
assert pos["2"].shape[-2:] == (shape[-2] // 4, shape[-1] // 4)
271+
272+
273+
def test_cli_compute_pyramid_plate_glob(tmp_path):
274+
store_path = tmp_path / "compute_pyramid_plate.zarr"
275+
_, position_key = _make_pyramid_test_plate(store_path)
276+
position_path = store_path.joinpath(*position_key)
277+
278+
runner = CliRunner()
279+
# Pass the plate root: `_validate_and_process_paths` expands it into positions.
280+
result = runner.invoke(
281+
cli,
282+
["compute-pyramid", "-i", str(store_path), "-l", "2"],
283+
)
284+
assert result.exit_code == 0, result.output
285+
286+
with open_ome_zarr(position_path, layout="fov", mode="r") as pos:
287+
dataset_paths = pos.metadata.multiscales[0].get_dataset_paths()
288+
assert dataset_paths == ["0", "1"]
289+
290+
291+
def test_cli_compute_pyramid_dims(tmp_path):
292+
store_path = tmp_path / "compute_pyramid_dims.zarr"
293+
shape, position_key = _make_pyramid_test_plate(store_path)
294+
position_path = store_path.joinpath(*position_key)
295+
296+
runner = CliRunner()
297+
result = runner.invoke(
298+
cli,
299+
[
300+
"compute-pyramid",
301+
"-i",
302+
str(position_path),
303+
"-l",
304+
"2",
305+
"--dims",
306+
"y,x",
307+
],
308+
)
309+
assert result.exit_code == 0, result.output
310+
311+
with open_ome_zarr(position_path, layout="fov", mode="r") as pos:
312+
# Z is preserved; YX halved.
313+
assert pos["1"].shape[-3] == shape[-3]
314+
assert pos["1"].shape[-2:] == (shape[-2] // 2, shape[-1] // 2)
315+
316+
317+
def test_cli_compute_pyramid_help():
318+
runner = CliRunner()
319+
for option in ("-h", "--help"):
320+
result = runner.invoke(cli, ["compute-pyramid", option])
321+
assert result.exit_code == 0
322+
assert "compute-pyramid" in result.output or "Compute multiscale" in result.output
323+
324+
325+
def test_cli_compute_pyramid_invalid_dims(tmp_path):
326+
store_path = tmp_path / "compute_pyramid_bad_dims.zarr"
327+
_, position_key = _make_pyramid_test_plate(store_path)
328+
position_path = store_path.joinpath(*position_key)
329+
330+
runner = CliRunner()
331+
result = runner.invoke(
332+
cli,
333+
[
334+
"compute-pyramid",
335+
"-i",
336+
str(position_path),
337+
"-l",
338+
"2",
339+
"--dims",
340+
"y,bogus",
341+
],
342+
)
343+
assert result.exit_code != 0
344+
assert "bogus" in result.output

0 commit comments

Comments
 (0)