Commit 89338ec

Fix gate cast

1 parent a386aa1 commit 89338ec

1 file changed: mlx_lm/models/mimo_v2.py
Lines changed: 33 additions & 38 deletions
@@ -17,48 +17,44 @@
 
 @dataclass
 class ModelArgs(BaseModelArgs):
-    model_type: str = "mimo_v2"
-    vocab_size: int = 152576
-    hidden_size: int = 4096
-    intermediate_size: int = 16384
-    moe_intermediate_size: int = 2048
-    num_hidden_layers: int = 48
-    num_attention_heads: int = 64
-    num_key_value_heads: int = 4
-    head_dim: int = 192
-    v_head_dim: int = 128
-    rope_theta: float = 5000000.0
-    add_full_attention_sink_bias: bool = False
-    swa_num_attention_heads: int = 64
-    swa_num_key_value_heads: int = 8
-    swa_head_dim: int = 192
-    swa_v_head_dim: int = 128
-    swa_rope_theta: float = 10000.0
-    sliding_window_size: int = 128
-    add_swa_attention_sink_bias: bool = True
-    hybrid_layer_pattern: Optional[List[int]] = None
-    n_routed_experts: int = 256
-    num_experts_per_tok: int = 8
-    moe_layer_freq: Optional[List[int]] = None
-    n_group: int = 1
-    topk_group: int = 1
-    norm_topk_prob: bool = True
+    model_type: str
+    vocab_size: int
+    hidden_size: int
+    intermediate_size: int
+    moe_intermediate_size: int
+    num_hidden_layers: int
+    num_attention_heads: int
+    num_key_value_heads: int
+    head_dim: int
+    v_head_dim: int
+    rope_theta: float
+    swa_num_attention_heads: int
+    swa_num_key_value_heads: int
+    swa_head_dim: int
+    swa_v_head_dim: int
+    swa_rope_theta: float
+    sliding_window_size: int
+    add_full_attention_sink_bias: bool
+    add_swa_attention_sink_bias: bool
+    hybrid_layer_pattern: List[int]
+    moe_layer_freq: List[int]
+    n_routed_experts: int
+    num_experts_per_tok: int
+    n_group: int
+    topk_group: int
+    norm_topk_prob: bool
+    topk_method: str
+    partial_rotary_factor: float
+    attention_bias: bool
+    layernorm_epsilon: float
+    max_position_embeddings: int
     routed_scaling_factor: Optional[float] = None
-    topk_method: str = "noaux_tc"
-    partial_rotary_factor: float = 0.334
-    attention_bias: bool = False
     attention_value_scale: Optional[float] = None
-    layernorm_epsilon: float = 1e-5
-    max_position_embeddings: int = 262144
     rope_scaling: Optional[Dict[str, Any]] = None
     tie_word_embeddings: bool = False
 
     def __post_init__(self):
         n = self.num_hidden_layers
-        if self.hybrid_layer_pattern is None:
-            self.hybrid_layer_pattern = [0] * n
-        if self.moe_layer_freq is None:
-            self.moe_layer_freq = [0] * n
         if len(self.hybrid_layer_pattern) != n:
             raise ValueError("hybrid_layer_pattern length must match num_hidden_layers")
         if len(self.moe_layer_freq) != n:
@@ -208,7 +204,7 @@ def __init__(self, config: ModelArgs):
         self.e_score_correction_bias = mx.zeros((config.n_routed_experts,))
 
     def __call__(self, x):
-        inds, scores = group_expert_select(
+        return group_expert_select(
             x @ self.weight.T,
             self.e_score_correction_bias,
             self.top_k,
@@ -217,7 +213,6 @@ def __call__(self, x):
             self.routed_scaling_factor,
             self.norm_topk_prob,
         )
-        return inds, scores.astype(x.dtype)
 
 
 class MoE(nn.Module):
@@ -236,7 +231,7 @@ def __call__(self, x):
             x = sum_gradients(self.sharding_group)(x)
         inds, scores = self.gate(x)
         y = self.switch_mlp(x, inds)
-        y = (y * scores[..., None]).sum(axis=-2)
+        y = (y * scores[..., None]).sum(axis=-2).astype(x.dtype)
         if self.sharding_group is not None:
             y = mx.distributed.all_sum(y, group=self.sharding_group)
         return y
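
Note on the first hunk: with the defaults removed, every architectural field of ModelArgs must now come from the checkpoint's config, and __post_init__ no longer back-fills hybrid_layer_pattern or moe_layer_freq with [0] * n; it only validates their lengths. Below is a minimal stand-in sketch of that behavior, not the repository code: TinyArgs and the second error message are illustrative.

from dataclasses import dataclass
from typing import List


@dataclass
class TinyArgs:
    # Stand-in for ModelArgs: no defaults, so the config must supply every field.
    num_hidden_layers: int
    hybrid_layer_pattern: List[int]
    moe_layer_freq: List[int]

    def __post_init__(self):
        # Mirrors the length checks kept in ModelArgs.__post_init__.
        n = self.num_hidden_layers
        if len(self.hybrid_layer_pattern) != n:
            raise ValueError("hybrid_layer_pattern length must match num_hidden_layers")
        if len(self.moe_layer_freq) != n:
            raise ValueError("moe_layer_freq length must match num_hidden_layers")


# Accepted: one entry per layer.
TinyArgs(num_hidden_layers=2, hybrid_layer_pattern=[0, 1], moe_layer_freq=[0, 1])

# Rejected: previously a missing pattern was silently replaced with [0] * n;
# now omitting the field is a TypeError and a wrong-length list is a ValueError.
try:
    TinyArgs(num_hidden_layers=2, hybrid_layer_pattern=[0], moe_layer_freq=[0, 1])
except ValueError as e:
    print(e)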

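Note on the gate cast itself: the gate previously cast the routing scores to x.dtype before returning them; after this commit it returns group_expert_select's outputs unchanged, and MoE casts only the final weighted sum back to the activations' dtype, so the expert mixing can presumably run in the scores' higher precision. A minimal, hypothetical sketch of where the cast now lands, assuming MLX is installed; the shapes and the bfloat16/float32 dtypes are illustrative, not taken from the model.

import mlx.core as mx

x_dtype = mx.bfloat16
# Per-expert outputs in the activations' dtype: (batch, tokens, top_k, dim).
y = mx.random.normal((2, 4, 8, 16)).astype(x_dtype)
# Router scores as produced by the gate, kept in float32: (batch, tokens, top_k).
scores = mx.random.uniform(shape=(2, 4, 8))

# Before this commit: the gate returned scores.astype(x.dtype), so the
# multiply and the sum over experts ran entirely in bfloat16.
y_before = (y * scores.astype(x_dtype)[..., None]).sum(axis=-2)

# After this commit: the float32 scores promote the product to float32,
# and only the combined result is cast back, as in MoE.__call__ above.
y_after = (y * scores[..., None]).sum(axis=-2).astype(x_dtype)

print(y_before.dtype, y_after.dtype)  # both bfloat16; the intermediates differ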