From b4dc9f61080acf832134a2ac5de0165486c6c758 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 29 Jun 2023 12:47:19 +0100 Subject: [PATCH] Add a seed parameter to llama. --- candle-core/examples/llama/main.rs | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs index ed913595..8feb7fb0 100644 --- a/candle-core/examples/llama/main.rs +++ b/candle-core/examples/llama/main.rs @@ -13,7 +13,7 @@ // transposition operations. use anyhow::{Error as E, Result}; use clap::Parser; -use rand::{distributions::Distribution, thread_rng}; +use rand::{distributions::Distribution, SeedableRng}; use candle::{DType, Device, Tensor}; use candle_hub::{api::Api, Repo, RepoType}; @@ -401,6 +401,10 @@ struct Args { #[arg(long)] temperature: Option, + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + /// The length of the sample to generate (in tokens). #[arg(long, default_value_t = 100)] sample_len: usize, @@ -408,7 +412,6 @@ struct Args { #[tokio::main] async fn main() -> Result<()> { - //use rand::prelude::*; use tokenizers::Tokenizer; let args = Args::parse(); @@ -418,8 +421,6 @@ async fn main() -> Result<()> { Device::new_cuda(0)? }; let api = Api::new()?; - let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); - println!("building the model"); let config = Config::config_7b(); let cache = Cache::new(&device); let start = std::time::Instant::now(); @@ -430,6 +431,7 @@ async fn main() -> Result<()> { std::path::Path::new("llama-tokenizer.json").to_path_buf(), ) } else { + let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); let tokenizer_filename = api.get(&repo, "tokenizer.json").await?; let mut filenames = vec![]; for rfilename in [ @@ -458,7 +460,7 @@ async fn main() -> Result<()> { let freqs_cis = precompute_freqs_cis(&config, &device)?; println!("starting the inference loop"); let mut new_tokens = vec![]; - let mut rng = thread_rng(); + let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed); let start_gen = std::time::Instant::now(); for index in 0..args.sample_len { let start_gen = std::time::Instant::now(); @@ -478,14 +480,9 @@ async fn main() -> Result<()> { logits_v .iter() .enumerate() - .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| { - if &val_max > val { - (idx_max, val_max) - } else { - (idx, *val) - } - }) - .0 as u32 + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap() }; tokens.push(next_token); new_tokens.push(next_token);