Handle LongStorage in pytorch checkpoints. (#1152)

Author: Laurent Mazare
Date: 2023-10-22 18:34:36 +01:00
Committed by: GitHub
Parent: df2f89b6cf
Commit: 8a82d623e5

2 changed files with 21 additions and 27 deletions

candle-core/src/pickle.rs

@@ -609,6 +609,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
         "HalfStorage" => DType::F16,
         "BFloat16Storage" => DType::BF16,
         "ByteStorage" => DType::U8,
+        "LongStorage" => DType::I64,
         other => {
             crate::bail!("unsupported storage type {other}")
         }
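
For context, rebuild_args maps the storage class name recorded in a pickled PyTorch checkpoint to a candle DType; before this one-line addition, any checkpoint holding int64 (torch.long) tensors aborted with "unsupported storage type LongStorage". A minimal sketch of exercising the new path, assuming candle's pickle reader candle_core::pickle::read_all and a placeholder checkpoint.pth that contains at least one int64 tensor:

    use candle_core::{DType, Result};

    fn main() -> Result<()> {
        // read_all deserializes every tensor stored in the pickle archive.
        let tensors = candle_core::pickle::read_all("checkpoint.pth")?;
        for (name, tensor) in tensors.iter() {
            // Tensors saved with LongStorage now load as DType::I64
            // instead of hitting the bail! branch above.
            if tensor.dtype() == DType::I64 {
                println!("{name}: {:?}", tensor.shape());
            }
        }
        Ok(())
    }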

candle-examples/examples/bert/main.rs

@@ -5,11 +5,11 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle_transformers::models::bert::{BertModel, Config, DTYPE};
-use anyhow::{anyhow, Error as E, Result};
+use anyhow::{Error as E, Result};
 use candle::Tensor;
 use candle_nn::VarBuilder;
 use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};
 
 #[derive(Parser, Debug)]
@@ -19,10 +19,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Run offline (you must have the files already cached)
-    #[arg(long)]
-    offline: bool,
-
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -38,6 +34,10 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,
 
+    /// Use the pytorch weights rather than the safetensors ones
+    #[arg(long)]
+    use_pth: bool,
+
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
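
The new flag is plain clap derive; as a self-contained sketch of just this field (the other Args fields are omitted here), clap's default kebab-case mapping exposes the use_pth field as --use-pth on the command line:

    use clap::Parser;

    #[derive(Parser, Debug)]
    struct Args {
        /// Use the pytorch weights rather than the safetensors ones
        #[arg(long)]
        use_pth: bool,
    }

    fn main() {
        // Parsed from e.g. `my-binary --use-pth`; defaults to false.
        let args = Args::parse();
        println!("use_pth = {}", args.use_pth);
    }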
@@ -60,34 +60,27 @@ impl Args {
         };
 
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
-            let cache = Cache::default().repo(repo);
-            (
-                cache
-                    .get("config.json")
-                    .ok_or(anyhow!("Missing config file in cache"))?,
-                cache
-                    .get("tokenizer.json")
-                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-                cache
-                    .get("model.safetensors")
-                    .ok_or(anyhow!("Missing weights file in cache"))?,
-            )
-        } else {
+        let (config_filename, tokenizer_filename, weights_filename) = {
             let api = Api::new()?;
             let api = api.repo(repo);
-            (
-                api.get("config.json")?,
-                api.get("tokenizer.json")?,
-                api.get("model.safetensors")?,
-            )
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = if self.use_pth {
+                api.get("pytorch_model.bin")?
+            } else {
+                api.get("model.safetensors")?
+            };
+            (config, tokenizer, weights)
         };
         let config = std::fs::read_to_string(config_filename)?;
         let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
-        let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
+        let vb = if self.use_pth {
+            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+        } else {
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+        };
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
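
A condensed sketch of the resulting weight-loading logic, assuming the crate aliases the example uses (candle for candle-core, plus candle-nn and candle-transformers), a CPU device, and a hypothetical load_bert helper; VarBuilder::from_pth parses the pickle checkpoint eagerly into memory, which is why the pth branch needs no unsafe while the memory-mapped safetensors branch does:

    use candle::Device;
    use candle_nn::VarBuilder;
    use candle_transformers::models::bert::{BertModel, Config, DTYPE};

    // Hypothetical helper mirroring build_model_and_tokenizer's weight handling.
    fn load_bert(weights: &str, config: &Config, use_pth: bool) -> anyhow::Result<BertModel> {
        let device = Device::Cpu;
        let vb = if use_pth {
            // Pickle checkpoints are fully parsed; LongStorage tensors now
            // come through as I64 thanks to the pickle.rs change above.
            VarBuilder::from_pth(weights, DTYPE, &device)?
        } else {
            // Memory-mapped safetensors: unsafe because the mapping breaks
            // if the file is modified underneath the running process.
            unsafe { VarBuilder::from_mmaped_safetensors(&[weights], DTYPE, &device)? }
        };
        Ok(BertModel::load(vb, config)?)
    }

With this in place the example can be run with something like cargo run --example bert --release -- --use-pth, and it falls back to the safetensors weights when the flag is omitted.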