mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Merge pull request #37 from LaurentMazare/llama-seed
Add a seed parameter to llama.
This commit is contained in:
@ -13,7 +13,7 @@
|
|||||||
// transposition operations.
|
// transposition operations.
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use rand::{distributions::Distribution, thread_rng};
|
use rand::{distributions::Distribution, SeedableRng};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_hub::{api::Api, Repo, RepoType};
|
use candle_hub::{api::Api, Repo, RepoType};
|
||||||
@ -401,6 +401,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
temperature: Option<f64>,
|
temperature: Option<f64>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, default_value_t = 100)]
|
#[arg(long, default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
@ -408,7 +412,6 @@ struct Args {
|
|||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() -> Result<()> {
|
async fn main() -> Result<()> {
|
||||||
//use rand::prelude::*;
|
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
@ -418,8 +421,6 @@ async fn main() -> Result<()> {
|
|||||||
Device::new_cuda(0)?
|
Device::new_cuda(0)?
|
||||||
};
|
};
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
|
||||||
println!("building the model");
|
|
||||||
let config = Config::config_7b();
|
let config = Config::config_7b();
|
||||||
let cache = Cache::new(&device);
|
let cache = Cache::new(&device);
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
@ -430,6 +431,7 @@ async fn main() -> Result<()> {
|
|||||||
std::path::Path::new("llama-tokenizer.json").to_path_buf(),
|
std::path::Path::new("llama-tokenizer.json").to_path_buf(),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
|
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
|
||||||
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
|
||||||
let mut filenames = vec![];
|
let mut filenames = vec![];
|
||||||
for rfilename in [
|
for rfilename in [
|
||||||
@ -458,7 +460,7 @@ async fn main() -> Result<()> {
|
|||||||
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
let freqs_cis = precompute_freqs_cis(&config, &device)?;
|
||||||
println!("starting the inference loop");
|
println!("starting the inference loop");
|
||||||
let mut new_tokens = vec![];
|
let mut new_tokens = vec![];
|
||||||
let mut rng = thread_rng();
|
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed);
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..args.sample_len {
|
for index in 0..args.sample_len {
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
@ -478,14 +480,9 @@ async fn main() -> Result<()> {
|
|||||||
logits_v
|
logits_v
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| {
|
.max_by(|(_, u), (_, v)| u.total_cmp(v))
|
||||||
if &val_max > val {
|
.map(|(i, _)| i as u32)
|
||||||
(idx_max, val_max)
|
.unwrap()
|
||||||
} else {
|
|
||||||
(idx, *val)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.0 as u32
|
|
||||||
};
|
};
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
new_tokens.push(next_token);
|
new_tokens.push(next_token);
|
||||||
|
Reference in New Issue
Block a user