From 8a82d623e5ad919ba422bc796bd31b6fc3b91ab1 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 22 Oct 2023 18:34:36 +0100
Subject: [PATCH] Handle LongStorage in pytorch checkpoints. (#1152)

---
 candle-core/src/pickle.rs             |  1 +
 candle-examples/examples/bert/main.rs | 47 ++++++++++++---------
 2 files changed, 21 insertions(+), 27 deletions(-)

diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 4a2c65fd..25640d1a 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -609,6 +609,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
         "HalfStorage" => DType::F16,
         "BFloat16Storage" => DType::BF16,
         "ByteStorage" => DType::U8,
+        "LongStorage" => DType::I64,
         other => {
             crate::bail!("unsupported storage type {other}")
         }
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 70592013..fcd2eab9 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -5,11 +5,11 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle_transformers::models::bert::{BertModel, Config, DTYPE};
-use anyhow::{anyhow, Error as E, Result};
+use anyhow::{Error as E, Result};
 use candle::Tensor;
 use candle_nn::VarBuilder;
 use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};
 
 #[derive(Parser, Debug)]
@@ -19,10 +19,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Run offline (you must have the files already cached)
-    #[arg(long)]
-    offline: bool,
-
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -38,6 +34,10 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,
 
+    /// Use the pytorch weights rather than the safetensors ones
+    #[arg(long)]
+    use_pth: bool,
+
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
@@ -60,34 +60,27 @@ impl Args {
         };
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
-            let cache = Cache::default().repo(repo);
-            (
-                cache
-                    .get("config.json")
-                    .ok_or(anyhow!("Missing config file in cache"))?,
-                cache
-                    .get("tokenizer.json")
-                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-                cache
-                    .get("model.safetensors")
-                    .ok_or(anyhow!("Missing weights file in cache"))?,
-            )
-        } else {
+        let (config_filename, tokenizer_filename, weights_filename) = {
             let api = Api::new()?;
             let api = api.repo(repo);
-            (
-                api.get("config.json")?,
-                api.get("tokenizer.json")?,
-                api.get("model.safetensors")?,
-            )
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = if self.use_pth {
+                api.get("pytorch_model.bin")?
+            } else {
+                api.get("model.safetensors")?
+            };
+            (config, tokenizer, weights)
         };
         let config = std::fs::read_to_string(config_filename)?;
         let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-        let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
+        let vb = if self.use_pth {
+            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+        } else {
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+        };
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
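
Note (not part of the patch): the candle-core half maps PyTorch's LongStorage to
DType::I64, which is what lets the new --use-pth path in the bert example load
pytorch_model.bin — BERT checkpoints typically carry an i64 position-ids buffer
that previously failed with "unsupported storage type LongStorage". A minimal
sketch of inspecting a checkpoint's dtypes with candle's pickle reader, assuming
a local "pytorch_model.bin" produced by torch.save (the path is a placeholder;
read_all is the public helper in candle_core::pickle):

    // List every tensor in a torch.save'd checkpoint with its shape and
    // dtype; LongStorage-backed tensors now come back as DType::I64.
    fn main() -> candle_core::Result<()> {
        let tensors = candle_core::pickle::read_all("pytorch_model.bin")?;
        for (name, tensor) in tensors {
            println!("{name}: {:?} {:?}", tensor.shape(), tensor.dtype());
        }
        Ok(())
    }

VarBuilder::from_pth goes through the same storage-type mapping, so the bert
example only needs to pick between from_pth and from_mmaped_safetensors.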