
Commit 33dd223

Added support for train/val/test split CodonMemmapDataset
1 parent 1db0a0c commit 33dd223

10 files changed

Lines changed: 730 additions & 24 deletions


bionemo-recipes/recipes/codonfm_native_te/codon_memmap_dataset.py

Lines changed: 513 additions & 0 deletions
Large diffs are not rendered by default.

bionemo-recipes/recipes/codonfm_native_te/dataset.py

Lines changed: 108 additions & 5 deletions
@@ -22,6 +22,7 @@
 import numpy as np
 import pyarrow.parquet as pq
 import torch
+from codon_memmap_dataset import CodonMemmapDataset
 from distributed_config import DistributedConfig
 from tokenizer import CodonTokenizer
 from torch.utils.data import DataLoader, Dataset, DistributedSampler
@@ -162,6 +163,11 @@ def __len__(self) -> int:  # noqa: D105
     def __getitem__(self, idx: int) -> dict[str, str]:  # noqa: D105
         chunk_id, start, end = self.global_indices[idx]
         token_ids = self.sequences_mmaps[chunk_id][start:end]
+        # Note: decode(skip_special_tokens=True) silently drops <UNK> tokens (ID 2). The codon
+        # tokenizer's tokenize() is strict 3-char chunking that cannot reparse the "<UNK>"
+        # literal in a decoded string, so any window containing ambiguous-base codons loses
+        # those positions when round-tripped. Use CodonMemmapDataset (returns sequence_tokens
+        # directly) for PTL-parity behavior.
         sequence = self.tokenizer.decode(token_ids.tolist(), skip_special_tokens=True)
         return {"sequence": sequence}
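
The failure mode described in that note can be reproduced in isolation. A minimal, hypothetical sketch, assuming the <UNK> token ID of 2 stated in the comment and the encode/decode signatures used elsewhere in this file:

from tokenizer import CodonTokenizer  # repo-local tokenizer imported by dataset.py

tokenizer = CodonTokenizer()
window = tokenizer.encode("ATGAAA", add_special_tokens=True)   # e.g. [CLS, ATG, AAA, SEP]
window.insert(2, 2)                                            # splice in an <UNK> codon position (ID 2, per the note)
decoded = tokenizer.decode(window, skip_special_tokens=True)   # "<UNK>" is dropped -> "ATGAAA"
reencoded = tokenizer.encode(decoded, add_special_tokens=True)
assert len(reencoded) == len(window) - 1                       # the ambiguous position is lost on the round trip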

@@ -203,7 +209,10 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
         all_labels = []

         for sample in batch:
-            ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
+            if "sequence_tokens" in sample:
+                ids = [self.tokenizer.cls_token_id, *sample["sequence_tokens"].tolist(), self.tokenizer.sep_token_id]
+            else:
+                ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
             # Truncate to max_seq_length, preserving trailing SEP token
             if len(ids) > self.max_seq_length:
                 ids = [*ids[: self.max_seq_length - 1], self.tokenizer.sep_token_id]
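
For reference, the two sample shapes the collator now accepts look roughly like this (a sketch; the token IDs are illustrative, and the memmap sample layout is inferred from the sequence_tokens branch above):

import numpy as np

from tokenizer import CodonTokenizer

tokenizer = CodonTokenizer()
legacy_sample = {"sequence": "ATGGCTTAA"}                    # string path: re-tokenized via encode()
memmap_sample = {"sequence_tokens": np.array([11, 42, 7])}   # ID path: tokens straight from CodonMemmapDataset
# The ID path only wraps the window in special tokens, skipping the lossy decode/encode round trip:
ids = [tokenizer.cls_token_id, *memmap_sample["sequence_tokens"].tolist(), tokenizer.sep_token_id]
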
@@ -281,7 +290,10 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
         seq_lengths = []

         for sample in batch:
-            ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
+            if "sequence_tokens" in sample:
+                ids = [self.tokenizer.cls_token_id, *sample["sequence_tokens"].tolist(), self.tokenizer.sep_token_id]
+            else:
+                ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
             # Truncate to max_seq_length, preserving trailing SEP token
             if len(ids) > self.max_seq_length:
                 ids = [*ids[: self.max_seq_length - 1], self.tokenizer.sep_token_id]
@@ -344,13 +356,26 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
         }


-def _create_dataset(data_path: str, max_seq_length: int, seed: int) -> Dataset:
+def _create_dataset(
+    data_path: str,
+    max_seq_length: int,
+    seed: int,
+    split: str | None = None,
+    split_kwargs: dict | None = None,
+) -> Dataset:
     """Create the appropriate dataset based on data_path format.

     Args:
         data_path: 'synthetic', path to a parquet file, or path to a memmap directory.
         max_seq_length: Maximum sequence length (used for memmap sliding windows).
         seed: Random seed.
+        split: If set ("train" / "validation" / "test"), construct the split-aware
+            CodonMemmapDataset (port of the PTL dataset) instead of MemmapCodonDataset.
+            Only meaningful when data_path is a memmap directory; ignored otherwise.
+        split_kwargs: Extra keyword arguments forwarded to CodonMemmapDataset
+            (train_val_test_ratio, context_overlap, pretraining_task, min_seq_length,
+            max_filter_seq_length, groups_to_use, taxid_exclusion_file, split_name_prefix,
+            force_recompute). Only used when split is set.

     Returns:
         A Dataset instance.
@@ -359,6 +384,14 @@ def _create_dataset(data_path: str, max_seq_length: int, seed: int) -> Dataset:
         return SyntheticCodonDataset(num_samples=500, seed=seed)
     data_dir = Path(data_path)
     if data_dir.is_dir() and (data_dir / "metadata.json").exists():
+        if split is not None:
+            return CodonMemmapDataset(
+                data_path,
+                split=split,
+                max_seq_length=max_seq_length,
+                seed=seed,
+                **(split_kwargs or {}),
+            )
         return MemmapCodonDataset(data_path, max_seq_length=max_seq_length)
     return ParquetCodonDataset(data_path)
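
The updated dispatch, with illustrative paths and kwargs (the directory and file names below are hypothetical):

# Memmap directory + split -> split-aware CodonMemmapDataset (PTL-parity windows and cluster split)
train_ds = _create_dataset(
    "/data/codon_memmap",   # hypothetical directory containing metadata.json
    max_seq_length=2048,
    seed=123,
    split="train",
    split_kwargs={"train_val_test_ratio": [0.9998, 0.0002, 0.0], "context_overlap": 0},
)
# Memmap directory without a split -> legacy MemmapCodonDataset
legacy_ds = _create_dataset("/data/codon_memmap", max_seq_length=2048, seed=123)
# Parquet file -> ParquetCodonDataset; data_path="synthetic" -> SyntheticCodonDataset
parquet_ds = _create_dataset("/data/codons.parquet", max_seq_length=2048, seed=123)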

@@ -372,6 +405,8 @@ def create_bshd_dataloader(
     num_workers: int = 1,
     seed: int = 42,
     pad_to_multiple_of: int | None = None,
+    split: str | None = None,
+    split_kwargs: dict | None = None,
 ) -> tuple[DataLoader, DistributedSampler]:
     """Create a BSHD-format dataloader.

@@ -384,25 +419,30 @@
         num_workers: Number of dataloader workers.
         seed: Random seed.
         pad_to_multiple_of: Unused in BSHD mode (only applies to THD).
+        split: If set, use the split-aware CodonMemmapDataset for memmap dirs.
+        split_kwargs: Extra arguments forwarded to CodonMemmapDataset when split is set.

     Returns:
         Tuple of (DataLoader, DistributedSampler).
     """
     tokenizer = CodonTokenizer()

-    dataset = _create_dataset(data_path, max_seq_length, seed)
+    dataset = _create_dataset(data_path, max_seq_length, seed, split=split, split_kwargs=split_kwargs)

+    sampler_kwargs = {"shuffle": False} if split == "validation" else {}
     sampler = DistributedSampler(
         dataset,
         rank=dist_config.rank,
         num_replicas=dist_config.world_size,
         seed=seed,
+        **sampler_kwargs,
     )

     collator = CodonMLMCollator(
         tokenizer=tokenizer,
         max_seq_length=max_seq_length,
         mlm_probability=mlm_probability,
+        seed=seed,
     )

     dataloader = DataLoader(
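
The sampler_kwargs toggle above keeps validation batches in a fixed order across evaluations; for the validation split the sampler construction expands to the equivalent of this sketch, while the train split keeps DistributedSampler's default shuffling:

val_sampler = DistributedSampler(
    dataset, rank=dist_config.rank, num_replicas=dist_config.world_size, seed=seed, shuffle=False
)
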
@@ -426,6 +466,8 @@ def create_thd_dataloader(
     num_workers: int = 1,
     seed: int = 42,
     pad_to_multiple_of: int | None = None,
+    split: str | None = None,
+    split_kwargs: dict | None = None,
 ) -> tuple[DataLoader, DistributedSampler]:
     """Create a THD-format (packed sequence) dataloader.

@@ -440,6 +482,8 @@
         pad_to_multiple_of: If set, pad total tokens to a multiple of this value. If None,
             defaults to micro_batch_size * max_seq_length for consistent tensor shapes
             (matching ESM2's approach). Set to 0 to disable padding.
+        split: If set, use the split-aware CodonMemmapDataset for memmap dirs.
+        split_kwargs: Extra arguments forwarded to CodonMemmapDataset when split is set.

     Returns:
         Tuple of (DataLoader, DistributedSampler).
@@ -454,20 +498,23 @@
     elif pad_to_multiple_of == 0:
         pad_to_multiple_of = None

-    dataset = _create_dataset(data_path, max_seq_length, seed)
+    dataset = _create_dataset(data_path, max_seq_length, seed, split=split, split_kwargs=split_kwargs)

+    sampler_kwargs = {"shuffle": False} if split == "validation" else {}
     sampler = DistributedSampler(
         dataset,
         rank=dist_config.rank,
         num_replicas=dist_config.world_size,
         seed=seed,
+        **sampler_kwargs,
     )

     collator = CodonTHDCollator(
         tokenizer=tokenizer,
         max_seq_length=max_seq_length,
         mlm_probability=mlm_probability,
         pad_to_multiple_of=pad_to_multiple_of,
+        seed=seed,
     )

     dataloader = DataLoader(
@@ -480,3 +527,59 @@
     )

     return dataloader, sampler
+
+
+def create_dataloaders(
+    dist_config: DistributedConfig,
+    *,
+    use_sequence_packing: bool,
+    build_validation: bool,
+    use_split_dataset: bool = True,
+    split_kwargs: dict | None = None,
+    **factory_kwargs,
+) -> tuple[DataLoader, DataLoader | None, DistributedSampler]:
+    """Build train (and optionally validation) dataloaders from a single configuration.
+
+    Wrapper modeled on esm2_peft_te.create_dataloader: one factory call produces both loaders, so
+    train and val datasets share the on-disk caches via mmap and the kernel page cache. When
+    use_split_dataset is True, the new CodonMemmapDataset is constructed for each split (train/val
+    samples are disjoint by the PTL proportional cluster split); when False, the legacy path is
+    used and the val loader simply re-reads the train data (placeholder behavior).
+
+    If split_kwargs requests force_recompute, the flag is honored only by the train call; the val
+    call is invoked with force_recompute=False so the cache written by train is reused instead of
+    rebuilt a second time in the same process.
+
+    Args:
+        dist_config: Distributed configuration.
+        use_sequence_packing: Pick THD factory if True, BSHD factory if False.
+        build_validation: If False, skip val-loader construction entirely (returns None).
+        use_split_dataset: When True (default), construct the split-aware CodonMemmapDataset
+            for memmap directories. Set to False to fall back to the legacy MemmapCodonDataset,
+            in which case the val loader (if requested) re-reads the train data as a
+            placeholder. Has no effect for synthetic/parquet data paths.
+        split_kwargs: Extra arguments forwarded to CodonMemmapDataset (only used when
+            use_split_dataset=True). See codon_memmap_dataset.CodonMemmapDataset for the full list.
+        **factory_kwargs: Remaining keyword arguments passed to the low-level factory
+            (data_path, micro_batch_size, max_seq_length, mlm_probability, num_workers, seed,
+            pad_to_multiple_of).
+
+    Returns:
+        Tuple of (train_dataloader, val_dataloader or None, train DistributedSampler).
+    """
+    factory = create_thd_dataloader if use_sequence_packing else create_bshd_dataloader
+
+    train_split = "train" if use_split_dataset else None
+    val_split = "validation" if use_split_dataset else None
+
+    train_dataloader, sampler = factory(dist_config, split=train_split, split_kwargs=split_kwargs, **factory_kwargs)
+
+    val_dataloader = None
+    if build_validation:
+        # The train call above has already regenerated the cache if force_recompute was set, so
+        # the val call must use the warmed cache rather than redo the work. Copy split_kwargs to
+        # avoid mutating the caller's dict.
+        val_split_kwargs = {**split_kwargs, "force_recompute": False} if split_kwargs is not None else None
+        val_dataloader, _ = factory(dist_config, split=val_split, split_kwargs=val_split_kwargs, **factory_kwargs)
+
+    return train_dataloader, val_dataloader, sampler
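
A hypothetical call site for the new factory (argument values mirror the defaults.yaml changes in this commit; the surrounding train.py plumbing is not shown in this diff):

train_dataloader, val_dataloader, train_sampler = create_dataloaders(
    dist_config,
    use_sequence_packing=True,   # THD factory; False selects the BSHD factory
    build_validation=True,
    use_split_dataset=True,
    split_kwargs={"train_val_test_ratio": [0.9998, 0.0002, 0.0], "force_recompute": False},
    data_path="/data/codon_memmap",   # hypothetical memmap directory
    micro_batch_size=4,
    max_seq_length=2048,
    mlm_probability=0.15,
    num_workers=1,
    seed=123,
    pad_to_multiple_of=32,
)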

bionemo-recipes/recipes/codonfm_native_te/hydra_config/L0_sanity.yaml

Lines changed: 5 additions & 0 deletions
@@ -30,3 +30,8 @@ checkpoint:

 logger:
   frequency: 1
+
+validation:
+  enabled: true
+  eval_interval: 50
+  num_batches: 4

bionemo-recipes/recipes/codonfm_native_te/hydra_config/defaults.yaml

Lines changed: 19 additions & 2 deletions
@@ -11,8 +11,20 @@ dataset:
   data_path: ???
   micro_batch_size: ???
   num_workers: 1
-  max_seq_length: 512
+  max_seq_length: 2048
   mlm_probability: 0.15
+  seed: 123  # Used for DistributedSampler, MLM masking RNG, and (when split mode is on) the cluster split.
+  use_split_dataset: true
+  split_kwargs:
+    train_val_test_ratio: [0.9998, 0.0002, 0.0]
+    context_overlap: 0
+    pretraining_task: mlm
+    min_seq_length: 100
+    max_filter_seq_length: 150_000
+    groups_to_use: null
+    taxid_exclusion_file: null
+    split_name_prefix: ""
+    force_recompute: false

 # WandB config
 wandb_init_args:
@@ -35,7 +47,7 @@ fp4_config:
 adamw_kwargs:
   lr: 4e-4
   fused: true
-  betas: [0.9, 0.98]
+  betas: [0.9, 0.999]
   eps: 1e-8
   weight_decay: 0.01

@@ -55,6 +67,11 @@ checkpoint:
 logger:
   frequency: 100

+validation:
+  enabled: false
+  eval_interval: 500
+  num_batches: 10
+
 quant_stats_config:
   enabled: false
   quant_stats_file: ./fp8_debugging_stats.yaml
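
How the new validation block is expected to drive the training loop, as a hedged sketch (the train.py wiring is not part of this diff; cfg, step, model, val_dataloader, perf_logger, and run_validation are assumptions here, and a sketch of run_validation follows the perf_logger.py diff below):

if cfg.validation.enabled and step % cfg.validation.eval_interval == 0:
    # Evaluate a fixed number of validation batches, then log the already-reduced metrics.
    val_metrics = run_validation(model, val_dataloader, num_batches=cfg.validation.num_batches)
    perf_logger.log_validation(step=step, val_metrics=val_metrics)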

bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_1b.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ dataset:
   data_path: ???
   micro_batch_size: 4
   num_workers: 1
-  max_seq_length: 512
+  max_seq_length: 2048

 # WandB config
 wandb_init_args:

bionemo-recipes/recipes/codonfm_native_te/hydra_config/encodon_5b.yaml

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ dataset:
   data_path: ???
   micro_batch_size: 4
   num_workers: 1
-  max_seq_length: 512
+  max_seq_length: 2048

 # WandB config
 wandb_init_args:

bionemo-recipes/recipes/codonfm_native_te/perf_logger.py

Lines changed: 12 additions & 0 deletions
@@ -181,6 +181,18 @@ def log_step(
         self.num_unpadded_tokens.zero_()
         self.grad_acc_step_count = 0

+    def log_validation(self, step: int, val_metrics: dict):
+        """Log validation metrics to wandb on the main process.
+
+        Args:
+            step: The current optimizer step.
+            val_metrics: Dict of metric name -> scalar value (already reduced across ranks).
+        """
+        if not self._dist_config.is_main_process():
+            return
+        wandb.log({f"val/{k}": v for k, v in val_metrics.items()}, step=step)
+        logger.info("[VAL step=%d] %s", step, ", ".join(f"{k}: {v:.4g}" for k, v in val_metrics.items()))
+
     def finish(self):
         """Finish the logger."""
         if self.quant_stats_config:
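
log_validation expects metrics that are already reduced across ranks. A minimal sketch of a caller that satisfies that contract (the validation loop itself is an assumption, not part of this commit; the HF-style .loss output is also assumed):

import torch
import torch.distributed as dist


@torch.no_grad()
def run_validation(model, val_dataloader, num_batches: int) -> dict:
    """Average the loss over a few validation batches, reduced across all ranks."""
    model.eval()
    losses = []
    for i, batch in enumerate(val_dataloader):
        if i >= num_batches:
            break
        batch = {k: v.cuda() for k, v in batch.items()}
        losses.append(model(**batch).loss)        # assumes an HF-style output with a .loss field
    loss = torch.stack(losses).mean()
    dist.all_reduce(loss, op=dist.ReduceOp.AVG)   # every rank ends up with the same scalar
    model.train()
    return {"loss": loss.item()}

# val_metrics = run_validation(model, val_dataloader, num_batches=cfg.validation.num_batches)
# perf_logger.log_validation(step=optimizer_step, val_metrics=val_metrics)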

bionemo-recipes/recipes/codonfm_native_te/run_1b.sh

Lines changed: 3 additions & 2 deletions
@@ -83,12 +83,13 @@ torchrun --nproc_per_node=${NPROC_PER_NODE} ${TRAIN_SCRIPT} \
     use_fp32_master_weights=${USE_FP32_MASTER_WEIGHTS} \
     lr_scheduler_kwargs.num_warmup_steps=${NUM_WARMUP_STEPS} \
     wandb_init_args.name=${WANDB_RUN_NAME} \
-    +wandb_init_args.project=${WANDB_PROJECT} \
+    wandb_init_args.project=${WANDB_PROJECT} \
     checkpoint.save_final_model=${SAVE_FINAL_MODEL} \
     checkpoint.save_every_n_steps=${SAVE_EVERY_N_STEPS} \
     checkpoint.ckpt_dir=${CKPT_DIR} \
     checkpoint.resume_from_checkpoint=${RESUME_FROM_CHECKPOINT} \
     hydra.run.dir=${HYDRA_RUN_DIR} \
     fp8_config.enabled=${FP8_ENABLED} \
     fp8_config.fp8_recipe=${FP8_RECIPE} \
-    fp8_config.fp8_format=${FP8_FORMAT}
+    fp8_config.fp8_format=${FP8_FORMAT} \
+    dataset.pad_to_multiple_of=32
