Kernel Selection in VeOmni#
VeOmni selects optimized kernel implementations for attention, cross-entropy
loss, Liger fused ops (RMSNorm, RoPE, SwiGLU), MoE, and load-balancing loss.
All selections are driven by config fields in OpsImplementationConfig.
Quick Reference#
Every configurable kernel lives under model.ops_implementation.* in YAML and
maps to a field on OpsImplementationConfig (veomni/arguments/arguments_types.py).
Below is the full list — if a field is not in this table, it is not a kernel
selection knob.
Kernel |
Config field |
Available values |
Default |
Selection time |
|---|---|---|---|---|
Attention |
|
|
|
Config |
Cross-entropy loss |
|
|
|
|
RMSNorm |
|
|
|
Model registration via ops config singleton |
SwiGLU MLP |
|
|
|
Model registration via ops config singleton |
Rotary embedding |
|
|
|
Model registration via ops config singleton |
Load-balancing loss |
|
|
|
|
MoE experts |
|
|
|
|
Defaults are GPU-optimal. On Ascend NPU the defaults above raise at
OpsImplementationConfig.__post_init__ time — NPU users must set every field
explicitly to an NPU-supported value (npu / chunk_loss / fused_npu /
triton for load-balancing loss) or to eager when the op has no NPU
backend (swiglu_mlp_implementation, DeepSeek-V3 RMSNorm/RoPE, Qwen2-VL
multimodal RoPE). The error message lists the allowed alternatives per op.
The per-op fields are typed as plain str (not Literal), so third-party
backends can be registered via extra_backends in a model’s device_patch.py
without modifying OpsImplementationConfig.
Lifecycle Overview#
import veomni # (1) import time
└─ apply_ops_patch()
└─ apply_veomni_attention_patch() # register FA2/3/4 with SP
OpsImplementationConfig.__post_init__() # (2) config parse time
├─ validate requested backends are available
├─ rewrite attn_implementation for SP
└─ set_ops_config(self) # populate singleton
BaseTrainer._build_model() # (3) model build time
└─ build_foundation_model(..., ops_implementation=ops)
├─ apply_ops_config(ops) # install LOSS_MAPPING + GLOBAL patches
│ ├─ install_loss_mapping(ce_impl) # partial(ForCausalLMLoss, cross_entropy_fn=<impl>)
│ └─ apply_global_ops(config) # load_balancing_loss, etc.
├─ apply_veomni_fused_moe_patch(...) # bind MoE kernel
├─ device_patch.py reads ops config # RMSNorm/RoPE/SwiGLU
├─ OpSlot.bind(impl_name) # per-model OpSlot dispatch
└─ model init + weight loading
model.forward() # (4) runtime
├─ attention: ALL_ATTENTION_FUNCTIONS[config._attn_implementation]
├─ loss: self.loss_function(...) -> LOSS_MAPPING[...] (pre-bound partial)
│ OR veomni_causal_lm_loss(...) via OpSlot.use_non_eager_impl guard
├─ RMSNorm/RoPE/SwiGLU: Liger or HF default (set at registration)
└─ MoE: fused_moe_forward(...) or eager loop
Single install point. apply_ops_config is the only place that binds
LOSS_MAPPING — there is no separate apply_veomni_loss_patch call. The
inner CE kernel (eager / liger / npu) is pre-bound onto the wrapper via
functools.partial, so runtime dispatch is just a function call and there
is no per-forward “which impl?” lookup.
Ownership. build_foundation_model owns the call to apply_ops_config:
when callers pass ops_implementation=ops (trainers do this), it runs
apply_ops_config(ops) before constructing the model and reads
attn_implementation from ops. Callers that pass neither
ops_implementation nor a prior apply_ops_config raise ValueError —
there is no silent all-eager fallback. Standalone scripts (tasks/infer/*)
construct an explicit OpsImplementationConfig (typically all-eager so
inference doesn’t depend on Liger / Triton). The DiT trainer is the one
exception that calls apply_ops_config manually — it has to populate the
singleton before building the condition model, which uses
model_class._from_config(...) rather than build_foundation_model. The
subsequent build_foundation_model call hits the
“singleton-already-installed” branch and leaves the prior config alone.
1. Attention#
Config#
model:
ops_implementation:
attn_implementation: flash_attention_2 # default
Field: OpsImplementationConfig.attn_implementation
Available implementations#
Value |
Kernel |
Sequence Parallel |
Requirements |
|---|---|---|---|
|
PyTorch |
No |
— |
|
|
No |
— |
|
Flash Attention v2 |
Yes |
|
|
Flash Attention v3 |
Yes |
|
|
Flash Attention v4 |
Yes |
|
|
Sparse attention |
No |
— |
When MODELING_BACKEND=veomni (the default), __post_init__ automatically
rewrites flash_attention_2/3/4 to VeOmni SP-aware variants
(veomni_flash_attention_2_with_sp, etc.) which wrap the underlying kernel
with DeepSpeed Ulysses sequence parallelism gather/scatter.
Key files#
Config:
veomni/arguments/arguments_types.py—OpsImplementationConfigRegistration:
veomni/ops/kernels/attention/__init__.py—apply_veomni_attention_patch()Plumbing:
veomni/models/auto.py—build_foundation_model(attn_implementation=...)
2. Cross-Entropy Loss#
Config#
model:
ops_implementation:
cross_entropy_loss_implementation: liger_kernel # default; set to "chunk_loss" / "npu" / "eager" on NPU
Field: OpsImplementationConfig.cross_entropy_loss_implementation
Available implementations#
Value |
Implementation |
Requirements |
|---|---|---|
|
|
|
|
|
|
|
|
— |
The npu chunk-loss binds only to ForCausalLM and
ForConditionalGeneration; ForSequenceClassification stays on
eager_cross_entropy because chunk_loss hard-codes the causal
labels[..., 1:] shift (incompatible with token-level classification
labels).
Selecting liger_kernel requires that the model’s forward pass pass
hidden_states= and weights=self.lm_head.weight through
self.loss_function(...) — the Liger fused linear+CE kernel does the
projection itself and has no full logits tensor to fall back on. VeOmni’s
patched modeling files (patched_modeling_*.py) already do this. If a model
whose forward was not patched calls self.loss_function without these
kwargs while cross_entropy_loss_implementation="liger_kernel", the Liger
kernel raises RuntimeError with a pointer to the patch pattern — it does
not silently fall back to eager. Switch the field to eager if the
model cannot be patched.
Key files#
Dispatch:
veomni/ops/kernels/cross_entropy/__init__.py—install_loss_mapping(impl)Eager impl:
veomni/ops/kernels/cross_entropy/eager.pyLiger impl:
veomni/ops/kernels/cross_entropy/liger.pyNPU chunk loss:
veomni/ops/kernels/cross_entropy/chunk_loss.py—chunk_loss_function
3. Per-Model Ops (RMSNorm, RoPE, SwiGLU MLP)#
Each operation can be independently controlled. Despite the historical
“Liger fused ops” label, these fields are not Liger-only: they also accept
npu (for Ascend NPU backends) and triton (for model-specific Triton
kernels registered in the model’s device_patch.py, e.g. DeepSeek-V3’s
batch-invariant RMSNorm and deterministic RoPE).
Config#
model:
ops_implementation:
rms_norm_implementation: liger_kernel # default; pin to "npu" / "eager" on NPU
swiglu_mlp_implementation: liger_kernel # default; pin to "eager" on NPU (no NPU backend)
rotary_pos_emb_implementation: liger_kernel # default; pin to "npu" / "eager" on NPU
Available implementations#
rms_norm_implementation#
Value |
Implementation |
Requirements |
|---|---|---|
|
|
|
|
|
|
|
Model-specific Triton kernel registered via |
|
|
HuggingFace default ( |
— |
rotary_pos_emb_implementation#
Value |
Implementation |
Requirements |
|---|---|---|
|
|
|
|
|
|
|
Model-specific Triton kernel registered via |
|
|
HuggingFace default ( |
— |
swiglu_mlp_implementation#
Value |
Implementation |
Requirements |
|---|---|---|
|
|
|
|
HuggingFace default ( |
— |
What gets patched#
For each selected backend, the model’s device_patch.py either swaps the
target HF class (replace_forward=False) or rebinds its forward
(replace_forward=True). The summary of the Liger swap shape (the most
common case):
Config field |
Original |
Liger replacement |
|---|---|---|
|
|
|
|
|
|
|
|
|
The npu and triton backends follow the same device_patch.py flow — the
only difference is the kernel callable on the other side of the registry.
Models with Liger support#
Qwen2, Qwen3, Qwen3-MoE, Qwen2-VL, DeepSeek-V3, Llama, Seed-OSS.
Key files#
Config singleton:
veomni/ops/config/singleton.py—get_ops_config(),set_ops_config()Unified registry:
veomni/ops/config/registry.py—register_op(),apply_per_model_patches(),apply_global_ops()OSS backend registration:
veomni/ops/kernels/{rms_norm,rotary,swiglu}/__init__.pyPer-model
extra_backends(e.g. DeepSeek-V3 Triton):veomni/models/transformers/{model}/device_patch.py
4. Load-Balancing Loss#
Config#
model:
ops_implementation:
load_balancing_loss_implementation: triton # default; pin to "eager" on NPU (triton-ascend not exposed as `triton`)
Field: OpsImplementationConfig.load_balancing_loss_implementation
Available implementations#
Value |
Implementation |
Requirements |
|---|---|---|
|
Fused Triton kernel ( |
|
|
Pure-PyTorch reference ( |
— |
This is a GLOBAL-scope op: the function pointer
veomni.ops.kernels.load_balancing_loss._load_balancing_loss is rebound
once per process from apply_ops_config(), and every call site that
imports from veomni.ops import load_balancing_loss_func picks up the
selected backend automatically — no per-model patching needed.
Key files#
Selection:
veomni/ops/kernels/load_balancing_loss/__init__.py—register_op(...)entryTriton impl:
veomni/ops/kernels/load_balancing_loss/triton.pyEager impl:
veomni/ops/kernels/load_balancing_loss/eager.py
5. MoE Kernel#
Config#
model:
ops_implementation:
moe_implementation: fused_triton # Triton group-gemm (GPU, SM70+)
# moe_implementation: fused_quack # Quack CUTLASS/CuTe (GPU, SM90+)
# moe_implementation: fused_npu # NPU group-gemm (Ascend)
# moe_implementation: eager # Reference PyTorch loop (very slow, debug only)
Field: OpsImplementationConfig.moe_implementation
Default: "fused_triton" (GPU). On NPU set to "fused_npu" or "eager" — fused_triton / fused_quack raise at config validation time.
The mode and kernel backend are expressed as a single field. Mismatches raise
at apply_veomni_fused_moe_patch time — no silent hardware fallback.
Value |
Kernel |
Hardware |
EP support |
|---|---|---|---|
|
PyTorch expert loop |
Any |
No |
|
Triton group-gemm |
GPU, SM70+ (V100+) |
Yes |
|
Quack CUTLASS/CuTe |
GPU, SM90+ (H100+) |
No |
|
NPU group-gemm |
Ascend NPU |
Yes |
Key files#
Config:
veomni/arguments/arguments_types.py—OpsImplementationConfigDispatch:
veomni/ops/kernels/moe/__init__.py—apply_veomni_fused_moe_patch()Plumbing:
veomni/models/auto.py—build_foundation_model(moe_implementation=...)
Environment Variables#
Env var |
Default |
Scope |
Notes |
|---|---|---|---|
|
|
Global |
|
Kernel selection is otherwise driven by OpsImplementationConfig fields.
The VEOMNI_USE_LIGER_KERNEL and USE_GROUP_GEMM environment variables
have been removed in favor of the per-op config fields.
All remaining env vars are registered in veomni/utils/env.py with defaults and can be
overridden by setting the corresponding shell environment variable.
5. Comparison with Transformers v5+ Kernel Selection#
Transformers v5 (transformers>=4.57) introduces a unified kernel selection
framework that replaces the ad-hoc patching used in earlier versions.
This section compares VeOmni’s approach (Sections 1-4 above) with the four
mechanisms available in Transformers v5, using Qwen3MoE and Qwen3.5MoE as
reference models.
5.1 Transformers v5 Mechanisms Overview#
# |
Mechanism |
Decorator / API |
What it replaces |
Scope |
|---|---|---|---|---|
1 |
Hub kernel layers |
|
|
Per-class, via |
2 |
Hub kernel functions |
|
Standalone functions (e.g. |
Per-function, via |
3 |
Attention interface |
|
Attention forward pass |
Per-model via |
4 |
Experts interface |
|
MoE expert forward pass |
Per-class via |
All four are defined in transformers.integrations:
hub_kernels.py— mechanisms 1 & 2moe.py— mechanism 4modeling_utils.py— mechanism 3 (ALL_ATTENTION_FUNCTIONS)
5.2 Side-by-Side Comparison#
RMSNorm#
VeOmni |
Transformers v5 |
|
|---|---|---|
Mechanism |
|
|
Config |
|
|
When |
Model registration (import time) |
Deferred — |
SP support |
N/A (norm is local) |
N/A |
Qwen3.5 MoE gap |
Same as Qwen3 — Liger swap works |
Not annotated. |
Rotary Position Embedding (RoPE)#
VeOmni |
Transformers v5 |
|
|---|---|---|
Mechanism |
|
|
Config |
|
|
When |
Model registration (import time) |
Import time (decorator) + |
Qwen3.5 MoE gap |
N/A — VeOmni does not yet support Qwen3.5 MoE |
Partially annotated. |
Attention#
VeOmni |
Transformers v5 |
|
|---|---|---|
Mechanism |
|
Same |
Config |
|
|
SP rewrite |
|
No SP support — upstream Transformers does not handle Ulysses SP |
Compatibility |
VeOmni registers into the same |
MoE Experts#
VeOmni |
Transformers v5 |
|
|---|---|---|
Mechanism |
A module-level |
|
Config |
|
|
EP support |
|
|
When |
Deferred to |
Decorator at class definition time; dispatch at forward time |
Note: Transformers v5 hardcodes two MoE experts implementations (batched_mm and grouped_mm) and does not expose a registration interface for external fused kernels, so backends like VeOmni’s Triton / Quack / NPU group-gemm must be plugged in through the OpSlot dispatch layer rather than via ALL_EXPERTS_FUNCTIONS.
5.3 Gaps — What Transformers v5 Does NOT Cover#
The following areas have kernel selection in VeOmni but no corresponding mechanism in Transformers v5:
1. Fused Cross-Entropy Loss#
Transformers v5 uses a loss_function property on PreTrainedModel that looks
up LOSS_MAPPING[self.loss_type] — this returns a standard PyTorch
F.cross_entropy-based loss. There is no decorator, no hub kernel, and no
env-var-based kernel swap for the loss function.
VeOmni replaces this at model-build time via apply_ops_config(...) →
install_loss_mapping(impl), which binds LOSS_MAPPING["ForCausalLM"] to
partial(ForCausalLMLoss, cross_entropy_fn=<impl>) — where <impl> is
fused_liger_kernel_cross_entropy (GPU liger_kernel), chunk_loss_function
(NPU), or eager_cross_entropy (portable default). The fused Liger
cross-entropy computes the loss without materializing the full logits
tensor, which significantly reduces memory for large-vocabulary models.
Implication: When using VeOmni’s trainer or build_foundation_model
with ops_implementation=..., the fused loss is transparent. A standalone
Transformers training loop that doesn’t go through build_foundation_model
would need to call apply_ops_config(OpsImplementationConfig(...))
itself before model construction (or directly monkey-patch LOSS_MAPPING).
2. MoE Load-Balancing Auxiliary Loss#
Both Qwen3MoE and Qwen3.5MoE in Transformers v5 include a standalone
load_balancing_loss_func() that computes the Switch Transformer auxiliary
loss. This function is called directly in Qwen3MoeForCausalLM.forward() —
there is no kernel selection, no registry, and no hub kernel for it.
VeOmni similarly does not provide a fused kernel for the auxiliary loss, but
this is worth noting because the load-balancing loss involves several
one_hot → mean → dot operations that could benefit from fusion, especially
at scale with many experts.
3. Qwen3.5 MoE Variant-Specific Ops#
Qwen3.5 MoE introduces architectural differences that prevent direct use of the standard hub kernel annotations:
Component |
Qwen3 MoE |
Qwen3.5 MoE |
Why standard kernel fails |
|---|---|---|---|
RMSNorm |
|
|
LigerRMSNorm assumes no offset; applying it would produce incorrect results |
RoPE |
Full rotary on all dims |
Partial rotary ( |
Hub |
RMSNormGated |
N/A |
|
Uses explicit |
RMSNormGated: explicit fla library selection (not the hub kernel framework)
Unlike RMSNorm and RoPE above, Qwen3.5 MoE’s RMSNormGated does have a
fused kernel path — but it bypasses the Transformers v5 @use_kernel_forward_from_hub
framework entirely. Instead, Qwen3_5MoeGatedDeltaNet.__init__ performs a
hard-coded conditional selection at model init time:
# transformers/models/qwen3_5_moe/modeling_qwen3_5_moe.py
# At module top level:
if is_flash_linear_attention_available():
from fla.modules import FusedRMSNormGated
from fla.ops.gated_delta_rule import chunk_gated_delta_rule, fused_recurrent_gated_delta_rule
else:
chunk_gated_delta_rule, fused_recurrent_gated_delta_rule = None, None
FusedRMSNormGated = None
# In Qwen3_5MoeGatedDeltaNet.__init__:
self.norm = (
Qwen3_5MoeRMSNormGated(self.head_v_dim, eps=self.layer_norm_epsilon)
if FusedRMSNormGated is None
else FusedRMSNormGated(
self.head_v_dim,
eps=self.layer_norm_epsilon,
activation=self.activation,
device=torch.cuda.current_device(),
dtype=config.dtype if config.dtype is not None else torch.get_default_dtype(),
)
)
This is a 5th kernel selection pattern — not covered by any of the four
Transformers v5 mechanisms. It is a simple if library_available else fallback
check, similar to how the same file selects between causal_conv1d_fn (from
the causal-conv1d library) and a pure-PyTorch torch_causal_conv1d_update
fallback, and between chunk_gated_delta_rule (from fla.ops) and
torch_chunk_gated_delta_rule.
Key characteristics of this pattern:
No decorator, no registry, no env var — purely hard-coded
if/elsein__init__Library:
flash-linear-attention(fla) — a separate library from both Liger and thekernelshubScope: Only the Gated DeltaNet linear attention layers in Qwen3.5 MoE; the standard full-attention
Qwen3_5MoeAttentionlayers do not use this normNot configurable at runtime — determined solely by whether
flais installedFusedRMSNormGatedfuses the RMSNorm + SiLU gate multiply into a single Triton kernel, which the eagerQwen3_5MoeRMSNormGateddoes in two steps:hidden = weight * (x / rms)thenhidden = hidden * silu(gate)
In Transformers v5, these remaining Qwen3.5 MoE ops (RMSNorm with +1 offset,
partial RoPE) are left un-annotated — they always run the eager PyTorch
implementation. In theory, fused kernels could still be written for each (e.g.,
a Triton RMSNorm with +1 offset, a partial-RoPE kernel), but no such kernels
currently exist in the kernels-community hub.
5.4 Summary Table#
Component |
VeOmni mechanism |
Transformers v5 mechanism |
Compatible? |
Gap |
|---|---|---|---|---|
RMSNorm |
|
|
Parallel — both can apply |
Qwen3.5 MoE |
RoPE |
|
|
Parallel |
Qwen3.5 MoE partial RoPE not covered by either |
SwiGLU MLP |
|
Not annotated in MoE models (MLP is per-expert, not standalone) |
VeOmni only |
— |
Attention |
|
|
Yes |
VeOmni adds SP wrapping |
MoE experts |
|
|
No — different dispatch paths |
VeOmni uses custom Triton kernels; HF uses PyTorch native |
Cross-entropy |
|
|
VeOmni only |
HF has no fused loss |
MoE aux loss |
Eager (same as HF) |
Eager |
Same |
Neither provides a fused kernel |
RMSNormGated |
N/A |
Hard-coded |
— |
Bypasses all 4 HF v5 mechanisms; 5th ad-hoc pattern |
Full Config Example#
model:
ops_implementation:
attn_implementation: flash_attention_2
moe_implementation: fused_triton
cross_entropy_loss_implementation: liger_kernel
rms_norm_implementation: liger_kernel
swiglu_mlp_implementation: eager # disable Liger for MLP only
rotary_pos_emb_implementation: liger_kernel
load_balancing_loss_implementation: triton