Mirror of https://github.com/huggingface/candle.git (synced 2025-06-16 18:48:51 +00:00)
Sample with temperature. (#106)
@@ -19,21 +19,27 @@ const DTYPE: DType = DType::F32;
 #[cfg(not(feature = "mkl"))]
 const DTYPE: DType = DType::BF16;
 
-const TEMPERATURE: Option<f64> = None;
-
 struct TextGeneration {
     model: Falcon,
     rng: rand::rngs::StdRng,
     device: Device,
+    temperature: Option<f64>,
     tokenizer: Tokenizer,
 }
 
 impl TextGeneration {
-    fn new(model: Falcon, tokenizer: Tokenizer, seed: u64, device: &Device) -> Self {
+    fn new(
+        model: Falcon,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temperature: Option<f64>,
+        device: &Device,
+    ) -> Self {
         Self {
             model,
             tokenizer,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
+            temperature,
             device: device.clone(),
         }
     }
@@ -61,7 +67,7 @@ impl TextGeneration {
             let logits = self.model.forward(&input)?;
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
 
-            let next_token = if let Some(temperature) = TEMPERATURE {
+            let next_token = if let Some(temperature) = self.temperature {
                 let prs = (&logits / temperature)?.softmax(D::Minus1)?;
                 let logits_v: Vec<f32> = prs.to_vec1()?;
                 let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
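
The sampling step in this hunk divides the logits by the temperature, applies a softmax, and draws a token index from the resulting categorical distribution; when temperature is None the code keeps a separate non-sampling path (the else branch, not shown here). The standalone sketch below reproduces that sampling logic over a plain Vec<f32>, independent of candle tensors, assuming the rand 0.8-era API that the diff itself uses (rand::distributions::WeightedIndex, StdRng); the function name sample_with_temperature is illustrative, not part of the commit.

use rand::distributions::{Distribution, WeightedIndex};
use rand::rngs::StdRng;
use rand::SeedableRng;

// Draw a token index from softmax(logits / temperature).
fn sample_with_temperature(logits: &[f32], temperature: f64, rng: &mut StdRng) -> usize {
    // Lower temperature sharpens the distribution, higher temperature flattens it.
    let scaled: Vec<f64> = logits.iter().map(|&l| l as f64 / temperature).collect();
    // Numerically stable softmax: subtract the max before exponentiating.
    let max = scaled.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    let exps: Vec<f64> = scaled.iter().map(|&l| (l - max).exp()).collect();
    let sum: f64 = exps.iter().sum();
    let prs: Vec<f64> = exps.iter().map(|&e| e / sum).collect();
    // Sample an index with probability proportional to its softmax weight.
    let distr = WeightedIndex::new(&prs).expect("weights must be finite and non-negative");
    distr.sample(rng)
}

fn main() {
    let mut rng = StdRng::seed_from_u64(299792458);
    let logits = [1.0f32, 2.0, 0.5];
    println!("sampled index: {}", sample_with_temperature(&logits, 0.8, &mut rng));
}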
@@ -107,6 +113,10 @@ struct Args {
     #[arg(long)]
     prompt: String,
 
+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
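
On the CLI side, an Option<f64> field with #[arg(long)] becomes an optional --temperature flag that parses to None when omitted. A minimal self-contained sketch of that pattern, assuming clap 4 with the derive feature enabled and including only the two flags visible in this hunk:

use clap::Parser;

/// Minimal argument struct mirroring the fields touched by this hunk.
#[derive(Parser, Debug)]
struct Args {
    /// The temperature used to generate samples.
    #[arg(long)]
    temperature: Option<f64>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
}

fn main() {
    // `--temperature 0.8` yields Some(0.8); omitting the flag yields None.
    let args = Args::parse();
    println!("temperature = {:?}, seed = {}", args.temperature, args.seed);
}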
@@ -161,7 +171,7 @@ fn main() -> Result<()> {
     let model = Falcon::load(&vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());
 
-    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, &device);
+    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }