import os


CI_HF_MODELS_DIR = os.getenv("CI_HF_MODELS_DIR", ".")
CI_DATASET_DIR = os.getenv("CI_DATASET_DIR", ".")


def qwen3_0p6b_base_tulu_sft_script():
    params = [
        "torchrun --nnodes=1 --nproc_per_node=8 --master-port=4321",
        "tasks/train_text.py",
        "configs/text/qwen3.yaml",
        f"--model.model_path {os.path.join(CI_HF_MODELS_DIR, 'Qwen', 'Qwen3-0.6B-Base')}",
        f"--data.train_path {os.path.join(CI_DATASET_DIR, 'tulu-3-sft-mixture/data')}",
        "--train.checkpoint.output_dir Qwen3-0.6B-Base-sft",
        "--train.enable_full_determinism true",
        "--train.num_train_epochs 1",
        "--train.max_steps 20",
        "--train.wandb.enable false $@ 2>&1",
    ]

    exec_script = " \\\n".join(params)

    return exec_script


def qwen3_0p6b_base_tulu_sft_no_reshard_script():
    params = [
        "torchrun --nnodes=1 --nproc_per_node=8 --master-port=4321",
        "tasks/train_text.py",
        "configs/text/qwen3.yaml",
        f"--model.model_path {os.path.join(CI_HF_MODELS_DIR, 'Qwen', 'Qwen3-0.6B-Base')}",
        f"--data.train_path {os.path.join(CI_DATASET_DIR, 'tulu-3-sft-mixture/data')}",
        "--train.checkpoint.output_dir Qwen3-0.6B-Base-sft-no-reshard",
        "--train.enable_full_determinism true",
        "--train.num_train_epochs 1",
        "--train.max_steps 20",
        "--train.accelerator.fsdp_config.reshard_after_forward false",
        "--train.accelerator.fsdp_config.reshard_after_backward false",
        "--train.wandb.enable false $@ 2>&1",
    ]

    exec_script = " \\\n".join(params)

    return exec_script


def qwen3_0p6b_base_tulu_sft_padded_script():
    params = [
        "torchrun --nnodes=1 --nproc_per_node=8 --master-port=4322",
        "tasks/train_text.py",
        "configs/text/qwen3.yaml",
        f"--model.model_path {os.path.join(CI_HF_MODELS_DIR, 'Qwen', 'Qwen3-0.6B-Base')}",
        f"--data.train_path {os.path.join(CI_DATASET_DIR, 'tulu-3-sft-mixture/data')}",
        "--train.checkpoint.output_dir Qwen3-0.6B-Base-sft-padded",
        "--train.enable_full_determinism true",
        "--train.num_train_epochs 1",
        "--train.max_steps 20",
        "--train.pad_to_length true",
        "--train.wandb.enable false $@ 2>&1",
    ]

    exec_script = " \\\n".join(params)

    return exec_script


SFT_SCRIPT = {
    "qwen3_0p6b_base_tulu_sft": qwen3_0p6b_base_tulu_sft_script(),
    "qwen3_0p6b_base_tulu_sft_no_reshard": qwen3_0p6b_base_tulu_sft_no_reshard_script(),
    "qwen3_0p6b_base_tulu_sft_padded": qwen3_0p6b_base_tulu_sft_padded_script(),
}

E2E_TEST_SCRIPT = {**SFT_SCRIPT}
