mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 02:16:37 +00:00
Implement top_p / nucleus sampling (#819)
* Implement top_p / nucleus sampling
* Update changelog
* rustfmt
* Add tests
* Fix clippy warning
* Fix another clippy error
This commit is contained in:
@@ -25,17 +25,25 @@ struct TextGeneration {
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
struct GenerationOptions {
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
fn new(
|
||||
model: Falcon,
|
||||
tokenizer: Tokenizer,
|
||||
generation_options: GenerationOptions,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
device: &Device,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp);
|
||||
let logits_processor =
|
||||
LogitsProcessor::new(seed, generation_options.temp, generation_options.top_p);
|
||||
let repeat_penalty = generation_options.repeat_penalty;
|
||||
let repeat_last_n = generation_options.repeat_last_n;
|
||||
Self {
|
||||
model,
|
||||
tokenizer,
|
||||
@@ -118,6 +126,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
@@ -185,15 +197,14 @@ fn main() -> Result<()> {
|
||||
let model = Falcon::load(vb, config)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
&device,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
);
|
||||
let generation_options = GenerationOptions {
|
||||
temp: args.temperature,
|
||||
top_p: args.top_p,
|
||||
repeat_penalty: args.repeat_penalty,
|
||||
repeat_last_n: args.repeat_last_n,
|
||||
};
|
||||
let mut pipeline =
|
||||
TextGeneration::new(model, tokenizer, generation_options, args.seed, &device);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user