-
Notifications
You must be signed in to change notification settings - Fork 846
Expand file tree
/
Copy pathprepare_rl_prompts.py
More file actions
76 lines (59 loc) · 2.6 KB
/
Copy pathprepare_rl_prompts.py
File metadata and controls
76 lines (59 loc) · 2.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
Build RL prompt sets ({"prompt", "gold"}) for PPO/GRPO:
- GSM8K train -> rl_prompts_train.jsonl (the RL training prompts)
- GSM8K test -> rl_prompts_test.jsonl (held-out benchmark for eval)
- a programmatic arithmetic set -> arithmetic_prompts.jsonl (RL curriculum warm-up,
where even a weak model gets some non-zero reward so RL has signal to start from)
Example:
PYTHONPATH=. HF_HOME=/ephemeral/hf_cache python scripts/prepare_rl_prompts.py --out_dir /ephemeral/data
"""
from __future__ import annotations
import argparse
import json
import os
import random
os.environ.setdefault("HF_HOME", "/ephemeral/hf_cache")
from src.post_training.rewards import gsm8k_gold_answer
def write_jsonl(rows, path):
    """Write ``rows`` to ``path`` as JSON Lines (one JSON object per line).

    Parent directories are created as needed, and an existing file at
    ``path`` is overwritten. A one-line summary with the row count is printed.

    Args:
        rows: Iterable of JSON-serialisable objects. Rows are counted as they
            are written, so generators and other non-sized iterables work.
        path: Destination file path.

    Returns:
        The number of rows written.
    """
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    count = 0
    # Explicit encoding so the output does not depend on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        for count, row in enumerate(rows, start=1):
            f.write(json.dumps(row) + "\n")
    print(f" wrote {count} prompts -> {path}")
    return count
def gsm8k_prompts(split: str, limit: int | None):
    """Load a GSM8K split and convert it into RL prompt rows.

    Args:
        split: Split name passed to ``load_dataset`` (e.g. "train", "test").
        limit: Keep only the first ``limit`` examples of the split; ``None``
            keeps the whole split. ``0`` (or a negative value) keeps nothing.

    Returns:
        A list of ``{"prompt": question, "gold": answer}`` dicts. Examples for
        which ``gsm8k_gold_answer`` returns ``None`` are skipped.
    """
    # Imported here rather than at module level, so the module itself does
    # not import `datasets`.
    from datasets import load_dataset

    ds = load_dataset("openai/gsm8k", "main", split=split)
    # Compare against None explicitly: a truthiness test would treat
    # limit=0 as "no limit" and load the entire split.
    if limit is not None:
        ds = ds.select(range(min(limit, len(ds))))
    rows = []
    for ex in ds:
        gold = gsm8k_gold_answer(ex["answer"])
        if gold is not None:
            rows.append({"prompt": ex["question"].strip(), "gold": gold})
    return rows
def arithmetic_prompts(n: int, max_val: int, seed: int):
    """Generate ``n`` random two-operand integer arithmetic prompts.

    Both operands are drawn uniformly from ``[0, max_val]`` (inclusive), and
    the operator is picked from ``+``, ``-``, ``*``. The output is fully
    determined by ``seed``.

    Returns:
        A list of ``{"prompt": "What is <a> <op> <b>?", "gold": float(result)}``
        dicts.
    """
    generator = random.Random(seed)
    operations = (
        ("+", lambda x, y: x + y),
        ("-", lambda x, y: x - y),
        ("*", lambda x, y: x * y),
    )

    def draw_one():
        # Order of RNG calls matters for reproducibility: lhs, rhs, then operator.
        lhs = generator.randint(0, max_val)
        rhs = generator.randint(0, max_val)
        symbol, apply = generator.choice(operations)
        return {"prompt": f"What is {lhs} {symbol} {rhs}?", "gold": float(apply(lhs, rhs))}

    return [draw_one() for _ in range(n)]
def main(argv=None):
    """Command-line entry point: write the GSM8K and arithmetic prompt files.

    Writes ``rl_prompts_train.jsonl``, ``rl_prompts_test.jsonl`` and
    ``arithmetic_prompts.jsonl`` into ``--out_dir``.

    Args:
        argv: Argument list to parse; ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing callers are unaffected.
    """
    p = argparse.ArgumentParser(
        description="Build RL prompt sets ({prompt, gold}) for PPO/GRPO."
    )
    p.add_argument("--out_dir", default="/ephemeral/data",
                   help="Directory for the output .jsonl files.")
    p.add_argument("--train_limit", type=int, default=None,
                   help="Max GSM8K train examples (default: the whole split).")
    p.add_argument("--test_limit", type=int, default=500,
                   help="Max GSM8K test examples.")
    p.add_argument("--arith_n", type=int, default=5000,
                   help="Number of arithmetic prompts to generate.")
    p.add_argument("--arith_max", type=int, default=20,
                   help="Operands are drawn uniformly from [0, arith_max].")
    # Previously hard-coded to 0; exposed so alternative curricula can be generated.
    p.add_argument("--arith_seed", type=int, default=0,
                   help="Random seed for the arithmetic curriculum.")
    args = p.parse_args(argv)

    print("Loading GSM8K train ...")
    write_jsonl(gsm8k_prompts("train", args.train_limit),
                os.path.join(args.out_dir, "rl_prompts_train.jsonl"))
    print("Loading GSM8K test ...")
    write_jsonl(gsm8k_prompts("test", args.test_limit),
                os.path.join(args.out_dir, "rl_prompts_test.jsonl"))
    print("Generating arithmetic curriculum ...")
    write_jsonl(arithmetic_prompts(args.arith_n, args.arith_max, seed=args.arith_seed),
                os.path.join(args.out_dir, "arithmetic_prompts.jsonl"))


if __name__ == "__main__":
    main()