Minimal-GRPO

An implementation of Group Relative Policy Optimization (GRPO) for fine-tuning open language models such as Llama-3.2 and Qwen2.5 on math tasks.

Overview

This project implements a simple, customizable pipeline for fine-tuning language models with the GRPO strategy introduced in DeepSeekMath. GRPO samples a group of responses for each prompt and scores every response relative to the others in its group; these group-relative advantage estimates guide the optimization. The training data consists of synthetic math expressions and their solutions. The framework can be extended to other tasks.
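As a rough illustration of the group-relative advantage idea, the sketch below normalizes each reward against the mean and standard deviation of its own group. The function name `group_relative_advantages`, the `eps` stabilizer, and the use of the population standard deviation are assumptions made for this example; they are not taken from this repository's code.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Return (r - group_mean) / (group_std + eps) for each reward in one group.

    `rewards` holds the scalar rewards of all responses sampled for a single
    prompt. `eps` (an assumption) avoids division by zero when every response
    in the group received the same reward.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)  # population std; a sample std is an equally plausible choice
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four sampled responses to one prompt: two correct (reward 1), two incorrect (reward 0).
rewards = [1.0, 0.0, 0.0, 1.0]
advantages = group_relative_advantages(rewards)
print(advantages)  # approximately [1, -1, -1, 1]
```

Responses that score above their group's average get a positive advantage and are reinforced, while below-average responses are discouraged. When all responses in a group receive the same reward, every advantage is zero and that group contributes no learning signal.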

Requirements

pytorch · transformers · peft · flash-attn · pyyaml · tensorboard

Usage

  1. Edit the config.yml file to set the required hyperparameters and the language model path.
  2. Run the following command to fine-tune the model:

    python3 grpo_train.py

    If needed, replace the existing dataset with your own data.

Results

The examples below illustrate the improvement in model responses after GRPO fine-tuning:

Model: Llama-3.2-1B-Instruct

Prompt: Expand 3*(x+1)*(x+2)

Before Fine-tuning:

    <think>
    1. Multiply the first two terms: 3(x+1) = 3x + 3
    2. Multiply the result by the third term: (3x + 3)*(x+2)
    </think>
    <answer> 3x^2 + 3x + 6x + 6 </answer>

After GRPO Fine-tuning:

    <think>
    1. To expand the expression 3(x+1)(x+2), we need to apply the distributive property to expand the expression.
    2. The distributive property states that a(b+c) = ab + ac.
    3. So, we can expand the expression as 3(x+1)(x+2) = 3(x*x + x*2 + 1*x + 1*2).
    4. Now, we can simplify the expression by multiplying the terms.
    5. 3*(x*x + x*2 + 1*x + 1*2) = 3x^2 + 6x + 3x + 2.
    6. Combine like terms.
    7. 3*x^2 + 6x + 3x + 2 = 3x^2 + 9x + 2.
    8. The final expanded expression is 3x^2 + 9x + 2.
    </think>
    <answer> 3x^2 + 9x + 2 </answer>

Prompt: Solve 78-(3*(47-2)+2)

Before Fine-tuning:

    <think>
    1. First, we need to evaluate the expression inside the parentheses: 47-2
    2. 47 - 2 = 45
    3. Now, we multiply 3 by the result: 3 * 45 = 135
    4. Finally, we subtract 135 from 78: 78 - 135 = -57
    </think>
    <answer> The final answer is -57 </answer>

After GRPO Fine-tuning:

    <think>
    1. Start by evaluating the expression inside the parentheses: 47-2 = 45
    2. Now, multiply 3 by the result: 3 * 45 = 135
    3. Add 2 to the result: 135 + 2 = 137
    4. Finally, subtract 137 from 78: 78 - 137 = -59
    </think>
    <answer> -59 </answer>
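The responses above place the reasoning inside `<think>` tags and the final result inside `<answer>` tags. The sketch below shows one way such output could be scored during training. It is only an illustration: the names `extract_answer` and `reward`, the exact-match comparison, and the weights 0.5 and 1.0 are all hypothetical and do not describe the reward function used in this repository.

```python
import re

# Expected layout: <think> ... </think> followed by <answer> ... </answer>.
RESPONSE_RE = re.compile(
    r"<think>.*?</think>\s*<answer>(.*?)</answer>", re.DOTALL
)


def extract_answer(response):
    """Return the text inside <answer>, or None if the layout is not followed."""
    match = RESPONSE_RE.search(response)
    return match.group(1).strip() if match else None


def reward(response, expected):
    """Hypothetical reward: 0.5 for a well-formed response, plus 1.0 for an exact match."""
    answer = extract_answer(response)
    if answer is None:
        return 0.0
    score = 0.5  # format reward (assumed weight)
    if answer == expected:
        score += 1.0  # correctness reward (assumed weight)
    return score


after = "<think>\n78 - 137 = -59\n</think>\n<answer> -59 </answer>"
before = "<think>\n78 - 135 = -57\n</think>\n<answer> The final answer is -57 </answer>"
print(reward(after, "-59"))   # 1.5
print(reward(before, "-59"))  # 0.5
```

An exact string match is deliberately strict: it rewards a bare answer such as `-59` but not a sentence wrapped around it. A real reward for algebraic prompts would likely need to compare expressions symbolically rather than as strings.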

Rewards accumulated per step during training

[Figure: Reward during training]

Contributing

Open an issue or submit a pull request if you have any suggestions or improvements.

Acknowledgments