Merge pull request #37 from LaurentMazare/llama-seed

Add a seed parameter to llama.
This commit is contained in:
Laurent Mazare
2023-06-29 12:51:45 +01:00
committed by GitHub

View File

@@ -13,7 +13,7 @@
 // transposition operations.
 use anyhow::{Error as E, Result};
 use clap::Parser;
-use rand::{distributions::Distribution, thread_rng};
+use rand::{distributions::Distribution, SeedableRng};
 use candle::{DType, Device, Tensor};
 use candle_hub::{api::Api, Repo, RepoType};
@@ -401,6 +401,10 @@ struct Args {
     #[arg(long)]
     temperature: Option<f64>,
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
     /// The length of the sample to generate (in tokens).
     #[arg(long, default_value_t = 100)]
     sample_len: usize,
@@ -408,7 +412,6 @@ struct Args {
 #[tokio::main]
 async fn main() -> Result<()> {
-    //use rand::prelude::*;
     use tokenizers::Tokenizer;
     let args = Args::parse();
@@ -418,8 +421,6 @@ async fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
     let api = Api::new()?;
-    let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
-    println!("building the model");
     let config = Config::config_7b();
     let cache = Cache::new(&device);
     let start = std::time::Instant::now();
@@ -430,6 +431,7 @@ async fn main() -> Result<()> {
             std::path::Path::new("llama-tokenizer.json").to_path_buf(),
         )
     } else {
+        let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
         let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
         let mut filenames = vec![];
         for rfilename in [
@@ -458,7 +460,7 @@ async fn main() -> Result<()> {
     let freqs_cis = precompute_freqs_cis(&config, &device)?;
     println!("starting the inference loop");
     let mut new_tokens = vec![];
-    let mut rng = thread_rng();
+    let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
     let start_gen = std::time::Instant::now();
     for index in 0..args.sample_len {
         let start_gen = std::time::Instant::now();
@@ -478,14 +480,9 @@ async fn main() -> Result<()> {
             logits_v
                 .iter()
                 .enumerate()
-                .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
-                    if &val_max > val {
-                        (idx_max, val_max)
-                    } else {
-                        (idx, *val)
-                    }
-                })
-                .0 as u32
+                .max_by(|(_, u), (_, v)| u.total_cmp(v))
+                .map(|(i, _)| i as u32)
+                .unwrap()
         };
         tokens.push(next_token);
         new_tokens.push(next_token);