Revert some changes.

This commit is contained in:
Nicolas Patry
2023-06-29 12:05:53 +00:00
parent e63ed6aaa3
commit c5e8f788be

View File

@ -422,9 +422,6 @@ async fn main() -> Result<()> {
} else { } else {
Device::new_cuda(0)? Device::new_cuda(0)?
}; };
let api = Api::new()?;
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
println!("building the model");
let config = Config::config_7b(); let config = Config::config_7b();
let cache = Cache::new(&device); let cache = Cache::new(&device);
let start = std::time::Instant::now(); let start = std::time::Instant::now();
@ -435,6 +432,9 @@ async fn main() -> Result<()> {
std::path::Path::new("llama-tokenizer.json").to_path_buf(), std::path::Path::new("llama-tokenizer.json").to_path_buf(),
) )
} else { } else {
let api = Api::new()?;
let repo = Repo::new("Narsil/amall-7b".to_string(), RepoType::Model);
println!("building the model");
let tokenizer_filename = api.get(&repo, "tokenizer.json").await?; let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
let mut filenames = vec![]; let mut filenames = vec![];
for rfilename in [ for rfilename in [
@ -483,14 +483,9 @@ async fn main() -> Result<()> {
logits_v logits_v
.iter() .iter()
.enumerate() .enumerate()
.fold((0, logits_v[0]), |(idx_max, val_max), (idx, val)| { .max_by(|(_, u), (_, v)| u.total_cmp(v))
if &val_max > val { .map(|(i, _)| i as u32)
(idx_max, val_max) .unwrap()
} else {
(idx, *val)
}
})
.0 as u32
}; };
tokens.push(next_token); tokens.push(next_token);
new_tokens.push(next_token); new_tokens.push(next_token);