import os
from dataclasses import asdict, dataclass, fields
from typing import Dict

import torch
from rich.console import Console
from rich.table import Table
from transformers import set_seed

from veomni.arguments.arguments_types import OpsImplementationConfig
from veomni.data.dummy_dataset import build_dummy_dataset
from veomni.ops import apply_ops_config
from veomni.utils.import_utils import is_liger_kernel_available, is_package_available, is_torch_npu_available


_LIGER_KERNEL = "liger_kernel"
_EAGER = "eager"
# Only include the Liger-backed mode when the package is actually installed
# (it's unavailable on the NPU CI runner); otherwise every parametrized case
# would fail in OpsImplementationConfig.__post_init__ with a ValueError.
_USE_LIGER_KERNEL = [True, False] if is_liger_kernel_available() else [False]
# NPU ships triton-ascend (not mainline ``triton``); the fused Triton
# load-balancing-loss kernel at ``veomni/ops/kernels/load_balancing_loss/triton.py``
# imports ``triton`` unconditionally, so fall back to the pure-PyTorch backend
# when the mainline package is missing.
_LOAD_BALANCING_LOSS_IMPL = "triton" if is_package_available("triton") else "eager"
# Pick the fused-MoE backend that matches the test hardware. On NPU the NPU
# kernel is the only option; on GPU default to Triton (SM70+).
_FUSED_MOE_IMPL = "fused_npu" if is_torch_npu_available() else "fused_triton"


@dataclass(frozen=True)
class ModelMode:
    modeling_backend: str
    attn_implementation: str
    moe_implementation: str = "eager"
    use_liger_kernel: bool = False

    def __str__(self):
        return (
            f"{self.modeling_backend}_[attn-{self.attn_implementation}]"
            f"_[moe-{self.moe_implementation}]_[ligerkernel-{self.use_liger_kernel}]"
        )


# HF uses _HF_ATTN, VeOmni uses _VEOMNI_ATTN x _USE_LIGER_KERNEL.
# On NPU skip FA3.
_HF_ATTN = ["flash_attention_2", "flash_attention_3"]
_VEOMNI_ATTN = [
    "veomni_flash_attention_2_with_sp",
    "veomni_flash_attention_3_with_sp",
]


def _skip_fa3_npu(attn_impl: str) -> bool:
    """Skip FA3 on NPU."""
    if not is_torch_npu_available():
        return False
    return attn_impl in ("flash_attention_3", "veomni_flash_attention_3_with_sp")


def _append_veomni_modes(modes: list, moe_implementation: str = "eager"):
    """Append VeOmni modes for case; every attn uses _USE_LIGER_KERNEL (True/False)."""
    for veomni_attn in _VEOMNI_ATTN:
        if _skip_fa3_npu(veomni_attn):
            continue
        for use_liger in _USE_LIGER_KERNEL:
            modes.append(
                ModelMode(
                    "veomni",
                    veomni_attn,
                    moe_implementation=moe_implementation,
                    use_liger_kernel=use_liger,
                )
            )


def _base_model_modes():
    """Base (non-MoE) model modes."""
    modes = []
    for hf_attn in _HF_ATTN:
        if _skip_fa3_npu(hf_attn):
            continue
        modes.append(ModelMode("hf", hf_attn))
    _append_veomni_modes(modes)
    return modes


def _moe_model_modes():
    """MoE model modes: same attn variants with a fused MoE backend matching the hardware."""
    modes = []
    for hf_attn in _HF_ATTN:
        if _skip_fa3_npu(hf_attn):
            continue
    _append_veomni_modes(modes, moe_implementation=_FUSED_MOE_IMPL)
    return modes


def prepare_model_modes(is_moe: bool = False):
    """
    Build model modes for patch tests.

    Args:
        is_moe: If True, include MoE-specific modes (e.g. fused MoE).
    """
    base_modes = _base_model_modes()
    moe_modes = _moe_model_modes()
    final_models_modes = base_modes + moe_modes if is_moe else base_modes

    hf_model_modes = [m for m in final_models_modes if m.modeling_backend == "hf"]
    veomni_model_modes = [m for m in final_models_modes if m.modeling_backend == "veomni"]
    return hf_model_modes, veomni_model_modes


MODEL_TO_DATASET = {
    "qwen3_vl": "qwen3vl",
    "qwen3_5": "qwen3vl",
    "qwen3_5_moe": "qwen3vl",
    "qwen3_vl_moe": "qwen3vl",
    "qwen2_vl": "qwen2vl",
    "qwen2_5_vl": "qwen2vl",
    "qwen2_5_omni": "qwen2omni",
    "qwen3_omni_moe": "qwen3omni",
}

UNSQUEECE_KEYS = ["input_ids", "attention_mask", "labels", "position_ids", "image_mask", "video_mask", "audio_mask"]


def parse_token_id_from_config(model_config):
    if model_config.model_type not in MODEL_TO_DATASET:
        return {}
    if model_config.model_type in ["qwen2_5_omni", "qwen3_omni_moe"]:
        token_ids_dict = {
            "image_token_id": model_config.thinker_config.image_token_id,
            "video_token_id": model_config.thinker_config.video_token_id,
            "audio_token_id": model_config.thinker_config.audio_token_id,
        }
    else:
        token_ids_dict = {
            "image_token_id": model_config.image_token_id,
            "video_token_id": model_config.video_token_id,
        }
    return token_ids_dict


