 import numpy as np
 import pyarrow.parquet as pq
 import torch
+from codon_memmap_dataset import CodonMemmapDataset
 from distributed_config import DistributedConfig
 from tokenizer import CodonTokenizer
 from torch.utils.data import DataLoader, Dataset, DistributedSampler
@@ -162,6 +163,11 @@ def __len__(self) -> int:  # noqa: D105
     def __getitem__(self, idx: int) -> dict[str, str]:  # noqa: D105
         chunk_id, start, end = self.global_indices[idx]
         token_ids = self.sequences_mmaps[chunk_id][start:end]
+        # Note: decode(skip_special_tokens=True) silently drops <UNK> tokens (ID 2). The codon
+        # tokenizer's tokenize() is strict 3-char chunking that cannot reparse the "<UNK>"
+        # literal in a decoded string, so any window containing ambiguous-base codons loses
+        # those positions when round-tripped. Use CodonMemmapDataset (returns sequence_tokens
+        # directly) for PTL-parity behavior.
         sequence = self.tokenizer.decode(token_ids.tolist(), skip_special_tokens=True)
         return {"sequence": sequence}

@@ -203,7 +209,10 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
         all_labels = []

         for sample in batch:
-            ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
+            if "sequence_tokens" in sample:
+                ids = [self.tokenizer.cls_token_id, *sample["sequence_tokens"].tolist(), self.tokenizer.sep_token_id]
+            else:
+                ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
             # Truncate to max_seq_length, preserving trailing SEP token
             if len(ids) > self.max_seq_length:
                 ids = [*ids[: self.max_seq_length - 1], self.tokenizer.sep_token_id]
@@ -281,7 +290,10 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
         seq_lengths = []

         for sample in batch:
-            ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
+            if "sequence_tokens" in sample:
+                ids = [self.tokenizer.cls_token_id, *sample["sequence_tokens"].tolist(), self.tokenizer.sep_token_id]
+            else:
+                ids = self.tokenizer.encode(sample["sequence"], add_special_tokens=True)
             # Truncate to max_seq_length, preserving trailing SEP token
             if len(ids) > self.max_seq_length:
                 ids = [*ids[: self.max_seq_length - 1], self.tokenizer.sep_token_id]
@@ -344,13 +356,26 @@ def __call__(self, batch: list[dict[str, str]]) -> dict[str, torch.Tensor]:
         }


-def _create_dataset(data_path: str, max_seq_length: int, seed: int) -> Dataset:
+def _create_dataset(
+    data_path: str,
+    max_seq_length: int,
+    seed: int,
+    split: str | None = None,
+    split_kwargs: dict | None = None,
+) -> Dataset:
     """Create the appropriate dataset based on data_path format.

     Args:
         data_path: 'synthetic', path to a parquet file, or path to a memmap directory.
         max_seq_length: Maximum sequence length (used for memmap sliding windows).
         seed: Random seed.
+        split: If set ("train" / "validation" / "test"), construct the split-aware
+            CodonMemmapDataset (port of the PTL dataset) instead of MemmapCodonDataset.
+            Only meaningful when data_path is a memmap directory; ignored otherwise.
+        split_kwargs: Extra keyword arguments forwarded to CodonMemmapDataset
+            (train_val_test_ratio, context_overlap, pretraining_task, min_seq_length,
+            max_filter_seq_length, groups_to_use, taxid_exclusion_file, split_name_prefix,
+            force_recompute). Only used when split is set.

     Returns:
         A Dataset instance.
@@ -359,6 +384,14 @@ def _create_dataset(data_path: str, max_seq_length: int, seed: int) -> Dataset:
         return SyntheticCodonDataset(num_samples=500, seed=seed)
     data_dir = Path(data_path)
     if data_dir.is_dir() and (data_dir / "metadata.json").exists():
+        if split is not None:
+            return CodonMemmapDataset(
+                data_path,
+                split=split,
+                max_seq_length=max_seq_length,
+                seed=seed,
+                **(split_kwargs or {}),
+            )
         return MemmapCodonDataset(data_path, max_seq_length=max_seq_length)
     return ParquetCodonDataset(data_path)

@@ -372,6 +405,8 @@ def create_bshd_dataloader(
     num_workers: int = 1,
     seed: int = 42,
     pad_to_multiple_of: int | None = None,
+    split: str | None = None,
+    split_kwargs: dict | None = None,
 ) -> tuple[DataLoader, DistributedSampler]:
     """Create a BSHD-format dataloader.

@@ -384,25 +419,30 @@ def create_bshd_dataloader(
         num_workers: Number of dataloader workers.
         seed: Random seed.
         pad_to_multiple_of: Unused in BSHD mode (only applies to THD).
+        split: If set, use the split-aware CodonMemmapDataset for memmap dirs.
+        split_kwargs: Extra arguments forwarded to CodonMemmapDataset when split is set.

     Returns:
         Tuple of (DataLoader, DistributedSampler).
     """
     tokenizer = CodonTokenizer()

-    dataset = _create_dataset(data_path, max_seq_length, seed)
+    dataset = _create_dataset(data_path, max_seq_length, seed, split=split, split_kwargs=split_kwargs)

+    sampler_kwargs = {"shuffle": False} if split == "validation" else {}
     sampler = DistributedSampler(
         dataset,
         rank=dist_config.rank,
         num_replicas=dist_config.world_size,
         seed=seed,
+        **sampler_kwargs,
     )

     collator = CodonMLMCollator(
         tokenizer=tokenizer,
         max_seq_length=max_seq_length,
         mlm_probability=mlm_probability,
+        seed=seed,
     )

     dataloader = DataLoader(
@@ -426,6 +466,8 @@ def create_thd_dataloader(
     num_workers: int = 1,
     seed: int = 42,
     pad_to_multiple_of: int | None = None,
+    split: str | None = None,
+    split_kwargs: dict | None = None,
 ) -> tuple[DataLoader, DistributedSampler]:
     """Create a THD-format (packed sequence) dataloader.

@@ -440,6 +482,8 @@ def create_thd_dataloader(
         pad_to_multiple_of: If set, pad total tokens to a multiple of this value. If None,
             defaults to micro_batch_size * max_seq_length for consistent tensor shapes
             (matching ESM2's approach). Set to 0 to disable padding.
+        split: If set, use the split-aware CodonMemmapDataset for memmap dirs.
+        split_kwargs: Extra arguments forwarded to CodonMemmapDataset when split is set.

     Returns:
         Tuple of (DataLoader, DistributedSampler).
@@ -454,20 +498,23 @@ def create_thd_dataloader(
     elif pad_to_multiple_of == 0:
         pad_to_multiple_of = None

-    dataset = _create_dataset(data_path, max_seq_length, seed)
+    dataset = _create_dataset(data_path, max_seq_length, seed, split=split, split_kwargs=split_kwargs)

+    sampler_kwargs = {"shuffle": False} if split == "validation" else {}
     sampler = DistributedSampler(
         dataset,
         rank=dist_config.rank,
         num_replicas=dist_config.world_size,
         seed=seed,
+        **sampler_kwargs,
     )

     collator = CodonTHDCollator(
         tokenizer=tokenizer,
         max_seq_length=max_seq_length,
         mlm_probability=mlm_probability,
         pad_to_multiple_of=pad_to_multiple_of,
+        seed=seed,
     )

     dataloader = DataLoader(
@@ -480,3 +527,59 @@ def create_thd_dataloader(
     )

     return dataloader, sampler
+
+
+def create_dataloaders(
+    dist_config: DistributedConfig,
+    *,
+    use_sequence_packing: bool,
+    build_validation: bool,
+    use_split_dataset: bool = True,
+    split_kwargs: dict | None = None,
+    **factory_kwargs,
+) -> tuple[DataLoader, DataLoader | None, DistributedSampler]:
+    """Build train (and optionally validation) dataloaders from a single configuration.
+
+    Wrapper modeled on esm2_peft_te.create_dataloader: a single call builds both loaders, so the
+    train and val datasets share the on-disk caches via mmap and the kernel page cache. When
+    use_split_dataset is True, the new CodonMemmapDataset is constructed for each split (train/val
+    samples are disjoint by the PTL proportional cluster split); when False, the legacy path is
+    used and the val loader simply re-reads the train data (placeholder behavior).
+
+    If split_kwargs requests force_recompute, the flag is honored only by the train call; the val
+    call is invoked with force_recompute=False so the cache written by train is reused instead of
+    being rebuilt a second time in the same process.
+
+    Args:
+        dist_config: Distributed configuration.
+        use_sequence_packing: Pick the THD factory if True, the BSHD factory if False.
+        build_validation: If False, skip val-loader construction entirely (returns None).
+        use_split_dataset: When True (default), construct the split-aware CodonMemmapDataset
+            for memmap directories. Set to False to fall back to the legacy MemmapCodonDataset,
+            in which case the val loader (if requested) re-reads the train data as a
+            placeholder. Has no effect for synthetic/parquet data paths.
+        split_kwargs: Extra arguments forwarded to CodonMemmapDataset (only used when
+            use_split_dataset=True). See codon_memmap_dataset.CodonMemmapDataset for the full list.
+        **factory_kwargs: Remaining keyword arguments passed to the low-level factory
+            (data_path, micro_batch_size, max_seq_length, mlm_probability, num_workers, seed,
+            pad_to_multiple_of).
+
+    Returns:
+        Tuple of (train_dataloader, val_dataloader or None, train DistributedSampler).
+    """
+    factory = create_thd_dataloader if use_sequence_packing else create_bshd_dataloader
+
+    train_split = "train" if use_split_dataset else None
+    val_split = "validation" if use_split_dataset else None
+
+    train_dataloader, sampler = factory(dist_config, split=train_split, split_kwargs=split_kwargs, **factory_kwargs)
+
+    val_dataloader = None
+    if build_validation:
+        # The train call above has already regenerated the cache if force_recompute was set, so
+        # the val call must use the warmed cache rather than redo the work. Copy split_kwargs to
+        # avoid mutating the caller's dict.
+        val_split_kwargs = {**split_kwargs, "force_recompute": False} if split_kwargs is not None else None
+        val_dataloader, _ = factory(dist_config, split=val_split, split_kwargs=val_split_kwargs, **factory_kwargs)

+    return train_dataloader, val_dataloader, sampler
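
A minimal usage sketch of the new wrapper, not part of the diff above: the data path, batch and sequence sizes, and the DistributedConfig constructor arguments are illustrative assumptions, and the importing module name is hypothetical.

    from distributed_config import DistributedConfig
    from dataloaders import create_dataloaders  # module name assumed for illustration

    # DistributedConfig construction is environment-specific; the samplers only consume the
    # rank and world_size fields, so a single-process config is assumed here.
    dist_config = DistributedConfig(rank=0, world_size=1)

    train_dl, val_dl, train_sampler = create_dataloaders(
        dist_config,
        use_sequence_packing=True,               # pick the THD (packed-sequence) factory
        build_validation=True,                   # also build the validation loader
        use_split_dataset=True,                  # split-aware CodonMemmapDataset for memmap dirs
        split_kwargs={"force_recompute": True},  # honored by the train call only
        data_path="/data/codon_memmap",          # placeholder memmap directory
        micro_batch_size=8,
        max_seq_length=1024,
        mlm_probability=0.15,
        num_workers=2,
        seed=42,
    )

    for epoch in range(3):
        train_sampler.set_epoch(epoch)           # standard DistributedSampler reshuffle per epoch
        for batch in train_dl:
            ...                                  # forward/backward pass

Returning the train sampler alongside the loaders leaves the per-epoch set_epoch call in the training loop's hands, which is why the factories return the sampler rather than hiding it.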