1212from __future__ import annotations
1313
1414import logging
15+ import os
1516
1617import numpy as np
1718import pandas as pd
19+ from iohub .core .config import TensorStoreConfig
1820from lightning .pytorch import LightningDataModule
1921from monai .data .thread_buffer import ThreadDataLoader
2022from monai .transforms import Compose , MapTransform
@@ -238,6 +240,12 @@ def __init__(
238240 # Loss hyperparameters (informational)
239241 # Other
240242 self .cache_pool_bytes = cache_pool_bytes
243+ cpus = os .environ .get ("SLURM_CPUS_PER_TASK" )
244+ cpus = int (cpus ) if cpus is not None else (os .cpu_count () or 4 )
245+ self .tensorstore_config = TensorStoreConfig (
246+ data_copy_concurrency = cpus ,
247+ cache_pool_bytes = cache_pool_bytes or None ,
248+ )
241249 self .seed = seed
242250 self .include_wells = include_wells
243251 self .exclude_fovs = exclude_fovs
@@ -361,13 +369,13 @@ def _setup_experiment_split(self, registry: ExperimentRegistry) -> None:
361369 positive_cell_source = self .positive_cell_source ,
362370 positive_match_columns = self .positive_match_columns ,
363371 max_border_shift = self .max_border_shift ,
372+ tensorstore_config = self .tensorstore_config ,
364373 )
365374 self .train_dataset = MultiExperimentTripletDataset (
366375 index = train_index ,
367376 fit = True ,
368377 tau_range_hours = self .tau_range ,
369378 tau_decay_rate = self .tau_decay_rate ,
370- cache_pool_bytes = self .cache_pool_bytes ,
371379 channels_per_sample = self .channels_per_sample ,
372380 positive_cell_source = self .positive_cell_source ,
373381 positive_match_columns = self .positive_match_columns ,
@@ -388,13 +396,13 @@ def _setup_experiment_split(self, registry: ExperimentRegistry) -> None:
388396 positive_cell_source = self .positive_cell_source ,
389397 positive_match_columns = self .positive_match_columns ,
390398 max_border_shift = self .max_border_shift ,
399+ tensorstore_config = self .tensorstore_config ,
391400 )
392401 self .val_dataset = MultiExperimentTripletDataset (
393402 index = val_index ,
394403 fit = True ,
395404 tau_range_hours = self .tau_range ,
396405 tau_decay_rate = self .tau_decay_rate ,
397- cache_pool_bytes = self .cache_pool_bytes ,
398406 channels_per_sample = self .channels_per_sample ,
399407 positive_cell_source = self .positive_cell_source ,
400408 positive_match_columns = self .positive_match_columns ,
@@ -419,6 +427,7 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None:
419427 num_workers = self .num_workers_index ,
420428 positive_cell_source = self .positive_cell_source ,
421429 positive_match_columns = self .positive_match_columns ,
430+ tensorstore_config = self .tensorstore_config ,
422431 )
423432
424433 rng = np .random .default_rng (self .seed )
@@ -458,7 +467,6 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None:
458467 fit = True ,
459468 tau_range_hours = self .tau_range ,
460469 tau_decay_rate = self .tau_decay_rate ,
461- cache_pool_bytes = self .cache_pool_bytes ,
462470 channels_per_sample = self .channels_per_sample ,
463471 positive_cell_source = self .positive_cell_source ,
464472 positive_match_columns = self .positive_match_columns ,
@@ -477,7 +485,6 @@ def _setup_fov_split(self, registry: ExperimentRegistry) -> None:
477485 fit = True ,
478486 tau_range_hours = self .tau_range ,
479487 tau_decay_rate = self .tau_decay_rate ,
480- cache_pool_bytes = self .cache_pool_bytes ,
481488 channels_per_sample = self .channels_per_sample ,
482489 positive_cell_source = self .positive_cell_source ,
483490 positive_match_columns = self .positive_match_columns ,
0 commit comments