mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Include topk sampling in the quantized example. (#2005)
* Include topk sampling in the quantized example. * Also sample with top-k on the mistral side.
This commit is contained in:
@ -13,7 +13,7 @@ use candle_transformers::models::quantized_mistral::Model as QMistral;
|
|||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
@ -39,11 +39,26 @@ impl TextGeneration {
|
|||||||
seed: u64,
|
seed: u64,
|
||||||
temp: Option<f64>,
|
temp: Option<f64>,
|
||||||
top_p: Option<f64>,
|
top_p: Option<f64>,
|
||||||
|
top_k: Option<usize>,
|
||||||
repeat_penalty: f32,
|
repeat_penalty: f32,
|
||||||
repeat_last_n: usize,
|
repeat_last_n: usize,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
let logits_processor = {
|
||||||
|
let temperature = temp.unwrap_or(0.);
|
||||||
|
let sampling = if temperature <= 0. {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match (top_k, top_p) {
|
||||||
|
(None, None) => Sampling::All { temperature },
|
||||||
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
tokenizer: TokenOutputStream::new(tokenizer),
|
tokenizer: TokenOutputStream::new(tokenizer),
|
||||||
@ -159,6 +174,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
top_p: Option<f64>,
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
/// The seed to use when generating random samples.
|
||||||
#[arg(long, default_value_t = 299792458)]
|
#[arg(long, default_value_t = 299792458)]
|
||||||
seed: u64,
|
seed: u64,
|
||||||
@ -314,6 +333,7 @@ fn main() -> Result<()> {
|
|||||||
args.seed,
|
args.seed,
|
||||||
args.temperature,
|
args.temperature,
|
||||||
args.top_p,
|
args.top_p,
|
||||||
|
args.top_k,
|
||||||
args.repeat_penalty,
|
args.repeat_penalty,
|
||||||
args.repeat_last_n,
|
args.repeat_last_n,
|
||||||
&device,
|
&device,
|
||||||
|
@ -10,7 +10,7 @@ use tokenizers::Tokenizer;
|
|||||||
|
|
||||||
use candle::quantized::{ggml_file, gguf_file};
|
use candle::quantized::{ggml_file, gguf_file};
|
||||||
use candle::Tensor;
|
use candle::Tensor;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
use candle_transformers::models::quantized_llama as model;
|
use candle_transformers::models::quantized_llama as model;
|
||||||
@ -200,6 +200,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
top_p: Option<f64>,
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
/// The seed to use when generating random samples.
|
/// The seed to use when generating random samples.
|
||||||
#[arg(long, default_value_t = 299792458)]
|
#[arg(long, default_value_t = 299792458)]
|
||||||
seed: u64,
|
seed: u64,
|
||||||
@ -349,11 +353,6 @@ fn main() -> anyhow::Result<()> {
|
|||||||
#[cfg(feature = "cuda")]
|
#[cfg(feature = "cuda")]
|
||||||
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
|
candle::quantized::cuda::set_force_dmmv(args.force_dmmv);
|
||||||
|
|
||||||
let temperature = if args.temperature == 0. {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some(args.temperature)
|
|
||||||
};
|
|
||||||
let _guard = if args.tracing {
|
let _guard = if args.tracing {
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
@ -500,7 +499,20 @@ fn main() -> anyhow::Result<()> {
|
|||||||
prompt_tokens
|
prompt_tokens
|
||||||
};
|
};
|
||||||
let mut all_tokens = vec![];
|
let mut all_tokens = vec![];
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
|
let mut logits_processor = {
|
||||||
|
let temperature = args.temperature;
|
||||||
|
let sampling = if temperature <= 0. {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match (args.top_k, args.top_p) {
|
||||||
|
(None, None) => Sampling::All { temperature },
|
||||||
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
let start_prompt_processing = std::time::Instant::now();
|
let start_prompt_processing = std::time::Instant::now();
|
||||||
let mut next_token = if !args.split_prompt {
|
let mut next_token = if !args.split_prompt {
|
||||||
|
@ -7,6 +7,7 @@ pub enum Sampling {
|
|||||||
All { temperature: f64 },
|
All { temperature: f64 },
|
||||||
TopK { k: usize, temperature: f64 },
|
TopK { k: usize, temperature: f64 },
|
||||||
TopP { p: f64, temperature: f64 },
|
TopP { p: f64, temperature: f64 },
|
||||||
|
TopKThenTopP { k: usize, p: f64, temperature: f64 },
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LogitsProcessor {
|
pub struct LogitsProcessor {
|
||||||
@ -77,7 +78,6 @@ impl LogitsProcessor {
|
|||||||
self.sample_multinomial(prs)
|
self.sample_multinomial(prs)
|
||||||
} else {
|
} else {
|
||||||
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||||
// Sort by descending probability.
|
|
||||||
let (indices, _, _) =
|
let (indices, _, _) =
|
||||||
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
||||||
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
let prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
||||||
@ -86,6 +86,26 @@ impl LogitsProcessor {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// top-k sampling samples from the k tokens with the largest probabilities.
|
||||||
|
// then top-p sampling.
|
||||||
|
fn sample_topk_topp(&mut self, prs: &mut Vec<f32>, top_k: usize, top_p: f32) -> Result<u32> {
|
||||||
|
if top_k >= prs.len() {
|
||||||
|
self.sample_topp(prs, top_p)
|
||||||
|
} else {
|
||||||
|
let mut argsort_indices = (0..prs.len()).collect::<Vec<_>>();
|
||||||
|
let (indices, _, _) =
|
||||||
|
argsort_indices.select_nth_unstable_by(top_k, |&i, &j| prs[j].total_cmp(&prs[i]));
|
||||||
|
let mut prs = indices.iter().map(|&i| prs[i]).collect::<Vec<_>>();
|
||||||
|
let sum_p = prs.iter().sum::<f32>();
|
||||||
|
let index = if top_p <= 0.0 || top_p >= sum_p {
|
||||||
|
self.sample_multinomial(&prs)?
|
||||||
|
} else {
|
||||||
|
self.sample_topp(&mut prs, top_p)?
|
||||||
|
};
|
||||||
|
Ok(indices[index as usize] as u32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
pub fn sample(&mut self, logits: &Tensor) -> Result<u32> {
|
||||||
self.sample_f(logits, |_| {})
|
self.sample_f(logits, |_| {})
|
||||||
}
|
}
|
||||||
@ -120,6 +140,10 @@ impl LogitsProcessor {
|
|||||||
let mut prs = prs(*temperature)?;
|
let mut prs = prs(*temperature)?;
|
||||||
self.sample_topk(&mut prs, *k)?
|
self.sample_topk(&mut prs, *k)?
|
||||||
}
|
}
|
||||||
|
Sampling::TopKThenTopP { k, p, temperature } => {
|
||||||
|
let mut prs = prs(*temperature)?;
|
||||||
|
self.sample_topk_topp(&mut prs, *k, *p as f32)?
|
||||||
|
}
|
||||||
};
|
};
|
||||||
Ok(next_token)
|
Ok(next_token)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user