Implementation of Group Relative Policy Optimization (GRPO) for fine-tuning open language models such as Llama-3.2 and Qwen2.5 on math tasks.
This project implements a simple, customizable pipeline for fine-tuning language models with the GRPO strategy introduced in DeepSeekMath. The approach uses group-based relative advantage estimates to guide the optimization. The training data consists of synthetic math expressions and their solutions. The framework is flexible and can be extended to other tasks.
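To make the group-relative advantage idea concrete, here is a minimal sketch of how rewards for several responses to the same prompt can be normalized against each other. It is not taken from `grpo_train.py`: the function name `group_relative_advantages`, the `eps` term, and the use of the population standard deviation are assumptions made for illustration.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-8):
    """Normalize each reward against the other samples drawn for the same prompt.

    Hypothetical helper: subtracts the group mean and divides by the group
    standard deviation (population std is an assumption; some implementations
    use the sample std instead).
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    # eps avoids division by zero when every response gets the same reward.
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled responses to one prompt, scored 1.0 if correct and 0.0 otherwise.
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
```

Responses that score above the group mean get a positive advantage and those below get a negative one, so no separate value model is needed to provide a baseline.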
pytorch · transformers · peft · flash-attn · pyyaml · tensorboard
Set the required hyperparameters and the language model path in `config.yml`, then run `python3 grpo_train.py`.
If needed, replace the existing dataset with your custom data.
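As a rough illustration of what synthetic math data could look like, the sketch below generates arithmetic prompts with their solutions. The record layout (`prompt` and `answer` keys), the expression template, and the function names are hypothetical; check the dataset code in this repository for the format it actually expects.

```python
import random


def make_arithmetic_example(a, b, c, d, e):
    """Build one prompt/answer pair of the form a-(b*(c-d)+e). Hypothetical format."""
    prompt = f"Solve {a}-({b}*({c}-{d})+{e})"
    answer = a - (b * (c - d) + e)
    return {"prompt": prompt, "answer": str(answer)}


def sample_dataset(n, seed=0):
    """Draw n examples with a seeded generator so the data is reproducible."""
    rng = random.Random(seed)
    return [
        make_arithmetic_example(*(rng.randint(1, 99) for _ in range(5)))
        for _ in range(n)
    ]


example = make_arithmetic_example(78, 3, 47, 2, 2)
dataset = sample_dataset(3)
```

Computing the answer with ordinary arithmetic instead of parsing the prompt string keeps the label trustworthy, which matters because the reward signal depends on it.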
The table below demonstrates the improvement in model responses before and after GRPO fine-tuning:
Model: Llama-3.2-1B-Instruct
| <div align="center">Before Fine-tuning</div> | <div align="center">After GRPO Fine-tuning</div> |
|---|---|
| Prompt: | Expand 3*(x+1)*(x+2) |
| Responses: | |
| `<think>`<br>1. Multiply the first two terms: 3(x+1) = 3x + 3<br>2. Multiply the result by the third term: (3x + 3)(x+2)<br>`</think>`<br>`<answer>` 3x^2 + 3x + 6x + 6 `</answer>` | `<think>`<br>1. To expand the expression 3(x+1)(x+2), we need to apply the distributive property to expand the expression.<br>2. The distributive property states that a(b+c) = ab + ac.<br>3. So, we can expand the expression as `3(x+1)(x+2) = 3(x*x + x*2 + 1*x + 1*2)`.<br>4. Now, we can simplify the expression by multiplying the terms.<br>5. `3(x*x + x*2 + 1*x + 1*2) = 3x^2 + 6x + 3x + 2`.<br>6. Combine like terms.<br>7. 3x^2 + 6x + 3x + 2 = 3x^2 + 9x + 2.<br>8. The final expanded expression is 3x^2 + 9x + 2.<br>`</think>`<br>`<answer>` 3x^2 + 9x + 2 `</answer>` |
| Prompt: | Solve 78-(3*(47-2)+2) |
| Responses: | |
| `<think>`<br>1. First, we need to evaluate the expression inside the parentheses: 47-2<br>2. 47 - 2 = 45<br>3. Now, we multiply 3 by the result: 3 * 45 = 135<br>4. Finally, we subtract 135 from 78: 78 - 135 = -57<br>`</think>`<br>`<answer>` The final answer is -57 `</answer>` | `<think>`<br>1. Start by evaluating the expression inside the parentheses: 47-2 = 45<br>2. Now, multiply 3 by the result: 3 * 45 = 135<br>3. Add 2 to the result: 135 + 2 = 137<br>4. Finally, subtract 137 from 78: 78 - 137 = -59<br>`</think>`<br>`<answer>` -59 `</answer>` |
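Responses in the `<think>` / `<answer>` format shown above can be scored automatically. The sketch below is one possible correctness reward, not the reward used by this project: `extract_answer` and `correctness_reward` are hypothetical names, and exact string matching after whitespace removal is a simplifying assumption.

```python
import re

# Matches the text between the first <answer> ... </answer> pair, across newlines.
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def extract_answer(response):
    """Return the stripped contents of the <answer> tag, or None if it is missing."""
    match = ANSWER_RE.search(response)
    return match.group(1).strip() if match else None


def correctness_reward(response, target):
    """Return 1.0 when the tagged answer equals the target (ignoring spaces), else 0.0."""
    answer = extract_answer(response)
    if answer is None:
        return 0.0
    return 1.0 if answer.replace(" ", "") == target.replace(" ", "") else 0.0


before = "<think>\n78 - 135 = -57\n</think>\n<answer> The final answer is -57 </answer>"
after = "<think>\n78 - 137 = -59\n</think>\n<answer> -59 </answer>"
scores = [correctness_reward(before, "-59"), correctness_reward(after, "-59")]
```

String matching is strict: an algebraically equivalent answer written differently (for example, with terms reordered) would score 0.0, so a symbolic comparison may be a better fit for expansion tasks.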
Open an issue or submit a pull request if you have any suggestions or improvements.