mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Putting back seed.
This commit is contained in:
@ -13,7 +13,7 @@
|
|||||||
// transposition operations.
|
// transposition operations.
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use rand::{distributions::Distribution, thread_rng};
|
use rand::{distributions::Distribution, SeedableRng};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_hub::{api::Api, Repo, RepoType};
|
use candle_hub::{api::Api, Repo, RepoType};
|
||||||
@ -406,6 +406,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
temperature: Option<f64>,
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, default_value_t = 100)]
|
#[arg(long, default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
@ -463,7 +467,7 @@ async fn main() -> Result<()> {
|
|||||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
let mut new_tokens = vec![];
|
let mut new_tokens = vec![];
|
||||||
let mut rng = thread_rng();
|
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
|
Reference in New Issue
Block a user