Fix SFT for base models (#604)
* Fix pad token bug in SFT * Add ChatML default * Clean up * Refactor grpo model load * Add doc * Bump deepspeed
Esse commit está contido em:
+39
-1
@@ -112,7 +112,6 @@ accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r
|
||||
--dataset_name open-r1/OpenR1-Math-220k \
|
||||
--learning_rate 5.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
--max_seq_length 16384 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--gradient_checkpointing \
|
||||
@@ -150,6 +149,45 @@ accelerate launch --config_file recipes/accelerate_configs/zero3.yaml src/open_r
|
||||
--wandb_entity huggingface --wandb_project open-r1 --run_name Qwen2.5-1.5B-GRPO
|
||||
```
|
||||
|
||||
**🚨 WARNING 🚨**
|
||||
|
||||
Most base models like `meta-llama/Llama-3.2-1B` do not have a chat template, so we set ChatML as the default during training. However, for Qwen base models like `Qwen/Qwen2.5-1.5B`, a chat template is pre-defined in the tokenizer, so the EOS token must be set accordingly, e.g.
|
||||
|
||||
```diff
|
||||
# Align EOS token with chat template for Qwen base models
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-1.5B \
|
||||
+ --eos_token '<|im_end|>' \
|
||||
--dataset_name open-r1/OpenR1-Math-220k \
|
||||
--learning_rate 5.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--max_seq_length 16384 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--gradient_checkpointing \
|
||||
--bf16 \
|
||||
--use_liger_kernel \
|
||||
--output_dir data/Qwen2.5-1.5B-Open-R1-Distill
|
||||
```
|
||||
|
||||
If you wish to use a custom chat template (e.g. Llama or Gemma), then the chat template and associated EOS token must be provided:
|
||||
|
||||
```diff
|
||||
# Align EOS token with custom chat template
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--model_name_or_path meta-llama/Llama-3.2-1B \
|
||||
+ --chat_template "$(cat llama_chat_template.jinja)" \
|
||||
+ --eos_token '<|eot_id|>' \
|
||||
--dataset_name open-r1/OpenR1-Math-220k \
|
||||
--learning_rate 5.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--max_seq_length 16384 \
|
||||
--per_device_train_batch_size 16 \
|
||||
--gradient_checkpointing \
|
||||
--bf16 \
|
||||
--use_liger_kernel \
|
||||
--output_dir data/Llama-3.2-1B-Open-R1-Distill
|
||||
```
|
||||
|
||||
### SFT
|
||||
|
||||
To run SFT on a dataset distilled from DeepSeek-R1 with reasoning traces such as [open-r1/OpenR1-Math-220k](https://huggingface.co/datasets/open-r1/OpenR1-Math-220k), run:
|
||||
|
||||
@@ -36,7 +36,7 @@ hub_strategy: every_save
|
||||
log_level: info
|
||||
logging_steps: 5
|
||||
logging_strategy: steps
|
||||
packing: true
|
||||
packing: false
|
||||
output_dir: data/OpenR1-Qwen-7B-SFT
|
||||
overwrite_output_dir: true
|
||||
push_to_hub: true
|
||||
|
||||
@@ -25,7 +25,7 @@ logging_strategy: steps
|
||||
lr_scheduler_type: cosine_with_min_lr
|
||||
lr_scheduler_kwargs:
|
||||
min_lr_rate: 0.1
|
||||
packing: true
|
||||
packing: false
|
||||
max_length: 16384
|
||||
max_steps: -1
|
||||
num_train_epochs: 1
|
||||
|
||||
+2
-2
@@ -44,7 +44,7 @@ _deps = [
|
||||
"accelerate==1.4.0",
|
||||
"bitsandbytes>=0.43.0",
|
||||
"datasets>=3.2.0",
|
||||
"deepspeed==0.15.4",
|
||||
"deepspeed==0.16.4",
|
||||
"distilabel[vllm,ray,openai]>=1.5.2",
|
||||
"e2b-code-interpreter>=1.0.5",
|
||||
"einops>=0.8.0",
|
||||
@@ -67,7 +67,7 @@ _deps = [
|
||||
"sentencepiece>=0.1.99",
|
||||
"torch==2.6.0",
|
||||
"transformers==4.51.2",
|
||||
"trl @ git+https://github.com/huggingface/trl.git@d625c5533a6b1c84d3565c8080857f6bb81c538a", # Bump for vLLM and 2x faster throughput: https://github.com/huggingface/trl/pull/3276
|
||||
"trl @ git+https://github.com/huggingface/trl.git@c04e84c4545acfaecdf7e0631ad07a86ab0fb2f6", # Fix EOS token for SFT on base models: https://github.com/huggingface/trl/pull/3299
|
||||
"vllm==0.8.3",
|
||||
"wandb>=0.19.1",
|
||||
]
|
||||
|
||||
+8
-16
@@ -17,7 +17,6 @@ import os
|
||||
import sys
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from transformers import set_seed
|
||||
@@ -25,7 +24,7 @@ from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from open_r1.configs import GRPOConfig, GRPOScriptArguments
|
||||
from open_r1.rewards import get_reward_funcs
|
||||
from open_r1.utils import get_tokenizer
|
||||
from open_r1.utils import get_model, get_tokenizer
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from trl import GRPOTrainer, ModelConfig, TrlParser, get_peft_config
|
||||
@@ -80,6 +79,12 @@ def main(script_args, training_args, model_args):
|
||||
################
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
|
||||
##############
|
||||
# Load model #
|
||||
##############
|
||||
logger.info("*** Loading model ***")
|
||||
model = get_model(model_args, training_args)
|
||||
|
||||
# Get reward functions from the registry
|
||||
reward_funcs = get_reward_funcs(script_args)
|
||||
|
||||
@@ -102,24 +107,11 @@ def main(script_args, training_args, model_args):
|
||||
if "messages" in dataset[split].column_names:
|
||||
dataset[split] = dataset[split].remove_columns("messages")
|
||||
|
||||
logger.info("*** Initializing model kwargs ***")
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
)
|
||||
training_args.model_init_kwargs = model_kwargs
|
||||
|
||||
#############################
|
||||
# Initialize the GRPO trainer
|
||||
#############################
|
||||
trainer = GRPOTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
model=model,
|
||||
reward_funcs=reward_funcs,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
|
||||
+11
-30
@@ -20,7 +20,7 @@ Usage:
|
||||
# On 1 node of 8 x H100s
|
||||
accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
|
||||
--dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
|
||||
--dataset_name open-r1/OpenR1-Math-220k \
|
||||
--learning_rate 2.0e-5 \
|
||||
--num_train_epochs 1 \
|
||||
--packing \
|
||||
@@ -40,25 +40,16 @@ import os
|
||||
import sys
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from transformers import set_seed
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
from open_r1.configs import SFTConfig
|
||||
from open_r1.utils import get_tokenizer
|
||||
from open_r1.utils import get_model, get_tokenizer
|
||||
from open_r1.utils.callbacks import get_callbacks
|
||||
from open_r1.utils.wandb_logging import init_wandb_training
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
ScriptArguments,
|
||||
SFTTrainer,
|
||||
TrlParser,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl import ModelConfig, ScriptArguments, SFTTrainer, TrlParser, get_peft_config, setup_chat_format
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -106,32 +97,22 @@ def main(script_args, training_args, model_args):
|
||||
# Load tokenizer
|
||||
################
|
||||
tokenizer = get_tokenizer(model_args, training_args)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
###################
|
||||
# Model init kwargs
|
||||
# Load model
|
||||
###################
|
||||
logger.info("*** Initializing model kwargs ***")
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
training_args.model_init_kwargs = model_kwargs
|
||||
logger.info("*** Loading model ***")
|
||||
model = get_model(model_args, training_args)
|
||||
|
||||
if tokenizer.chat_template is None:
|
||||
logger.info("No chat template provided, using ChatML.")
|
||||
model, tokenizer = setup_chat_format(model, tokenizer, format="chatml")
|
||||
|
||||
############################
|
||||
# Initialize the SFT Trainer
|
||||
############################
|
||||
trainer = SFTTrainer(
|
||||
model=model_args.model_name_or_path,
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from .import_utils import is_e2b_available
|
||||
from .model_utils import get_tokenizer
|
||||
from .model_utils import get_model, get_tokenizer
|
||||
|
||||
|
||||
__all__ = ["get_tokenizer", "is_e2b_available"]
|
||||
__all__ = ["get_tokenizer", "is_e2b_available", "get_model"]
|
||||
|
||||
@@ -1,16 +1,12 @@
|
||||
from transformers import AutoTokenizer, PreTrainedTokenizer
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer
|
||||
|
||||
from trl import ModelConfig
|
||||
from trl import ModelConfig, get_kbit_device_map, get_quantization_config
|
||||
|
||||
from ..configs import GRPOConfig, SFTConfig
|
||||
|
||||
|
||||
DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
model_args: ModelConfig, training_args: SFTConfig | GRPOConfig, auto_set_chat_template: bool = True
|
||||
) -> PreTrainedTokenizer:
|
||||
def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer:
|
||||
"""Get the tokenizer for the model."""
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
@@ -20,7 +16,27 @@ def get_tokenizer(
|
||||
|
||||
if training_args.chat_template is not None:
|
||||
tokenizer.chat_template = training_args.chat_template
|
||||
elif auto_set_chat_template and tokenizer.get_chat_template() is None:
|
||||
tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
|
||||
"""Get the model"""
|
||||
torch_dtype = (
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_args.model_revision,
|
||||
trust_remote_code=model_args.trust_remote_code,
|
||||
attn_implementation=model_args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
**model_kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
Referência em uma Nova Issue
Bloquear um usuário