Skip to content

Commit c9638be

Browse files
committed
feat lm: complete basic training and inference.
1 parent 28dce04 commit c9638be

File tree

14 files changed

+1294
-22
lines changed

14 files changed

+1294
-22
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
/target
2+
/data
3+
/lm/data
24
*.db
35
/result
46
.env

lm/Cargo.lock

Lines changed: 71 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lm/Cargo.toml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,19 @@ version = "0.1.0"
44
edition = "2024"
55

66
[dependencies]
7-
anyhow = "1.0.101"
8-
burn = { version = "0.20.1", features = ["candle", "train", "std"] }
7+
anyhow = { version = "1.0.101", features = ["backtrace"] }
8+
burn = { version = "0.20.1", features = ["ndarray", "tch", "train", "std"] }
99
tokenizers = { version = "0.22.2" }
10+
11+
serde_json = { version = "1.0.149" }
12+
rand = "0.10.0"
13+
memmap2 = "0.5"
14+
num-traits = "0.2.19"
15+
color-backtrace = "0.7.2"
16+
log = "0.4.29"
17+
18+
[[bin]]
19+
name = "train_tokenizer"
20+
21+
[[bin]]
22+
name = "run_inference"

lm/README.md

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# LM
2+
3+
A small transformers language model for the bot.
4+
5+
## Inference
6+
7+
Simply run `cargo r --bin infer model_path` with desired prompt as stdin.
8+
Here the model path is typically `data/model.mpk`.
9+
10+
## Training
11+
12+
First of all, export telegram chat messages to json format and copy the result.json to `data/result.json`.
13+
Nothing else is needed, and only text messages with some simple markups are accepted.
14+
Stickers and images are filtered out and ignored.
15+
16+
Before starting training, some steps must be performed.
17+
They can easily be done by running the corresponding binary with `cargo r --bin name`:
18+
19+
- Train the tokenizer (MUST be the first step): `train_tokenizer`;
20+
- Translate the message json into datasets: `trans_dataset`;
21+
- Run pretrain: `pretrain`.
22+
23+
Training can be continued by simply rerunning pretrain:
24+
it would automatically read `data/model.mpk` to resume the model.
25+
Therefore, remember to remove the model file if you want to retrain from scratch or if the model config has changed.
26+
27+
## Changing the model
28+
29+
**Note**: it's better to remove the data dir and rerun the whole training process after changing the model.
30+
Unless you know exactly what you're doing.
31+
32+
The backend for training and inference (no other process needs a backend) is `lib.rs::DefaultBackend`.
33+
Normally changing this after training won't affect inference,
34+
but note that the default float precision may also be provided as a generic argument (as its value in HEAD is),
35+
in which case the inference would be affected if the precision is changed.
36+
37+
The model config (layer num, hidden size, etc.) is provided in `impl Default for LmConfig` in `lib.rs`.
38+
The default config is not clever (0.08b --- too small), and would definitely overfit on small groups/messages.

