|
#![recursion_limit = "256"]
use std::io::{Read, Write};

use anyhow::Context;
use burn::{
    config::Config,
    module::Module,
    record::{FullPrecisionSettings, NamedMpkFileRecorder},
    tensor::{Int, Tensor, activation::softmax},
};
use lm::{DefaultBackend, ExtraInfo, LmConfig, LmModel};
use rand::RngExt;
use tokenizers::Tokenizer;
| 13 | + |
| 14 | +fn main() -> anyhow::Result<()> { |
| 15 | + color_backtrace::BacktracePrinter::new() |
| 16 | + .strip_function_hash(true) |
| 17 | + .add_frame_filter(Box::new(|frames| { |
| 18 | + let crate_path = std::path::Path::new(file!()).canonicalize().unwrap(); |
| 19 | + let crate_path = crate_path.parent().unwrap().parent().unwrap(); |
| 20 | + frames.retain(|f| { |
| 21 | + f.filename.as_ref().is_some_and(|f| { |
| 22 | + f.canonicalize().ok().is_some_and(|f| f.starts_with(crate_path)) |
| 23 | + }) |
| 24 | + }); |
| 25 | + })) |
| 26 | + .install(color_backtrace::default_output_stream()); |
| 27 | + |
| 28 | + let mut args = std::env::args(); |
| 29 | + args.next(); |
| 30 | + let model_path = args.next().expect("model path argument required"); |
| 31 | + |
| 32 | + let config = LmConfig::load("data/config.json")?; |
| 33 | + let tokenizer = |
| 34 | + Tokenizer::from_file("data/tokenizer.json").map_err(anyhow::Error::from_boxed)?; |
| 35 | + |
| 36 | + let device = Default::default(); |
| 37 | + let model: LmModel<DefaultBackend> = LmModel::new(&config, &device); |
| 38 | + let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new(); |
| 39 | + let model = model.load_file(model_path, &recorder, &device)?; |
| 40 | + |
| 41 | + let mut prompt = String::new(); |
| 42 | + std::io::stdin().read_to_string(&mut prompt)?; |
| 43 | + let prompt = prompt.trim(); |
| 44 | + println!("prompt: '{prompt}'"); |
| 45 | + |
| 46 | + let tokens = tokenizer.encode(prompt, true).map_err(anyhow::Error::from_boxed)?; |
| 47 | + let tokens = tokens.get_ids(); |
| 48 | + let mut info = ExtraInfo::new(&config, true, &device); |
| 49 | + let mut input = |
| 50 | + Tensor::<_, 1, Int>::from_ints(tokens, &device).reshape([1, tokens.len() as isize]); |
| 51 | + |
| 52 | + let mut rng = rand::rng(); |
| 53 | + for _ in 0..config.max_seq_len { |
| 54 | + let output = softmax(model.forward(input, &mut info), 2); |
| 55 | + let seq_len = output.dims()[1]; |
| 56 | + let (prob, idx) = output |
| 57 | + .slice_dim(1, seq_len - 1..) |
| 58 | + .reshape([config.vocab_size as isize]) |
| 59 | + .topk_with_indices(10, 0); |
| 60 | + /*println!( |
| 61 | + "{:?}", |
| 62 | + idx.clone() |
| 63 | + .to_data() |
| 64 | + .iter() |
| 65 | + .map(|i| (i, tokenizer.decode(&[i], true).unwrap())) |
| 66 | + .collect::<Vec<_>>() |
| 67 | + );*/ |
| 68 | + let prob = prob.slice_dim(0, 0..10); |
| 69 | + let prob = (prob.clone().div(prob.sum_dim(0))).cumsum(0); |
| 70 | + let mask = prob.greater_elem(rng.random::<f32>()).int().argmax(0).reshape([-1]); |
| 71 | + let next = idx.select(0, mask).reshape([1, 1]); |
| 72 | + |
| 73 | + input = next.clone(); |
| 74 | + |
| 75 | + let next = next.into_data().iter::<u32>().next().expect("token"); |
| 76 | + if next == 0 { |
| 77 | + println!("<|delim|>"); |
| 78 | + } else { |
| 79 | + print!("{}", tokenizer.decode(&[next], true).map_err(anyhow::Error::from_boxed)?); |
| 80 | + } |
| 81 | + std::io::stdout().flush()?; |
| 82 | + if next == 0 && rng.random::<f32>() < 0.3 { |
| 83 | + break; |
| 84 | + } |
| 85 | + } |
| 86 | + |
| 87 | + Ok(()) |
| 88 | +} |
0 commit comments