# LoRA Fine-Tuning
VeOmni supports LoRA (Low-Rank Adaptation) as a first-class feature of `BaseTrainer`. LoRA injects trainable low-rank matrices into selected linear layers while freezing the rest of the base model. Because gradients and optimizer state are kept only for the small adapter matrices, fine-tuning needs significantly less GPU memory than full-parameter training.
## Installation
LoRA support requires the `peft` library. Install it via the `lora` optional extra:

```bash
uv sync --extra gpu --extra lora --group dev
```

Or with pip:

```bash
pip install peft==0.18.0
```
## 1. LoRA Config Definition
LoRA is configured through the model.lora_config field in your YAML:
```yaml
model:
  lora_config:
    rank: 64          # LoRA rank r
    alpha: 32         # Scaling factor α (effective scale = α / r)
    lora_modules:     # Target linear layer names (matched as module-name suffixes)
      - q_proj
      - k_proj
      - v_proj
      - o_proj
      - gate_proj
      - up_proj
      - down_proj
```
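The `alpha` comment above means the adapter's contribution is multiplied by α / r before it is added to the frozen layer's output. A minimal pure-Python sketch of the standard LoRA forward pass, y = W x + (α / r) · B (A x), without dropout or bias; `matvec` and `lora_linear` are illustrative helpers, not VeOmni or PEFT APIs:

```python
def matvec(m, v):
    """Multiply a row-major matrix (list of rows) by a vector (list)."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_linear(x, W, A, B, alpha):
    """Sketch of a LoRA-wrapped linear layer: W x + (alpha / r) * B (A x).

    W: frozen base weight (out x in); A: r x in; B: out x r; r = len(A).
    """
    r = len(A)
    scale = alpha / r  # with rank: 64 and alpha: 32 this is 0.5
    base = matvec(W, x)
    delta = matvec(B, matvec(A, x))
    return [b + scale * d for b, d in zip(base, delta)]


# Toy layer: in = out = 2, r = 1.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]    # r x in
B = [[2.0], [0.0]]  # out x r
print(lora_linear([1.0, 2.0], W, A, B, alpha=1.0))  # [7.0, 2.0]
```

Keeping α fixed while changing r rescales the update, which is why the two values are usually tuned together.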
| Field | Type | Description |
|---|---|---|
| `rank` | int | Rank r of the decomposition matrices A and B |
| `alpha` | int | LoRA scaling factor α; effective scale = α / r |
| `lora_modules` | list[str] | Target linear-layer names, matched as suffixes of module FQNs |
| `lora_adapter` | str (optional) | Path to a saved adapter directory to resume from |
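Because `lora_modules` is passed to PEFT as `target_modules`, a list entry selects a module whose fully qualified name equals the entry or ends with `.` followed by the entry. A simplified sketch of that list-matching rule (PEFT's real check has more cases, such as regex strings and layer filters); `select_lora_targets` is a hypothetical helper and the FQNs are illustrative:

```python
def select_lora_targets(module_names, lora_modules):
    """Return the module FQNs a list-style target_modules would match.

    Simplified PEFT rule: exact name match, or the FQN ends with "." + entry.
    """
    selected = []
    for name in module_names:
        if name in lora_modules or any(name.endswith("." + t) for t in lora_modules):
            selected.append(name)
    return selected


fqns = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.q_norm",
    "model.layers.0.mlp.down_proj",
    "blocks.0.attn1.to_out.0",
    "blocks.0.attn1.to_out.1",
]
print(select_lora_targets(fqns, ["q_proj", "down_proj", "to_out.0"]))
```

Suffix matching is why a dotted entry such as `to_out.0` selects one child of `to_out` rather than every module whose name merely contains that text.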
To resume training from a saved adapter, add the `lora_adapter` key pointing to the saved HF adapter directory:

```yaml
model:
  lora_config:
    rank: 64
    alpha: 32
    lora_modules: [q_proj, k_proj, v_proj, o_proj]
    lora_adapter: ./exp/my_run/global_step_500  # HF adapter dir to resume from
```
## 2. LoRA Initialization in BaseTrainer
LoRA wrapping happens in BaseTrainer._setup_lora(), called from _freeze_model_module():
```python
# veomni/trainer/base.py
def _setup_lora(self):
    lora_config = self.args.model.lora_config
    if not bool(lora_config):
        return

    lora_adapter_path = lora_config.get("lora_adapter", None)
    if lora_adapter_path is not None:
        # Resume: read PEFT config from disk; weights loaded later during parallelization
        from peft import PeftModel

        self.model = PeftModel.from_pretrained(
            self.model, lora_adapter_path, is_trainable=True
        )
    else:
        # From scratch: wrap with LoraConfig
        from peft import LoraConfig, get_peft_model

        peft_cfg = LoraConfig(
            r=lora_config["rank"],
            lora_alpha=lora_config["alpha"],
            target_modules=lora_config["lora_modules"],
        )
        self.model = get_peft_model(self.model, peft_cfg)
```
After wrapping, BaseTrainer._init_callbacks() automatically selects HFLoraCkptCallback
instead of HuggingfaceCkptCallback when lora_config is set:
```python
if self.args.model.lora_config:
    self.hf_ckpt_callback = HFLoraCkptCallback(self)
else:
    self.hf_ckpt_callback = HuggingfaceCkptCallback(self)
```
## 3. Weight Loading with LoRA
VeOmni LoRA training uses FSDP2 with `init_device: meta`. Weight loading goes through `build_parallelize_model` and then `post_process_after_weight_loading` in `torch_parallelize.py`. The LoRA-specific path:

- Base-model weights: loaded via `rank0_load_and_broadcast_weights` or `load_model_weights`, the standard FSDP2 path, unchanged for LoRA.
- Adapter weights (resume only): `_build_parallelized_model` passes `adapter_path` to `build_parallelize_model`, which calls `load_lora_model_weights` (all ranks read) or `rank0_load_and_broadcast_adapter_weights` (rank 0 reads, then broadcasts). Both functions remap PEFT keys to model FQNs before dispatching into DTensors.
- Adapter weight initialisation from scratch: `post_process_after_weight_loading` calls `_init_lora_parameter` for any LoRA parameter not yet filled, invoking `reset_lora_parameters` to apply Kaiming/zero init.
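For reference, PEFT's default LoRA initialization draws A from a Kaiming-uniform distribution with `a = sqrt(5)` and sets B to zero, so B A = 0 and the wrapped layer initially reproduces the base model. A pure-Python sketch of that scheme, assuming the default `init_lora_weights=True` behavior; `init_lora_matrices` is a hypothetical helper, and the bound 1/sqrt(in_features) follows from the Kaiming-uniform formula with a = sqrt(5):

```python
import math
import random


def init_lora_matrices(rank, in_features, out_features, seed=0):
    """Sketch of PEFT-style default LoRA init.

    A (rank x in_features): Kaiming-uniform with a=sqrt(5), i.e. U(-b, b)
    with b = gain * sqrt(3 / fan_in) = 1 / sqrt(in_features).
    B (out_features x rank): zeros.
    """
    rng = random.Random(seed)
    a = math.sqrt(5)
    gain = math.sqrt(2.0 / (1 + a * a))
    bound = gain * math.sqrt(3.0 / in_features)
    A = [[rng.uniform(-bound, bound) for _ in range(in_features)] for _ in range(rank)]
    B = [[0.0] * rank for _ in range(out_features)]
    return A, B, bound


A, B, bound = init_lora_matrices(rank=4, in_features=16, out_features=8)
print(round(bound, 6))  # 0.25
```

Starting B at zero (rather than A) keeps the adapter's initial output at zero while still giving B a nonzero gradient, so training moves B away from zero on the first step.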
Key difference from base model loading: PEFT saves adapter keys without the adapter-name
infix (e.g. lora_A.weight), whereas the live model stores them as
lora_A.<adapter_name>.weight. The _remap_adapter_key utility handles this translation.
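A minimal sketch of that translation, inserting the adapter name before the final `weight`/`bias` component of `lora_A`/`lora_B` keys. `remap_adapter_key` is a hypothetical stand-in, not VeOmni's `_remap_adapter_key`; it assumes `default` (PEFT's default adapter name) and handles only the infix, not any other prefix remapping to model FQNs:

```python
import re

# Matches "...lora_A.weight" / "...lora_B.bias" at the end of a key.
_LORA_KEY = re.compile(r"\.(lora_A|lora_B)\.(weight|bias)$")


def remap_adapter_key(saved_key, adapter_name="default"):
    """Turn a PEFT-saved key into the live-model key with the adapter-name infix."""
    return _LORA_KEY.sub(rf".\1.{adapter_name}.\2", saved_key)


key = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
print(remap_adapter_key(key))
# base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
```

Keys that are not LoRA tensors pass through unchanged, so the same function can be applied to every key in a loaded file.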
## 4. Checkpoint Saving

### DCP checkpoint (training state)
CheckpointerCallback._save_checkpoint saves the full distributed state (model + optimizer +
extra state) via PyTorch DCP. For LoRA training this includes both base-model parameters
and adapter parameters; the optimizer state only covers the trainable adapter parameters.
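Since only A and B are trainable, the optimizer state grows with the adapter size: each targeted linear layer with input width d_in and output width d_out adds r · (d_in + d_out) parameters. A small calculator for that count; the layer shapes below are illustrative placeholders, not taken from any VeOmni config:

```python
def lora_param_count(layer_shapes, rank):
    """Trainable LoRA parameters for a list of (in_features, out_features) layers.

    A is rank x in_features and B is out_features x rank, so each layer
    contributes rank * (in_features + out_features).
    """
    return sum(rank * (d_in + d_out) for d_in, d_out in layer_shapes)


# Illustrative: one 1024 -> 1024 projection and one 1024 -> 4096 projection, r = 64.
shapes = [(1024, 1024), (1024, 4096)]
print(lora_param_count(shapes, rank=64))  # 458752
```

The same two layers hold 5,242,880 frozen base weights, so an Adam-style optimizer (two state tensors per trainable parameter) only tracks the much smaller adapter count.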
### HF LoRA adapter (inference artifact)
HFLoraCkptCallback._save_checkpoint calls save_lora_adapter_with_dcp
(veomni/utils/save_safetensor_utils.py), which:
1. Extracts adapter-only tensors via `get_peft_model_state_dict`.
2. Saves them with `dcp.save` in parallel to a temporary DCP directory.
3. Consolidates them on rank 0 into `adapter_model.bin` and `adapter_config.json`.
4. Removes the temporary DCP directory.
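Step 1 relies on PEFT's `get_peft_model_state_dict`, which keeps only LoRA tensors and strips the adapter-name infix, the inverse of the key remapping described in Section 3. A dictionary-only sketch of that filtering; `extract_adapter_state` is hypothetical, and real PEFT handles more cases (LoRA biases, `modules_to_save`, other adapter types):

```python
def extract_adapter_state(state_dict, adapter_name="default"):
    """Keep only LoRA entries and drop the '.<adapter_name>' infix from their keys."""
    infix = f".{adapter_name}"
    return {
        key.replace(infix, ""): value
        for key, value in state_dict.items()
        if "lora_" in key
    }


live = {
    "base_model.model.layers.0.q_proj.base_layer.weight": "W",
    "base_model.model.layers.0.q_proj.lora_A.default.weight": "A",
    "base_model.model.layers.0.q_proj.lora_B.default.weight": "B",
}
print(extract_adapter_state(live))
```

Strings stand in for tensors here; the point is which keys survive and how they are renamed.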
Output structure for each checkpoint:
```text
<output_dir>/
├── checkpoints/
│   └── global_step_N/        ← DCP checkpoint (resume training)
│       ├── __0_0.distcp
│       └── .metadata
└── global_step_N/            ← HF adapter (inference / resume)
    ├── adapter_config.json
    └── adapter_model.bin
```
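Both directories above are named `global_step_N`, so it is easy to point `lora_adapter` at the DCP checkpoint by mistake. A hypothetical helper (not part of VeOmni) that labels a directory by the files shown in the layout above:

```python
import tempfile
from pathlib import Path


def classify_checkpoint_dir(path):
    """Hypothetical helper: label a global_step_N directory by its contents."""
    p = Path(path)
    if (p / "adapter_config.json").is_file() and (p / "adapter_model.bin").is_file():
        return "hf_adapter"  # candidate for model.lora_config.lora_adapter
    if (p / ".metadata").is_file():
        return "dcp"  # full training state, used by train.checkpoint.load_path
    return None


with tempfile.TemporaryDirectory() as out:
    hf = Path(out, "global_step_500")
    dcp = Path(out, "checkpoints", "global_step_500")
    hf.mkdir(parents=True)
    dcp.mkdir(parents=True)
    (hf / "adapter_config.json").write_text("{}")
    (hf / "adapter_model.bin").write_bytes(b"")
    (dcp / ".metadata").write_bytes(b"")
    print(classify_checkpoint_dir(hf), classify_checkpoint_dir(dcp))  # hf_adapter dcp
```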
## 5. Training Examples

### 5.1 Wan2.1-I2V-1.3B LoRA (DiT, FSDP2)
Config: configs/dit/wan2.1_I2V_1.3B_lora.yaml
```yaml
model:
  lora_config:
    rank: 128
    alpha: 64
    lora_modules:
      - to_q
      - to_k
      - to_v
      - to_out.0
      - ffn.net.0.proj
      - ffn.net.2
train:
  init_device: meta
  accelerator:
    fsdp_config:
      fsdp_mode: fsdp2
```
Launch (8 GPUs, SP=2):
```bash
SP_SIZE=2
export NPROC_PER_NODE=8
bash train.sh tasks/train_dit.py configs/dit/wan2.1_I2V_1.3B_lora.yaml \
  --model.model_path ./Wan2.1-T2V-1.3B-Diffusers/transformer \
  --model.condition_model_path ./Wan2.1-T2V-1.3B-Diffusers \
  --data.train_path ./my_dataset_offline \
  --data.source_name my_dataset \
  --train.training_task offline_training \
  --train.global_batch_size 8 \
  --train.micro_batch_size 1 \
  --train.accelerator.ulysses_size ${SP_SIZE} \
  --train.checkpoint.output_dir ./exp/wan_lora \
  --train.checkpoint.save_hf_weights true \
  --train.checkpoint.save_epochs 5 \
  --train.checkpoint.load_path auto \
  --train.num_train_epochs 30
```
See Wan2.1-I2V-1.3B Training Guide for the complete dataset preparation and inference workflow.
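The launch flags above imply a gradient-accumulation factor. Assuming the common convention that all ranks in one Ulysses sequence-parallel group work on the same micro-batch (so the data-parallel size is `NPROC_PER_NODE / SP_SIZE`), the arithmetic can be sketched as follows; `grad_accum_steps` is a hypothetical helper, not a VeOmni function:

```python
def grad_accum_steps(global_batch_size, micro_batch_size, world_size, sp_size):
    """Micro-steps per optimizer step, assuming dp_size = world_size // sp_size."""
    if world_size % sp_size:
        raise ValueError("world_size must be divisible by sp_size")
    dp_size = world_size // sp_size
    per_step = micro_batch_size * dp_size  # samples consumed per micro-step
    if global_batch_size % per_step:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * dp_size")
    return global_batch_size // per_step


print(grad_accum_steps(8, 1, world_size=8, sp_size=2))  # 2
```

With 8 GPUs and SP=2 there are 4 data-parallel groups, so a global batch of 8 at micro-batch 1 takes 2 micro-steps per optimizer step under this assumption.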
### 5.2 Qwen3-0.6B LoRA (LLM, FSDP2)
Config: configs/text/qwen3_lora.yaml
```yaml
model:
  model_path: Qwen3-0.6B-Base
  ops_implementation:
    attn_implementation: flash_attention_2
    cross_entropy_loss_implementation: eager
    rms_norm_implementation: eager
    swiglu_mlp_implementation: eager
    rotary_pos_emb_implementation: eager
  lora_config:
    rank: 64
    alpha: 32
    lora_modules:
      - q_proj
      - k_proj
      - v_proj
      - o_proj
      - gate_proj
      - up_proj
      - down_proj
train:
  init_device: meta  # required for FSDP2
  accelerator:
    fsdp_config:
      fsdp_mode: fsdp2
```
Launch (8 GPUs, SP=2):
```bash
SP_SIZE=2
export NPROC_PER_NODE=8
bash train.sh tasks/train_text.py configs/text/qwen3_lora.yaml \
  --model.model_path /path/to/Qwen3-0.6B-Base \
  --data.train_path /path/to/tulu-3-sft-mixture/data \
  --train.global_batch_size 8 \
  --train.micro_batch_size 1 \
  --train.num_train_epochs 3 \
  --train.accelerator.ulysses_size ${SP_SIZE} \
  --train.checkpoint.output_dir ./exp/qwen3_lora \
  --train.checkpoint.save_hf_weights true \
  --train.checkpoint.load_path auto \
  --train.wandb.enable true
```
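To resume an interrupted run, relaunch with the same output directory; `--train.checkpoint.load_path auto` restores the latest DCP checkpoint (model, optimizer, and extra state). To resume from a saved HF adapter (adapter weights only) instead, set `model.lora_config.lora_adapter` as described in Section 1.

```bash
bash train.sh tasks/train_text.py configs/text/qwen3_lora.yaml \
  --model.model_path /path/to/Qwen3-0.6B-Base \
  --data.train_path /path/to/tulu-3-sft-mixture/data \
  --train.checkpoint.output_dir ./exp/qwen3_lora \
  --train.checkpoint.load_path auto  # auto-picks latest DCP checkpoint
```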
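As a rough model of what `load_path auto` looks for, here is a sketch that assumes it selects the `checkpoints/global_step_N` directory with the largest N under the output directory; this selection rule is an assumption, and `latest_dcp_checkpoint` is a hypothetical helper rather than VeOmni's code:

```python
import re
from pathlib import Path

_STEP_DIR = re.compile(r"global_step_(\d+)$")


def latest_dcp_checkpoint(output_dir):
    """Return the checkpoints/global_step_N directory with the largest N, or None."""
    ckpt_root = Path(output_dir) / "checkpoints"
    if not ckpt_root.is_dir():
        return None
    best_step, best_path = -1, None
    for child in ckpt_root.iterdir():
        m = _STEP_DIR.match(child.name)
        if child.is_dir() and m and int(m.group(1)) > best_step:
            best_step, best_path = int(m.group(1)), child
    return best_path
```

Comparing N numerically matters: sorted lexically, `global_step_1000` would come before `global_step_500`.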
## 6. Testing
The test suite is under tests/lora/ and verifies save/load correctness using a
two-layer toy Qwen3 model:
```bash
torchrun --nproc_per_node=4 tests/lora/test_lora_trainer_saveload.py \
  tests/lora/qwen3_toy_lora.yaml
```
What the test verifies:
1. Train 3 steps with LoRA on a dummy dataset (FSDP2, meta device).
2. After step 1, snapshot the LoRA weights and save a DCP checkpoint.
3. Continue training (steps 2–3 mutate the adapter weights).
4. Reload the step-1 checkpoint and assert that every LoRA tensor is bit-identical to the snapshot.
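The final check compares weights bit for bit rather than within a tolerance. A framework-free sketch of such a check on plain Python floats, comparing IEEE-754 encodings so that, for example, `0.0` and `-0.0` count as different; `assert_bit_identical` is hypothetical, and a PyTorch test would more typically apply `torch.equal` to each tensor:

```python
import struct


def float_bits(x):
    """IEEE-754 double-precision encoding of x as raw bytes."""
    return struct.pack("<d", x)


def assert_bit_identical(snapshot, reloaded):
    """Assert two {name: list[float]} dicts hold exactly the same values, bit for bit."""
    assert snapshot.keys() == reloaded.keys(), "parameter names differ"
    for name, values in snapshot.items():
        other = reloaded[name]
        assert len(values) == len(other), f"{name}: length mismatch"
        for i, (a, b) in enumerate(zip(values, other)):
            assert float_bits(a) == float_bits(b), f"{name}[{i}]: {a!r} != {b!r}"


snap = {"q_proj.lora_A.default.weight": [0.1, -0.25, 0.0]}
assert_bit_identical(snap, {"q_proj.lora_A.default.weight": [0.1, -0.25, 0.0]})
```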