Fix SFT for base models (#604)

* Fix pad token bug in SFT

* Add ChatML default

* Clean up

* Refactor grpo model load

* Add doc

* Bump deepspeed
Esse commit está contido em:
lewtun
2025-04-16 11:45:50 +02:00
commit de GitHub
commit 5112bfc401
8 arquivos alterados com 90 adições e 63 exclusões
+39 -1
Ver Arquivo
@@ -112,7 +112,6 @@ accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r
--dataset_name open-r1/OpenR1-Math-220k \
--learning_rate 5.0e-5 \
--num_train_epochs 1 \
--packing \
--max_seq_length 16384 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
@@ -150,6 +149,45 @@ accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r
--wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
```
**🚨 WARNING 🚨**
Most base models like `meta-llama/Llama-3.2-1B` do not have a chat template, so we set ChatML as the default during training. However, for Qwen base models like `Qwen/Qwen2.5-1.5B`, a chat template is pre-defined in the tokenizer, so the EOS token must be set accordingly, e.g.
```diff
# Align EOS token with chat template for Qwen base models
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B \
+ --eos_token '<|im_end|>'
--dataset_name open-r1/OpenR1-Math-220k \
--learning_rate 5.0e-5 \
--num_train_epochs 1 \
--max_seq_length 16384 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
--bf16 \
--use_liger_kernel \
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
```
If you wish to use a custom chat template (e.g. Llama or Gemma), then the chat template and associated EOS token must be provided:
```diff
# Align EOS token with custom chat template
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path meta-llama/Llama-3.2-1B \
+ --chat_template "$(cat llama_chat_template.jinja)" \
+ --eos_token '<|eot_id|>' \
--dataset_name open-r1/OpenR1-Math-220k \
--learning_rate 5.0e-5 \
--num_train_epochs 1 \
--max_seq_length 16384 \
--per_device_train_batch_size 16 \
--gradient_checkpointing \
--bf16 \
--use_liger_kernel \
--output_dir data/Llama-3.2-1B-Open-R1-Distill
```
### SFT
To run SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), run:
+1 -1
Ver Arquivo
@@ -36,7 +36,7 @@ hub_strategy: every_save
log_level: info
logging_steps: 5
logging_strategy: steps
packing: true
packing: false
output_dir: data/OpenR1-Qwen-7B-SFT
overwrite_output_dir: true
push_to_hub: true
@@ -25,7 +25,7 @@ logging_strategy: steps
lr_scheduler_type: cosine_with_min_lr
lr_scheduler_kwargs:
min_lr_rate: 0.1
packing: true
packing: false
max_length: 16384
max_steps: -1
num_train_epochs: 1
+2 -2
Ver Arquivo
@@ -44,7 +44,7 @@ _deps = [
"accelerate==1.4.0",
"bitsandbytes>=0.43.0",
"datasets>=3.2.0",
"deepspeed==0.15.4",
"deepspeed==0.16.4",
"distilabel[vllm,ray,openai]>=1.5.2",
"e2b-code-interpreter>=1.0.5",
"einops>=0.8.0",
@@ -67,7 +67,7 @@ _deps = [
"sentencepiece>=0.1.99",
"torch==2.6.0",
"transformers==4.51.2",
"trl @ git+https://github.com/huggingface/trl.git@d625c5533a6b1c84d3565c8080857f6bb81c538a", # Bump for vLLM and 2x faster throughput: https://github.com/huggingface/trl/pull/3276
"trl @ git+https://github.com/huggingface/trl.git@c04e84c4545acfaecdf7e0631ad07a86ab0fb2f6", # Fix EOS token for SFT on base models: https://github.com/huggingface/trl/pull/3299
"vllm==0.8.3",
"wandb>=0.19.1",
]
+8 -16
Ver Arquivo
@@ -17,7 +17,6 @@ import os
import sys
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
@@ -25,7 +24,7 @@ from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import GRPOConfig, GRPOScriptArguments
from open_r1.rewards import get_reward_funcs
from open_r1.utils import get_tokenizer
from open_r1.utils import get_model, get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import GRPOTrainer, ModelConfig, TrlParser, get_peft_config
@@ -80,6 +79,12 @@ def main(script_args, training_args, model_args):
################
tokenizer = get_tokenizer(model_args, training_args)
##############
# Load model #
##############
logger.info("*** Loading model ***")
model = get_model(model_args, training_args)
# Get reward functions from the registry
reward_funcs = get_reward_funcs(script_args)
@@ -102,24 +107,11 @@ def main(script_args, training_args, model_args):
if "messages" in dataset[split].column_names:
dataset[split] = dataset[split].remove_columns("messages")
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
)
training_args.model_init_kwargs = model_kwargs
#############################
# Initialize the GRPO trainer
#############################
trainer = GRPOTrainer(
model=model_args.model_name_or_path,
model=model,
reward_funcs=reward_funcs,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
+11 -30
Ver Arquivo
@@ -20,7 +20,7 @@ Usage:
# One 1 node of 8 x H100s
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
--dataset_name open-r1/OpenR1-Math-220k \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
@@ -40,25 +40,16 @@ import os
import sys
import datasets
import torch
import transformers
from datasets import load_dataset
from transformers import set_seed
from transformers.trainer_utils import get_last_checkpoint
from open_r1.configs import SFTConfig
from open_r1.utils import get_tokenizer
from open_r1.utils import get_model, get_tokenizer
from open_r1.utils.callbacks import get_callbacks
from open_r1.utils.wandb_logging import init_wandb_training
from trl import (
ModelConfig,
ScriptArguments,
SFTTrainer,
TrlParser,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
from trl import ModelConfig, ScriptArguments, SFTTrainer, TrlParser, get_peft_config, setup_chat_format
logger = logging.getLogger(__name__)
@@ -106,32 +97,22 @@ def main(script_args, training_args, model_args):
# Load tokenizer
################
tokenizer = get_tokenizer(model_args, training_args)
tokenizer.pad_token = tokenizer.eos_token
###################
# Model init kwargs
# Load model
###################
logger.info("*** Initializing model kwargs ***")
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
training_args.model_init_kwargs = model_kwargs
logger.info("*** Loading model ***")
model = get_model(model_args, training_args)
if tokenizer.chat_template is None:
logger.info("No chat template provided, using ChatML.")
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
############################
# Initialize the SFT Trainer
############################
trainer = SFTTrainer(
model=model_args.model_name_or_path,
model=model,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
+2 -2
Ver Arquivo
@@ -1,5 +1,5 @@
from .import_utils import is_e2b_available
from .model_utils import get_tokenizer
from .model_utils import get_model, get_tokenizer
__all__ = ["get_tokenizer", "is_e2b_available"]
__all__ = ["get_tokenizer", "is_e2b_available", "get_model"]
+26 -10
Ver Arquivo
@@ -1,16 +1,12 @@
from transformers import AutoTokenizer, PreTrainedTokenizer
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
from trl import ModelConfig
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
from ..configs import GRPOConfig, SFTConfig
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
def get_tokenizer(
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig, auto_set_chat_template: bool = True
) -> PreTrainedTokenizer:
def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer:
"""Get the tokenizer for the model."""
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
@@ -20,7 +16,27 @@ def get_tokenizer(
if training_args.chat_template is not None:
tokenizer.chat_template = training_args.chat_template
elif auto_set_chat_template and tokenizer.get_chat_template() is None:
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
return tokenizer
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
"""Get the model"""
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
**model_kwargs,
)
return model