Resurrect the llama npy support. (#140)

Author: Laurent Mazare
Date:   2023-07-11 19:32:10 +01:00
Committed by: GitHub
Parent: 760f1d7055
Commit: 37cad85869
6 changed files with 264 additions and 90 deletions


@@ -144,8 +144,14 @@ fn main() -> Result<()> {
     let config = Config::config_7b();
     let cache = model::Cache::new(!args.no_kv_cache, &config, &device);
     let (llama, tokenizer_filename) = match args.npy {
-        Some(_) => {
-            todo!("fix numpy handling if we continue supporting it")
-        }
+        Some(filename) => {
+            let tensors = Tensor::read_npz(filename)?
+                .into_iter()
+                .map(|(n, t)| Ok((n, t.to_dtype(DTYPE)?)))
+                .collect::<Result<std::collections::HashMap<String, Tensor>>>()?;
+            let vb = VarBuilder::from_tensors(tensors, DTYPE, &device);
+            let tokenizer = std::path::PathBuf::from("llama-tokenizer.json");
+            (Llama::load(vb, &cache, &config)?, tokenizer)
+        }
         None => {
             let api = Api::new()?;
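
For context, the added branch boils down to the following standalone sketch. It assumes the candle_core/candle_nn crate split; the file name "model.npz" and the f16 dtype stand in for the example's filename argument and its DTYPE constant, and Cache/Config/Llama are omitted so the npz path can be shown on its own.

use candle_core::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // read_npz returns the archive contents as (name, tensor) pairs.
    let tensors = Tensor::read_npz("model.npz")?
        .into_iter()
        // Cast each tensor to the dtype the model expects.
        .map(|(name, t)| Ok((name, t.to_dtype(DType::F16)?)))
        .collect::<Result<std::collections::HashMap<String, Tensor>>>()?;
    // Wrap the name -> tensor map so model code can look weights up by name.
    let vb = VarBuilder::from_tensors(tensors, DType::F16, &device);
    // A model's load function would now pull its weights from `vb`,
    // e.g. Llama::load(vb, &cache, &config)? as in the diff above.
    let _ = vb;
    Ok(())
}

Building the VarBuilder from an in-memory map is what lets the same Llama::load path serve both the npz route and the Hugging Face Hub route taken in the None arm.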