From 805bf9ffa78119a1a7e047b4ddf6b2ea7df4d94f Mon Sep 17 00:00:00 2001
From: Juarez Bochi
Date: Tue, 12 Sep 2023 09:10:16 -0700
Subject: [PATCH] Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
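
For reviewers, a minimal usage sketch of the extended `LogitsProcessor` API (the helper name
`sample_next_token`, the 0.8/0.9 settings, and the toy logits are illustrative, not part of
this patch; the constructor and `sample` signatures match the code below):

    use candle::{Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;

    // temperature = 0.8 and top_p = 0.9: after temperature scaling, sample only from the
    // smallest set of tokens whose cumulative probability exceeds 0.9.
    fn sample_next_token(logits: &Tensor) -> Result<u32> {
        let mut processor = LogitsProcessor::new(299792458, Some(0.8), Some(0.9));
        processor.sample(logits)
    }

    fn main() -> Result<()> {
        let logits = Tensor::new(&[0.1f32, 0.2, 0.3, 0.4], &Device::Cpu)?;
        println!("sampled token id: {}", sample_next_token(&logits)?);
        Ok(())
    }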
---
CHANGELOG.md | 2 +
candle-examples/examples/bigcode/main.rs | 16 ++++-
candle-examples/examples/falcon/main.rs | 37 ++++++----
candle-examples/examples/llama/main.rs | 6 +-
candle-examples/examples/llama2-c/main.rs | 7 +-
candle-examples/examples/quantized/main.rs | 6 +-
candle-transformers/src/generation/mod.rs | 70 +++++++++++++++----
candle-transformers/tests/generation_tests.rs | 29 ++++++++
.../llama2-c/lib-example.html | 20 +++++-
candle-wasm-examples/llama2-c/src/app.rs | 23 ++++--
candle-wasm-examples/llama2-c/src/bin/m.rs | 10 ++-
candle-wasm-examples/llama2-c/src/worker.rs | 16 +++--
12 files changed, 199 insertions(+), 43 deletions(-)
create mode 100644 candle-transformers/tests/generation_tests.rs
diff --git a/CHANGELOG.md b/CHANGELOG.md
index a0275c57..06041294 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -4,6 +4,8 @@ This documents the main changes to the `candle` crate.
## v0.2.2 - Unreleased
### Added
+- Support for `top_p` sampling
+ [819](https://github.com/huggingface/candle/pull/819).
### Modified
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 3540f75d..5f17109e 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -28,9 +28,10 @@ impl TextGeneration {
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
+ top_p: Option<f64>,
device: &Device,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
@@ -94,6 +95,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -149,7 +154,14 @@ fn main() -> Result<()> {
let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
+ let mut pipeline = TextGeneration::new(
+ model,
+ tokenizer,
+ args.seed,
+ args.temperature,
+ args.top_p,
+ &device,
+ );
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index c45fe545..b0973d64 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -25,17 +25,25 @@ struct TextGeneration {
repeat_last_n: usize,
}
+struct GenerationOptions {
+ temp: Option<f64>,
+ top_p: Option<f64>,
+ repeat_penalty: f32,
+ repeat_last_n: usize,
+}
+
impl TextGeneration {
fn new(
model: Falcon,
tokenizer: Tokenizer,
+ generation_options: GenerationOptions,
seed: u64,
- temp: Option<f64>,
device: &Device,
- repeat_penalty: f32,
- repeat_last_n: usize,
) -> Self {
- let logits_processor = LogitsProcessor::new(seed, temp);
+ let logits_processor =
+ LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
+ let repeat_penalty = generation_options.repeat_penalty;
+ let repeat_last_n = generation_options.repeat_last_n;
Self {
model,
tokenizer,
@@ -118,6 +126,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -185,15 +197,14 @@ fn main() -> Result<()> {
let model = Falcon::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed());
- let mut pipeline = TextGeneration::new(
- model,
- tokenizer,
- args.seed,
- args.temperature,
- &device,
- args.repeat_penalty,
- args.repeat_last_n,
- );
+ let generation_options = GenerationOptions {
+ temp: args.temperature,
+ top_p: args.top_p,
+ repeat_penalty: args.repeat_penalty,
+ repeat_last_n: args.repeat_last_n,
+ };
+ let mut pipeline =
+ TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index db3d216c..b2d7d938 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -42,6 +42,10 @@ struct Args {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -193,7 +197,7 @@ fn main() -> Result<()> {
println!("starting the inference loop");
print!("{prompt}");
- let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, args.temperature, args.top_p);
let start_gen = std::time::Instant::now();
let mut index_pos = 0;
let mut token_generated = 0;
diff --git a/candle-examples/examples/llama2-c/main.rs b/candle-examples/examples/llama2-c/main.rs
index e0ade322..e752a494 100644
--- a/candle-examples/examples/llama2-c/main.rs
+++ b/candle-examples/examples/llama2-c/main.rs
@@ -27,6 +27,10 @@ struct InferenceCmd {
#[arg(long)]
temperature: Option<f64>,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
#[arg(long, default_value = "")]
prompt: String,
@@ -133,6 +137,7 @@ fn main() -> anyhow::Result<()> {
None => {
let cmd = InferenceCmd {
temperature: None,
+ top_p: None,
prompt: "".to_string(),
config: None,
model_id: "karpathy/tinyllamas".to_string(),
@@ -256,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
- let mut logits_processor = LogitsProcessor::new(299792458, args.temperature);
+ let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
let mut index_pos = 0;
print!("{}", args.prompt);
diff --git a/candle-examples/examples/quantized/main.rs b/candle-examples/examples/quantized/main.rs
index c8179d33..a80ad420 100644
--- a/candle-examples/examples/quantized/main.rs
+++ b/candle-examples/examples/quantized/main.rs
@@ -71,6 +71,10 @@ struct Args {
#[arg(long, default_value_t = 0.8)]
temperature: f64,
+ /// Nucleus sampling probability cutoff.
+ #[arg(long)]
+ top_p: Option<f64>,
+
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
@@ -310,7 +314,7 @@ fn main() -> anyhow::Result<()> {
prompt_tokens
};
let mut all_tokens = vec![];
- let mut logits_processor = LogitsProcessor::new(args.seed, temperature);
+ let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
diff --git a/candle-transformers/src/generation/mod.rs b/candle-transformers/src/generation/mod.rs
index b1d20168..6c8c8ae4 100644
--- a/candle-transformers/src/generation/mod.rs
+++ b/candle-transformers/src/generation/mod.rs
@@ -4,32 +4,76 @@ use rand::{distributions::Distribution, SeedableRng};
pub struct LogitsProcessor {
rng: rand::rngs::StdRng,
temperature: Option<f64>,
+ top_p: Option<f64>,
}
impl LogitsProcessor {
- pub fn new(seed: u64, temperature: Option<f64>) -> Self {
+ pub fn new(seed: u64, temperature: Option<f64>, top_p: Option<f64>) -> Self {
Self {
rng: rand::rngs::StdRng::seed_from_u64(seed),
temperature,
+ top_p,
}
}
+ fn sample_argmax(&mut self, logits: Tensor) -> Result<u32> {
+ let logits_v: Vec<f32> = logits.to_vec1()?;
+ let next_token = logits_v
+ .iter()
+ .enumerate()
+ .max_by(|(_, u), (_, v)| u.total_cmp(v))
+ .map(|(i, _)| i as u32)
+ .unwrap();
+ Ok(next_token)
+ }
+
+ fn sample_multi(&mut self, prs: &Vec<f32>) -> Result<u32> {
+ let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
+ let next_token = distr.sample(&mut self.rng) as u32;
+ Ok(next_token)
+ }
+
+ fn sample_topp(&mut self, prs: &mut Vec<f32>, top_p: f32) -> Result<u32> {
+ // top-p sampling (or "nucleus sampling") samples from the smallest set of
+ // tokens whose cumulative probability exceeds top_p. This way we never sample
+ // tokens with very low probability, so generation is less likely to go "off the rails".
+ let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
+
+ // Sort by descending probability.
+ argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
+
+ // Clamp smaller probabilities to zero.
+ let mut cumsum = 0.;
+ for index in &argsort_indices {
+ if cumsum >= top_p {
+ prs[*index] = 0.0;
+ } else {
+ cumsum += prs[*index];
+ }
+ }
+
+ // Sample with clamped probabilities.
+ let next_token = self.sample_multi(prs)?;
+ Ok(next_token)
+ }
+
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
let logits = logits.to_dtype(DType::F32)?;
let temperature = self.temperature.unwrap_or(0.);
- let next_token = if temperature > 0. {
- let prs = candle_nn::ops::softmax(&(&logits / temperature)?, D::Minus1)?;
- let prs: Vec<f32> = prs.to_vec1()?;
- let distr = rand::distributions::WeightedIndex::new(prs).map_err(Error::wrap)?;
- distr.sample(&mut self.rng) as u32
+ let top_p = self.top_p.unwrap_or(1.);
+ let next_token = if temperature == 0. {
+ self.sample_argmax(logits)?
} else {
- let logits_v: Vec<f32> = logits.to_vec1()?;
- logits_v
- .iter()
- .enumerate()
- .max_by(|(_, u), (_, v)| u.total_cmp(v))
- .map(|(i, _)| i as u32)
- .unwrap()
+ let logits = &(&logits / temperature)?;
+ let prs = candle_nn::ops::softmax(logits, D::Minus1)?;
+ let mut prs: Vec<f32> = prs.to_vec1()?;
+ if top_p <= 0.0 || top_p >= 1.0 {
+ // simply sample from the predicted probability distribution
+ self.sample_multi(&prs)?
+ } else {
+ // top-p (nucleus) sampling, clamping the least likely tokens to zero
+ self.sample_topp(&mut prs, top_p as f32)?
+ }
};
Ok(next_token)
}
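
To make the clamping loop above concrete, here is a standalone sketch of the same top-p step
applied to a plain probability vector (the helper name `clamp_to_nucleus` and the toy values
are illustrative, not part of this patch):

    // Zero out every token outside the nucleus: keep the highest-probability tokens
    // until their cumulative mass reaches top_p, then zero the rest.
    fn clamp_to_nucleus(prs: &mut [f32], top_p: f32) {
        let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
        // Sort indices by descending probability.
        argsort_indices.sort_by(|&i, &j| prs[j].partial_cmp(&prs[i]).unwrap());
        let mut cumsum = 0.0;
        for &index in &argsort_indices {
            if cumsum >= top_p {
                prs[index] = 0.0;
            } else {
                cumsum += prs[index];
            }
        }
    }

    fn main() {
        let mut prs = vec![0.1f32, 0.2, 0.3, 0.4];
        clamp_to_nucleus(&mut prs, 0.5);
        // 0.4 + 0.3 >= 0.5, so tokens 3 and 2 survive; tokens 0 and 1 are zeroed.
        assert_eq!(prs, vec![0.0f32, 0.0, 0.3, 0.4]);
        println!("{prs:?}");
    }

The surviving weights are then fed to `sample_multi`, whose `WeightedIndex` renormalizes them
implicitly, so no explicit renormalization step is needed.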
diff --git a/candle-transformers/tests/generation_tests.rs b/candle-transformers/tests/generation_tests.rs
new file mode 100644
index 00000000..76f994d0
--- /dev/null
+++ b/candle-transformers/tests/generation_tests.rs
@@ -0,0 +1,29 @@
+use candle::{Device, Result, Tensor};
+use candle_transformers::generation::LogitsProcessor;
+
+#[test]
+fn sample_with_zero_temperature() -> Result<()> {
+ let mut logits_process = LogitsProcessor::new(1337, None, None);
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 3);
+ Ok(())
+}
+
+#[test]
+fn sample_with_temperature() -> Result<()> {
+ let mut logits_process = LogitsProcessor::new(42, Some(0.9), None);
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 0);
+ Ok(())
+}
+
+#[test]
+fn sample_with_top_p() -> Result<()> {
+ let mut logits_process = LogitsProcessor::new(42, Some(1.0), Some(0.5));
+ let logits = Tensor::new(&[0.1, 0.2, 0.3, 0.4], &Device::Cpu)?;
+ let token = logits_process.sample(&logits)?;
+ assert_eq!(token, 2);
+ Ok(())
+}
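
For reference, the rough arithmetic behind the `sample_with_top_p` expectation (approximate
values, a sanity check rather than part of the patch): with temperature 1.0,
softmax([0.1, 0.2, 0.3, 0.4]) ≈ [0.214, 0.236, 0.261, 0.289]. Accumulating in descending
order, token 3 (0.289) and token 2 (0.261) are kept before the cumulative sum 0.550 exceeds
top_p = 0.5, so tokens 0 and 1 are zeroed and the weighted draw with seed 42 lands on token 2.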
diff --git a/candle-wasm-examples/llama2-c/lib-example.html b/candle-wasm-examples/llama2-c/lib-example.html
index b5033c54..22b12517 100644
--- a/candle-wasm-examples/llama2-c/lib-example.html
+++ b/candle-wasm-examples/llama2-c/lib-example.html
@@ -56,6 +56,7 @@
const weightsURL = `${MODELS_BASE_URL}/${model.url}`;
const prompt = getValue("prompt");
const temperature = getValue("temperature");
+ const topP = getValue("top-p");
const repeatPenalty = getValue("repeat_penalty");
const seed = getValue("seed");
const maxSeqLen = getValue("max-seq");
@@ -99,6 +100,7 @@
tokenizerURL: "tokenizer.json",
prompt,
temp: temperature,
+ top_p: topP,
repeatPenalty,
seed: BigInt(seed),
maxSeqLen,
@@ -251,7 +253,7 @@
          [temperature slider markup showing "0.50"; HTML stripped during extraction]
+         [three lines of markup adding a "top-p" slider, mirroring the temperature slider; HTML stripped during extraction]
diff --git a/candle-wasm-examples/llama2-c/src/app.rs b/candle-wasm-examples/llama2-c/src/app.rs
--- a/candle-wasm-examples/llama2-c/src/app.rs
+++ b/candle-wasm-examples/llama2-c/src/app.rs
@@ ... @@ struct App {
temperature: std::rc::Rc<std::cell::RefCell<f64>>,
+ top_p: std::rc::Rc<std::cell::RefCell<f64>>,
prompt: std::rc::Rc<std::cell::RefCell<String>>,
generated: String,
n_tokens: usize,
@@ -81,6 +82,7 @@ impl Component for App {
status,
n_tokens: 0,
temperature: std::rc::Rc::new(std::cell::RefCell::new(0.)),
+ top_p: std::rc::Rc::new(std::cell::RefCell::new(1.0)),
prompt: std::rc::Rc::new(std::cell::RefCell::new("".to_string())),
generated: String::new(),
current_decode: None,
@@ -122,10 +124,11 @@ impl Component for App {
self.n_tokens = 0;
self.generated.clear();
let temp = *self.temperature.borrow();
+ let top_p = *self.top_p.borrow();
let prompt = self.prompt.borrow().clone();
- console_log!("temp: {}, prompt: {}", temp, prompt);
+ console_log!("temp: {}, top_p: {}, prompt: {}", temp, top_p, prompt);
ctx.link()
- .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, prompt)))
+ .send_message(Msg::WorkerInMsg(WorkerInput::Run(temp, top_p, prompt)))
}
true
}
@@ -177,13 +180,21 @@ impl Component for App {
fn view(&self, ctx: &Context<Self>) -> Html {
use yew::TargetCast;
let temperature = self.temperature.clone();
- let oninput = ctx.link().callback(move |e: yew::InputEvent| {
+ let oninput_temperature = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
if let Ok(temp) = f64::from_str(&input.value()) {
*temperature.borrow_mut() = temp
}
Msg::Refresh
});
+ let top_p = self.top_p.clone();
+ let oninput_top_p = ctx.link().callback(move |e: yew::InputEvent| {
+ let input: web_sys::HtmlInputElement = e.target_unchecked_into();
+ if let Ok(top_p_input) = f64::from_str(&input.value()) {
+ *top_p.borrow_mut() = top_p_input
+ }
+ Msg::Refresh
+ });
let prompt = self.prompt.clone();
let oninput_prompt = ctx.link().callback(move |e: yew::InputEvent| {
let input: web_sys::HtmlInputElement = e.target_unchecked_into();
@@ -201,9 +212,13 @@ impl Component for App {
{"temperature \u{00a0} "}
- [temperature slider <input> bound to {oninput}; markup stripped during extraction]
+ [temperature slider <input> bound to {oninput_temperature}; markup stripped during extraction]
{format!(" \u{00a0} {}", self.temperature.borrow())}
+ {"top_p \u{00a0} "}
+ [top_p slider <input> bound to {oninput_top_p}; markup stripped during extraction]
+ {format!(" \u{00a0} {}", self.top_p.borrow())}
+
{"prompt: "}
{
diff --git a/candle-wasm-examples/llama2-c/src/bin/m.rs b/candle-wasm-examples/llama2-c/src/bin/m.rs
index 6628ab7e..61de9d7f 100644
--- a/candle-wasm-examples/llama2-c/src/bin/m.rs
+++ b/candle-wasm-examples/llama2-c/src/bin/m.rs
@@ -47,7 +47,7 @@ impl Model {
tokenizer,
model: weights,
});
- let logits_processor = LogitsProcessor::new(299792458, None);
+ let logits_processor = LogitsProcessor::new(299792458, None, None);
match model {
Ok(inner) => Ok(Self {
inner,
@@ -69,6 +69,7 @@ impl Model {
&mut self,
prompt: String,
temp: f64,
+ top_p: f64,
repeat_penalty: f32,
seed: u64,
) -> Result<String, JsError> {
@@ -80,7 +81,12 @@ impl Model {
}
}
let temp = if temp <= 0. { None } else { Some(temp) };
- self.logits_processor = LogitsProcessor::new(seed, temp);
+ let top_p = if top_p <= 0. || top_p >= 1. {
+ None
+ } else {
+ Some(top_p)
+ };
+ self.logits_processor = LogitsProcessor::new(seed, temp, top_p);
self.repeat_penalty = repeat_penalty;
self.tokens.clear();
let tokens = self
diff --git a/candle-wasm-examples/llama2-c/src/worker.rs b/candle-wasm-examples/llama2-c/src/worker.rs
index 7e97b5da..79dd2f32 100644
--- a/candle-wasm-examples/llama2-c/src/worker.rs
+++ b/candle-wasm-examples/llama2-c/src/worker.rs
@@ -62,12 +62,18 @@ impl Model {
link: &WorkerLink<Worker>,
id: HandlerId,
temp: f64,
+ top_p: f64,
prompt: String,
) -> Result<()> {
let dev = Device::Cpu;
let temp = if temp <= 0. { None } else { Some(temp) };
- console_log!("{temp:?} {prompt}");
- let mut logits_processor = LogitsProcessor::new(299792458, temp);
+ let top_p = if top_p <= 0. || top_p >= 1.0 {
+ None
+ } else {
+ Some(top_p)
+ };
+ console_log!("temp: {temp:?} top_p: {top_p:?} prompt: {prompt}");
+ let mut logits_processor = LogitsProcessor::new(299792458, temp, top_p);
let mut index_pos = 0;
let mut tokens = self
.tokenizer
@@ -268,7 +274,7 @@ pub struct Worker {
#[derive(Serialize, Deserialize)]
pub enum WorkerInput {
ModelData(ModelData),
- Run(f64, String),
+ Run(f64, f64, String),
}
#[derive(Serialize, Deserialize)]
@@ -301,7 +307,7 @@ impl yew_agent::Worker for Worker {
}
Err(err) => Err(format!("model creation error {err:?}")),
},
- WorkerInput::Run(temp, prompt) => match &mut self.model {
+ WorkerInput::Run(temp, top_p, prompt) => match &mut self.model {
None => Err("model has not been set yet".to_string()),
Some(model) => {
{
@@ -311,7 +317,7 @@ impl yew_agent::Worker for Worker {
}
}
let result = model
- .run(&self.link, id, temp, prompt)
+ .run(&self.link, id, temp, top_p, prompt)
.map_err(|e| e.to_string());
Ok(WorkerOutput::GenerationDone(result))
}