diff --git a/docs/source/openvino/models.mdx b/docs/source/openvino/models.mdx index c11505fa4d..a91526db13 100644 --- a/docs/source/openvino/models.mdx +++ b/docs/source/openvino/models.mdx @@ -131,6 +131,7 @@ Here is the list of the supported architectures : - Qwen2VL - Qwen2.5VL - Qwen3VL +- Qwen3.5 - Qwen3-Next - ResNet - Roberta diff --git a/optimum/exporters/openvino/convert.py b/optimum/exporters/openvino/convert.py index d0efa2259f..23e309efdb 100644 --- a/optimum/exporters/openvino/convert.py +++ b/optimum/exporters/openvino/convert.py @@ -447,6 +447,16 @@ def ts_patched_forward(*args, **kwargs): extension=conversion_extensions, ) + if patch_16bit_model: + # Undo __make_16bit_traceable patching on sub-modules to avoid corrupting + # forward methods of modules shared across export behaviors (e.g. pos_embed + # Embedding in VLMs that is also exported separately as vision_embeddings_pos). + _orig_forward_attr = "_openvino_module_extension_patch_orig_forward" + for module in model.modules(): + if hasattr(module, _orig_forward_attr): + module.forward = getattr(module, _orig_forward_attr) + delattr(module, _orig_forward_attr) + ov_model.validate_nodes_and_infer_types() # TODO: remove as unnecessary validation? 
output_names = list(config.outputs.keys()) diff --git a/optimum/exporters/openvino/model_configs.py b/optimum/exporters/openvino/model_configs.py index cc1cac2714..5467e641c3 100644 --- a/optimum/exporters/openvino/model_configs.py +++ b/optimum/exporters/openvino/model_configs.py @@ -194,6 +194,9 @@ Qwen2MoEPatcher, Qwen2VLLanguageModelPatcher, Qwen2VLVisionEmbMergerPatcher, + Qwen3_5ModelPatcher, + Qwen3_5MoeModelPatcher, + Qwen3_5VisionEmbMergerPatcher, Qwen3MoeModelPatcher, Qwen3NextModelPatcher, Qwen3VLLanguageModelPatcher, @@ -253,6 +256,14 @@ def init_model_configs(): "transformers", "AutoModelForCausalLM", ) + TasksManager._CUSTOM_CLASSES[("pt", "qwen3_5", "image-text-to-text")] = ( + "transformers", + "AutoModelForImageTextToText", + ) + TasksManager._CUSTOM_CLASSES[("pt", "qwen3_5_moe", "image-text-to-text")] = ( + "transformers", + "AutoModelForImageTextToText", + ) # since transformers v4.46, model can be loaded using default AutoModelForImageTextToText # https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/models/auto/modeling_auto.py#L776 @@ -3551,6 +3562,14 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int return generated_input +class DummyQwen3_5LMInputGenerator(DummyTextInputGenerator): + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + generated_input = super().generate(input_name, framework, int_dtype, float_dtype) + if input_name == "position_ids": + return generated_input.unsqueeze(0).expand(4, -1, -1) + return generated_input + + class DummyQwen2VLVisionEmbedInputGenerator(DummyVisionInputGenerator): SUPPORTED_INPUT_NAMES = ( "hidden_states", @@ -5497,3 +5516,309 @@ def generate_dummy_inputs(self, framework: str = "pt", **kwargs): ) return dummy_inputs + + +class Qwen3_5DummyPastKeyValuesGenerator(DummyPastKeyValuesGenerator): + """ + Generates dummy cache_params inputs for Qwen3.5 architectures. 
+ """ + + SUPPORTED_INPUT_NAMES = ("cache_params",) + + def __init__( + self, + task: str, + normalized_config, + batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"], + sequence_length: int = DEFAULT_DUMMY_SHAPES["sequence_length"], + **kwargs, + ): + super().__init__( + task=task, + normalized_config=normalized_config, + batch_size=batch_size, + sequence_length=sequence_length, + **kwargs, + ) + + config = normalized_config.config + self.num_full_attn_layers = config.layer_types.count("full_attention") + self.num_linear_attn_layers = config.layer_types.count("linear_attention") + self.conv_kernel_size = config.linear_conv_kernel_dim + self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads) + self.head_k_dim = config.linear_key_head_dim + self.head_v_dim = config.linear_value_head_dim + self.num_v_heads = config.linear_num_value_heads + self.num_k_heads = config.linear_num_key_heads + self.num_key_value_heads = config.num_key_value_heads + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"): + cache_params = [] + + for idx in range(self.num_linear_attn_layers): + d_inner = self.num_k_heads * (2 * self.head_k_dim + self.head_v_dim * self.num_v_heads // self.num_k_heads) + conv_state_shape = ( + self.batch_size, + d_inner, + self.conv_kernel_size, + ) + conv_state = self.random_float_tensor(conv_state_shape, framework=framework, dtype=float_dtype) + cache_params.append(conv_state) + num_heads = self.num_v_heads + recurrent_state_shape = (self.batch_size, num_heads, self.head_k_dim, self.head_v_dim) + recurrent_state = self.random_float_tensor(recurrent_state_shape, framework=framework, dtype=float_dtype) + cache_params.append(recurrent_state) + + for idx in range(self.num_full_attn_layers): + kv_shape = (self.batch_size, self.num_key_value_heads, self.sequence_length, self.head_dim) + k = self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype) + v = 
self.random_float_tensor(kv_shape, framework=framework, dtype=float_dtype) + cache_params.append(k) + cache_params.append(v) + + return cache_params + + +@register_in_tasks_manager( + "qwen3_5_text", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class Qwen3_5TextOpenVINOConfig(Qwen3VLTextOpenVINOConfig): + DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator, Qwen3_5DummyPastKeyValuesGenerator) + DUMMY_PKV_GENERATOR_CLASS = Qwen3_5DummyPastKeyValuesGenerator + NORMALIZED_CONFIG_CLASS = NormalizedTextConfig + MIN_TRANSFORMERS_VERSION = "4.57.0" + _MODEL_PATCHER = Qwen3_5ModelPatcher + + def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], direction: str): + if direction not in ["inputs", "outputs"]: + raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given') + + if direction == "inputs": + decoder_sequence_name = "past_sequence_length" + cache_name_prefix = "cache_params.past" + else: + decoder_sequence_name = "past_sequence_length + sequence_length" + cache_name_prefix = "cache_params.present" + + self.num_full_attn_layers = self._normalized_config.layer_types.count("full_attention") + self.num_linear_attn_layers = self._normalized_config.layer_types.count("linear_attention") + + for i in range(self.num_linear_attn_layers): + inputs_or_outputs[f"{cache_name_prefix}.conv.{i}"] = {0: "batch_size"} + inputs_or_outputs[f"{cache_name_prefix}.ssm.{i}"] = {0: "batch_size"} + + for i in range(self.num_full_attn_layers): + inputs_or_outputs[f"{cache_name_prefix}.key.{i}"] = {0: "batch_size", 2: decoder_sequence_name} + inputs_or_outputs[f"{cache_name_prefix}.value.{i}"] = {0: "batch_size", 2: decoder_sequence_name} + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + common_inputs = { + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "position_ids": {0: "batch_size", 1: "sequence_length"}, 
+ } + if self.use_past_in_inputs: + self.add_past_key_values(common_inputs, direction="inputs") + return common_inputs + + def generate_dummy_inputs(self, framework: str = "pt", **kwargs): + dummy_inputs_generators = self._create_dummy_input_generator_classes(**kwargs) + + dummy_inputs = {} + input_names = [key for key in self.inputs.keys() if not key.startswith("cache_params")] + if self.use_past_in_inputs: + input_names.extend(["cache_params"]) + + for input_name in input_names: + input_was_inserted = False + for dummy_input_gen in dummy_inputs_generators: + if dummy_input_gen.supports_input(input_name): + dummy_inputs[input_name] = self.overwrite_shape_and_generate_input( + dummy_input_gen, + input_name, + framework, + input_shapes=kwargs, + ) + input_was_inserted = True + break + if not input_was_inserted: + raise RuntimeError( + f'Could not generate dummy input for "{input_name}". Try adding a proper dummy input generator to the model ONNX config.' + ) + + return dummy_inputs + + +@register_in_tasks_manager( + "qwen3_5", + *["image-text-to-text"], + library_name="transformers", +) +class Qwen3_5OpenVINOConfig(Qwen3VLOpenVINOConfig): + SUPPORTED_BEHAVIORS = [model_type.value for model_type in QwenVLConfigBehavior] + DUMMY_INPUT_GENERATOR_CLASSES = (DummyQwen3VLVisionEmbedInputGenerator,) + MIN_TRANSFORMERS_VERSION = "4.57.0" + + def __init__( + self, + config: "PretrainedConfig", + task: str = "feature-extraction", + int_dtype: str = "int64", + float_dtype: str = "fp32", + behavior: QwenVLConfigBehavior = QwenVLConfigBehavior.VISION_EMBEDDINGS, + preprocessors: Optional[List[Any]] = None, + ): + super().__init__( + config=config, + task=task, + int_dtype=int_dtype, + float_dtype=float_dtype, + preprocessors=preprocessors, + behavior=behavior, + ) + if self._behavior == QwenVLConfigBehavior.VISION_EMBEDDINGS_POS and hasattr(config, "vision_config"): + self._config = config.vision_config + self._normalized_config = self.NORMALIZED_CONFIG_CLASS(self._config) + 
self._normalized_config.use_embed_dim = True + + def with_behavior( + self, + behavior: Union[str, QwenVLConfigBehavior], + ): + """ + Creates a config for different behaviour. + Args: + behavior ([`ConfigBehavior`]): + The behavior to use for the new instance. + """ + if isinstance(behavior, str) and not isinstance(behavior, QwenVLConfigBehavior): + behavior = QwenVLConfigBehavior(behavior) + + if behavior == QwenVLConfigBehavior.TEXT_EMBEDDINGS: + return get_vlm_text_embeddings_config( + "qwen3_5_text", self._orig_config.text_config, self.int_dtype, self.float_dtype + ) + + if behavior == QwenVLConfigBehavior.LANGUAGE: + return get_vlm_text_generation_config( + "qwen3_5_text", + self._orig_config.text_config, + self.int_dtype, + self.float_dtype, + model_patcher=Qwen3_5ModelPatcher, + dummy_input_generator=DummyQwen3_5LMInputGenerator, + inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}}, + ) + + if behavior in ( + QwenVLConfigBehavior.VISION_EMBEDDINGS, + QwenVLConfigBehavior.VISION_EMBEDDINGS_MERGER, + QwenVLConfigBehavior.VISION_EMBEDDINGS_POS, + ): + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + + def patch_model_for_export(self, model: Union["PreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None): + model_kwargs = model_kwargs or {} + if self._behavior == QwenVLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return Qwen3_5VisionEmbMergerPatcher(self, model, model_kwargs) + if ( + self._behavior == QwenVLConfigBehavior.VISION_EMBEDDINGS + or self._behavior == QwenVLConfigBehavior.VISION_EMBEDDINGS_POS + ): + return ModelPatcher(self, model, model_kwargs=model_kwargs) + return super().patch_model_for_export(model, model_kwargs) + + @property + def inputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == QwenVLConfigBehavior.VISION_EMBEDDINGS_POS: + return { + "input": {1: 
"sequence_length"}, + } + return super().inputs + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == QwenVLConfigBehavior.VISION_EMBEDDINGS: + return super().outputs + if self._behavior == QwenVLConfigBehavior.VISION_EMBEDDINGS_MERGER: + return {"last_hidden_state": {0: "seq_len"}} + if self._behavior == QwenVLConfigBehavior.VISION_EMBEDDINGS_POS: + return {"last_hidden_state": {0: "seq_len", 1: "seq_len"}} + if self._behavior == QwenVLConfigBehavior.TEXT_EMBEDDINGS: + return {"inputs_embeds": {0: "batch_size", 1: "sequence_length"}} + if self._behavior == QwenVLConfigBehavior.LANGUAGE: + return get_vlm_internal_text_generation_config( + "qwen3_5_text", self._orig_config.text_config, self.int_dtype, self.float_dtype + ).outputs + raise Exception("Unknown Qwen3.5 behavior type.") + + +@register_in_tasks_manager( + "qwen3_5_moe_text", + *["text-generation", "text-generation-with-past"], + library_name="transformers", +) +class Qwen3_5MoeTextOpenVINOConfig(Qwen3_5TextOpenVINOConfig): + _MODEL_PATCHER = Qwen3_5MoeModelPatcher + + +@register_in_tasks_manager( + "qwen3_5_moe", + *["image-text-to-text"], + library_name="transformers", +) +class Qwen3_5MoeOpenVINOConfig(Qwen3_5OpenVINOConfig): + def with_behavior( + self, + behavior: Union[str, QwenVLConfigBehavior], + ): + if isinstance(behavior, str) and not isinstance(behavior, QwenVLConfigBehavior): + behavior = QwenVLConfigBehavior(behavior) + + if behavior == QwenVLConfigBehavior.TEXT_EMBEDDINGS: + return get_vlm_text_embeddings_config( + "qwen3_5_moe_text", self._orig_config.text_config, self.int_dtype, self.float_dtype + ) + + if behavior == QwenVLConfigBehavior.LANGUAGE: + return get_vlm_text_generation_config( + "qwen3_5_moe_text", + self._orig_config.text_config, + self.int_dtype, + self.float_dtype, + model_patcher=Qwen3_5MoeModelPatcher, + dummy_input_generator=DummyQwen3_5LMInputGenerator, + inputs_update={"position_ids": {1: "batch_size", 2: "sequence_length"}}, + ) + + if 
behavior in ( + QwenVLConfigBehavior.VISION_EMBEDDINGS, + QwenVLConfigBehavior.VISION_EMBEDDINGS_MERGER, + QwenVLConfigBehavior.VISION_EMBEDDINGS_POS, + ): + return self.__class__( + self._orig_config, + task=self.task, + int_dtype=self.int_dtype, + float_dtype=self.float_dtype, + behavior=behavior, + preprocessors=self._preprocessors, + ) + + @property + def outputs(self) -> Dict[str, Dict[int, str]]: + if self._behavior == QwenVLConfigBehavior.LANGUAGE: + return get_vlm_internal_text_generation_config( + "qwen3_5_moe_text", self._orig_config.text_config, self.int_dtype, self.float_dtype + ).outputs + return super().outputs diff --git a/optimum/exporters/openvino/model_patcher.py b/optimum/exporters/openvino/model_patcher.py index a0c8b4601a..f906fa52d0 100644 --- a/optimum/exporters/openvino/model_patcher.py +++ b/optimum/exporters/openvino/model_patcher.py @@ -8417,3 +8417,395 @@ def __exit__(self, exc_type, exc_value, traceback): sparse_moe_block = decoder_layer.mlp decoder_layer.mlp.forward = decoder_layer.mlp._orig_forward del sparse_moe_block.down_projs, sparse_moe_block.gate_projs, sparse_moe_block.up_projs + + +# The CausalConv1D block is overridden with a generic patch provided by `ov_causal_conv1d()`. +# The GatedDeltaNet block is overridden with a recurrent version of its implementation. +# +# To replace GatedDeltaNet with its recurrent form, patching uses the ModuleExtension +# approach, which replaces the GatedDeltaNet block with a single operation, +# `GatedDeltaNetOp`. OpenVINO then applies the `convert_recurrent_attention_cell()` +# conversion rule to this operation. 
+def qwen3_5_gated_delta_net_forward( + self, + hidden_states: torch.Tensor, + cache_params=None, + cache_position: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, +): + def apply_mask_to_padding_states(hidden_states, attention_mask): + """ + Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 + """ + # NOTE: attention mask is a 2D boolean tensor + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + dtype = hidden_states.dtype + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + + return hidden_states + + hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) + + # Set up dimensions for reshapes later + batch_size, seq_len, _ = hidden_states.shape + + # getting projected states from cache if it exists + layer_idx = None + recurrent_state = None + if cache_params is not None: + layer_idx = cache_params.linear_attn_mapping[self.layer_idx] + conv_state = cache_params.conv_states[layer_idx] + recurrent_state = cache_params.recurrent_states[layer_idx] + + mixed_qkv = self.in_proj_qkv(hidden_states) + mixed_qkv = mixed_qkv.transpose(1, 2) + + z = self.in_proj_z(hidden_states) + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) + + b = self.in_proj_b(hidden_states) + a = self.in_proj_a(hidden_states) + + if cache_params is not None: + new_mixed_qkv, new_conv_state = ov_causal_conv1d(conv_state, mixed_qkv, self.conv1d.weight, self.conv1d.bias) + mixed_qkv = F.silu(new_mixed_qkv) + cache_params.conv_states[layer_idx] = new_conv_state + else: + mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len]) + + mixed_qkv = mixed_qkv.transpose(1, 2) + query, key, value = torch.split( + mixed_qkv, + [ + self.key_dim, + self.key_dim, + self.value_dim, + ], + dim=-1, + ) + query = query.reshape(query.shape[0], query.shape[1], -1, self.head_k_dim) + key = key.reshape(key.shape[0], key.shape[1], -1, 
self.head_k_dim) + value = value.reshape(value.shape[0], value.shape[1], -1, self.head_v_dim) + + beta = b.sigmoid() + # If the model is loaded in fp16, without the .float() here, A might be -inf + g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) + if self.num_v_heads // self.num_k_heads > 1: + query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) + + core_attn_out, last_recurrent_state = self.recurrent_gated_delta_rule( + self, + query, + key, + value, + g=g, + beta=beta, + initial_state=recurrent_state, + output_final_state=cache_params is not None, + use_qk_l2norm_in_kernel=True, + ) + + # Update cache + if cache_params is not None: + cache_params.recurrent_states[layer_idx] = last_recurrent_state + + # reshape input data into 2D tensor + core_attn_out = core_attn_out.reshape(-1, self.head_v_dim) + z = z.reshape(-1, self.head_v_dim) + core_attn_out = self.norm(core_attn_out, z) + core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) + + output = self.out_proj(core_attn_out) + return output + + +class Qwen3_5ModelPatcher(OVDecoderModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Optional[Dict[str, Any]] = None, + ): + from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DynamicCache + + from openvino.frontend.pytorch import ConversionExtension, ModuleExtension + + from ._ov_ops import convert_recurrent_attention_cell + + super().__init__(config, model, model_kwargs) + + # Detect VLM vs text-only model + self._is_vlm = hasattr(self._model.model, "language_model") + if self._is_vlm: + self._text_model = self._model.model.language_model + self._text_config = self._model.config.text_config + else: + self._text_model = self._model.model + self._text_config = self._model.model.config + + class Qwen3_5DynamicCacheWrap(Qwen3_5DynamicCache): + def __init__(self, config, 
conv_states, recurrent_states, key_cache, value_cache): + # Call parent constructor with all required arguments + super().__init__(config=config) + + self.conv_states = conv_states + self.recurrent_states = recurrent_states + self.key_cache = key_cache + self.value_cache = value_cache + self.full_attn_mapping = {} + self.linear_attn_mapping = {} + full_attn_layer_idx = 0 + linear_attn_layer_idx = 0 + for i in range(len(config.layer_types)): + if self.layer_types[i] == "full_attention": + self.full_attn_mapping[i] = full_attn_layer_idx + full_attn_layer_idx += 1 + elif self.layer_types[i] == "linear_attention": + self.linear_attn_mapping[i] = linear_attn_layer_idx + linear_attn_layer_idx += 1 + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + # map layer_idx to key_cache (value_cache) idx + layer_idx = self.full_attn_mapping[layer_idx] + if self.key_cache[layer_idx] is None: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # take any layer that contains cache and not empty tensor + layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx + layer_idx = self.full_attn_mapping[layer_idx] + if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx] is None: + return 0 + return self.key_cache[layer_idx].shape[-2] + + @property + def has_previous_state(self): + """We have a previous state if the last linear (conv) layer was already updated.""" + layer_idx = self.linear_attn_mapping[self.last_linear_layer] + return self.conv_states[layer_idx] is not None + + # the patch is needed to include KV-cache, Conv, and SSM states in the inputs and outputs. + def patched_forward( + input_ids=None, + attention_mask=None, + cache_params=None, + inputs_embeds=None, + position_ids=None, + ): + text_config = self._text_config + num_full_attn_layers = text_config.layer_types.count("full_attention") + num_linear_attn_layers = text_config.layer_types.count("linear_attention") + + use_cache = False + wrapped_cache_params = None + if cache_params is not None: + use_cache = True + conv_states = [] + recurrent_states = [] + key_cache = [] + value_cache = [] + + # decouple ssm_states, conv_states, keys and values from cache_params + for idx in range(num_linear_attn_layers): + conv_states.append(cache_params[2 * idx]) + recurrent_states.append(cache_params[2 * idx + 1]) + + for idx in range(num_full_attn_layers): + key_cache.append(cache_params[2 * num_linear_attn_layers + 2 * idx]) + value_cache.append(cache_params[2 * num_linear_attn_layers + 2 * idx + 1]) + + wrapped_cache_params = Qwen3_5DynamicCacheWrap( + text_config, conv_states, recurrent_states, key_cache, value_cache + ) + + if self._is_vlm: + # VLM case: call language model through the composite model + outputs_lm = self._text_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + 
past_key_values=wrapped_cache_params, + use_cache=use_cache, + ) + hidden_states = outputs_lm[0] + logits = self._model.lm_head(hidden_states) + past_kv = outputs_lm.past_key_values + else: + causal_lm_output = self.model_orig_forward( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=wrapped_cache_params, + use_cache=use_cache, + ) + logits = causal_lm_output.logits + past_kv = causal_lm_output.past_key_values + outputs = { + "logits": logits, + } + + if use_cache: + present_key_values = [] + for idx in range(num_linear_attn_layers): + present_key_values.append(past_kv.conv_states[idx]) + present_key_values.append(past_kv.recurrent_states[idx]) + + for idx in range(num_full_attn_layers): + present_key_values.append(past_kv.key_cache[idx]) + present_key_values.append(past_kv.value_cache[idx]) + + outputs["present_key_values"] = present_key_values + + return outputs + + self.patched_forward = patched_forward + self.model_orig_forward = self.orig_forward + self.orig_forward = patched_forward + + self.module_extensions = { + RecurrentAttentionCell: ModuleExtension(RecurrentAttentionCell, "RecurrentAttentionCellOp"), + } + self.conversion_extensions = [ + ConversionExtension("RecurrentAttentionCellOp", convert_recurrent_attention_cell), + ] + + def __enter__(self): + super().__enter__() + setattr(self._model, self.orig_forward_name, self.patched_forward) + + for idx, decoder_layer in enumerate(self._text_model.layers): + layer_type = self._text_config.layer_types[idx] + if layer_type == "linear_attention": + linear_attn_layer = decoder_layer.linear_attn + linear_attn_layer._orig_forward = linear_attn_layer.forward + linear_attn_layer.forward = types.MethodType(qwen3_5_gated_delta_net_forward, linear_attn_layer) + linear_attn_layer.recurrent_gated_delta_rule = patched_recurrent_gated_delta_rule + linear_attn_layer.recurrent_attention_cell = RecurrentAttentionCell() + + def __exit__(self, exc_type, exc_value, traceback): + 
super().__exit__(exc_type, exc_value, traceback) + setattr(self._model, self.orig_forward_name, self.model_orig_forward) + for idx, decoder_layer in enumerate(self._text_model.layers): + layer_type = self._text_config.layer_types[idx] + if layer_type == "linear_attention": + linear_attn_layer = decoder_layer.linear_attn + linear_attn_layer.forward = linear_attn_layer._orig_forward + + +class Qwen3_5VisionEmbMergerPatcher(ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Dict[str, Any] = None, + ): + model.__orig_forward = model.forward + + # Adapted from Qwen3.5 VisionModel forward + # added attention_mask input instead of cu_seqlens for its internal calculation + # separated patch_embed and rot_pos_emb calls for performing as part of another model + def image_embed_forward( + self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, rotary_pos_emb: torch.Tensor + ) -> torch.Tensor: + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + position_embeddings = (emb.cos(), emb.sin()) + for blk in self.blocks: + hidden_states = blk( + hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings + ) + return self.merger(hidden_states) + + model.forward = types.MethodType(image_embed_forward, model) + super().__init__(config, model, model_kwargs) + + def __enter__(self): + patch_qwen2vl_vision_blocks(self._model) + super().__enter__() + + def __exit__(self, exc_type, exc_value, traceback): + super().__exit__(exc_type, exc_value, traceback) + self._model.forward = self._model.__orig_forward + for block in self._model.blocks: + block.forward = block._orig_forward + block.attn.forward = block.attn._orig_forward + + +def patched_qwen3_5_moe_sparse_moe_block(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_experts = self.experts.num_experts + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # router 
returns (logits, scores, indices) + _, routing_weights, selected_experts = self.gate(hidden_states) + + new_routing_weights = torch.zeros(batch_size * sequence_length, num_experts, dtype=routing_weights.dtype) + new_routing_weights.scatter_(dim=1, index=selected_experts, src=routing_weights) + + shared_expert_output = self.shared_expert(hidden_states) + shared_expert_output = torch.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output + + hidden_states = hidden_states.repeat(num_experts, 1) + hidden_states = hidden_states.view(num_experts, -1, hidden_dim) + act_fn = self.experts.act_fn + + # compute experts outputs in a vectorized form using torch.bmm + gate = torch.bmm(hidden_states, self.gate_projs.transpose(1, 2)) + up = torch.bmm(hidden_states, self.up_projs.transpose(1, 2)) + gate_up = act_fn(gate) * up + next_states = torch.bmm(gate_up, self.down_projs.transpose(1, 2)) + next_states = next_states.view(num_experts, batch_size, -1, hidden_dim) + next_states = next_states * new_routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.sum(dim=0) + + shared_expert_output = shared_expert_output.view(batch_size, -1, hidden_dim) + output = shared_expert_output + next_states + return output.view(batch_size, sequence_length, hidden_dim) + + +class Qwen3_5MoeModelPatcher(Qwen3_5ModelPatcher): + def __init__( + self, + config: "OnnxConfig", + model: "PreTrainedModel", + model_kwargs: Optional[Dict[str, Any]] = None, + ): + super().__init__(config, model, model_kwargs) + + def __enter__(self): + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock + + super().__enter__() + for decoder_layer in self._text_model.layers: + if isinstance(decoder_layer.mlp, Qwen3_5MoeSparseMoeBlock): + sparse_moe_block = decoder_layer.mlp + intermediate_dim = sparse_moe_block.experts.intermediate_dim + sparse_moe_block._orig_forward = sparse_moe_block.forward + sparse_moe_block.forward = 
types.MethodType(patched_qwen3_5_moe_sparse_moe_block, sparse_moe_block) + # TODO: remove `float()` casting when CVS-181449 is fixed + # now it is needed to have MoE optimizations to be applied + sparse_moe_block.gate_projs = sparse_moe_block.experts.gate_up_proj[:, :intermediate_dim, :].float() + sparse_moe_block.up_projs = sparse_moe_block.experts.gate_up_proj[:, intermediate_dim:, :].float() + sparse_moe_block.down_projs = sparse_moe_block.experts.down_proj.data.float() + + def __exit__(self, exc_type, exc_value, traceback): + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import Qwen3_5MoeSparseMoeBlock + + super().__exit__(exc_type, exc_value, traceback) + for decoder_layer in self._text_model.layers: + if isinstance(decoder_layer.mlp, Qwen3_5MoeSparseMoeBlock): + sparse_moe_block = decoder_layer.mlp + sparse_moe_block.forward = sparse_moe_block._orig_forward + del sparse_moe_block.gate_projs, sparse_moe_block.up_projs, sparse_moe_block.down_projs diff --git a/optimum/exporters/openvino/stateful.py b/optimum/exporters/openvino/stateful.py index 3b8642d65a..38ffef5d05 100644 --- a/optimum/exporters/openvino/stateful.py +++ b/optimum/exporters/openvino/stateful.py @@ -310,6 +310,10 @@ def patch_stateful(config: PretrainedConfig, ov_model: ov.Model): return patch_stateful_encoder_decoder(config, ov_model) if config.model_type in SSM_MODELS: return patch_stateful_hybrid_ssm(ov_model) + # For VLM models, the text sub-model may be SSM-based (e.g. 
qwen3_5 VLM with qwen3_5_text language model) + text_config = getattr(config, "text_config", None) + if text_config is not None and getattr(text_config, "model_type", None) in SSM_MODELS: + return patch_stateful_hybrid_ssm(ov_model) return patch_stateful_decoder(config, ov_model) diff --git a/optimum/exporters/openvino/utils.py b/optimum/exporters/openvino/utils.py index af2f1edaba..6314803bbc 100644 --- a/optimum/exporters/openvino/utils.py +++ b/optimum/exporters/openvino/utils.py @@ -295,6 +295,8 @@ def get_submodels(model): "qwen2_vl", "qwen2_5_vl", "qwen3_vl", + "qwen3_5", + "qwen3_5_moe", "got_ocr2", "gemma3", "idefics3", @@ -305,7 +307,16 @@ def get_submodels(model): "minicpmo", ] -SSM_MODELS = ["mamba", "falcon_mamba", "zamba2", "lfm2", "granitemoehybrid", "qwen3_next"] +SSM_MODELS = [ + "mamba", + "falcon_mamba", + "zamba2", + "lfm2", + "granitemoehybrid", + "qwen3_next", + "qwen3_5_text", + "qwen3_5_moe_text", +] # All transformers, diffusers, timm and sentence transformers models that are supported via optimum-onnx OnnxConfigs but that have currently no test # TODO: add tests for all models that are compatible and remove support for all others diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py index ccf177df9d..7044953664 100644 --- a/optimum/intel/openvino/modeling_decoder.py +++ b/optimum/intel/openvino/modeling_decoder.py @@ -1449,8 +1449,14 @@ def prepare_inputs_for_generation( # decoding stage so it takes the last token input_ids = input_ids[:, -1].unsqueeze(-1) - if self.config.model_type not in ["lfm2", "granitemoehybrid", "qwen3_next"]: - # LFM2, GraniteMoeHybrid (Granite-4.0), and Qwen3-Next require the attention mask + if self.config.model_type not in [ + "lfm2", + "granitemoehybrid", + "qwen3_next", + "qwen3_5_text", + "qwen3_5_moe_text", + ]: + # LFM2, GraniteMoeHybrid (Granite-4.0), Qwen3-Next, and Qwen3.5 require the attention mask # to be the length of the full context, so default mask from 
OVModelForCausalLM needs to be used. # Other models like Mamba typically do not require an attention_mask # for the decoding step after the first token so use attention mask of ones. diff --git a/optimum/intel/openvino/modeling_visual_language.py b/optimum/intel/openvino/modeling_visual_language.py index beb7b974eb..427860775e 100644 --- a/optimum/intel/openvino/modeling_visual_language.py +++ b/optimum/intel/openvino/modeling_visual_language.py @@ -188,9 +188,11 @@ def prepare_inputs( position_ids = np.cumsum(attention_mask, axis=1) - 1 position_ids[attention_mask == 0] = 1 if past_len: - position_ids = position_ids[:, -inputs_embeds.shape[1] :] + position_ids = position_ids[..., -inputs_embeds.shape[1] :] - if (self.config.model_type in ["qwen2_vl", "qwen3_vl"]) and position_ids.ndim != 3: + if self.config.model_type in ["qwen3_5", "qwen3_5_moe"] and position_ids.ndim != 3: + position_ids = np.repeat(np.expand_dims(position_ids, 0), 4, axis=0) + elif self.config.model_type in ["qwen2_vl", "qwen3_vl"] and position_ids.ndim != 3: position_ids = np.repeat(np.expand_dims(position_ids, 0), 3, axis=0) inputs["position_ids"] = position_ids @@ -3439,6 +3441,44 @@ class Qwen3VLModel: class Qwen3VLVisionModel: pass + class Qwen3_5Model: + pass + + class Qwen3_5VisionModel: + pass + + +if is_transformers_version(">=", "5.2.0"): + from transformers.models.qwen3_5.modeling_qwen3_5 import ( + Qwen3_5Model, + Qwen3_5VisionModel, + Qwen3_5VisionRotaryEmbedding, + ) +else: + + class Qwen3_5Model: + pass + + class Qwen3_5VisionModel: + pass + + class Qwen3_5VisionRotaryEmbedding: + pass + + +if is_transformers_version(">=", "5.2.0"): + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeModel, + Qwen3_5MoeVisionModel, + ) +else: + + class Qwen3_5MoeModel: + pass + + class Qwen3_5MoeVisionModel: + pass + # The inheritance from Qwen3VLModel is needed to get access to methods: # get_placeholder_mask(): 
https://github.com/huggingface/transformers/blob/v4.57.6/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py#L1066 @@ -4802,6 +4842,396 @@ def preprocess_inputs( return inputs + + +# The inheritance from Qwen3_5Model is needed to get access to methods: +# get_placeholder_mask(), get_rope_index(), get_image_features(), get_video_features(), compute_3d_position_ids() +# +# and inheritance from Qwen3_5VisionModel is needed for accessing the following method: +# rot_pos_emb() +class _OVQwen3_5ForCausalLM(OVModelForVisualCausalLM, Qwen3_5Model, Qwen3_5VisionModel): + additional_parts = ["vision_embeddings_merger", "vision_embeddings_pos"] + + def __init__( + self, + language_model: ov.Model, + text_embeddings: ov.Model, + vision_embeddings: ov.Model, + config: PretrainedConfig = None, + device: str = "CPU", + dynamic_shapes: bool = None, + ov_config: Optional[Dict[str, str]] = None, + model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None, + quantization_config: Union[OVWeightQuantizationConfig, Dict] = None, + **kwargs, + ): + if is_transformers_version("<", "5.2.0"): + raise Exception("Qwen3.5 is not supported in transformers versions earlier than 5.2.0.") + + super().__init__( + language_model=language_model, + text_embeddings=text_embeddings, + vision_embeddings=vision_embeddings, + config=config, + device=device, + dynamic_shapes=dynamic_shapes, + ov_config=ov_config, + model_save_dir=model_save_dir, + quantization_config=quantization_config, + **kwargs, + ) + self.rope_deltas = None # cache rope_deltas here + + self.num_grid_per_side = int(config.vision_config.num_position_embeddings**0.5) + self.spatial_merge_size = config.vision_config.spatial_merge_size + head_dim = config.vision_config.hidden_size // config.vision_config.num_heads + self.rotary_pos_emb = Qwen3_5VisionRotaryEmbedding(head_dim // 2) + + def __setattr__(self, name, value): + OVModelForVisualCausalLM.__setattr__(self, name, value) + + def prepare_inputs_for_generation( + self, + 
input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + **kwargs, + ): + # Overwritten -- in specific circumstances we don't want to forward image inputs to the model + if past_key_values is not None: + if inputs_embeds is not None and input_ids.shape[1] == 0: # Exception 4 + inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :] + elif inputs_embeds is not None: + input_ids = input_ids[:, -cache_position.shape[0] :] + elif input_ids.shape[1] != cache_position.shape[0]: # Default case (the "else", a no op, is Exception 2) + input_ids = input_ids[:, cache_position] + + if cache_position[0] != 0: + pixel_values = None + pixel_values_videos = None + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]: + model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} + else: + model_inputs = {"input_ids": input_ids, "inputs_embeds": None} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": use_cache, + "attention_mask": attention_mask, + "pixel_values": pixel_values, + "pixel_values_videos": pixel_values_videos, + "image_grid_thw": image_grid_thw, + "video_grid_thw": video_grid_thw, + "cache_position": cache_position, + } + ) + return model_inputs + + # Adapted from Qwen3_5VisionModel.fast_pos_embed_interpolate + # This method needs to be changed, as instead of running self.pos_embed of type nn.Embedding, openvino model needs to be inferred (self.vision_embeddings_pos) + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, 
grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list) + weight_tensor = torch.tensor(weight_list) + pos_embeds = torch.from_numpy(self.vision_embeddings_pos(idx_tensor)) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.vision_config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def get_vision_embeddings(self, pixel_values, grid_thw, **kwargs): + hidden_states = 
torch.from_numpy(self.vision_embeddings(pixel_values)[0]) + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + hidden_states = hidden_states + pos_embeds + + rotary_pos_emb = self.rot_pos_emb(grid_thw) + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum( + dim=0, dtype=torch.int32 + ) + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + attention_mask = torch.zeros((1, hidden_states.shape[0], hidden_states.shape[0]), dtype=torch.bool) + causal_mask = torch.zeros_like(attention_mask, dtype=torch.float32) + for i in range(1, len(cu_seqlens)): + attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = True + + causal_mask.masked_fill_(torch.logical_not(attention_mask), float("-inf")) + + res = self.vision_embeddings_merger( + pixel_values=hidden_states, attention_mask=causal_mask, rotary_pos_emb=rotary_pos_emb + ) + return res[0] + + def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None): + """ + Encodes images into continuous embeddings that can be forwarded to the language model. + """ + image_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values, image_grid_thw)) + split_sizes = (image_grid_thw.prod(-1) // self.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + return image_embeds + + def get_video_features( + self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None + ): + """ + Encodes videos into continuous embeddings that can be forwarded to the language model. 
+ """ + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_multimodal_embeddings( + self, + input_ids, + pixel_values=None, + attention_mask=None, + position_ids=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + cache_position=None, + **kwargs, + ): + inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids)) + if pixel_values is not None and input_ids.shape[1] != 1: + image_embeds = self.get_image_features(pixel_values, image_grid_thw) + image_embeds = torch.cat(image_embeds, dim=0) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + + mask = input_ids == self.config.image_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + image_mask = mask_expanded.to(inputs_embeds.device) + + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None and input_ids.shape[1] != 1: + video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw) + video_embeds = torch.cat(video_embeds, dim=0) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + + mask = input_ids == self.config.video_token_id + mask_unsqueezed = mask.unsqueeze(-1) + mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) + video_mask = mask_expanded.to(inputs_embeds.device) + + video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = 
inputs_embeds.masked_scatter(video_mask, video_embeds) + + if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None: + vision_positions, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + # Compute text positions (simple cumsum) and concatenate as dim 0 + # to create shape (4, batch, seq_len): [text_pos, temporal, height, width] + if attention_mask is not None: + text_positions = attention_mask.long().cumsum(-1) - 1 + text_positions = text_positions.masked_fill(attention_mask == 0, 1) + else: + text_positions = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .unsqueeze(0) + .expand(input_ids.shape[0], -1) + ) + position_ids = torch.cat([text_positions.unsqueeze(0), vision_positions], dim=0) + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + # Prepend text positions for shape (4, batch, seq_len) + text_positions = torch.arange(seq_length, device=inputs_embeds.device) + text_positions = text_positions.view(1, -1).expand(batch_size, -1) + if cache_position is not None: + text_positions = text_positions + cache_position[0] + position_ids = 
torch.cat([text_positions.unsqueeze(0), position_ids], dim=0) + + return inputs_embeds, attention_mask, position_ids + + @staticmethod + def preprocess_inputs( + text: str, + image: Optional["Image"] = None, + processor: Optional[AutoImageProcessor] = None, + tokenizer: Optional[PreTrainedTokenizer] = None, + config: Optional[PretrainedConfig] = None, + video: Optional["VideoInput"] = None, + audio: Optional[np.ndarray] = None, + ): + if processor is None: + raise ValueError("Processor is required.") + if audio is not None: + raise ValueError("Audio input is not supported") + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text}, + ], + } + ] + if image is not None: + conversation[0]["content"].insert(0, {"type": "image"}) + if video is not None: + conversation[0]["content"].insert(0, {"type": "video"}) + + text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + + inputs = processor(images=image, text=text_prompt, videos=video, return_tensors="pt") + return inputs + + def forward( + self, + input_ids, + pixel_values=None, + past_key_values=None, + inputs_embeds=None, + image_sizes=None, + attention_mask=None, + position_ids=None, + image_bound=None, + tgt_sizes=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + rope_deltas=None, + **kwargs, + ): + result = super().forward( + input_ids, + pixel_values, + past_key_values, + inputs_embeds, + image_sizes, + attention_mask, + position_ids, + image_bound, + tgt_sizes, + pixel_values_videos, + image_grid_thw, + video_grid_thw, + rope_deltas, + **kwargs, + ) + final_result = QWen2VLModelOutputWithPast( + logits=result.logits, past_key_values=result.past_key_values, rope_deltas=rope_deltas + ) + return final_result + + def _prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs): + # Mirrors Qwen3_5ForConditionalGeneration._prepare_position_ids_for_generation + # Creates proper 4D position_ids: [text_positions, 
temporal, height, width] + text_positions = GenerationMixin._prepare_position_ids_for_generation(self, inputs_tensor, model_kwargs) + + if "input_ids" in model_kwargs and model_kwargs["input_ids"].shape[1] > 0: + inputs_tensor = model_kwargs["input_ids"] + + is_input_ids = len(inputs_tensor.shape) == 2 and inputs_tensor.dtype in [torch.int, torch.long] + if is_input_ids and ( + model_kwargs.get("image_grid_thw") is not None or model_kwargs.get("video_grid_thw") is not None + ): + filtered_kwargs = {k: v for k, v in model_kwargs.items() if k != "input_ids"} + vision_positions, rope_deltas = self.get_rope_index(inputs_tensor, **filtered_kwargs) + self.rope_deltas = rope_deltas + else: + vision_positions = text_positions.unsqueeze(0).expand(3, -1, -1) + self.rope_deltas = torch.zeros(inputs_tensor.shape[0], 1, dtype=torch.long, device=inputs_tensor.device) + + # Concatenate "text + vision" positions into [4, bs, seq-len] + text_positions = text_positions[None, ...] + position_ids = torch.cat([text_positions, vision_positions], dim=0) + return position_ids + + def generate(self, *args, **kwargs): + # Clear cached rope delta from previous generations + self.rope_deltas = None + + return super().generate(*args, **kwargs) + + MODEL_TYPE_TO_CLS_MAPPING = { "llava": _OVLlavaForCausalLM, "llava_next": _OVLlavaNextForCausalLM, @@ -4823,5 +5253,9 @@ def preprocess_inputs( "phi4_multimodal": _OVPhi4MMForCausalLM, "llama4": _OVLlama4ForCausalLM, "qwen3_vl": _OVQwen3VLForCausalLM, + "qwen3_5": _OVQwen3_5ForCausalLM, + "qwen3_5_text": _OVQwen3_5ForCausalLM, + "qwen3_5_moe": _OVQwen3_5ForCausalLM, + "qwen3_5_moe_text": _OVQwen3_5ForCausalLM, "minicpmo": _OVMiniCPMOForCausalLM, } diff --git a/tests/openvino/test_decoder.py b/tests/openvino/test_decoder.py index 3067f1c5c4..a1e47c7451 100644 --- a/tests/openvino/test_decoder.py +++ b/tests/openvino/test_decoder.py @@ -27,6 +27,7 @@ BitnetOpenVINOConfig, DeepseekOpenVINOConfig, LFM2OpenVINOConfig, + Qwen3_5TextOpenVINOConfig, 
Qwen3VLOpenVINOConfig, ) from optimum.exporters.openvino.model_patcher import patch_update_causal_mask @@ -337,6 +338,10 @@ def test_find_untested_architectures(self): "qwen3_next", } + # qwen3_5_text is a part of the qwen3_5 architecture and is tested in the seq2seq group + if is_transformers_version(">=", str(Qwen3_5TextOpenVINOConfig.MIN_TRANSFORMERS_VERSION)): + supported_architectures -= {"qwen3_5_text"} + supported_architectures -= ONNX_SUPPORTED_ARCHITECTURES untested_architectures = supported_architectures - tested_architectures diff --git a/tests/openvino/utils_tests.py b/tests/openvino/utils_tests.py index 1117604b7b..783e3584ba 100644 --- a/tests/openvino/utils_tests.py +++ b/tests/openvino/utils_tests.py @@ -173,6 +173,7 @@ "qwen3_moe": "optimum-intel-internal-testing/tiny-random-qwen3moe", "qwen3_vl": "optimum-intel-internal-testing/tiny-random-qwen3-vl", "qwen3_next": "optimum-intel-internal-testing/tiny-random-qwen3-next", + "qwen3_5": "optimum-intel-internal-testing/tiny-random-qwen3.5", "rembert": "optimum-intel-internal-testing/tiny-random-rembert", "resnet": "optimum-intel-internal-testing/tiny-random-resnet", "roberta": "optimum-intel-internal-testing/tiny-random-roberta", @@ -335,6 +336,13 @@ "vision_embeddings_merger_model": 32, "vision_embeddings_pos_model": 1, }, + "qwen3_5": { + "lm_model": 100, + "text_embeddings_model": 1, + "vision_embeddings_model": 1, + "vision_embeddings_merger_model": 32, + "vision_embeddings_pos_model": 1, + }, "sana": { "transformer": 58, "vae_decoder": 28,