Skip to content
Open

add vapo #10951

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
0bda8a7
add ppo
DesmonDay Jul 23, 2025
faaa350
Update run_rl.py
dddd-d Jul 24, 2025
9f1b06c
Update ppo_trainer.py
dddd-d Jul 24, 2025
f2664bc
Update config_utils.py
dddd-d Jul 29, 2025
b517f02
Update run_rl.py
dddd-d Jul 29, 2025
5d4c614
Update gsm8k_processor.py
dddd-d Jul 30, 2025
1a574bd
Update run_rl.py
dddd-d Jul 31, 2025
d48283a
Update score_model_utils.py
dddd-d Aug 1, 2025
d6d43a9
Update score_model_utils.py
dddd-d Aug 1, 2025
a4705c4
Update run_rl.py
dddd-d Aug 1, 2025
01599d3
Update ppo_trainer.py
dddd-d Aug 4, 2025
89f25f3
Update ppo_trainer.py
dddd-d Aug 4, 2025
325d76c
Update ppo_trainer.py
dddd-d Aug 4, 2025
c55a878
Update score_model_utils.py
dddd-d Aug 5, 2025
15b277a
add vapo
DesmonDay Aug 5, 2025
2614d18
Update gsm8k_processor.py
dddd-d Aug 5, 2025
abe2722
add vapo_argument.yaml
DesmonDay Aug 5, 2025
60486dc
pre_commit
DesmonDay Aug 5, 2025
da24bff
Update advantage.py
dddd-d Aug 5, 2025
07a1dea
Update advantage.py
dddd-d Aug 5, 2025
dde6b5d
pre-commit check
DesmonDay Aug 6, 2025
99a7297
[Auto Parallel] Fix sp in gpt modeling_auto (#10835)
Xing-lil Jul 25, 2025
8d68813
fix CI bug (#10890)
Liujie0926 Jul 28, 2025
495fbe9
fix ordered_save func (#10896)
RevL147 Aug 5, 2025
9734c0e
fix global_mini_batch_size
DesmonDay Aug 12, 2025
7136e10
fix ppo_argument.yaml
DesmonDay Aug 12, 2025
ce589fc
fix ppo_argument.yaml
DesmonDay Aug 12, 2025
fe61ae5
fix
DesmonDay Aug 13, 2025
9dcabdf
fix global_mini_batch_size
DesmonDay Aug 13, 2025
52f620a
merge ppo and vapo
DesmonDay Aug 14, 2025
839d1d3
add vapo_argument_kk.yaml
DesmonDay Aug 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions llm/alignment/rl/run_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def process_args(model_args: ModelArgument, data_args: DataArgument, training_ar
if model_args.reward_server is None:
raise ValueError("Please specify reward_server when use_rm_server is true.")
logger.info(f"Use reward server: {model_args.reward_server} for training.")
if training_args.rl_algorithm == "ppo" and model_args.critic_model_name_or_path is None:
if training_args.rl_algorithm in ["ppo", "vapo"] and model_args.critic_model_name_or_path is None:
raise ValueError("Please specify critic_model_name_or_path when use_rm_server is true.")
else:
if model_args.reward_model_name_or_path is None:
Expand Down Expand Up @@ -337,7 +337,7 @@ def main():
else:
reward_model, reward_tokenizer = model_args.reward_server, actor_tokenizer

if training_args.rl_algorithm == "ppo":
if training_args.rl_algorithm in ["ppo", "vapo"]:
critic_model, critic_eval_model, critic_tokenizer = create_critic_models(
model_args, data_args, training_args, common_config
)
Expand All @@ -352,7 +352,7 @@ def main():
offload_tensor_to_cpu((actor_eval_model, "freeze_model"))
offload_tensor_to_cpu((reference_model, "freeze_model"))

if training_args.rl_algorithm == "ppo":
if training_args.rl_algorithm in ["ppo", "vapo"]:
if not training_args.use_rm_server and not training_args.use_rule_reward:
offload_tensor_to_cpu((reward_model, "freeze_model"))
if critic_eval_model is not None:
Expand All @@ -362,10 +362,10 @@ def main():
paddle.device.cuda.empty_cache()

def compute_metrics(eval_preds):
'''
"""
If "use_rm_server" is TRUE, the score ranges from -3 to 3, with 3 being the only correct score (format + result).
If using the "Regularized Matching Function (use_rule_reward=True)" (currently only implemented for the gsm8k dataset), the score ranges from 0 to 1.
'''
"""
if training_args.use_rule_reward:
accuracy = (eval_preds.predictions == 1).astype("float32").mean().item()
else:
Expand Down
22 changes: 11 additions & 11 deletions llm/config/qwen/ppo_argument.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ do_train: true # Whether to perform training
seed: 42 # Random seed for reproducibility
global_batch_size: 256 # Global batch size for training (rollouts = rollout_n * global_batch_size)
global_gen_batch_size: -1 # Global generation batch size for dynamic sampling
global_mini_batch_size: 64 # Mini-batch size for training, default = (global_batch_size * rollout_n * update_iters) // dataset_world_size
rollout_n: 1 # Number of rollouts, set rollout_n = 1 for 'ppo'
global_mini_batch_size: 64 # Mini-batch size for training, default = global_batch_size
rollout_n: 1 # Number of rollouts
update_iters: 1 # Number of training iterations for rollout samples
per_device_logprob_batch_size: 4 # Log probability batch size per device
per_device_reward_batch_size: 2 # Reward batch size per device
per_device_value_batch_size: 2 # Value batch size per device
per_device_train_batch_size: 2 # Training micro batch size per device
# gradient_accumulation_steps: 4 # Gradient accumulation steps (auto-calculated): global_bz * rollout_n *
num_train_epochs: 5 # Number of training epochs
# gradient_accumulation_steps: 4 # Gradient accumulation steps (auto-calculated):
num_train_epochs: 15 # Number of training epochs
max_length: 2048 # Maximum length for training, should be larger than max_prompt_len + max_dec_len
adam_beta1: 0.9 # AdamW optimizer beta1
adam_beta2: 0.999 # AdamW optimizer beta2
Expand All @@ -72,17 +72,17 @@ disable_tqdm: true # Whether to disable tqdm progress bar

# actor training args
learning_rate: 1e-6 # Learning rate for training
min_learning_rate: 1e-6 # Minimum learning rate
lr_scheduler_type: "constant" # Learning rate scheduler type
min_learning_rate: 5e-7 # Minimum learning rate
lr_scheduler_type: "cosine" # Learning rate scheduler type
weight_decay: 1e-2 # Weight decay for the AdamW optimizer
warmup_ratio: 0.0 # Number of warmup steps
warmup_ratio: 0.2 # Number of warmup steps

# critic training args
critic_learning_rate: 1e-5 # Learning rate for critic model
critic_min_learning_rate: 1e-5 # Minimum learning rate for critic model
critic_lr_scheduler_type: "constant" # Learning rate scheduler type for critic model
critic_min_learning_rate: 5e-6 # Minimum learning rate for critic model
critic_lr_scheduler_type: "cosine" # Learning rate scheduler type for critic model
critic_weight_decay: 1e-2 # Weight decay for the AdamW optimizer of critic model
critic_warmup_ratio: 0.0 # Number of warmup steps for critic model
critic_warmup_ratio: 0.2 # Number of warmup steps for critic model

# RL args
kl_coeff: 0.0 # KL coefficient
Expand All @@ -107,7 +107,7 @@ use_fp32_compute: true # Whether to use fp32 to compute xx_log_prob,rewards, adv
do_eval: true # Whether to perform evaluation
per_device_eval_batch_size: 1319 # Evaluation batch size per device
evaluation_strategy: "steps" # Evaluation strategy, e.g., "steps"
eval_steps: 10 # Number of steps between evaluations
eval_steps: 32 # Number of steps between evaluations

# device memory optimization args
use_flash_attention: true # Whether to use fused attention operations
Expand Down
132 changes: 132 additions & 0 deletions llm/config/qwen/vapo_argument.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# RL algorithms
rl_algorithm: "vapo" # The reinforcement learning algorithm used, supported: "ppo", "grpo", "reinforce_plus_plus", "vapo"

# models
actor_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the actor model
reward_model_name_or_path: "" # The name or path of the reward model
critic_model_name_or_path: "Qwen/Qwen2.5-1.5B-Instruct" # The name or path of the critic model
use_rm_server: false # Whether to use the reward model server
reward_server: "http://127.0.0.1:8731" # The address of the reward model server
use_rule_reward: true # Rule-based reward for the gsm8k dataset. If use_rule_reward is true, use_rm_server must be false

# logging
logging_dir: vapo-logs # Directory for logging
logging_steps: 1 # Number of steps between logging
output_dir: "qwen2.5-1.5b-gsm8k-vapo/checkpoints" # Directory for output ckpts
report_to: "wandb" # Supported reporting options: "all", "wandb", "tensorboard", "visualdl"(default), "none"
wandb_http_proxy: "http://agent.baidu.com:8891" # HTTP proxy for wandb
run_name: "qwen2.5-1.5b-gsm8k-vapo" # Name of the run

# data
train_datasets: "gsm8k/train.jsonl" # Path to the training dataset
eval_datasets: "gsm8k/test.jsonl" # Path to the evaluation dataset
prompt_key: "src" # Key for the prompt in the dataset
response_key: "tgt" # Key for the response in the dataset
dataloader_drop_last: true # Whether to drop the last incomplete batch in the DataLoader
balance_batch: true # Whether to balance batch size across dataset_world_size
use_remove_padding: true # Whether to remove padding tokens in the input

# distributed training args
tensor_parallel_degree: 2 # Degree of tensor parallelism
sequence_parallel: true # Whether to enable sequence parallelism
sharding_parallel_degree: -1 # Degree of sharding parallelism
sharding: "stage1" # Sharding strategy, e.g., "stage1" or "stage2"
sharding_parallel_config: "enable_release_grads" # Configuration for sharding parallelism
pipeline_parallel_degree: 1 # Degree of pipeline parallelism
virtual_pp_degree: 1 # Degree of virtual pipeline parallelism

# rollout args
max_prompt_len: 1024 # Maximum length of the prompt, exceeding which will be automatically truncated
max_dec_len: 512 # Maximum length of the response
min_dec_len: 32 # Minimum length of the response
top_p: 1.0 # Top-p sampling parameter
temperature: 1.0 # Temperature parameter for sampling
repetition_penalty: 1.0 # Repetition penalty parameter
rollout_max_num_seqs: 32 # The maximum number of sequences that can be processed in a single inference
rollout_quant_type: "" # Quantization type, e.g., "weight_only_int8"

# training args
do_train: true # Whether to perform training
seed: 42 # Random seed for reproducibility
global_batch_size: 32 # Global batch size for training (rollouts = rollout_n * global_batch_size)
global_gen_batch_size: -1 # Global generation batch size for dynamic sampling
global_mini_batch_size: 8 # Mini-batch size for training, default = (global_batch_size * rollout_n * update_iters) // dataset_world_size
rollout_n: 8 # Number of rollouts, set rollout_n = 1 for 'ppo'
update_iters: 1 # Number of training iterations for rollout samples
per_device_logprob_batch_size: 4 # Log probability batch size per device
per_device_reward_batch_size: 2 # Reward batch size per device
per_device_value_batch_size: 2 # Value batch size per device
per_device_train_batch_size: 2 # Training micro batch size per device
# gradient_accumulation_steps: 8 # Gradient accumulation steps (auto-calculated):
num_train_epochs: 15 # Number of training epochs
max_length: 2048 # Maximum length for training, should be larger than max_prompt_len + max_dec_len
adam_beta1: 0.9 # AdamW optimizer beta1
adam_beta2: 0.999 # AdamW optimizer beta2
adam_epsilon: 1e-8 # AdamW optimizer epsilon
max_grad_norm: 1.0 # Maximum gradient norm for clipping
max_steps: -1 # Maximum number of training steps
save_steps: 300 # Number of steps between model saves
save_strategy: "steps" # Strategy for saving models
ignore_save_lr_and_optim: true # Whether to ignore saving learning rate and optimizer state (leave empty if not specified)
disable_tqdm: true # Whether to disable tqdm progress bar

# actor training args
learning_rate: 1e-6 # Learning rate for training
min_learning_rate: 1e-6 # Minimum learning rate
lr_scheduler_type: "constant" # Learning rate scheduler type
weight_decay: 1e-2 # Weight decay for the AdamW optimizer
warmup_ratio: 0.0 # Ratio of total training steps used for learning-rate warmup

# critic training args
critic_learning_rate: 1e-5 # Learning rate for critic model
critic_min_learning_rate: 1e-5 # Minimum learning rate for critic model
critic_lr_scheduler_type: "constant" # Learning rate scheduler type for critic model
critic_weight_decay: 1e-2 # Weight decay for the AdamW optimizer of critic model
critic_warmup_ratio: 0.0 # Ratio of total training steps used for critic learning-rate warmup
pretrain_critic_steps: 50 # Pre-training steps for critic model

# RL args
kl_coeff: 0.0 # KL coefficient
kl_loss_coeff: 0.001 # KL loss coefficient
pg_loss_coeff: 1.0 # Policy gradient loss coefficient
entropy_coeff: 0.001 # Entropy coefficient
clip_range_ratio: 0.2 # The clipping range for ratio between the old and new policy. (PPO algorithm)
clip_range_ratio_low: 0.2 # The clipping range for ratio between the old and new policy. (PPO algorithm)
clip_range_ratio_high: 0.28 # The clipping range for ratio between the old and new policy. (PPO algorithm)
clip_range_score: 10.0 # The clipping range for the output of the score model. The reward is clipped into [-clip_range_score, clip_range_score].
enable_overlong_reward_buffer: false # Whether to enable overlong reward buffer
overlong_reward_buffer: 256 # The length of the overlong reward buffer
overlong_penalty_factor: 1.0 # The penalty factor for overlong reward buffer
clip_range_value: 0.5 # The clipping range for the output of the value model. The value is clipped into [-clip_range_value, clip_range_value].
normalize_reward: false # Whether to normalize reward
normalize_advantage: false # Whether to normalize advantage
dynamic_sampling: false # Whether to use dynamic sampling, which is introduced in the DAPO algorithm https://arxiv.org/abs/2503.14476
max_gen_batches: 2 # Maximum number of generation batches for dynamic sampling
use_fp32_compute: true # Whether to use fp32 to compute xx_log_prob,rewards, advantages and loss

# eval args
do_eval: true # Whether to perform evaluation
per_device_eval_batch_size: 1319 # Evaluation batch size per device
evaluation_strategy: "steps" # Evaluation strategy, e.g., "steps"
eval_steps: 32 # Number of steps between evaluations

# device memory optimization args
use_flash_attention: true # Whether to use fused attention operations
use_fused_rms_norm: true # Whether to use fused RMS norm operations, which needs to install fused_ln in slm/model_zoo/gpt-3/external_ops
use_fused_rope: false # Whether to use fused rope operations
use_fused_head_and_loss_fn: true # Whether to use fused head and loss function
use_fused_linear: true # Whether to use fused linear operations.
recompute: false # Whether to enable gradient checkpointing for memory optimization
recompute_use_reentrant: false # Whether to use reentrant recompute
recompute_granularity: "full" # Granularity of recompute
bf16: true # Whether to use mixed precision with bfloat16
fp16_opt_level: "O2" # Optimization level for fp16 and bf16 training
amp_master_grad: false # Whether to use float32 weight gradients for master weights in amp opt level='O2'
amp_custom_black_list: ["reduce_sum", "softmax_with_cross_entropy", "c_softmax_with_cross_entropy", "elementwise_div", "sin", "cos"] # Custom black list for amp
amp_custom_white_list: ["lookup_table", "lookup_table_v2", "flash_attn", "matmul", "matmul_v2", "fused_gemm_epilogue"] # Custom white list for amp
offload_level: "freeze_model" # Level of model offloading to pinned memory, supported values: freeze_model, train_model, optimizer
release_grads: true # Whether to release gradients
offload_optim: false # Whether to offload optimizer to pinned memory

# benchmark args
skip_profile_timer: false # Whether to skip profiling time
Loading
Loading