lm/src/bin/infer.rs

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
#![recursion_limit = "256"]
2+
use std::io::{Read, Write};
3+
4+
use burn::{
5+
config::Config,
6+
module::Module,
7+
record::{FullPrecisionSettings, NamedMpkFileRecorder},
8+
tensor::{Int, Tensor, activation::softmax},
9+
};
10+
use lm::{DefaultBackend, ExtraInfo, LmConfig, LmModel};
11+
use rand::RngExt;
12+
use tokenizers::Tokenizer;
13+
14+
fn main() -> anyhow::Result<()> {
    // Install pretty panic backtraces, filtered to frames that originate
    // from this crate only.
    color_backtrace::BacktracePrinter::new()
        .strip_function_hash(true)
        .add_frame_filter(Box::new(|frames| {
            // file!() is relative to the workspace; canonicalize it and step
            // up two directories to obtain the crate root used for filtering.
            let crate_path = std::path::Path::new(file!()).canonicalize().unwrap();
            let crate_path = crate_path.parent().unwrap().parent().unwrap();
            frames.retain(|f| {
                f.filename.as_ref().is_some_and(|f| {
                    f.canonicalize().ok().is_some_and(|f| f.starts_with(crate_path))
                })
            });
        }))
        .install(color_backtrace::default_output_stream());

    // First CLI argument: path to the serialized model (per the README,
    // typically `data/model.mpk`).
    let mut args = std::env::args();
    args.next();
    let model_path = args.next().expect("model path argument required");

    // Model config and tokenizer come from fixed paths produced by training.
    let config = LmConfig::load("data/config.json")?;
    let tokenizer =
        Tokenizer::from_file("data/tokenizer.json").map_err(anyhow::Error::from_boxed)?;

    // Build an empty model with the saved config, then load the trained
    // weights from `model_path`.
    let device = Default::default();
    let model: LmModel<DefaultBackend> = LmModel::new(&config, &device);
    let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    let model = model.load_file(model_path, &recorder, &device)?;

    // The prompt is the entire stdin, trimmed.
    let mut prompt = String::new();
    std::io::stdin().read_to_string(&mut prompt)?;
    let prompt = prompt.trim();
    println!("prompt: '{prompt}'");

    let tokens = tokenizer.encode(prompt, true).map_err(anyhow::Error::from_boxed)?;
    let tokens = tokens.get_ids();
    // NOTE(review): `ExtraInfo` presumably holds incremental decoding state
    // (e.g. a KV cache) enabled by the `true` flag — confirm in lib.rs.
    let mut info = ExtraInfo::new(&config, true, &device);
    // Shape the full prompt as a [1, seq_len] batch for the first forward pass.
    let mut input =
        Tensor::<_, 1, Int>::from_ints(tokens, &device).reshape([1, tokens.len() as isize]);

    let mut rng = rand::rng();
    // Autoregressive generation, capped at the model's maximum sequence length.
    for _ in 0..config.max_seq_len {
        // Softmax over dim 2 (the vocab dimension) to get token probabilities.
        let output = softmax(model.forward(input, &mut info), 2);
        let seq_len = output.dims()[1];
        // Top-k (k = 10) sampling over the distribution at the last position.
        let (prob, idx) = output
            .slice_dim(1, seq_len - 1..)
            .reshape([config.vocab_size as isize])
            .topk_with_indices(10, 0);
        /*println!(
            "{:?}",
            idx.clone()
                .to_data()
                .iter()
                .map(|i| (i, tokenizer.decode(&[i], true).unwrap()))
                .collect::<Vec<_>>()
        );*/
        let prob = prob.slice_dim(0, 0..10);
        // Renormalize the top-10 probabilities and sample by inverse CDF:
        // the first cumulative probability exceeding a uniform draw is picked
        // (argmax returns the first index of the maximal 0/1 mask value).
        let prob = (prob.clone().div(prob.sum_dim(0))).cumsum(0);
        let mask = prob.greater_elem(rng.random::<f32>()).int().argmax(0).reshape([-1]);
        let next = idx.select(0, mask).reshape([1, 1]);

        // Feed back only the newly sampled token; earlier context lives in `info`.
        input = next.clone();

        let next = next.into_data().iter::<u32>().next().expect("token");
        if next == 0 {
            // Token id 0 renders as the message delimiter.
            println!("<|delim|>");
        } else {
            print!("{}", tokenizer.decode(&[next], true).map_err(anyhow::Error::from_boxed)?);
        }
        // Flush so tokens appear as they are generated.
        std::io::stdout().flush()?;
        // After a delimiter, stop generation with 30% probability.
        if next == 0 && rng.random::<f32>() < 0.3 {
            break;
        }
    }

    Ok(())
}

lm/src/bin/inspect_dataset.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
use anyhow::Context;
2+
use burn::data::dataset::Dataset;
3+
use lm::dataset::{ChatFile, SeqLenWrapper};
4+
use tokenizers::Tokenizer;
5+
6+
fn main() -> anyhow::Result<()> {
7+
let mut args = std::env::args();
8+
args.next();
9+
let data = ChatFile::new(args.next().context("dataset in arg")?)?;
10+
let data = SeqLenWrapper::new(data, 128);
11+
let tokenizer =
12+
Tokenizer::from_file("data/tokenizer.json").map_err(anyhow::Error::from_boxed)?;
13+
println!("dataset len {}", data.len());
14+
for i in 0..=data.len() {
15+
let Some(it) = data.get(i) else {
16+
println!("(Empty at #{i})");
17+
continue;
18+
};
19+
println!("#{i}: {it:?} => '{}'", tokenizer.decode(&it, false).map_err(anyhow::Error::from_boxed)?);
20+
}
21+
Ok(())
22+
}

0 commit comments

Comments
 (0)