mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Putting back seed.
This commit is contained in:
@ -13,7 +13,7 @@
|
||||
// transposition operations.
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
use rand::{distributions::Distribution, thread_rng};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_hub::{api::Api, Repo, RepoType};
|
||||
@ -406,6 +406,10 @@ struct Args {
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, default_value_t = 100)]
|
||||
sample_len: usize,
|
||||
@ -463,7 +467,7 @@ async fn main() -> Result<()> {
|
||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||
println!("starting the inference loop");
|
||||
let mut new_tokens = vec![];
|
||||
let mut rng = thread_rng();
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..args.sample_len {
|
||||
let start_gen = std::time::Instant::now();
|
||||
|
Reference in New Issue
Block a user