Sample with temperature. (#106)

Author: Laurent Mazare
Date: 2023-07-07 18:12:25 +01:00
Committed by: GitHub
Parent: 03dffe9ecc
Commit: f35cfc5e97

@@ -19,21 +19,27 @@ const DTYPE: DType = DType::F32;
 #[cfg(not(feature = "mkl"))]
 const DTYPE: DType = DType::BF16;
-const TEMPERATURE: Option<f64> = None;

 struct TextGeneration {
     model: Falcon,
     rng: rand::rngs::StdRng,
     device: Device,
+    temperature: Option<f64>,
     tokenizer: Tokenizer,
 }

 impl TextGeneration {
-    fn new(model: Falcon, tokenizer: Tokenizer, seed: u64, device: &Device) -> Self {
+    fn new(
+        model: Falcon,
+        tokenizer: Tokenizer,
+        seed: u64,
+        temperature: Option<f64>,
+        device: &Device,
+    ) -> Self {
         Self {
             model,
             tokenizer,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
+            temperature,
             device: device.clone(),
         }
     }
@@ -61,7 +67,7 @@ impl TextGeneration {
             let logits = self.model.forward(&input)?;
             let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
-            let next_token = if let Some(temperature) = TEMPERATURE {
+            let next_token = if let Some(temperature) = self.temperature {
                 let prs = (&logits / temperature)?.softmax(D::Minus1)?;
                 let logits_v: Vec<f32> = prs.to_vec1()?;
                 let distr = rand::distributions::WeightedIndex::new(&logits_v)?;
@@ -107,6 +113,10 @@ struct Args {
     #[arg(long)]
     prompt: String,

+    /// The temperature used to generate samples.
+    #[arg(long)]
+    temperature: Option<f64>,
+
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -161,7 +171,7 @@ fn main() -> Result<()> {
     let model = Falcon::load(&vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());

-    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, &device);
+    let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device);
     pipeline.run(&args.prompt, args.sample_len)?;
     Ok(())
 }
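
For reference, the sampling branch introduced above can be sketched outside of candle's tensor API. The snippet below is a minimal, self-contained illustration assuming rand 0.8 (the crate providing the rand::distributions::WeightedIndex call seen in the diff); the function name sample_next_token, the plain slice of logits, and the hand-rolled max-shift and exponentiation are illustrative stand-ins for the Tensor-based code, not part of this commit.

// Minimal sketch of temperature sampling vs. greedy decoding, assuming rand 0.8.
use rand::distributions::{Distribution, WeightedIndex};
use rand::rngs::StdRng;
use rand::SeedableRng;

/// Picks the next token id from raw logits: a draw from the temperature-scaled
/// softmax distribution when a temperature is given, otherwise greedy argmax.
fn sample_next_token(logits: &[f32], temperature: Option<f64>, rng: &mut StdRng) -> usize {
    match temperature {
        Some(t) => {
            // Shift by the max for numerical stability, divide by the temperature,
            // and exponentiate. The softmax normalization is implicit: WeightedIndex
            // only needs relative (non-negative) weights.
            let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let weights: Vec<f64> = logits
                .iter()
                .map(|&l| (f64::from(l - max) / t).exp())
                .collect();
            let distr = WeightedIndex::new(&weights)
                .expect("weights must be non-negative and not all zero");
            distr.sample(rng)
        }
        // No temperature: deterministic argmax over the logits.
        None => logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
            .map(|(idx, _)| idx)
            .unwrap_or(0),
    }
}

fn main() {
    // Same seeding scheme as the example's --seed flag.
    let mut rng = StdRng::seed_from_u64(299792458);
    let logits = [2.0f32, 0.5, 3.0, 1.0];
    println!("sampled (t = 0.8): {}", sample_next_token(&logits, Some(0.8), &mut rng));
    println!("greedy (no temperature): {}", sample_next_token(&logits, None, &mut rng));
}

With the new flag, passing --temperature on the command line (alongside --prompt) enables stochastic sampling; leaving it unset keeps the previous deterministic argmax behavior, since the option defaults to None.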