@@ -17,48 +17,44 @@
 
 @dataclass
 class ModelArgs(BaseModelArgs):
-    model_type: str = "mimo_v2"
-    vocab_size: int = 152576
-    hidden_size: int = 4096
-    intermediate_size: int = 16384
-    moe_intermediate_size: int = 2048
-    num_hidden_layers: int = 48
-    num_attention_heads: int = 64
-    num_key_value_heads: int = 4
-    head_dim: int = 192
-    v_head_dim: int = 128
-    rope_theta: float = 5000000.0
-    add_full_attention_sink_bias: bool = False
-    swa_num_attention_heads: int = 64
-    swa_num_key_value_heads: int = 8
-    swa_head_dim: int = 192
-    swa_v_head_dim: int = 128
-    swa_rope_theta: float = 10000.0
-    sliding_window_size: int = 128
-    add_swa_attention_sink_bias: bool = True
-    hybrid_layer_pattern: Optional[List[int]] = None
-    n_routed_experts: int = 256
-    num_experts_per_tok: int = 8
-    moe_layer_freq: Optional[List[int]] = None
-    n_group: int = 1
-    topk_group: int = 1
-    norm_topk_prob: bool = True
+    model_type: str
+    vocab_size: int
+    hidden_size: int
+    intermediate_size: int
+    moe_intermediate_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    head_dim: int
+    v_head_dim: int
+    rope_theta: float
+    swa_num_attention_heads: int
+    swa_num_key_value_heads: int
+    swa_head_dim: int
+    swa_v_head_dim: int
+    swa_rope_theta: float
+    sliding_window_size: int
+    add_full_attention_sink_bias: bool
+    add_swa_attention_sink_bias: bool
+    hybrid_layer_pattern: List[int]
+    moe_layer_freq: List[int]
+    n_routed_experts: int
+    num_experts_per_tok: int
+    n_group: int
+    topk_group: int
+    norm_topk_prob: bool
+    topk_method: str
+    partial_rotary_factor: float
+    attention_bias: bool
+    layernorm_epsilon: float
+    max_position_embeddings: int
     routed_scaling_factor: Optional[float] = None
-    topk_method: str = "noaux_tc"
-    partial_rotary_factor: float = 0.334
-    attention_bias: bool = False
     attention_value_scale: Optional[float] = None
-    layernorm_epsilon: float = 1e-5
-    max_position_embeddings: int = 262144
     rope_scaling: Optional[Dict[str, Any]] = None
     tie_word_embeddings: bool = False
 
     def __post_init__(self):
         n = self.num_hidden_layers
-        if self.hybrid_layer_pattern is None:
-            self.hybrid_layer_pattern = [0] * n
-        if self.moe_layer_freq is None:
-            self.moe_layer_freq = [0] * n
         if len(self.hybrid_layer_pattern) != n:
             raise ValueError("hybrid_layer_pattern length must match num_hidden_layers")
         if len(self.moe_layer_freq) != n:
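With the class-level defaults removed above, every field must now come from the checkpoint configuration, so a stale or mismatched built-in default can no longer mask a bad config. A minimal sketch of the effect, assuming the usual mlx-lm pattern where `BaseModelArgs.from_dict` builds the args from a parsed `config.json` (the file path here is illustrative):

```python
import json

# Illustrative path; in practice mlx-lm reads this from the model directory.
with open("config.json") as f:
    config = json.load(f)

# With no defaults on the dataclass, a config missing a required key such as
# hybrid_layer_pattern now fails loudly with a TypeError at construction time
# instead of silently running with an unrelated model's default.
args = ModelArgs.from_dict(config)
```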
@@ -208,7 +204,7 @@ def __init__(self, config: ModelArgs):
         self.e_score_correction_bias = mx.zeros((config.n_routed_experts,))
 
     def __call__(self, x):
-        inds, scores = group_expert_select(
+        return group_expert_select(
             x @ self.weight.T,
             self.e_score_correction_bias,
             self.top_k,
@@ -217,7 +213,6 @@ def __call__(self, x):
             self.routed_scaling_factor,
             self.norm_topk_prob,
         )
-        return inds, scores.astype(x.dtype)
 
 
 class MoE(nn.Module):
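A side effect of returning `group_expert_select` directly: the gate no longer casts `scores` down to the input dtype, so the (typically float32) routing weights reach `MoE.__call__` at full precision and are cast only after the combine in the next hunk. A standalone sketch of why the ordering matters, assuming bfloat16 activations and float32 scores:

```python
import mlx.core as mx

y = mx.random.normal((4, 2, 8)).astype(mx.bfloat16)     # top-2 expert outputs per token
scores = mx.softmax(mx.random.normal((4, 2)), axis=-1)  # float32 router weights

# Dtype promotion keeps the weighted sum in float32; casting once at the end
# loses less precision than downcasting the scores before combining.
out = (y * scores[..., None]).sum(axis=-2).astype(mx.bfloat16)
assert out.dtype == mx.bfloat16
```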
@@ -236,7 +231,7 @@ def __call__(self, x):
             x = sum_gradients(self.sharding_group)(x)
         inds, scores = self.gate(x)
         y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2)
+        y = (y * scores[..., None]).sum(axis=-2).astype(x.dtype)
         if self.sharding_group is not None:
             y = mx.distributed.all_sum(y, group=self.sharding_group)
         return y
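For the sharded branch at the end of this hunk: under expert parallelism each rank evaluates only the experts it hosts, so its `y` is a partial sum, and `all_sum` combines the partials across ranks. A rough sketch of the collective under that assumption (the partial tensor here is a placeholder, not the real `switch_mlp` output):

```python
import mlx.core as mx

# Trivial single-rank group when the script is not launched via mlx.launch.
group = mx.distributed.init()

# Placeholder for the per-rank partial MoE output.
y_partial = mx.ones((4, 8))

# all_sum adds the partials elementwise; afterwards every rank holds the
# fully combined output.
y = mx.distributed.all_sum(y_partial, group=group)
```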