Skip to content

Commit 51dba3f

Browse files
committed
[Model]: support mimo model
Signed-off-by: wp-alpha <wangpeng66@xiaomi.com>
1 parent ebb3930 commit 51dba3f

File tree

5 files changed

+497
-4
lines changed

5 files changed

+497
-4
lines changed

vllm/config.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -1107,7 +1107,8 @@ def get_num_attention_heads(self,
11071107
def get_layers_start_end_indices(
11081108
self, parallel_config: "ParallelConfig") -> tuple[int, int]:
11091109
from vllm.distributed.utils import get_pp_indices
1110-
if self.hf_text_config.model_type == "deepseek_mtp":
1110+
if (self.hf_text_config.model_type == "deepseek_mtp"
1111+
or self.hf_config.model_type == "mimo_mtp"):
11111112
total_num_hidden_layers = getattr(self.hf_text_config,
11121113
"num_nextn_predict_layers", 0)
11131114
else:
@@ -2290,6 +2291,17 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
22902291
"n_predict": n_predict,
22912292
"architectures": ["DeepSeekMTPModel"]
22922293
})
2294+
2295+
if hf_config.architectures[0] == "MiMoForCausalLM":
2296+
hf_config.model_type = "mimo_mtp"
2297+
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
2298+
hf_config.update({
2299+
"num_hidden_layers": 0,
2300+
"n_predict": n_predict,
2301+
"architectures": ["MiMoMTPModel"]
2302+
})
2303+
return hf_config
2304+
22932305
return hf_config
22942306

22952307
def __post_init__(self):
@@ -2306,8 +2318,10 @@ def __post_init__(self):
23062318
# TODO(Shangming): Refactor mtp configuration logic when supporting
23072319
# mtp acceleration for more models besides deepseek_v3
23082320
if self.target_model_config and \
2309-
self.target_model_config.hf_text_config.model_type \
2310-
== "deepseek_v3":
2321+
(self.target_model_config.hf_text_config.model_type \
2322+
== "deepseek_v3" or
2323+
self.target_model_config.hf_text_config.architectures[0] \
2324+
== "MiMoForCausalLM"):
23112325
# use the draft model from the same model:
23122326
self.model = self.target_model_config.model
23132327
elif self.method in ("ngram", "[ngram]"):

