From c5e8f788be94c2aa7e91db6ef3f72b0e3b55bd5b Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 29 Jun 2023 12:05:53 +0000 Subject: [PATCH] Revert some changes. --- candle-core/examples/llama/main.rs | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/candle-core/examples/llama/main.rs b/candle-core/examples/llama/main.rs index 3a025683..2f9daec0 100644 --- a/candle-core/examples/llama/main.rs +++ b/candle-core/examples/llama/main.rs @@ -422,9 +422,6 @@ async fn main() -> Result<()> { } else { Device::new_cuda(0)? }; - let api = Api::new()?; - let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); - println!("building the model"); let config = Config::config_7b(); let cache = Cache::new(&device); let start = std::time::Instant::now(); @@ -435,6 +432,9 @@ async fn main() -> Result<()> { std::path::Path::new("llama-tokenizer.json").to_path_buf(), ) } else { + let api = Api::new()?; + let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model); + println!("building the model"); let tokenizer_filename = api.get(&repo, "tokenizer.json").await?; let mut filenames = vec![]; for rfilename in [ @@ -483,14 +483,9 @@ async fn main() -> Result<()> { logits_v .iter() .enumerate() - .fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| { - if &val_max > val { - (idx_max, val_max) - } else { - (idx, *val) - } - }) - .0 as u32 + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap() }; tokens.push(next_token); new_tokens.push(next_token);