# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Patch configuration for Qwen3-Omni-MoE transformers>=5.2.0 code generation.

Covers the thinker training path (text + vision + audio + MoE):
  - Vision SP slicing with pad_scale=4 + varlen-aware attention
  - Audio tower with SP gather/slice of (mel, time) features
  - Thinker text model FSDP-safe deepstack process + MoeModelOutputWithPast
  - Qwen3OmniMoeThinkerTextExperts fused-MoE replacement (drops the
    @use_experts_implementation decorator which would otherwise bypass our
    fused kernel)
  - Thinker ForConditionalGeneration: pre-computed image/video/audio masks,
    async-Ulysses-aware embed gather+scatter, deepstack SP selection, fused
    loss via self.loss_function, precomputed multimodal position-ids
  - Qwen3OmniMoeForConditionalGeneration: skip talker, pin _no_split_modules
    down to thinker.text_config, forward-to-thinker only (skip talker),
    VeOmni parallel plan.

Regen command:
python -m veomni.patchgen.run_codegen veomni.models.transformers.qwen3_omni_moe.qwen3_omni_moe_gpu_patch_gen_config -o veomni/models/transformers/qwen3_omni_moe/generated --diff
"""

import copy
from functools import partial
from types import SimpleNamespace
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN
from transformers.cache_utils import DynamicCache
from transformers.masking_utils import create_causal_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import (
    BaseModelOutputWithPooling,
    MoeModelOutputWithPast,
)
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import (
    BaseModelOutputWithDeepstackFeatures,
    Qwen3OmniMoeThinkerForConditionalGeneration,
    _get_feat_extract_output_lengths,
    apply_rotary_pos_emb_vision,
    eager_attention_forward,
    load_balancing_loss_func,
)
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

from veomni.distributed.parallel_state import get_parallel_state
from veomni.distributed.sequence_parallel import (
    gather_outputs,
    slice_input_tensor,
    sp_pad_and_slice,
    unpad_tensor,
)
from veomni.distributed.sequence_parallel.ulysses import _Gather
from veomni.models.transformers.attention_utils import VARLEN_ATTENTION_TYPES
from veomni.ops import fused_moe_forward
from veomni.patchgen.patch_spec import PatchConfig
from veomni.utils.constants import (
    AUDIO_INPUT_INDEX,
    IGNORE_INDEX,
    IMAGE_INPUT_INDEX,
    VIDEO_INPUT_INDEX,
)
from veomni.utils.model_outputs import Qwen3OmniMoeThinkerCausalLMOutputWithLogProbs


config = PatchConfig(
    source_module="transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe",
    target_file="patched_modeling_qwen3_omni_moe_gpu.py",
    description="Qwen3-Omni-MoE thinker with VeOmni v5 compatibility (SP + FSDP + fused MoE + fused loss)",
)


# ================================================================
# Additional imports needed by the patched methods in the generated file
# ================================================================
config.add_import("copy", is_from_import=False)
config.add_import("functools", names=["partial"])
config.add_import("types", names=["SimpleNamespace"])
config.add_import("torch.nn.functional", alias="F", is_from_import=False)
config.add_import("veomni.distributed.parallel_state", names=["get_parallel_state"])
config.add_import(
    "veomni.distributed.sequence_parallel",
    names=[
        "gather_outputs",
        "slice_input_tensor",
        "sp_pad_and_slice",
        "unpad_tensor",
    ],
)
config.add_import("veomni.distributed.sequence_parallel.ulysses", names=["_Gather"])
config.add_import("veomni.models.transformers.attention_utils", names=["VARLEN_ATTENTION_TYPES"])
config.add_import("veomni.ops", names=["fused_moe_forward"])
# Surface ``Qwen3OmniMoeThinkerCausalLMOutputWithLogProbs`` so the patched
# ``Qwen3OmniMoeThinkerForConditionalGeneration.forward`` can return per-token
# log-probs / entropy as constructor fields while preserving ``aux_loss`` and
# ``rope_deltas``. Mutating ``output.log_probs`` / ``output.entropy`` after the
# base-class constructor would bypass ``ModelOutput`` pytree flattening,
# breaking FSDP2's pre-backward unshard hook on ``lm_head`` and triggering
# ``setStorage … storage of size 0`` in ``chunk_logprobs.backward`` (parallels
# VeOmni #731's qwen3_5_moe fix).
config.add_import(
    "veomni.utils.model_outputs",
    names=["FusedLinearAuxOutput", "FusedLinearAuxOutputMixin", "Qwen3OmniMoeThinkerCausalLMOutputWithLogProbs"],
)
config.drop_import_names("Qwen3OmniMoeThinkerCausalLMOutputWithPast")

config.add_post_import_block(
    """
    # ── OpSlot declarations ──────────────────────────────────────────────────
    # Only the Thinker forward is in VeOmni's training path (Talker/CodePredictor
    # are inference-only speech paths excluded from the generated file).
    from veomni.ops.dispatch import OpSlot
    veomni_moe_experts_forward = OpSlot("moe_experts", "standard")
    veomni_causal_lm_loss = OpSlot("cross_entropy_loss", "causal")
    veomni_load_balancing_loss = OpSlot("load_balancing_loss", "standard")
    """
)
config.add_import(
    "veomni.utils.constants",
    names=["AUDIO_INPUT_INDEX", "IGNORE_INDEX", "IMAGE_INPUT_INDEX", "VIDEO_INPUT_INDEX"],
)


# ================================================================
# Drop talker + code2wav from generated output.
#
# The patched `Qwen3OmniMoeForConditionalGeneration.__init__` forces
# `has_talker=False` and never calls `enable_talker()`, so the talker /
# code2wav modules are never instantiated on the training path. But HF's
# `PreTrainedModel.post_init` aggregates `_no_split_modules` by walking the
# class hierarchy, so leaving `Qwen3OmniMoeTalkerDecoderLayer`,
# `Qwen3OmniMoeCode2WavTransformerLayer`, etc. in the generated module
# still lets them contribute to the aggregation and can pull dead layer
# names into the FSDP no-split set. Excluding them here also trims the
# generated file (~1500 lines) and removes an import-time footprint that
# isn't exercised.
#
# The top-level package `__init__.py` imports `Qwen3OmniMoeTalkerModel` /
# `Qwen3OmniMoeTalkerForConditionalGeneration` directly from upstream
# transformers (not from the generated file), so excluding them here is
# safe for the registry.
#
# The remaining methods on `Qwen3OmniMoeForConditionalGeneration`
# (`enable_talker`, `_get_talker_*`, `generate`, `token2wav`) reference
# these classes by name but only at call time (Python late binding) — the
# training forward never reaches them.
# ================================================================
config.exclude_from_output(
    # Talker
    "Qwen3OmniMoeTalkerResizeMLP",
    "Qwen3OmniMoeTalkerCodePredictorOutputWithPast",
    "Qwen3OmniMoeTalkerCodePredictorAttention",
    "Qwen3OmniMoeTalkerCodePredictorDecoderLayer",
    "Qwen3OmniMoeTalkerCodePredictorModel",
    "Qwen3OmniMoeTalkerCodePredictorModelForConditionalGeneration",
    "Qwen3OmniMoeTalkerOutputWithPast",
    "Qwen3OmniMoeTalkerRotaryEmbedding",
    "Qwen3OmniMoeTalkerTextMLP",
    "Qwen3OmniMoeTalkerTextTopKRouter",
    "Qwen3OmniMoeTalkerTextExperts",
    "Qwen3OmniMoeTalkerTextSparseMoeBlock",
    "Qwen3OmniMoeTalkerDecoderLayer",
    "Qwen3OmniMoeTalkerModel",
    "Qwen3OmniMoeTalkerForConditionalGeneration",
    # Shared by talker + code2wav (not referenced by thinker)
    "Qwen3OmniMoeRMSNorm",
    "Qwen3OmniMoeMLP",
    "Qwen3OmniMoeRotaryEmbedding",
    # Code2Wav
    "Qwen3OmniMoeCausalConvNet",
    "Qwen3OmniMoeCausalTransConvNet",
    "Qwen3OmniMoeConvNeXtBlock",
    "Qwen3OmniMoeCode2WavAttention",
    "Qwen3OmniMoeCode2WavMlp",
    "Qwen3OmniMoeCode2WavRMSNorm",
    "Qwen3OmniMoeCode2WavLayerScale",
    "Qwen3OmniMoeCode2WavTransformerLayer",
    "Qwen3OmniMoeCode2WavTransformerModel",
    "Qwen3OmniMoeCode2WavDecoderResidualUnit",
    "Qwen3OmniMoeCode2WavDecoderBlock",
    "Qwen3OmniMoeCode2Wav",
    # SnakeBeta activation is only referenced inside the excluded Code2Wav
    # residual blocks, so exclude it too to avoid generating dead code.
    "SnakeBeta",
)


# ================================================================
# Module-level helper emitted into the generated file so multiprocessing
# dataloaders can pickle the per-sample position-id closure.
# ================================================================
config.add_post_import_block(
    '''
def get_position_id(main_func, self, **kwargs):
    """Per-sample position-ids for VeOmni dataloader workers.

    Invoked inside the data pipeline (bs=1 per sample). Wraps the HF
    get_rope_index so it can be partial-bound with a SimpleNamespace carrying
    the model config and helpers, then shipped across multiprocessing workers.
    """
    position_ids, rope_deltas = main_func(self, **kwargs)
    assert len(position_ids.shape) == 3 and position_ids.shape[1] == 1
    assert len(rope_deltas.shape) == 2 and rope_deltas.shape[0] == 1
    return {"position_ids": position_ids.squeeze(1), "rope_deltas": rope_deltas.squeeze(0)}


def collate_multimodal_metadata(batch, sp_pad):
    """Derive ``multimodal_metadata`` for the Qwen3-Omni-MoE ViT.

    Module-level so ``get_metadata_collate_func`` can hand it to VeOmni's
    collator as a picklable callable (mirrors ``get_position_id``). Runs
    purely on CPU inside the collator after SP padding — every value it
    produces (CPU int tensors / Python ints / lists) is consumed by the ViT
    forward without a host-device sync.

    ``batch`` is the packed (+ SP-padded) batch dict; ``sp_pad`` maps
    ``pixel_values`` / ``pixel_values_videos`` to the number of patch rows
    the SP collator appended. Mutates ``batch`` in place, writing
    ``batch["multimodal_metadata"]``. Audio metadata is not covered here
    (audio uses feature lengths, not a grid_thw); see the design doc.
    """
    md = {}
    # ViT varlen-attention metadata, derived from the HF processor's
    # ``*_grid_thw`` CPU LongTensor (packed across the batch by the collator
    # via DataCollateInfo pack_dim=0). ``.tolist()`` here is a pure-CPU op —
    # the collator runs in dataloader workers, no host-device sync.
    # Temporal unroll: each (t, h, w) expands to ``t`` cu steps of ``h * w``.
    for modality, grid_key, pad_key in (
        ("image", "image_grid_thw", "pixel_values"),
        ("video", "video_grid_thw", "pixel_values_videos"),
    ):
        grid = batch.get(grid_key)
        if grid is None:
            continue
        grid_list = grid.tolist() if torch.is_tensor(grid) else grid
        if not grid_list:
            continue
        md[f"{modality}_grid_thw_list"] = grid_list
        cu = [0]
        max_hw = 0
        for t, h, w in grid_list:
            hw = h * w
            max_hw = max(max_hw, hw)
            for _ in range(t):
                cu.append(cu[-1] + hw)
        # SP-pad tail: the collator zero-pads pixel_values to SP-divisible;
        # those patches become one synthetic "image" so varlen attention
        # treats them as an independent sequence. Discarded after the slice.
        pad = sp_pad.get(pad_key, 0)
        if pad > 0:
            cu.append(cu[-1] + pad)
            max_hw = max(max_hw, pad)
        # device='cpu': runs in CPU dataloader workers — pin to CPU so a
        # global torch.set_default_device('cuda') can't misallocate it.
        md[f"vit_{modality}_cu_seqlens"] = torch.tensor(cu, dtype=torch.int32, device="cpu")
        md[f"vit_{modality}_max_seqlen"] = max_hw

    if md:
        batch["multimodal_metadata"] = md
'''
)


# ================================================================
# Patch: Qwen3OmniMoePreTrainedModel._init_weights
# Upstream branches on `isinstance(module, Qwen3OmniMoeCode2Wav)`; that
# class is excluded from the generated file (training never instantiates
# code2wav), so the name lookup fails at runtime during `post_init`. Drop
# the Code2Wav branch — everything else matches upstream verbatim.
# ================================================================
# These names (`init`, `np`, `SinusoidsPositionEmbedding`,
# `Qwen3OmniMoeThinkerTextSparseMoeBlock`, `Qwen3OmniMoeVisionRotaryEmbedding`)
# are resolved from the generated file's module namespace at import time,
# not from this patch config — patchgen lifts the function body verbatim
# into the generated class.
@config.override_method(
    "Qwen3OmniMoePreTrainedModel._init_weights",
    description="Drop Qwen3OmniMoeCode2Wav branch since the class is excluded from the generated file",
)
@torch.no_grad()
def qwen3_omni_moe_pretrained_init_weights_patched(self, module):
    super()._init_weights(module)
    std = self.config.initializer_range
    if isinstance(module, Qwen3OmniMoeThinkerTextSparseMoeBlock):  # noqa: F821
        init.normal_(module.experts.gate_up_proj, mean=0.0, std=std)  # noqa: F821
        init.normal_(module.experts.down_proj, mean=0.0, std=std)  # noqa: F821
        init.normal_(module.gate.weight, mean=0.0, std=std)  # noqa: F821
    elif isinstance(module, SinusoidsPositionEmbedding):  # noqa: F821
        log_timescale_increment = np.log(module.max_timescale) / (module.channels // 2 - 1)  # noqa: F821
        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(module.channels // 2).float())
        scaled_time = torch.arange(module.length)[:, np.newaxis] * inv_timescales[np.newaxis, :]  # noqa: F821
        init.copy_(module.positional_embedding, torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1))  # noqa: F821
    elif isinstance(module, Qwen3OmniMoeVisionRotaryEmbedding):  # noqa: F821
        inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
        init.copy_(module.inv_freq, inv_freq)  # noqa: F821


# ================================================================
# Patch: Qwen3OmniMoePreTrainedModelForConditionalGeneration.get_rope_index
# 1. [PosID] support interleaved video-with-audio vs video-without-audio in
#    one batch. v5 upstream uses a single `use_audio_in_video` boolean flag
#    applied globally; we use `audio_seqlens[audio_idx] == 0` (inherited from
#    the Qwen2.5-Omni convention) to decide per-video, consuming the zero
#    placeholder to keep `audio_idx` aligned.
# 2. [Mask] tolerate `attention_mask=None` at the data boundary.
# ================================================================
@config.override_method(
    "Qwen3OmniMoePreTrainedModelForConditionalGeneration.get_rope_index",
    description="Per-video use_audio_in_video via audio_seqlens + None attention_mask tolerance",
)
def qwen3_omni_moe_get_rope_index_patched(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    # --- Patch.1 ---
    # use_audio_in_video removed; decided per-video below.
    # --- Patch.1 ---
    audio_seqlens: Optional[torch.LongTensor] = None,
    second_per_grids: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    spatial_merge_size = self.spatial_merge_size
    image_token_id = self.config.image_token_id
    video_token_id = self.config.video_token_id
    audio_token_id = self.config.audio_token_id
    vision_start_token_id = self.config.vision_start_token_id
    audio_start_token_id = self.config.audio_start_token_id
    position_id_per_seconds = self.config.position_id_per_seconds

    mrope_position_deltas = []
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        total_input_ids = input_ids
        # --- Patch.2 ---
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        attention_mask = attention_mask == 1
        # --- Patch.2 ---
        position_ids = torch.zeros(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=torch.float,
            device=input_ids.device,
        )
        image_idx, video_idx, audio_idx = 0, 0, 0
        for i, input_ids_i in enumerate(total_input_ids):
            # --- Patch.2 ---
            input_ids_i = input_ids_i[attention_mask[i]]
            # --- Patch.2 ---
            image_nums, video_nums, audio_nums = 0, 0, 0
            vision_start_indices = torch.argwhere(input_ids_i == vision_start_token_id).squeeze(1)
            vision_tokens = input_ids_i[vision_start_indices + 1]

            # --- Patch.1 ---
            audio_start_indices = torch.argwhere(input_ids_i == audio_start_token_id).squeeze(1)
            audio_nums = torch.sum(
                input_ids_i[audio_start_indices - 1] != vision_start_token_id
            )  # audio but not in <video><audio>
            # --- Patch.1 ---

            image_nums = (vision_tokens == image_token_id).sum()
            # --- Patch.1 ---
            video_nums = (vision_tokens == audio_start_token_id).sum() + (vision_tokens == video_token_id).sum()
            # --- Patch.1 ---

            input_tokens = input_ids_i.tolist()
            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos, remain_audios = image_nums, video_nums, audio_nums

            # --- Patch.1 ---
            multimodal_nums = image_nums + video_nums + audio_nums
            # --- Patch.1 ---

            for _ in range(multimodal_nums):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                if (image_token_id in input_tokens or video_token_id in input_tokens) and (
                    remain_videos > 0 or remain_images > 0
                ):
                    ed_vision_start = input_tokens.index(vision_start_token_id, st)
                else:
                    ed_vision_start = len(input_tokens) + 1
                if audio_token_id in input_tokens and remain_audios > 0:
                    ed_audio_start = input_tokens.index(audio_start_token_id, st)
                else:
                    ed_audio_start = len(input_tokens) + 1
                min_ed = min(ed_vision_start, ed_audio_start)

                text_len = min_ed - st
                if text_len != 0:
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                    st_idx += text_len
                # audio-in-video shares bos (vision_start immediately followed by audio_start)
                if min_ed == ed_vision_start and input_ids_i[ed_vision_start + 1] == audio_start_token_id:
                    bos_len, eos_len = 2, 2
                else:
                    bos_len, eos_len = 1, 1
                llm_pos_ids_list.append(torch.arange(bos_len).view(1, -1).expand(3, -1) + st_idx)
                st_idx += bos_len
                # Audio only
                if min_ed == ed_audio_start:
                    audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                    llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                    llm_pos_ids_list.append(llm_pos_ids)

                    st += int(text_len + bos_len + audio_len + eos_len)
                    audio_idx += 1
                    remain_audios -= 1

                # Image only
                elif min_ed == ed_vision_start and input_ids_i[ed_vision_start + 1] == image_token_id:
                    grid_t = image_grid_thw[image_idx][0]
                    grid_hs = image_grid_thw[:, 1]
                    grid_ws = image_grid_thw[:, 2]
                    t_index = (torch.arange(grid_t) * 1 * position_id_per_seconds).float()
                    llm_pos_ids = self.get_llm_pos_ids_for_vision(
                        st_idx, image_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                    )
                    image_len = image_grid_thw[image_idx].prod() // (spatial_merge_size**2)
                    llm_pos_ids_list.append(llm_pos_ids)

                    st += int(text_len + bos_len + image_len + eos_len)
                    image_idx += 1
                    remain_images -= 1

                # Video only — audio track determined per-video via audio_seqlens
                elif min_ed == ed_vision_start:
                    # --- Patch.1 ---
                    if audio_seqlens[audio_idx] == 0:
                        use_audio_in_video = False
                        audio_idx += 1  # consume zero-length placeholder
                    else:
                        use_audio_in_video = True
                    # --- Patch.1 ---

                    if not use_audio_in_video:
                        assert input_ids_i[ed_vision_start + 1] == video_token_id

                        grid_t = video_grid_thw[video_idx][0]
                        grid_hs = video_grid_thw[:, 1]
                        grid_ws = video_grid_thw[:, 2]
                        t_index = (
                            torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds
                        ).float()
                        llm_pos_ids = self.get_llm_pos_ids_for_vision(
                            st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                        )
                        video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)
                        llm_pos_ids_list.append(llm_pos_ids)

                        st += int(text_len + bos_len + video_len + eos_len)
                        video_idx += 1
                        remain_videos -= 1
                    else:
                        assert input_ids_i[ed_vision_start + 1] == audio_start_token_id
                        audio_len = _get_feat_extract_output_lengths(audio_seqlens[audio_idx])
                        audio_llm_pos_ids = torch.arange(audio_len).view(1, -1).expand(3, -1) + st_idx
                        grid_t = video_grid_thw[video_idx][0]
                        grid_hs = video_grid_thw[:, 1]
                        grid_ws = video_grid_thw[:, 2]
                        t_index = (
                            torch.arange(grid_t) * second_per_grids[video_idx].cpu().float() * position_id_per_seconds
                        ).float()
                        video_llm_pos_ids = self.get_llm_pos_ids_for_vision(
                            st_idx, video_idx, spatial_merge_size, t_index, grid_hs, grid_ws
                        )
                        video_data_index, audio_data_index = 0, 0
                        while (
                            video_data_index < video_llm_pos_ids.shape[-1]
                            and audio_data_index < audio_llm_pos_ids.shape[-1]
                        ):
                            if video_llm_pos_ids[0][video_data_index] <= audio_llm_pos_ids[0][audio_data_index]:
                                llm_pos_ids_list.append(video_llm_pos_ids[:, video_data_index : video_data_index + 1])
                                video_data_index += 1
                            else:
                                llm_pos_ids_list.append(audio_llm_pos_ids[:, audio_data_index : audio_data_index + 1])
                                audio_data_index += 1
                        if video_data_index < video_llm_pos_ids.shape[-1]:
                            llm_pos_ids_list.append(
                                video_llm_pos_ids[:, video_data_index : video_llm_pos_ids.shape[-1]]
                            )
                        if audio_data_index < audio_llm_pos_ids.shape[-1]:
                            llm_pos_ids_list.append(
                                audio_llm_pos_ids[:, audio_data_index : audio_llm_pos_ids.shape[-1]]
                            )
                        video_len = video_grid_thw[video_idx].prod() // (spatial_merge_size**2)

                        st += int(text_len + bos_len + audio_len + video_len + eos_len)
                        audio_idx += 1
                        video_idx += 1
                        remain_videos -= 1

                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                llm_pos_ids_list.append(torch.arange(eos_len).view(1, -1).expand(3, -1) + st_idx)

            if st < len(input_tokens):
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            llm_positions = torch.cat([item.float() for item in llm_pos_ids_list], dim=1).reshape(3, -1)

            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
            mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids_i))
        mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)

        return position_ids, mrope_position_deltas
    else:
        position_ids = attention_mask.float().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
        max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
        mrope_position_deltas = max_position_ids + 1 - torch.sum(attention_mask, dim=-1, keepdim=True)

        return position_ids, mrope_position_deltas


# ================================================================
# Patch: Qwen3OmniMoeVisionAttention.forward
# 1. [SP] Dispatch via VARLEN_ATTENTION_TYPES (covers veomni_flash_attention_*
#    custom names) instead of v5 upstream `is_flash_attention_requested`,
#    which only recognizes HF built-in flash attention names. Without this
#    dispatch the SP-appended cu_seqlens-padding entry would run through the
#    non-varlen split branch and size-mismatch.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeVisionAttention.forward",
    description="Route through VARLEN_ATTENTION_TYPES so veomni_flash_attention_* with cu_seqlens works",
)
def qwen3_omni_moe_vision_attention_forward_patched(
    self,
    hidden_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    rotary_pos_emb: Optional[torch.Tensor] = None,
    position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
    **kwargs,
) -> torch.Tensor:
    seq_length = hidden_states.shape[0]
    query_states, key_states, value_states = (
        self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
    )
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

    query_states = query_states.transpose(0, 1).unsqueeze(0)
    key_states = key_states.transpose(0, 1).unsqueeze(0)
    value_states = value_states.transpose(0, 1).unsqueeze(0)

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    # --- Patch.1 ---
    if self.config._attn_implementation in VARLEN_ATTENTION_TYPES:
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=None,
            scaling=self.scaling,
            dropout=0.0 if not self.training else self.attention_dropout,
            cu_seq_lens_q=cu_seqlens,
            cu_seq_lens_k=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )
    else:
        lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        splits = [torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)]
        attn_outputs = [
            attention_interface(
                self,
                q,
                k,
                v,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                is_causal=False,
                **kwargs,
            )[0]
            for q, k, v in zip(*splits)
        ]
        attn_output = torch.cat(attn_outputs, dim=1)
    # --- Patch.1 ---

    attn_output = attn_output.reshape(seq_length, -1).contiguous()
    attn_output = self.proj(attn_output)
    return attn_output


# ================================================================
# Patch: Qwen3OmniMoeVisionEncoder.forward
# 1. [SP] Slice pos_embeds and rotary cos/sin to match the SP-sharded
#    hidden_states (pad_scale=4 matches the patch-embed's padding scale).
# 2. [SP] Append an extra cu_seqlens entry for the padded tail when the
#    total seq length is not divisible by sp_size, so the varlen kernel
#    attends within the padding chunk rather than across samples.
# 3. Return v5 BaseModelOutputWithDeepstackFeatures (with pooler_output =
#    merged_hidden_states and deepstack_features list).
# ================================================================
@config.override_method(
    "Qwen3OmniMoeVisionEncoder.forward",
    description="SP-slice pos/rotary, extend cu_seqlens for SP pad tail, return BaseModelOutputWithDeepstackFeatures",
)
def qwen3_omni_moe_vision_forward_patched(
    self,
    hidden_states: torch.Tensor,
    grid_thw: torch.Tensor,
    **kwargs: Unpack[TransformersKwargs],
) -> tuple | BaseModelOutputWithDeepstackFeatures:
    # Precomputed ViT metadata — a per-modality sub-dict Thinker.forward
    # selects from `multimodal_metadata` and passes as the single
    # `vit_metadata` kwarg. All .get() below fall back to None for callers
    # that bypass MainCollator. See .agents/knowledge/multimodal_metadata.md.
    vit_metadata = kwargs.pop("vit_metadata", None) or {}
    precomputed_grid_thw_list = vit_metadata.get("grid_thw_list")
    precomputed_cu_seqlens = vit_metadata.get("cu_seqlens")
    precomputed_max_seqlen = vit_metadata.get("max_seqlen")

    hidden_states = self.patch_embed(hidden_states)

    grid_thw_list = precomputed_grid_thw_list
    if grid_thw_list is None:
        grid_thw_list = grid_thw.tolist()

    pos_embeds = self.fast_pos_embed_interpolate(grid_thw)

    sp_group = get_parallel_state().sp_group if get_parallel_state().sp_enabled else None

    # --- Patch.1 ---
    if sp_group is not None:
        pos_embeds = sp_pad_and_slice(pos_embeds, dim=0, pad_value=0, pad_scale=4)
    # --- Patch.1 ---
    hidden_states = hidden_states + pos_embeds

    # total_seq_len (pre-sp-pad) — host int, used for the rotary reshape below.
    total_seq_len = sum(t * h * w for t, h, w in grid_thw_list)

    # Prefer precomputed cu_seqlens (already includes sp-pad tail via collator).
    # Fallback builds host-side and handles sp-pad inline below.
    if precomputed_cu_seqlens is not None:
        cu_seqlens = precomputed_cu_seqlens.to(
            hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
            non_blocking=True,
        )
    else:
        cu_seqlens_list = [0]
        for t, h, w in grid_thw_list:
            frame_len = h * w
            for _ in range(t):
                cu_seqlens_list.append(cu_seqlens_list[-1] + frame_len)
        cu_seqlens = torch.tensor(
            cu_seqlens_list,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )

    rotary_pos_emb = self.rot_pos_emb(grid_thw)
    seq_len, _ = hidden_states.size()
    hidden_states = hidden_states.reshape(seq_len, -1)
    # total_seq_len is already a Python int — no sync to use as a shape arg.
    rotary_pos_emb = rotary_pos_emb.reshape(total_seq_len, -1)
    emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
    position_embeddings = (emb.cos(), emb.sin())

    # --- Patch.1 ---
    if sp_group is not None:
        cos, sin = position_embeddings
        cos = sp_pad_and_slice(cos, dim=0, pad_value=0, pad_scale=4)
        sin = sp_pad_and_slice(sin, dim=0, pad_value=0, pad_scale=4)
        position_embeddings = (cos, sin)
    # --- Patch.1 ---

    # --- Patch.2 ---
    if sp_group is not None:
        sp_size = getattr(get_parallel_state(), "sp_size", 1)
        pad_seq_len = seq_len * sp_size - total_seq_len  # already host int
        # Precomputed cu_seqlens already has the sp-pad tail entry; only the
        # fallback path needs to extend it here.
        if pad_seq_len > 0 and precomputed_cu_seqlens is None:
            new_cumsum = cu_seqlens[-1] + pad_seq_len
            cu_seqlens = torch.cat([cu_seqlens, new_cumsum.unsqueeze(0)], dim=0)
    # --- Patch.2 ---

    # ``precomputed_max_seqlen`` is currently unused at the per-block level
    # in omni_moe ViT (varlen attention reads max from cu_seqlens internally);
    # accepted as a kwarg for contract symmetry with qwen3_vl and to insulate
    # callers from per-model differences. Touch to avoid 'unused-var' lints.
    _ = precomputed_max_seqlen

    deepstack_feature_lists = []
    for layer_num, blk in enumerate(self.blocks):
        hidden_states = blk(
            hidden_states,
            cu_seqlens=cu_seqlens,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        if layer_num in self.deepstack_visual_indexes:
            deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)](
                hidden_states
            )
            deepstack_feature_lists.append(deepstack_feature)

    # --- Patch.3 ---
    merged_hidden_states = self.merger(hidden_states)
    return BaseModelOutputWithDeepstackFeatures(
        last_hidden_state=hidden_states,
        pooler_output=merged_hidden_states,
        deepstack_features=deepstack_feature_lists,
    )
    # --- Patch.3 ---


# ================================================================
# Patch: Qwen3OmniMoeVisionEncoder.dummy_forward (NEW method)
# [FSDP] Drive a synthetic forward so reduce-scatter stays in sync when
# some ranks get pixel_values=None.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeVisionEncoder.dummy_forward",
    description="FSDP dummy forward with patch-embed dtype lookup to stay bf16-safe under MixedPrecision",
)
def qwen3_omni_moe_vision_dummy_forward_patched(self):
    # Pull dtype from a live parameter rather than `self.dtype`: under FSDP2 +
    # MixedPrecision the module's reported dtype may lag the per-call compute
    # cast, causing float/bf16 mismatches when the real-data rank runs in bf16.
    dtype = self.patch_embed.proj.weight.dtype
    pixel_values = torch.zeros((16, 3 * 2 * 16 * 16), dtype=dtype, device=self.device)
    if get_parallel_state().sp_enabled:
        # grid_thw describes the *global* pre-sharded vision grid (H scaled by sp_size).
        t, h, w = 1, 4 * get_parallel_state().sp_size, 4
    else:
        t, h, w = 1, 4, 4
    grid_thw = torch.tensor([[t, h, w]], dtype=torch.int32, device=self.device)

    # Precompute the ViT metadata host-side and pass it straight to forward:
    # dummy_forward runs inside Thinker.forward, so the collator can't
    # precompute it — but t / h / w are Python ints here, so the dummy ViT
    # forward skips the `grid_thw.tolist()` + cu_seqlens build it would
    # otherwise sync on.
    cu = [0]
    for _ in range(t):
        cu.append(cu[-1] + h * w)
    vit_metadata = {
        "grid_thw_list": [[t, h, w]],
        "cu_seqlens": torch.tensor(cu, dtype=torch.int32, device="cpu"),
        "max_seqlen": h * w,
    }
    return self(hidden_states=pixel_values, grid_thw=grid_thw, vit_metadata=vit_metadata)


# ================================================================
# Patch: Qwen3OmniMoeAudioEncoder.forward
# 1. [data] VeOmni ships input_features as (len, num_mel_bins); permute to
#    (num_mel_bins, len) to match the HF contract.
# 2. [SP] Gather along time, strip SP-padding so chunking is deterministic.
# 3. [SP] Slice hidden_states along seq dim before encoder layers; extend
#    cu_seqlens for the padded tail on the last rank.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeAudioEncoder.forward",
    description="Permute VeOmni (len, mel) input + SP gather/strip + SP slice + extend cu_seqlens",
)
def qwen3_omni_moe_audio_forward_patched(
    self,
    input_features,
    feature_lens=None,
    aftercnn_lens=None,
    **kwargs,
):
    r"""
    feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
        mel length
    aftercnn_lens (`torch.LongTensor` of shape `(batch_size,)`):
        mel length after cnn
    """
    # --- Patch.1 ---
    input_features = input_features.permute(1, 0)  # (len, num_mel_bins) -> (num_mel_bins, len)
    # --- Patch.1 ---

    aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
    chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

    chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device)
    tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
    chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
    chunk_lengths[chunk_lengths == 0] = self.n_window * 2

    # --- Patch.2 ---
    if get_parallel_state().sp_enabled:
        unpadded_input_len = torch.sum(chunk_lengths)
        input_features = gather_outputs(input_features, gather_dim=1, group=get_parallel_state().sp_group)
        sp_input_padding = input_features.size(1) - unpadded_input_len
        if sp_input_padding > 0:
            input_features = unpad_tensor(input_features, dim=1, padding_size=sp_input_padding)
    # --- Patch.2 ---

    chunk_list = input_features.T.split(chunk_lengths.tolist(), dim=0)
    padded_feature = nn.utils.rnn.pad_sequence(chunk_list, batch_first=True).transpose(1, 2)
    feature_lens_after_cnn = _get_feat_extract_output_lengths(chunk_lengths)
    padded_mask_after_cnn = nn.utils.rnn.pad_sequence(
        [torch.ones(length, dtype=torch.bool, device=padded_feature.device) for length in feature_lens_after_cnn],
        batch_first=True,
    )
    padded_feature = padded_feature.unsqueeze(1)
    padded_embeds = []
    for chunk in padded_feature.split(self.conv_chunksize, dim=0):
        padded_embed = F.gelu(self.conv2d1(chunk))
        padded_embed = F.gelu(self.conv2d2(padded_embed))
        padded_embed = F.gelu(self.conv2d3(padded_embed))
        padded_embeds.append(padded_embed)
    padded_embed = torch.cat(padded_embeds, dim=0)
    b, c, f, t = padded_embed.size()
    padded_embed = self.conv_out(padded_embed.permute(0, 3, 1, 2).contiguous().view(b, t, c * f))

    positional_embedding = (
        self.positional_embedding.positional_embedding[: padded_embed.shape[1], :].unsqueeze(0).to(padded_embed.dtype)
    )
    padded_embed = padded_embed + positional_embedding
    hidden_states = padded_embed[padded_mask_after_cnn]
    cu_chunk_lens = [0]
    window_aftercnn = padded_mask_after_cnn.shape[-1] * (self.n_window_infer // (self.n_window * 2))
    for cnn_len in aftercnn_lens:
        cu_chunk_lens += [window_aftercnn] * (cnn_len // window_aftercnn)
        remainder = cnn_len % window_aftercnn
        if remainder != 0:
            cu_chunk_lens += [remainder]
    cu_seqlens = torch.tensor(cu_chunk_lens, device=aftercnn_lens.device).cumsum(-1, dtype=torch.int32)

    # --- Patch.3 ---
    if get_parallel_state().sp_enabled:
        unpadded_hidden_len = cu_seqlens[-1]
        hidden_states = slice_input_tensor(hidden_states, dim=0, group=get_parallel_state().sp_group)
        pad_seq_len = hidden_states.size(0) * get_parallel_state().sp_size - unpadded_hidden_len
        if pad_seq_len > 0:
            cu_seqlens = torch.cat([cu_seqlens, (cu_seqlens[-1] + pad_seq_len).unsqueeze(0)], dim=0)
    # --- Patch.3 ---

    for encoder_layer in self.layers:
        layer_outputs = encoder_layer(
            hidden_states,
            cu_seqlens,
        )
        hidden_states = layer_outputs[0]

    hidden_states = self.ln_post(hidden_states)
    hidden_states = self.proj1(hidden_states)
    hidden_states = self.act(hidden_states)
    hidden_states = self.proj2(hidden_states)
    return BaseModelOutputWithPooling(last_hidden_state=hidden_states)


# ================================================================
# Patch: Qwen3OmniMoeAudioEncoder.dummy_forward (NEW method)
# [FSDP] Synthetic forward so reduce-scatter stays in sync on ranks with
# no audio data. Minimum valid shape is one chunk of length n_window*2.
# Dtype is looked up from `self.conv2d1.weight.dtype` at call time — under
# FSDP2 + MixedPrecision, `self.dtype` (via `next(self.parameters()).dtype`)
# may still report the sharded full precision (fp32) while the per-module
# compute dtype has already been cast to bf16; using the conv weight's
# live dtype keeps the dummy inputs matched to whatever the conv actually
# runs in. We do NOT cache (`_dummy_data`) because the cached tensor
# would stay at first-call dtype and break on subsequent calls that
# enter a different mixed-precision context.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeAudioEncoder.dummy_forward",
    description="FSDP dummy forward with conv-weight dtype lookup (no caching) to stay bf16-safe",
)
def qwen3_omni_moe_audio_dummy_forward_patched(self):
    min_len = self.n_window * 2
    dtype = self.conv2d1.weight.dtype
    input_features = torch.zeros((min_len, self.num_mel_bins), dtype=dtype, device=self.device)
    feature_lens = torch.tensor([min_len], dtype=torch.long, device=self.device)
    return self(input_features=input_features, feature_lens=feature_lens)


# ================================================================
# Patch: Qwen3OmniMoeThinkerTextModel.forward
# Same outer shape as v5 upstream but:
# 1. Preserve v4's explicit 4-axis position_ids handling (`[4, bs, L]` where
#    index 0 is the text_position_ids used for the causal mask).
# 2. Return MoeModelOutputWithPast so @capture_outputs can inject
#    router_logits via the registered OutputRecorder.
# NOTE: the @merge_with_config_defaults / @capture_outputs / @auto_docstring
# decorators on the upstream method are preserved by patchgen.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeThinkerTextModel.forward",
    description="Preserve explicit [4, bs, L] position_ids handling and MoeModelOutputWithPast return",
)
def qwen3_omni_moe_thinker_text_model_forward_patched(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values=None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    visual_pos_masks: Optional[torch.Tensor] = None,
    deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> tuple | MoeModelOutputWithPast:
    r"""
    visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
        The mask of the visual positions.
    deepstack_visual_embeds (`list[torch.Tensor]`, *optional*):
        The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim).
    """
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

    if use_cache and past_key_values is None and not torch.jit.is_tracing():
        past_key_values = DynamicCache(config=self.config)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if cache_position is None:
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        cache_position = torch.arange(
            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
        )

    if position_ids is None:
        position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
    elif position_ids.ndim == 2:
        position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

    if position_ids.ndim == 3 and position_ids.shape[0] == 4:
        text_position_ids = position_ids[0]
        position_ids = position_ids[1:]
    else:
        text_position_ids = position_ids[0]

    attention_mask = create_causal_mask(
        config=self.config,
        inputs_embeds=inputs_embeds,
        attention_mask=attention_mask,
        cache_position=cache_position,
        past_key_values=past_key_values,
        position_ids=text_position_ids,
    )

    hidden_states = inputs_embeds

    position_embeddings = self.rotary_emb(hidden_states, position_ids)

    for layer_idx, decoder_layer in enumerate(self.layers):
        layer_outputs = decoder_layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=text_position_ids,
            past_key_values=past_key_values,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = layer_outputs

        if deepstack_visual_embeds is not None and layer_idx in range(len(deepstack_visual_embeds)):
            hidden_states = self._deepstack_process(
                hidden_states,
                visual_pos_masks,
                deepstack_visual_embeds[layer_idx],
            )

    hidden_states = self.norm(hidden_states)

    return MoeModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=past_key_values,
    )


# ================================================================
# Patch: Qwen3OmniMoeThinkerTextModel._deepstack_process
# 1. [FSDP] If visual_pos_masks is None (no visual input on this rank) still
#    touch visual_embeds so FSDP reduce-scatter stays in sync across ranks.
# 2. [Mask] Squeeze trailing dim when mask is still 3D (legacy path).
# ================================================================
@config.override_method(
    "Qwen3OmniMoeThinkerTextModel._deepstack_process",
    description="Handle visual_pos_masks=None by adding 0.0 so FSDP reduce-scatter stays in sync",
)
def qwen3_omni_moe_thinker_text_deepstack_process_patched(
    self,
    hidden_states,
    visual_pos_masks,
    visual_embeds,
):
    # --- Patch.1 ---
    if visual_pos_masks is None:
        visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
        hidden_states = hidden_states + visual_embeds.mean() * 0.0
        return hidden_states
    # --- Patch.1 ---

    # --- Patch.2 ---
    visual_pos_masks = visual_pos_masks.to(hidden_states.device)
    if visual_pos_masks.ndim == 3:
        visual_pos_masks = visual_pos_masks[..., 0]
    # --- Patch.2 ---
    visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
    local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds
    hidden_states[visual_pos_masks, :] = local_this
    return hidden_states


# ================================================================
# Patch: Qwen3OmniMoeThinkerTextExperts (replace_class)
# 1. Drop the upstream `@use_experts_implementation` decorator — routing
#    through ALL_EXPERTS_FUNCTIONS bypasses our fused MoE kernel.
# 2. Add VeOmni fused-MoE dispatch via the module-level
#    ``veomni_moe_experts_forward`` OpSlot; pass `gate_up_proj` directly as
#    `fc1_1_2_weight` (v5 already stores it in the fused `[E, 2*I, H]` layout).
# ================================================================
@config.replace_class(
    "Qwen3OmniMoeThinkerTextExperts",
    description="Drop @use_experts_implementation and add VeOmni fused MoE dispatch",
)
class PatchedQwen3OmniMoeThinkerTextExperts(nn.Module):
    """Collection of expert weights stored as 3D tensors with VeOmni fused/eager dispatch."""

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.hidden_dim = config.hidden_size
        self.intermediate_dim = config.moe_intermediate_size
        self.gate_up_proj = nn.Parameter(torch.empty(self.num_experts, 2 * self.intermediate_dim, self.hidden_dim))
        self.down_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, self.intermediate_dim))
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(
        self,
        hidden_states: torch.Tensor,
        top_k_index: torch.Tensor,
        top_k_weights: torch.Tensor,
    ) -> torch.Tensor:
        final_hidden_states = torch.zeros_like(hidden_states)
        # --- Patch.2 ---
        if veomni_moe_experts_forward.use_non_eager_impl:
            return fused_moe_forward(
                num_experts=self.num_experts,
                routing_weights=top_k_weights.to(final_hidden_states.dtype),
                selected_experts=top_k_index,
                hidden_states=hidden_states,
                fc1_1_weight=None,
                fc1_2_weight=None,
                fc2_weight=self.down_proj,
                fc1_1_2_weight=self.gate_up_proj,
            )
        # --- Patch.2 ---

        with torch.no_grad():
            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
            expert_mask = expert_mask.permute(2, 1, 0)
            expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

        for expert_idx in expert_hit:
            expert_idx = expert_idx[0]
            if expert_idx == self.num_experts:
                continue
            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[token_idx]
            gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
            current_hidden_states = self.act_fn(gate) * up
            current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
            final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

        return final_hidden_states


# ================================================================
# Patch: Qwen3OmniMoeThinkerForConditionalGeneration.get_audio_features
# Simplified to the VeOmni training path: input_features is already the
# flat (len, num_mel_bins) tensor (after the collator strips feature
# padding), and feature_attention_mask is not carried in training.
# Return the raw last_hidden_state to keep the forward body terse.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeThinkerForConditionalGeneration.get_audio_features",
    description="Simplify get_audio_features for VeOmni flat (len, mel) inputs — no feature_attention_mask",
)
def qwen3_omni_moe_thinker_get_audio_features_patched(
    self,
    input_features,
    audio_feature_lengths=None,
    **kwargs,
):
    r"""
    audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
        The length of feature shape of each audio in LLM.
    """
    audio_outputs = self.audio_tower(
        input_features,
        feature_lens=audio_feature_lengths,
    )
    return audio_outputs.last_hidden_state


# ================================================================
# Patch: Qwen3OmniMoeThinkerForConditionalGeneration.get_position_id_func
# Returns a per-sample closure that converts VeOmni's multimodal tokens
# (IMAGE_INPUT_INDEX / VIDEO_INPUT_INDEX / AUDIO_INPUT_INDEX) into 3D
# position_ids at data-preprocessing time. SimpleNamespace + unbound
# methods avoid pickling the full model across dataloader workers.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeThinkerForConditionalGeneration.get_position_id_func",
    description="Multiprocessing-safe per-sample position-id closure with VeOmni multimodal token ids",
)
def qwen3_omni_moe_thinker_get_position_id_func_patched(self):
    fake_config = copy.copy(self.config)
    fake_config.image_token_id = IMAGE_INPUT_INDEX
    fake_config.video_token_id = VIDEO_INPUT_INDEX
    fake_config.audio_token_id = AUDIO_INPUT_INDEX
    fake_model = SimpleNamespace(
        config=fake_config,
        spatial_merge_size=self.spatial_merge_size,
        get_llm_pos_ids_for_vision=partial(type(self).get_llm_pos_ids_for_vision, None),
        get_chunked_index=partial(type(self).get_chunked_index, None),
    )
    return partial(
        get_position_id,  # noqa: F821 — defined at module scope via add_post_import_block
        type(self).get_rope_index,
        fake_model,
    )


# ================================================================
# Patch: Qwen3OmniMoeThinkerForConditionalGeneration.forward
# 1. [Constants] Use VeOmni data constants for multimodal token indices (via
#    get_position_id_func); precomputed masks arrive via kwargs.
# 2. [Mask] Pop pre-computed image/video/audio masks — avoids extra all_gather
#    for full mask information when using SP.
# 3. [ViT] Pop flash-attention kwargs before ViT forward so ViT computes its
#    own cu_seqlens from grid_thw.
# 4. [SP] gather_outputs on input/image/video/audio embeddings to
#    do the multimodal fill-back on the full sequence.
# 5. [FSDP] Dummy ViT/audio forward when pixel_values/input_features is None
#    on this rank.
# 6. [SP] slice_input_tensor to restore seq-parallel layout.
# 7. [SP] all_gather deepstack embeddings then select per-rank slice.
# 8. [Loss] Delegate loss to `self.loss_function` for fused CE.
# 9. [PosIDs] Transpose precomputed position_ids from (bs, 3, L) to (3, bs, L).
# 10.[Data] Filter zero-length audio_feature_lengths (placeholder entries for
#    videos without audio) before forwarding the audio tower.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeThinkerForConditionalGeneration.forward",
    description="VeOmni SP + FSDP + precomputed masks + fused loss + precomputed multimodal position-ids",
)
def qwen3_omni_moe_thinker_forward_patched(
    self,
    input_ids=None,
    input_features=None,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    attention_mask=None,
    audio_feature_lengths=None,
    position_ids=None,
    past_key_values=None,
    inputs_embeds=None,
    rope_deltas=None,
    labels=None,
    use_cache=None,
    output_router_logits: Optional[bool] = None,
    use_audio_in_video=None,
    cache_position=None,
    video_second_per_grid=None,
    **kwargs,
) -> tuple | Qwen3OmniMoeThinkerCausalLMOutputWithLogProbs:
    r"""
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
        The length of feature shape of each audio in LLM.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    use_audio_in_video (`bool`, *optional*):
        Whether or not use audio track in video, should same as the parameter in `process_audio_info`.
    video_second_per_grid (`torch.LongTensor` of shape `(num_videos)`, *optional*):
        Number of seconds per grid for each video, used for temporal feature mapping.
    """
    output_router_logits = (
        output_router_logits if output_router_logits is not None else self.config.text_config.output_router_logits
    )

    if inputs_embeds is None:
        inputs_embeds = self.get_input_embeddings()(input_ids)

    # --- Patch.2 ---
    assert "image_mask" in kwargs, "image_mask should have already been computed in process_sample"
    assert "video_mask" in kwargs, "video_mask should have already been computed in process_sample"
    assert "audio_mask" in kwargs, "audio_mask should have already been computed in process_sample"
    image_mask = kwargs.pop("image_mask")
    video_mask = kwargs.pop("video_mask")
    audio_mask = kwargs.pop("audio_mask")
    # --- Patch.2 ---

    # --- Patch.11 ---
    # Mirror of qwen3_vl: unpack per-modality ViT kwargs from
    # `multimodal_metadata` (collator-precomputed) so the patched ViT
    # forward can skip the in-forward .tolist() / cu_seqlens build.
    # See .agents/knowledge/multimodal_metadata.md.
    multimodal_metadata = kwargs.pop("multimodal_metadata", None) or {}
    image_vit_kwargs = {
        "vit_metadata": {
            "grid_thw_list": multimodal_metadata.get("image_grid_thw_list"),
            "cu_seqlens": multimodal_metadata.get("vit_image_cu_seqlens"),
            "max_seqlen": multimodal_metadata.get("vit_image_max_seqlen"),
        }
    }
    video_vit_kwargs = {
        "vit_metadata": {
            "grid_thw_list": multimodal_metadata.get("video_grid_thw_list"),
            "cu_seqlens": multimodal_metadata.get("vit_video_cu_seqlens"),
            "max_seqlen": multimodal_metadata.get("vit_video_max_seqlen"),
        }
    }
    # --- Patch.11 ---

    # --- Patch.3 ---
    flash_attn_kwargs = {}
    for key in ["cu_seq_lens_q", "cu_seq_lens_k", "max_length_q", "max_length_k"]:
        if key in kwargs:
            flash_attn_kwargs[key] = kwargs.pop(key)
    # --- Patch.3 ---

    # --- Patch.4 ---
    if self.training and get_parallel_state().sp_enabled:
        inputs_embeds = gather_outputs(inputs_embeds, gather_dim=1, group=get_parallel_state().sp_group)
    # --- Patch.4 ---

    # --- Patch.10 ---
    if input_features is not None:
        valid_mask = audio_feature_lengths != 0
        audio_feature_lengths = audio_feature_lengths[valid_mask]
        if input_features.shape[0] == 0:
            input_features = None
    # --- Patch.10 ---

    if input_features is not None:
        audio_features = self.get_audio_features(
            input_features,
            audio_feature_lengths=audio_feature_lengths,
        )
        audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
        # --- Patch.4 ---
        if self.training and get_parallel_state().sp_enabled:
            audio_features = gather_outputs(audio_features, gather_dim=0, group=get_parallel_state().sp_group)
        # --- Patch.4 ---
        # `masked_scatter` consumes exactly `audio_mask.sum()` leading rows; padded audio rows are
        # trailing and unused — no `[:n]` slice / `.item()` sync needed (mirrors qwen3_5 fix).
        audio_mask_expanded = audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        inputs_embeds = inputs_embeds.masked_scatter(audio_mask_expanded, audio_features)
    elif get_parallel_state().fsdp_enabled:
        # --- Patch.5 ---
        fake_audio = self.audio_tower.dummy_forward().last_hidden_state.mean() * 0.0
        fake_audio = fake_audio.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds + fake_audio
        # --- Patch.5 ---

    fake_deepstack = None

    if pixel_values is not None:
        image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features(
            pixel_values, image_grid_thw, return_dict=True, **image_vit_kwargs
        )
        image_embeds = image_outputs.pooler_output
        deepstack_image_embeds = image_outputs.deepstack_features
        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        # --- Patch.4 ---
        if self.training and get_parallel_state().sp_enabled:
            image_embeds = gather_outputs(image_embeds, gather_dim=0, group=get_parallel_state().sp_group)
        # --- Patch.4 ---

        image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device, non_blocking=True)

        # `masked_scatter` consumes exactly `image_mask.sum()` leading rows; padded rows are
        # trailing and unused — no `[:n]` slice / `.item()` sync needed. The previous
        # `n_image_tokens != n_image_features` assertion was a debug guard that also forced the
        # sync; masked_scatter's own size mismatch already raises with a clear message if the
        # underlying invariant is violated.
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
    elif get_parallel_state().fsdp_enabled:
        # --- Patch.5 ---
        fake_vision = self.visual.dummy_forward()
        fake_embeds = fake_vision.pooler_output.mean() * 0.0
        fake_embeds = fake_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds + fake_embeds
        fake_deepstack = fake_vision.deepstack_features
        # --- Patch.5 ---

    if pixel_values_videos is not None:
        video_outputs: BaseModelOutputWithDeepstackFeatures = self.get_video_features(
            pixel_values_videos, video_grid_thw, return_dict=True, **video_vit_kwargs
        )
        video_embeds = video_outputs.pooler_output
        deepstack_video_embeds = video_outputs.deepstack_features
        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        # --- Patch.4 ---
        if self.training and get_parallel_state().sp_enabled:
            video_embeds = gather_outputs(video_embeds, gather_dim=0, group=get_parallel_state().sp_group)
        # --- Patch.4 ---

        video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device, non_blocking=True)

        # Same as image branch above: drop the `[:n]` slice, the `.item()` sync, and the debug
        # assertion (masked_scatter raises on size mismatch).
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
    elif get_parallel_state().fsdp_enabled:
        # --- Patch.5 ---
        fake_vision = self.visual.dummy_forward()
        fake_embeds = fake_vision.pooler_output.mean() * 0.0
        fake_embeds = fake_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds + fake_embeds
        fake_deepstack = fake_vision.deepstack_features
        # --- Patch.5 ---

    rank_image_mask = None
    rank_video_mask = None

    # --- Patch.6 ---
    if self.training and get_parallel_state().sp_enabled:
        inputs_embeds = slice_input_tensor(inputs_embeds, dim=1, group=get_parallel_state().sp_group)

        sp_size = get_parallel_state().sp_size
        sp_rank = get_parallel_state().sp_rank

        # --- Patch.7 ---
        if pixel_values is not None:
            deepstack_image_embeds = [
                _Gather.apply(get_parallel_state().sp_group, embed, 0, False) for embed in deepstack_image_embeds
            ]
            image_mask_1d = image_mask[..., 0]
            seq_len = image_mask_1d.shape[1]
            seq_per_rank = seq_len // sp_size
            rank_start = sp_rank * seq_per_rank
            rank_end = rank_start + seq_per_rank
            rank_image_mask = image_mask_1d[:, rank_start:rank_end]
            offset = image_mask_1d[:, :rank_start].sum().item()
            num_visual_tokens = rank_image_mask.sum().item()
            deepstack_image_embeds = [embed[offset : offset + num_visual_tokens] for embed in deepstack_image_embeds]

        if pixel_values_videos is not None:
            deepstack_video_embeds = [
                _Gather.apply(get_parallel_state().sp_group, embed, 0, False) for embed in deepstack_video_embeds
            ]
            video_mask_1d = video_mask[..., 0]
            seq_len = video_mask_1d.shape[1]
            seq_per_rank = seq_len // sp_size
            rank_start = sp_rank * seq_per_rank
            rank_end = rank_start + seq_per_rank
            rank_video_mask = video_mask_1d[:, rank_start:rank_end]
            offset = video_mask_1d[:, :rank_start].sum().item()
            num_visual_tokens = rank_video_mask.sum().item()
            deepstack_video_embeds = [embed[offset : offset + num_visual_tokens] for embed in deepstack_video_embeds]
        # --- Patch.7 ---
    # --- Patch.6 ---

    visual_pos_masks = None
    deepstack_visual_embeds = None

    if pixel_values is not None and pixel_values_videos is not None:
        image_mask = rank_image_mask if rank_image_mask is not None else image_mask[..., 0]
        video_mask = rank_video_mask if rank_video_mask is not None else video_mask[..., 0]
        visual_pos_masks = image_mask | video_mask
        deepstack_visual_embeds = []
        image_mask_joint = image_mask[visual_pos_masks]
        video_mask_joint = video_mask[visual_pos_masks]
        for img_embed, vid_embed in zip(deepstack_image_embeds, deepstack_video_embeds):
            embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to(img_embed.device)
            embed_joint[image_mask_joint, :] = img_embed
            embed_joint[video_mask_joint, :] = vid_embed
            deepstack_visual_embeds.append(embed_joint)
    elif pixel_values is not None:
        image_mask = rank_image_mask if rank_image_mask is not None else image_mask[..., 0]
        visual_pos_masks = image_mask
        deepstack_visual_embeds = deepstack_image_embeds
    elif pixel_values_videos is not None:
        video_mask = rank_video_mask if rank_video_mask is not None else video_mask[..., 0]
        visual_pos_masks = video_mask
        deepstack_visual_embeds = deepstack_video_embeds
    else:
        # --- Patch.5 ---
        if fake_deepstack is not None:
            deepstack_visual_embeds = fake_deepstack
        # --- Patch.5 ---

    if attention_mask is not None and position_ids is None:
        if (
            cache_position is None
            or (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
        ):
            delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                attention_mask,
                audio_feature_lengths,
                video_second_per_grid,
            )
            rope_deltas = rope_deltas - delta0
            self.rope_deltas = rope_deltas
        else:
            batch_size, seq_length = input_ids.shape
            delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
            position_ids = torch.arange(seq_length, device=input_ids.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            position_ids = position_ids.add(delta)
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
    elif position_ids is not None:
        # --- Patch.9 ---
        if position_ids.ndim == 3 and position_ids.shape[1] == 3:
            position_ids = position_ids.transpose(0, 1).contiguous()
        # --- Patch.9 ---

    # --- Patch.3 ---
    kwargs.update(flash_attn_kwargs)
    # --- Patch.3 ---

    outputs = self.model(
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_router_logits=output_router_logits,
        cache_position=cache_position,
        deepstack_visual_embeds=deepstack_visual_embeds,
        visual_pos_masks=visual_pos_masks,
        **kwargs,
    )

    hidden_states = outputs[0]
    # --- Patch.8 ---
    loss = None
    logits = None
    fused_linear_aux = None
    if labels is not None:
        # Modification: OpSlot guard for cross-entropy loss.
        if veomni_causal_lm_loss.use_non_eager_impl:
            loss, logits, fused_linear_aux = veomni_causal_lm_loss(
                logits=logits,
                labels=labels,
                vocab_size=self.config.text_config.vocab_size,
                hidden_states=hidden_states,
                weights=self.lm_head.weight,
                ignore_index=IGNORE_INDEX,
                **kwargs,
            )
        else:
            logits = self.lm_head(hidden_states)
            # Modification: VeOmni's patched `loss_function` (via LOSS_MAPPING)
            # returns (loss, logits, fused_linear_aux); unpack to match the
            # OpSlot branch above.
            loss, _, fused_linear_aux = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.text_config.vocab_size,
                hidden_states=hidden_states,
                weights=self.lm_head.weight,
                ignore_index=IGNORE_INDEX,
                **kwargs,
            )
            if fused_linear_aux is not None:
                # fused_linear_aux path empties loss/logits slots; clear the local 3D
                # logits so output mirrors the OpSlot branch's contract.
                logits = None
    else:
        logits = self.lm_head(hidden_states)
    # --- Patch.8 ---

    aux_loss = None
    if output_router_logits:
        # Modification: OpSlot guard for load-balancing loss.
        if veomni_load_balancing_loss.use_non_eager_impl:
            aux_loss = veomni_load_balancing_loss(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
        else:
            aux_loss = load_balancing_loss_func(
                outputs.router_logits,
                self.num_experts,
                self.num_experts_per_tok,
                attention_mask,
            )
        if labels is not None and isinstance(aux_loss, torch.Tensor):
            loss = loss + self.router_aux_loss_coef * aux_loss.to(loss.device)

    return Qwen3OmniMoeThinkerCausalLMOutputWithLogProbs(
        loss=loss,
        logits=logits,
        aux_loss=aux_loss,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
        past_key_values=outputs.past_key_values,
        router_logits=getattr(outputs, "router_logits", None),
        rope_deltas=self.rope_deltas,
        fused_linear_aux=fused_linear_aux,
    )


# ================================================================
# Patch: Qwen3OmniMoeForConditionalGeneration.__init__
# 1. [Talker] Force `has_talker=False` — VeOmni's training path only
#    forwards through `thinker` (see the patched `forward` below), so
#    constructing the talker and code2wav would only add unused
#    parameters and drag `Qwen3OmniMoeTalker*Layer` into the FSDP
#    `_no_split_modules` aggregation (HF recursively merges children's
#    `_no_split_modules` at `post_init`, see transformers
#    `modeling_utils.PreTrainedModel.post_init`). Unused talker layers
#    FSDP-wrapped but never forwarded cause a rank-desync hang during
#    asymmetric-modality forward.
# 2. [FSDP] After `post_init`, replace the aggregated
#    `self._no_split_modules` with the exact VeOmni target set. The
#    upstream top-level `Qwen3OmniMoePreTrainedModel._no_split_modules`
#    lists `Qwen3OmniMoeDecoderLayer` (a typo — no such class exists),
#    which `post_init` seeds into the aggregation. Resetting here
#    removes the phantom entry and pins the set to the three real
#    training targets.
# 3. [State-dict] Drop talker/code2wav keys on state-dict load.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeForConditionalGeneration.__init__",
    description="Skip talker, pin _no_split_modules, drop talker/code2wav keys on state-dict load",
)
def qwen3_omni_moe_for_conditional_generation_init_patched(self, config):
    super().__init__(config)
    self.thinker = Qwen3OmniMoeThinkerForConditionalGeneration._from_config(config.thinker_config)
    # --- Patch.1 ---
    self.has_talker = False
    # --- Patch.1 ---
    self.post_init()
    # --- Patch.2 ---
    self._no_split_modules = {
        "Qwen3OmniMoeThinkerTextDecoderLayer",
        "Qwen3OmniMoeVisionBlock",
        "Qwen3OmniMoeAudioEncoderLayer",
    }
    # --- Patch.2 ---

    # --- Patch.3 ---
    # Training builds the model with ``has_talker=False`` and excludes
    # ``Qwen3OmniMoeTalker*`` / ``Qwen3OmniMoeCode2Wav*`` classes from the
    # generated module entirely. Full pretrained checkpoints and HF-backend
    # state_dicts (e.g. the bitwise-equal patch test deep-copies the HF
    # model.state_dict() before handing it to VeOmni's load_state_dict)
    # still carry ``talker.*`` / ``code2wav.*`` keys, and the default
    # strict load_state_dict raises on them. Strip those keys at the
    # top-level prefix before the strict check fires — the rest of the
    # state_dict (thinker.*) loads unchanged.
    def _drop_talker_and_code2wav_keys(
        module, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        if prefix:  # only filter at the top-level module
            return
        for k in list(state_dict.keys()):
            if k.startswith("talker.") or k.startswith("code2wav."):
                del state_dict[k]

    self.register_load_state_dict_pre_hook(_drop_talker_and_code2wav_keys)
    # --- Patch.3 ---


# ================================================================
# Patch: Qwen3OmniMoeForConditionalGeneration.enable_talker
# The talker + code2wav classes are excluded from the generated file
# (training never instantiates them), so the upstream body
# `self.talker = Qwen3OmniMoeTalkerForConditionalGeneration._from_config(...)`
# would fail at import-time static analysis and at call time. Replace
# with an explicit NotImplementedError so the reason is clear if anything
# reaches here.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeForConditionalGeneration.enable_talker",
    description="Disable talker/code2wav path in the training modeling (excluded classes)",
)
def qwen3_omni_moe_enable_talker_patched(self):
    raise NotImplementedError(
        "talker / code2wav are not available in the VeOmni training modeling. "
        "Use the upstream transformers implementation for TTS generation."
    )


# ================================================================
# Patch: Qwen3OmniMoeForConditionalGeneration.forward
# Simplified training path: only forward through thinker; talker +
# code2wav are skipped (only used in the TTS generate path).
# ================================================================
@config.override_method(
    "Qwen3OmniMoeForConditionalGeneration.forward",
    description="Forward through thinker only (talker/code2wav not trained via this path)",
)
def qwen3_omni_moe_for_conditional_generation_forward_patched(
    self,
    **kwargs,
) -> tuple | Qwen3OmniMoeThinkerCausalLMOutputWithLogProbs:
    return self.thinker(**kwargs)


# ================================================================
# Patch: Qwen3OmniMoeForConditionalGeneration.get_position_id_func (NEW)
# Delegate to the thinker's closure; data pipeline calls the top-level
# model's get_position_id_func.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeForConditionalGeneration.get_position_id_func",
    description="Delegate position-id computation to the thinker submodule",
)
def qwen3_omni_moe_top_get_position_id_func_patched(self):
    return self.thinker.get_position_id_func()


# ================================================================
# Patch: Qwen3OmniMoeForConditionalGeneration.get_metadata_collate_func (NEW)
# Expose the ViT metadata derivation (cu_seqlens / max_seqlen) to VeOmni's
# collator as a picklable callable, mirroring get_position_id_func.
# See .agents/knowledge/multimodal_metadata.md.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeForConditionalGeneration.get_metadata_collate_func",
    description="Expose CPU-side ViT multimodal-metadata derivation to the VeOmni collator",
)
def qwen3_omni_moe_top_get_metadata_collate_func_patched(self):
    # collate_multimodal_metadata is a module-level helper (defined via
    # add_post_import_block) — a bare function reference is picklable for the
    # DataLoader workers; the temporal-unroll formula needs no model config.
    return collate_multimodal_metadata  # noqa: F821 defined via add_post_import_block


# ================================================================
# Patch: Qwen3OmniMoeForConditionalGeneration.get_extra_collate_infos (NEW)
# Declare the omni-specific collate rules (audio feature tensors) so the
# trainer doesn't hardcode them by model_type — the model owns its own
# modality-specific collate topology. Tuples are
# (pack_dim, sp_slice, sp_pad_value, sp_pad_scale); MainCollator coerces them.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeForConditionalGeneration.get_extra_collate_infos",
    description="Declare omni-specific (audio) collate rules for the VeOmni collator",
)
def qwen3_omni_moe_top_get_extra_collate_infos_patched(self):
    return {
        "audio_feature_lengths": (0, False, None, None),
        "input_features": (0, True, 0, 1),
        "audio_mask": (-1, False, 0, 1),
    }


# ================================================================
# Patch: Qwen3OmniMoeForConditionalGeneration.get_parallel_plan (NEW)
# Register the VeOmni EP plan for thinker.model.layers.*.mlp.experts.*
# on the generated v5 modeling.
# ================================================================
@config.override_method(
    "Qwen3OmniMoeForConditionalGeneration.get_parallel_plan",
    description="Register Qwen3-Omni-MoE thinker expert parallel plan for v5 generated modeling",
)
def qwen3_omni_moe_get_parallel_plan_patched(self):
    from ..parallel_plan import get_parallel_plan as _get_parallel_plan

    return _get_parallel_plan()