def _to_feature_dict(example: list, token_ids_dict: dict) -> dict:
    """Convert dataset item to feature dict for MainCollator (no batch dim, no precomputed cu_seqlens)."""
    example = example[0]
    example = {key: (v.clone() if isinstance(v, torch.Tensor) else torch.tensor(v)) for key, v in example.items()}

    if "image_mask" in example and "image_token_id" in token_ids_dict:
        example["input_ids"] = example["input_ids"].masked_fill(
            example["image_mask"], token_ids_dict["image_token_id"]
        )
    if "video_mask" in example and "video_token_id" in token_ids_dict:
        example["input_ids"] = example["input_ids"].masked_fill(
            example["video_mask"], token_ids_dict["video_token_id"]
        )
    if "audio_mask" in example and "audio_token_id" in token_ids_dict:
        example["input_ids"] = example["input_ids"].masked_fill(
            example["audio_mask"], token_ids_dict["audio_token_id"]
        )

    return example


def prepare_data(model_name: str, max_seq_len: int, model_config):
    """
    Build a DummyDataLoader that yields raw features (list of 2 dicts).
    MainCollator packing + cu_seqlens precompute is done in the trainer (TrainerTest).
    """

    dataset_name = MODEL_TO_DATASET.get(model_name, "text")
    dataset = build_dummy_dataset(dataset_name, 2, max_seq_len)

    token_ids_dict = parse_token_id_from_config(model_config)

    class DummyDataLoader:
        def __iter__(self):
            set_seed(42)
            features = [
                _to_feature_dict(dataset[0], token_ids_dict),
                _to_feature_dict(dataset[1], token_ids_dict),
            ]
            yield features

    return DummyDataLoader()


def print_all_values(output_dict, value_key: str, model_type: str = ""):
    console = Console()
    first_mode = next(iter(output_dict.keys()))

    table = Table(title=f"Alignment Result: [bold magenta]{model_type} {value_key}[/bold magenta]")
    mode_fields = [f.name for f in fields(first_mode)]

    for field in mode_fields:
        table.add_column(field, style="cyan", justify="left")

    table.add_column(value_key.upper(), style="bold green", justify="right")

    for mode, output in output_dict.items():
        mode_data = asdict(mode)
        row_cells = []

        for field in mode_fields:
            row_cells.append(str(mode_data[field]))

        val_obj = output.get(value_key, "N/A")
        val_str = f"{val_obj.item() if hasattr(val_obj, 'item') else val_obj:.8f}"
        row_cells.append(val_str)

        table.add_row(*row_cells)

    console.print(table)


def compare_multi_items(outputs_dict: Dict, rtol=0.01, atol=0.01):
    base_task = next(iter(outputs_dict))
    base_output = outputs_dict[base_task]

    for task, output in outputs_dict.items():
        if task == base_task:
            continue
        for key in output.keys():
            try:
                torch.testing.assert_close(
                    output[key],
                    base_output[key],
                    rtol=rtol,
                    atol=atol,
                )
            except AssertionError as e:
                print_all_values(outputs_dict, key)
                raise AssertionError(f"{key} not match") from e


def apply_veomni_loss_unpatch():
    from transformers.loss.loss_utils import (
        LOSS_MAPPING,
        ForCausalLMLoss,
        ForSequenceClassificationLoss,
    )

    LOSS_MAPPING["ForCausalLM"] = ForCausalLMLoss
    LOSS_MAPPING["ForConditionalGeneration"] = ForCausalLMLoss
    LOSS_MAPPING["ForSequenceClassification"] = ForSequenceClassificationLoss


def apply_veomni_moe_unpatch():
    from veomni.ops.kernels import moe

    moe._fused_moe_forward = None


def _build_ops_config_for_mode(model_mode: ModelMode) -> OpsImplementationConfig:
    """Build an OpsImplementationConfig from a ModelMode for testing."""
    liger_impl = _LIGER_KERNEL if model_mode.use_liger_kernel else _EAGER
    # Linear-attention fields (rms_norm_gated/causal_conv1d/chunk_gated_delta_rule)
    # are intentionally left at their production default ``"fla"``. Qwen3.5
    # has no NPU backend for these ops, so the qwen3_5 / qwen3_5_moe
    # parametrize cases in test_models_patch.py skip on NPU at the test
    # level.
    return OpsImplementationConfig(
        attn_implementation=model_mode.attn_implementation,
        moe_implementation=model_mode.moe_implementation,
        cross_entropy_loss_implementation=liger_impl,
        rms_norm_implementation=liger_impl,
        swiglu_mlp_implementation=liger_impl,
        rotary_pos_emb_implementation=liger_impl,
        load_balancing_loss_implementation=_LOAD_BALANCING_LOSS_IMPL,
    )


def set_environ_param(model_mode: ModelMode):
    apply_veomni_loss_unpatch()
    apply_veomni_moe_unpatch()
    if model_mode.modeling_backend == "veomni":
        os.environ["MODELING_BACKEND"] = "veomni"
    else:
        os.environ["MODELING_BACKEND"] = "hf"

    ops_config = _build_ops_config_for_mode(model_mode)
    apply_ops_config(ops_config)
