From 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Tue, 12 Sep 2023 09:10:16 -0700
Subject: [PATCH] Implement top_p / nucleus sampling (#819)

* Implement top_p / nucleus sampling

* Update changelog

* rustfmt

* Add tests

* Fix clippy warning

* Fix another clippy error
---
 CHANGELOG.md                                   |  2 +
 candle-examples/examples/bigcode/main.rs       | 16 ++++-
 candle-examples/examples/falcon/main.rs        | 37 ++++++----
 candle-examples/examples/llama/main.rs         |  6 +-
 candle-examples/examples/llama2-c/main.rs      |  7 +-
 candle-examples/examples/quantized/main.rs     |  6 +-
 candle-transformers/src/generation/mod.rs      | 70 +++++++++++++++----
 candle-transformers/tests/generation_tests.rs  | 29 ++++++++
 .../llama2-c/lib-example.html                  | 20 +++++-
 candle-wasm-examples/llama2-c/src/app.rs       | 23 ++++--
 candle-wasm-examples/llama2-c/src/bin/m.rs     | 10 ++-
 candle-wasm-examples/llama2-c/src/worker.rs    | 16 +++--
 12 files changed, 199 insertions(+), 43 deletions(-)
 create mode 100644 candle-transformers/tests/generation_tests.rs

diff --git a/CHANGELOG.md b/CHANGELOG.md
index a0275c57..06041294 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,8 @@ This documents the main changes to the `candle` crate.
 ## v0.2.2 - Unreleased
 
 ### Added
+- Support for `top_p` sampling
+  [819](https://github.com/huggingface/candle/pull/819).
 
 ### Modified
 
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 3540f75d..5f17109e 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -28,9 +28,10 @@ impl TextGeneration {
         tokenizer: Tokenizer,
         seed: u64,
         temp: Option<f64>,
+        top_p: Option<f64>,
         device: &Device,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor = LogitsProcessor::new(seed, temp, top_p);
         Self {
             model,
             tokenizer,
@@ -94,6 +95,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -149,7 +154,14 @@ fn main() -> Result<()> {
     let model = GPTBigCode::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+    let mut pipeline = TextGeneration::new(
+        model,
+        tokenizer,
+        args.seed,
+        args.temperature,
+        args.top_p,
+        &device,
+    );
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index c45fe545..b0973d64 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -25,17 +25,25 @@ struct TextGeneration {
     repeat_last_n: usize,
 }
 
+struct GenerationOptions {
+    temp: Option<f64>,
+    top_p: Option<f64>,
+    repeat_penalty: f32,
+    repeat_last_n: usize,
+}
+
 impl TextGeneration {
     fn new(
         model: Falcon,
         tokenizer: Tokenizer,
+        generation_options: GenerationOptions,
         seed: u64,
-        temp: Option<f64>,
         device: &Device,
-        repeat_penalty: f32,
-        repeat_last_n: usize,
     ) -> Self {
-        let logits_processor = LogitsProcessor::new(seed, temp);
+        let logits_processor =
+            LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+        let repeat_penalty = generation_options.repeat_penalty;
+        let repeat_last_n = generation_options.repeat_last_n;
         Self {
             model,
             tokenizer,
@@ -118,6 +126,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -185,15 +197,14 @@ fn main() -> Result<()> {
     let model = Falcon::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(
-        model,
-        tokenizer,
-        args.seed,
-        args.temperature,
-        &device,
-        args.repeat_penalty,
-        args.repeat_last_n,
-    );
+    let generation_options = GenerationOptions {
+        temp: args.temperature,
+        top_p: args.top_p,
+        repeat_penalty: args.repeat_penalty,
+        repeat_last_n: args.repeat_last_n,
+    };
+    let mut pipeline =
+        TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index db3d216c..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -42,6 +42,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -193,7 +197,7 @@ fn main() -> Result<()> {
     println!("starting the inference loop");
     print!("{prompt}");
 
-    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
     let start_gen = std::time::Instant::now();
     let mut index_pos = 0;
     let mut token_generated = 0;
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index e0ade322..e752a494 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -27,6 +27,10 @@ struct InferenceCmd {
     #[arg(long)]
     temperature: Option<f64>,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     #[arg(long, default_value = "")]
     prompt: String,
 
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
         None => {
             let cmd = InferenceCmd {
                 temperature: None,
+                top_p: None,
                 prompt: "".to_string(),
                 config: None,
                 model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
     let model = Llama::load(vb, &cache, config)?;
 
     println!("starting the inference loop");
-    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+    let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
     let mut index_pos = 0;
 
     print!("{}", args.prompt);
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index c8179d33..a80ad420 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -71,6 +71,10 @@ struct Args {
     #[arg(long, default_value_t = 0.8)]
     temperature: f64,
 
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
         prompt_tokens
     };
     let mut all_tokens = vec![];
-    let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+    let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
 
     let start_prompt_processing = std::time::Instant::now();
     let mut next_token = {
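With these changes, each of the example binaries above exposes the nucleus cutoff next to the existing temperature option. Given clap's default kebab-case naming for `#[arg(long)]` fields, the new option should surface on the command line as `--top-p` (for instance `--temperature 0.8 --top-p 0.9`); leaving it unset keeps the previous behaviour of sampling from the full softmax distribution.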
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index b1d20168..6c8c8ae4 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -4,32 +4,76 @@ use rand::{distributions::Distribution, SeedableRng};
 pub struct LogitsProcessor {
     rng: rand::rngs::StdRng,
     temperature: Option<f64>,
+    top_p: Option<f64>,
 }
 
 impl LogitsProcessor {
-    pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+    pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
         Self {
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             temperature,
+            top_p,
         }
     }
 
+    fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
+        let logits_v: Vec<f32> = logits.to_vec1()?;
+        let next_token = logits_v
+            .iter()
+            .enumerate()
+            .max_by(|(_, u), (_, v)| u.total_cmp(v))
+            .map(|(i, _)| i as u32)
+            .unwrap();
+        Ok(next_token)
+    }
+
+    fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
+        let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+        let next_token = distr.sample(&mut self.rng) as u32;
+        Ok(next_token)
+    }
+
+    fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
+        // top-p sampling (or "nucleus sampling") samples from the smallest set of
+        // tokens that exceed probability top_p. This way we never sample tokens that
+        // have very low probabilities and are less likely to go "off the rails".
+        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+
+        // Sort by descending probability.
+        argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+
+        // Clamp smaller probabilities to zero.
+        let mut cumsum = 0.;
+        for index in &argsort_indices {
+            if cumsum >= top_p {
+                prs[*index] = 0.0;
+            } else {
+                cumsum += prs[*index];
+            }
+        }
+
+        // Sample with clamped probabilities.
+        let next_token = self.sample_multi(prs)?;
+        Ok(next_token)
+    }
+
     pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
         let logits = logits.to_dtype(DType::F32)?;
         let temperature = self.temperature.unwrap_or(0.);
-        let next_token = if temperature > 0. {
-            let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
-            let prs: Vec<f32> = prs.to_vec1()?;
-            let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
-            distr.sample(&mut self.rng) as u32
+        let top_p = self.top_p.unwrap_or(1.);
+        let next_token = if temperature == 0. {
+            self.sample_argmax(logits)?
         } else {
-            let logits_v: Vec<f32> = logits.to_vec1()?;
-            logits_v
-                .iter()
-                .enumerate()
-                .max_by(|(_, u), (_, v)| u.total_cmp(v))
-                .map(|(i, _)| i as u32)
-                .unwrap()
+            let logits = &(&logits / temperature)?;
+            let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
+            let mut prs: Vec<f32> = prs.to_vec1()?;
+            if top_p <= 0.0 || top_p >= 1.0 {
+                // simply sample from the predicted probability distribution
+                self.sample_multi(&prs)?
+            } else {
+                // top-p (nucleus) sampling, clamping the least likely tokens to zero
+                self.sample_topp(&mut prs, top_p as f32)?
+            }
         };
         Ok(next_token)
     }
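To make the clamping in `sample_topp` concrete, here is a small self-contained sketch of the same cutoff applied to a toy, already-normalized distribution; the `nucleus_cutoff` helper and the example values are illustrative only and are not part of the patch:

    // Zero out every probability outside the smallest set whose cumulative
    // mass reaches `top_p`, mirroring the clamping loop in `sample_topp`.
    fn nucleus_cutoff(prs: &mut [f32], top_p: f32) {
        // Indices sorted by descending probability.
        let mut argsort_indices: Vec<usize> = (0..prs.len()).collect();
        argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
        let mut cumsum = 0.0;
        for &i in &argsort_indices {
            if cumsum >= top_p {
                prs[i] = 0.0; // outside the nucleus: can never be sampled
            } else {
                cumsum += prs[i]; // inside the nucleus: kept as-is
            }
        }
    }

    fn main() {
        // Toy distribution that already sums to 1.0.
        let mut prs = vec![0.1_f32, 0.2, 0.3, 0.4];
        nucleus_cutoff(&mut prs, 0.5);
        // Only the two most likely tokens survive; the rest are clamped to zero.
        assert_eq!(prs, vec![0.0_f32, 0.0, 0.3, 0.4]);
        println!("{prs:?}");
    }

With a cutoff of 0.5 only the two most probable tokens remain eligible, which is consistent with the expectation encoded in the `sample_with_top_p` test below.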
diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs
new file mode 100644
index 00000000..76f994d0
--- /dev/null
+++ b/candle-transformers/tests/generation_tests.rs
@@ -0,0 +1,29 @@
+use candle::{Device, Result, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+
+#[test]
+fn sample_with_zero_temperature() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(1337, None, None);
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 3);
+    Ok(())
+}
+
+#[test]
+fn sample_with_temperature() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(42, Some(0.9), None);
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 0);
+    Ok(())
+}
+
+#[test]
+fn sample_with_top_p() -> Result<()> {
+    let mut logits_process = LogitsProcessor::new(42, Some(1.0), Some(0.5));
+    let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+    let token = logits_process.sample(&logits)?;
+    assert_eq!(token, 2);
+    Ok(())
+}
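For completeness, a minimal sketch of how downstream code might drive the new three-argument constructor end to end; the seed and toy logits are arbitrary and this snippet is not part of the patch:

    use candle::{Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;

    fn main() -> Result<()> {
        // temperature = None falls back to greedy argmax decoding.
        let mut greedy = LogitsProcessor::new(42, None, None);
        // With a temperature and a top_p in (0, 1), sampling is restricted to the nucleus.
        let mut nucleus = LogitsProcessor::new(42, Some(0.8), Some(0.9));

        let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
        println!("greedy token: {}", greedy.sample(&logits)?);
        println!("nucleus token: {}", nucleus.sample(&logits)?);
        Ok(())
    }

Passing `None` for both options reproduces the old greedy behaviour, while a `top_p` outside (0, 1) falls back to plain temperature sampling.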
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
index b5033c54..22b12517 100644
--- a/candle-wasm-examples/llama2-c/lib-example.html
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -56,6 +56,7 @@
       const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
       const prompt = getValue("prompt");
       const temperature = getValue("temperature");
+      const topP = getValue("top-p");
       const repeatPenalty = getValue("repeat_penalty");
       const seed = getValue("seed");
       const maxSeqLen = getValue("max-seq");
@@ -99,6 +100,7 @@
         tokenizerURL: "tokenizer.json",
         prompt,
         temp: temperature,
+        top_p: topP,
         repeatPenalty,
         seed: BigInt(seed),
         maxSeqLen,
@@ -251,7 +253,7 @@
             0.50
             [markup in this hunk, including the new "top-p" slider input, was lost in extraction]
             1.00

diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ ... @@
     temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+    top_p: std::rc::Rc<std::cell::RefCell<f64>>,
     prompt: std::rc::Rc<std::cell::RefCell<String>>,
     generated: String,
     n_tokens: usize,
@@ -81,6 +82,7 @@ impl Component for App {
             status,
             n_tokens: 0,
             temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+            top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
             prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
             generated: String::new(),
             current_decode: None,
@@ -122,10 +124,11 @@ impl Component for App {
                 self.n_tokens = 0;
                 self.generated.clear();
                 let temp = *self.temperature.borrow();
+                let top_p = *self.top_p.borrow();
                 let prompt = self.prompt.borrow().clone();
-                console_log!("temp: {}, prompt: {}", temp, prompt);
+                console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
                 ctx.link()
-                    .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
+                    .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
             }
             true
         }
@@ -177,13 +180,21 @@ impl Component for App {
     fn view(&self, ctx: &Context<Self>) -> Html {
         use yew::TargetCast;
         let temperature = self.temperature.clone();
-        let oninput = ctx.link().callback(move |e: yew::InputEvent| {
+        let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
            let input: web_sys::HtmlInputElement = e.target_unchecked_into();
            if let Ok(temp) = f64::from_str(&input.value()) {
                *temperature.borrow_mut() = temp
            }
            Msg::Refresh
        });
+        let top_p = self.top_p.clone();
+        let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
+           let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+           if let Ok(top_p_input) = f64::from_str(&input.value()) {
+               *top_p.borrow_mut() = top_p_input
+           }
+           Msg::Refresh
+        });
         let prompt = self.prompt.clone();
         let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
            let input: web_sys::HtmlInputElement = e.target_unchecked_into();
@@ -201,9 +212,13 @@ impl Component for App {
                 <div>
                   {"temperature \u{00a0} "}
-                    [temperature slider <input> markup lost in extraction]
+                    [temperature slider <input> markup lost in extraction; presumably rewired to oninput_temperature]
                     {format!(" \u{00a0} {}", self.temperature.borrow())}
                 </div>
+                <div>
+                  {"top_p \u{00a0} "}
+                    [top_p slider <input> markup lost in extraction; presumably wired to oninput_top_p]
+                    {format!(" \u{00a0} {}", self.top_p.borrow())}
+                </div>
                 <div>
                   {"prompt: "}
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index 6628ab7e..61de9d7f 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -47,7 +47,7 @@ impl Model {
             tokenizer,
             model: weights,
         });
-        let logits_processor = LogitsProcessor::new(299792458, None);
+        let logits_processor = LogitsProcessor::new(299792458, None, None);
         match model {
             Ok(inner) => Ok(Self {
                 inner,
@@ -69,6 +69,7 @@ impl Model {
         &mut self,
         prompt: String,
         temp: f64,
+        top_p: f64,
         repeat_penalty: f32,
         seed: u64,
     ) -> Result {
@@ -80,7 +81,12 @@ impl Model {
             }
         }
         let temp = if temp <= 0. { None } else { Some(temp) };
-        self.logits_processor = LogitsProcessor::new(seed, temp);
+        let top_p = if top_p <= 0. || top_p >= 1. {
+            None
+        } else {
+            Some(top_p)
+        };
+        self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
         self.repeat_penalty = repeat_penalty;
         self.tokens.clear();
         let tokens = self
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 7e97b5da..79dd2f32 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -62,12 +62,18 @@ impl Model {
         link: &WorkerLink,
         id: HandlerId,
         temp: f64,
+        top_p: f64,
         prompt: String,
     ) -> Result<()> {
         let dev = Device::Cpu;
         let temp = if temp <= 0. { None } else { Some(temp) };
-        console_log!("{temp:?} {prompt}");
-        let mut logits_processor = LogitsProcessor::new(299792458, temp);
+        let top_p = if top_p <= 0. || top_p >= 1.0 {
+            None
+        } else {
+            Some(top_p)
+        };
+        console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+        let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
         let mut index_pos = 0;
         let mut tokens = self
             .tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
 #[derive(Serialize, Deserialize)]
 pub enum WorkerInput {
     ModelData(ModelData),
-    Run(f64, String),
+    Run(f64, f64, String),
 }
 
 #[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
                 }
                 Err(err) => Err(format!("model creation error {err:?}")),
             },
-            WorkerInput::Run(temp, prompt) => match &mut self.model {
+            WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
                 None => Err("model has not been set yet".to_string()),
                 Some(model) => {
                     {
@@ ... @@
                     }
                 }
                 let result = model
-                    .run(&self.link, id, temp, prompt)
+                    .run(&self.link, id, temp, top_p, prompt)
                     .map_err(|e| e.to_string());
                 Ok(WorkerOutput::GenerationDone(result))
             }