Commit a5a2af9

Evo2 spike-no-more support (#1011)
### Description

Spike-no-more support for training evo2 models.

### Type of changes

- [ ] Bug fix (non-breaking change which fixes an issue)
- [x] New feature (non-breaking change which adds functionality)
- [ ] Refactor
- [ ] Documentation update
- [ ] Other (please describe):

---------

Signed-off-by: John St John <jstjohn@nvidia.com>
1 parent 267eb51 · commit a5a2af9

3 files changed: 29 additions & 5 deletions
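
In short, the change threads three knobs through `train.py`: `--spike-no-more-embedding-init` (which now also unties embedding and output weights), `--no-weight-decay-embeddings` (now honored on the Hyena path as well as Mamba), and a new `--ffn-hidden-size` override. A minimal sketch of exercising them through the module's own `parse_args`/`train` entry points; the argument list is illustrative ("7b" and the FFN width are hypothetical values, and any required dataset/trainer arguments are omitted):

```python
from bionemo.evo2.run.train import parse_args, train

argv = [
    "--model-size", "7b",              # hypothetical key; must exist in HYENA_MODEL_OPTIONS
    "--spike-no-more-embedding-init",  # Normal(0, 1.0) embedding init; also unties embeddings/outputs
    "--no-weight-decay-embeddings",    # exclude embedding parameters from weight decay
    "--ffn-hidden-size", "11264",      # hypothetical FFN width override
    # ...plus whatever dataset/trainer arguments a real run requires...
]
args = parse_args(argv)
trainer = train(args)  # returns an nl.Trainer per the function signature
```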

sub-packages/bionemo-evo2/src/bionemo/evo2/models/mamba.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -335,6 +335,7 @@ class HybridMambaConfig8BEvo2Loss(NemotronHConfigBase):
     # to be close to a target value (1.0).
     use_targeted_variance_loss: bool = False
     targeted_variance_loss_loss_coeff: float = 0.1
+    share_embeddings_and_output_weights: bool = False
 
     def __post_init__(self):
         """Post-init logic for Evo2 to enable backwards compatibility with old configs."""
@@ -378,6 +379,7 @@ def configure_model(
             seq_len_interpolation_factor=self.seq_len_interpolation_factor,
             pre_process=pre_process or parallel_state.is_pipeline_first_stage(),
             post_process=post_process or parallel_state.is_pipeline_last_stage(),
+            share_embeddings_and_output_weights=self.share_embeddings_and_output_weights,
         )
```

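For context on why untying matters here: when the embeddings and the output projection share one tensor, the Normal(0, 1.0) spike-no-more init would also rescale the output logits. A minimal PyTorch sketch of tied vs. untied weights (illustrative, not the Megatron implementation):

```python
import torch.nn as nn

vocab, hidden = 512, 64

# Tied: the output head reuses the embedding matrix, so changing the
# embedding init would change the logit scale as well.
tied_emb = nn.Embedding(vocab, hidden)
tied_head = nn.Linear(hidden, vocab, bias=False)
tied_head.weight = tied_emb.weight  # one shared parameter

# Untied (share_embeddings_and_output_weights=False): embeddings can take
# the std-1.0 spike-no-more init while the head keeps a small init.
untied_emb = nn.Embedding(vocab, hidden)
nn.init.normal_(untied_emb.weight, mean=0.0, std=1.0)
untied_head = nn.Linear(hidden, vocab, bias=False)
nn.init.normal_(untied_head.weight, mean=0.0, std=0.02)
```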
sub-packages/bionemo-evo2/src/bionemo/evo2/run/train.py

Lines changed: 19 additions & 5 deletions
```diff
@@ -48,6 +48,7 @@
 
 from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings
 from bionemo.evo2.run.peft import Evo2LoRA
+from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings
 from bionemo.evo2.utils.logging.callbacks import TEVCallback
 from bionemo.llm.utils.datamodule_utils import infer_global_batch_size
 from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger
@@ -176,7 +177,6 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         help="TP communication backend to use. Defaults to 'nccl'.",
     )
     parser.add_argument("--align-param-gather", action="store_true", default=False)
-    # parser.add_argument("--straggler-detection", action="store_true", default=False)
     parser.add_argument(
         "--model-size",
         type=str,
@@ -356,7 +356,8 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         help="If set, the embeddings are initialized with a Normal(0, 1.0) distribution rather "
         "than the default Normal(0, 0.02). This may help avoid loss spiking during training. Consider using this with "
         "--no-weight-decay-embeddings to avoid shrinking the embeddings to 0 by skipping weight decay on these layers, "
-        "or with --use-targeted-variance-loss to maintain a 1.0 variance during training even with weight decay.",
+        "or with --use-targeted-variance-loss to maintain a 1.0 variance during training even with weight decay. This "
+        "also turns off shared weights between embeddings and outputs.",
     )
     parser.add_argument(
         "--no-weight-decay-embeddings",
@@ -442,6 +443,12 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         default=0.0,
         help="Dropout probability for the hyena layers",
     )
+    parser.add_argument(
+        "--ffn-hidden-size",
+        type=int,
+        default=None,
+        help="FFN hidden size for the hyena layers",
+    )
     parser.add_argument(
         "--log-num-zeros-in-grad",
         action="store_true",
@@ -549,7 +556,6 @@ def train(args: argparse.Namespace) -> nl.Trainer:
         tokenizer=tokenizer,
         eod_mask_loss=args.eod_pad_in_loss_mask,
     )
-
     if args.no_activation_checkpointing:
         activation_checkpointing_args = {
             "recompute_granularity": None,
@@ -583,6 +589,12 @@ def train(args: argparse.Namespace) -> nl.Trainer:
         "add_bias_output": args.add_bias_output,
         **activation_checkpointing_args,
     }
+    if args.spike_no_more_embedding_init:
+        config_modifiers_init["embedding_init_method_std"] = 1.0
+        # When using spike_no_more_embedding_init, we don't want to share embeddings and outputs.
+        config_modifiers_init["share_embeddings_and_output_weights"] = False
+    if args.ffn_hidden_size:
+        config_modifiers_init["ffn_hidden_size"] = args.ffn_hidden_size
     if args.use_targeted_variance_loss:
         config_modifiers_init["use_targeted_variance_loss"] = True
     if args.use_b2b_causal_conv1d:
@@ -603,6 +615,10 @@ def train(args: argparse.Namespace) -> nl.Trainer:
         if args.model_size not in HYENA_MODEL_OPTIONS:
             raise ValueError(f"Invalid model size for Hyena: {args.model_size}")
         model_config = HYENA_MODEL_OPTIONS[args.model_size](**config_modifiers_init)
+        if args.no_weight_decay_embeddings:
+            # Override the default weight decay condition for Hyena with our bionemo version that also excludes
+            # embeddings
+            model_config.hyena_no_weight_decay_cond_fn = hyena_no_weight_decay_cond_with_embeddings
         # Lora adaptors configuration
         lora_transform = None
         if args.lora_finetune:
@@ -612,8 +628,6 @@ def train(args: argparse.Namespace) -> nl.Trainer:
     else:  # mamba
         if args.no_weight_decay_embeddings:
             config_modifiers_init["hyena_no_weight_decay_cond_fn"] = mamba_no_weight_decay_cond_with_embeddings
-        if args.spike_no_more_embedding_init:  # --spike-no-more-embedding-init
-            config_modifiers_init["spike_no_more_embedding_init"] = True
         config_modifiers_init["lowercase_loss_reweighting"] = args.mamba_lowercase_loss_weight
         if args.model_size not in MAMBA_MODEL_OPTIONS:
             raise ValueError(f"Invalid model size for Mamba: {args.model_size}")
```

sub-packages/bionemo-evo2/src/bionemo/evo2/utils/config.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -20,9 +20,17 @@
 from pathlib import Path
 from typing import Literal
 
+from nemo.collections.llm.gpt.model.megatron.hyena.hyena_utils import hyena_no_weight_decay_cond
 from pydantic import BaseModel
 
 
+def hyena_no_weight_decay_cond_with_embeddings(name, param):
+    """Condition for no weight decay for Hyena parameters with embeddings."""
+    if "embedding" in name:
+        return True
+    return hyena_no_weight_decay_cond(name, param)
+
+
 class Evo2TaxonomyLineage(BaseModel):
     """Pydantic model class that defines the source lineage of a DNA sequence."""
```

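For reference, a condition function like this is typically consumed by splitting parameters into optimizer groups. The helper below is a sketch of that pattern (`make_param_groups` is a hypothetical name, not bionemo API); in the actual change the function is instead attached to the model config as `hyena_no_weight_decay_cond_fn`:

```python
import torch

from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings


def make_param_groups(model: torch.nn.Module, weight_decay: float = 0.1):
    """Hypothetical helper: split params into decay / no-decay optimizer groups."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if hyena_no_weight_decay_cond_with_embeddings(name, param):
            no_decay.append(param)  # embeddings plus Hyena's default exclusions
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```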