vllm/model_executor/models/mimo.py

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# Adapted from
4+
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
5+
# Copyright 2025 Xiaomi Corporation.
6+
# Copyright 2024 The Qwen team.
7+
# Copyright 2023 The vLLM team.
8+
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
9+
#
10+
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
11+
# and OPT implementations in this library. It has been modified from its
12+
# original forms to accommodate minor architectural differences compared
13+
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
14+
#
15+
# Licensed under the Apache License, Version 2.0 (the "License");
16+
# you may not use this file except in compliance with the License.
17+
# You may obtain a copy of the License at
18+
#
19+
# http://www.apache.org/licenses/LICENSE-2.0
20+
#
21+
# Unless required by applicable law or agreed to in writing, software
22+
# distributed under the License is distributed on an "AS IS" BASIS,
23+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
24+
# See the License for the specific language governing permissions and
25+
# limitations under the License.
26+
"""Inference-only MiMo model compatible with HuggingFace weights."""
27+
from typing import Iterable, Optional, Set, Tuple, Union
28+
29+
import torch
30+
import torch.nn as nn
31+
32+
from vllm.compilation.decorators import support_torch_compile
33+
from vllm.config import VllmConfig
34+
from vllm.distributed import get_pp_group
35+
from vllm.logger import init_logger
36+
from vllm.model_executor.layers.logits_processor import LogitsProcessor
37+
from vllm.model_executor.layers.sampler import get_sampler
38+
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
39+
from vllm.model_executor.model_loader.weight_utils import (
40+
default_weight_loader, maybe_remap_kv_scale_name)
41+
from vllm.model_executor.models.qwen2 import Qwen2ForCausalLM, Qwen2Model
42+
from vllm.model_executor.sampling_metadata import SamplingMetadata
43+
from vllm.sequence import IntermediateTensors
44+
45+
from .utils import PPMissingLayer, is_pp_missing_parameter, maybe_prefix
46+
47+
logger = init_logger(__name__)
48+
49+
50+
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class MiMoModel(Qwen2Model):
    """Qwen2-based decoder stack used by MiMo.

    ``forward`` returns ``hidden_states + residual`` on the last pipeline
    rank without applying a final norm; ``MiMoForCausalLM.compute_logits``
    applies ``self.model.norm`` before computing logits.
    """

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of decoder layers.

        Args:
            input_ids: Token ids; embedded on the first pipeline rank when
                ``inputs_embeds`` is not given.
            positions: Passed unchanged to every decoder layer.
            intermediate_tensors: Required on non-first pipeline ranks;
                supplies ``"hidden_states"`` and ``"residual"``.
            inputs_embeds: Optional embeddings used instead of embedding
                ``input_ids`` on the first rank.

        Returns:
            ``IntermediateTensors`` holding ``hidden_states`` and
            ``residual`` on non-last ranks, otherwise the tensor
            ``hidden_states + residual``.
        """
        pp_group = get_pp_group()

        if not pp_group.is_first_rank:
            # Later ranks resume from what the previous rank handed over.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            residual = None
            if inputs_embeds is None:
                hidden_states = self.get_input_embeddings(input_ids)
            else:
                hidden_states = inputs_embeds

        for decoder_layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = decoder_layer(
                positions,
                hidden_states,
                residual,
            )

        if pp_group.is_last_rank:
            # Fold the residual stream back in; no norm is applied here.
            return hidden_states + residual

        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })

    def load_weights(self, weights: Iterable[Tuple[str,
                                                   torch.Tensor]]) -> Set[str]:
        """Load checkpoint tensors into this module's parameters.

        Names containing ``"mtp_layers"`` or ``"rotary_emb.inv_freq"`` are
        skipped. The q/k/v and gate/up projection weights are routed into
        the fused ``qkv_proj`` / ``gate_up_proj`` parameters.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names recorded as loaded.
        """
        # (fused parameter name, checkpoint shard name, shard id)
        stacked_params_mapping = (
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        )
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            if "mtp_layers" in name or "rotary_emb.inv_freq" in name:
                continue

            scale_name = None
            if self.quant_config is not None:
                scale_name = self.quant_config.get_cache_scale(name)
            if scale_name:
                # Loading kv cache quantization scales
                scale_param = params_dict[scale_name]
                load_scale = getattr(scale_param, "weight_loader",
                                     default_weight_loader)
                if loaded_weight.dim() != 0:
                    loaded_weight = loaded_weight[0]
                load_scale(scale_param, loaded_weight)
                loaded_params.add(scale_name)
                continue

            for fused_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name in name:
                    name = name.replace(shard_name, fused_name)
                    # Skip loading extra bias for GPTQ models, and skip
                    # parameters that live on another pipeline rank.
                    # `continue` only advances this mapping loop; when no
                    # entry reaches `break`, the `else` branch below runs
                    # with the already rewritten name.
                    if ((name.endswith(".bias") and name not in params_dict)
                            or is_pp_missing_parameter(name, self)):
                        continue
                    fused_param = params_dict[name]
                    fused_param.weight_loader(fused_param, loaded_weight,
                                              shard_id)
                    break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                plain_param = params_dict[name]
                load_plain = getattr(plain_param, "weight_loader",
                                     default_weight_loader)
                load_plain(plain_param, loaded_weight)

            loaded_params.add(name)

        return loaded_params
146+
147+
148+
class MiMoForCausalLM(Qwen2ForCausalLM, nn.Module):
    """MiMo causal LM: a ``MiMoModel`` backbone plus an LM head.

    ``nn.Module.__init__`` is called directly in ``__init__``, so the
    attributes are built here rather than by ``Qwen2ForCausalLM.__init__``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head, logits processor and sampler.

        Args:
            vllm_config: Source of the HF model config, quantization
                config and LoRA config.
            prefix: Prefix passed through ``maybe_prefix`` when naming the
                ``model`` and ``lm_head`` submodules.
        """
        nn.Module.__init__(self)

        hf_config = vllm_config.model_config.hf_config
        quantization = vllm_config.quant_config

        self.config = hf_config
        self.lora_config = vllm_config.lora_config
        self.quant_config = quantization

        self.model = MiMoModel(vllm_config=vllm_config,
                               prefix=maybe_prefix(prefix, "model"))

        if not get_pp_group().is_last_rank:
            # Only the last pipeline rank holds an LM head.
            self.lm_head = PPMissingLayer()
        elif hf_config.tie_word_embeddings:
            # Tied weights: reuse the input embedding as the output head.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                hf_config.vocab_size,
                hf_config.hidden_size,
                quant_config=quantization,
                prefix=maybe_prefix(prefix, "lm_head"),
            )

        self.logits_processor = LogitsProcessor(hf_config.vocab_size)
        self.sampler = get_sampler()

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Apply ``self.model.norm`` and project to vocabulary logits.

        Args:
            hidden_states: Hidden states to normalize and project.
            sampling_metadata: Forwarded to the logits processor.

        Returns:
            Whatever the logits processor returns for the normalized
            hidden states.
        """
        # MiMoModel.forward does not apply the norm, so it is applied here.
        normed_states = self.model.norm(hidden_states)
        return self.logits_processor(self.lm_head, normed_states,
                                     sampling_metadata)

0 commit comments

Comments
 (0)