diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs index 4783fbfd..3d79e58c 100644 --- a/candle-examples/examples/falcon/main.rs +++ b/candle-examples/examples/falcon/main.rs @@ -19,21 +19,27 @@ const DTYPE: DType = DType::F32; #[cfg(not(feature = "mkl"))] const DTYPE: DType = DType::BF16; -const TEMPERATURE: Option = None; - struct TextGeneration { model: Falcon, rng: rand::rngs::StdRng, device: Device, + temperature: Option, tokenizer: Tokenizer, } impl TextGeneration { - fn new(model: Falcon, tokenizer: Tokenizer, seed: u64, device: &Device) -> Self { + fn new( + model: Falcon, + tokenizer: Tokenizer, + seed: u64, + temperature: Option, + device: &Device, + ) -> Self { Self { model, tokenizer, rng: rand::rngs::StdRng::seed_from_u64(seed), + temperature, device: device.clone(), } } @@ -61,7 +67,7 @@ impl TextGeneration { let logits = self.model.forward(&input)?; let logits = logits.squeeze(0)?.to_dtype(DType::F32)?; - let next_token = if let Some(temperature) = TEMPERATURE { + let next_token = if let Some(temperature) = self.temperature { let prs = (&logits / temperature)?.softmax(D::Minus1)?; let logits_v: Vec = prs.to_vec1()?; let distr = rand::distributions::WeightedIndex::new(&logits_v)?; @@ -107,6 +113,10 @@ struct Args { #[arg(long)] prompt: String, + /// The temperature used to generate samples. + #[arg(long)] + temperature: Option, + /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, @@ -161,7 +171,7 @@ fn main() -> Result<()> { let model = Falcon::load(&vb, config)?; println!("loaded the model in {:?}", start.elapsed()); - let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, &device); + let mut pipeline = TextGeneration::new(model, tokenizer, args.seed, args.temperature, &device); pipeline.run(&args.prompt, args.sample_len)?; Ok(()) }