From 0958c588f606443312bfd33502fd658ad2c0ccb1 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 29 Jun 2023 12:07:21 +0000 Subject: [PATCH] Putting back seed. --- candle-core/examples/llama/main.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs index 2f9daec0..c91537d8 100644 --- a/candle-core/examples/llama/main.rs +++ b/candle-core/examples/llama/main.rs @@ -13,7 +13,7 @@ // transposition operations. use anyhow::{Error as E, Result}; use clap::Parser; -use rand::{distributions::Distribution, thread_rng}; +use rand::{distributions::Distribution, SeedableRng}; use candle::{DType, Device, Tensor}; use candle_hub::{api::Api, Repo, RepoType}; @@ -406,6 +406,10 @@ struct Args { #[arg(long)] temperature: Option, + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + /// The length of the sample to generate (in tokens). #[arg(long, default_value_t = 100)] sample_len: usize, @@ -463,7 +467,7 @@ async fn main() -> Result<()> { let freqs_cis = precompute_freqs_cis(&config, &device)?; println!("starting the inference loop"); let mut new_tokens = vec![]; - let mut rng = thread_rng(); + let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed); let start_gen = std::time::Instant::now(); for index in 0..args.sample_len { let start_gen = std::time::Instant::now();