@@ -48,6 +48,7 @@
 
 from bionemo.evo2.models.mamba import MAMBA_MODEL_OPTIONS, MambaModel, mamba_no_weight_decay_cond_with_embeddings
 from bionemo.evo2.run.peft import Evo2LoRA
+from bionemo.evo2.utils.config import hyena_no_weight_decay_cond_with_embeddings
 from bionemo.evo2.utils.logging.callbacks import TEVCallback
 from bionemo.llm.utils.datamodule_utils import infer_global_batch_size
 from bionemo.llm.utils.logger_utils import WandbConfig, setup_nemo_lightning_logger
@@ -176,7 +177,6 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         help="TP communication backend to use. Defaults to 'nccl'.",
     )
     parser.add_argument("--align-param-gather", action="store_true", default=False)
-    # parser.add_argument("--straggler-detection", action="store_true", default=False)
     parser.add_argument(
         "--model-size",
         type=str,
@@ -356,7 +356,8 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         help="If set, the embeddings are initialized with a Normal(0, 1.0) distribution rather "
         "than the default Normal(0, 0.02). This may help avoid loss spiking during training. Consider using this with "
         "--no-weight-decay-embeddings to avoid shrinking the embeddings to 0 by skipping weight decay on these layers, "
-        "or with --use-targeted-variance-loss to maintain a 1.0 variance during training even with weight decay.",
+        "or with --use-targeted-variance-loss to maintain a 1.0 variance during training even with weight decay. This "
+        "also turns off shared weights between embeddings and outputs.",
     )
     parser.add_argument(
         "--no-weight-decay-embeddings",
@@ -442,6 +443,12 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
         default=0.0,
         help="Dropout probability for the hyena layers",
     )
+    parser.add_argument(
+        "--ffn-hidden-size",
+        type=int,
+        default=None,
+        help="FFN hidden size for the hyena layers",
+    )
     parser.add_argument(
         "--log-num-zeros-in-grad",
         action="store_true",
@@ -549,7 +556,6 @@ def train(args: argparse.Namespace) -> nl.Trainer:
549556 tokenizer = tokenizer ,
550557 eod_mask_loss = args .eod_pad_in_loss_mask ,
551558 )
552-
553559 if args .no_activation_checkpointing :
554560 activation_checkpointing_args = {
555561 "recompute_granularity" : None ,
@@ -583,6 +589,12 @@ def train(args: argparse.Namespace) -> nl.Trainer:
         "add_bias_output": args.add_bias_output,
         **activation_checkpointing_args,
     }
+    if args.spike_no_more_embedding_init:
+        config_modifiers_init["embedding_init_method_std"] = 1.0
+        # When using spike_no_more_embedding_init, we don't want to share embeddings and outputs.
+        config_modifiers_init["share_embeddings_and_output_weights"] = False
+    if args.ffn_hidden_size:
+        config_modifiers_init["ffn_hidden_size"] = args.ffn_hidden_size
     if args.use_targeted_variance_loss:
         config_modifiers_init["use_targeted_variance_loss"] = True
     if args.use_b2b_causal_conv1d:
@@ -603,6 +615,10 @@ def train(args: argparse.Namespace) -> nl.Trainer:
         if args.model_size not in HYENA_MODEL_OPTIONS:
             raise ValueError(f"Invalid model size for Hyena: {args.model_size}")
         model_config = HYENA_MODEL_OPTIONS[args.model_size](**config_modifiers_init)
+        if args.no_weight_decay_embeddings:
+            # Override the default weight decay condition for Hyena with our bionemo version
+            # that also excludes embeddings.
+            model_config.hyena_no_weight_decay_cond_fn = hyena_no_weight_decay_cond_with_embeddings
         # Lora adaptors configuration
         lora_transform = None
         if args.lora_finetune:
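
For reference, Megatron-style no-weight-decay conditions take a parameter name and tensor and return whether to skip decay for it. A hypothetical sketch of a condition that also excludes embeddings (not the actual bionemo implementation of hyena_no_weight_decay_cond_with_embeddings, whose details may differ):

import torch


def no_weight_decay_cond_with_embeddings(name: str, param: torch.Tensor) -> bool:
    # Skip decay for biases and 1-D params (e.g. layernorm), matching the usual
    # Megatron default, and additionally for embedding weights so a Normal(0, 1.0)
    # initialization is not shrunk toward 0 by weight decay.
    if "embedding" in name:
        return True
    return name.endswith(".bias") or param.ndim == 1
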
@@ -612,8 +628,6 @@ def train(args: argparse.Namespace) -> nl.Trainer:
     else:  # mamba
         if args.no_weight_decay_embeddings:
             config_modifiers_init["hyena_no_weight_decay_cond_fn"] = mamba_no_weight_decay_cond_with_embeddings
-        if args.spike_no_more_embedding_init:  # --spike-no-more-embedding-init
-            config_modifiers_init["spike_no_more_embedding_init"] = True
         config_modifiers_init["lowercase_loss_reweighting"] = args.mamba_lowercase_loss_weight
         if args.model_size not in MAMBA_MODEL_OPTIONS:
             raise ValueError(f"Invalid model size for Mamba: {args.model_size}")
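
Taken together, the new options might be exercised as below. Only the parse_args/train signatures come from this diff's hunk headers; the module path and flag values are illustrative assumptions, and other required arguments (e.g. data configuration) are omitted:

from bionemo.evo2.run.train import parse_args, train  # module path assumed

# Illustrative flag combination; values are placeholders, not recommendations.
args = parse_args([
    "--model-size", "1b",              # must be a key in HYENA_MODEL_OPTIONS
    "--spike-no-more-embedding-init",  # Normal(0, 1.0) embedding init, untied output weights
    "--no-weight-decay-embeddings",    # skip weight decay on embedding parameters
    "--ffn-hidden-size", "8192",       # override the preset FFN width
])
train(args)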