Handle LongStorage in pytorch checkpoints. (#1152)

Laurent Mazare
2023-10-22 18:34:36 +01:00
committed by GitHub
parent df2f89b6cf
commit 8a82d623e5
2 changed files with 21 additions and 27 deletions


@@ -609,6 +609,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
         "HalfStorage" => DType::F16,
         "BFloat16Storage" => DType::BF16,
         "ByteStorage" => DType::U8,
+        "LongStorage" => DType::I64,
         other => {
             crate::bail!("unsupported storage type {other}")
         }

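For context, LongStorage is the pickle storage class PyTorch uses for torch.int64 tensors, so checkpoints containing i64 buffers (for example the position_ids buffer that ships in Hugging Face BERT's pytorch_model.bin) previously hit the "unsupported storage type" bail-out. Below is a minimal sketch of the loading path this enables, reusing the VarBuilder::from_pth call adopted by the BERT example further down; the checkpoint path, Device::Cpu, and the helper name are illustrative rather than part of this commit:

    use candle::{DType, Device};
    use candle_nn::VarBuilder;
    use candle_transformers::models::bert::{BertModel, Config};

    // Load BERT weights from a pickle-based PyTorch checkpoint. Any i64 storages in the
    // archive are now decoded as DType::I64 instead of aborting the whole load.
    fn load_bert_from_pth(config: &Config) -> anyhow::Result<BertModel> {
        let device = Device::Cpu;
        let vb = VarBuilder::from_pth("pytorch_model.bin", DType::F32, &device)?;
        Ok(BertModel::load(vb, config)?)
    }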

@@ -5,11 +5,11 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle_transformers::models::bert::{BertModel, Config, DTYPE};
-use anyhow::{anyhow, Error as E, Result};
+use anyhow::{Error as E, Result};
 use candle::Tensor;
 use candle_nn::VarBuilder;
 use clap::Parser;
-use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
+use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};
 
 #[derive(Parser, Debug)]
@@ -19,10 +19,6 @@ struct Args {
     #[arg(long)]
     cpu: bool,
 
-    /// Run offline (you must have the files already cached)
-    #[arg(long)]
-    offline: bool,
-
     /// Enable tracing (generates a trace-timestamp.json file).
     #[arg(long)]
     tracing: bool,
@@ -38,6 +34,10 @@ struct Args {
     #[arg(long)]
     prompt: Option<String>,
 
+    /// Use the pytorch weights rather than the safetensors ones
+    #[arg(long)]
+    use_pth: bool,
+
     /// The number of times to run the prompt.
     #[arg(long, default_value = "1")]
     n: usize,
@@ -60,34 +60,27 @@ impl Args {
         };
         let repo = Repo::with_revision(model_id, RepoType::Model, revision);
-        let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
-            let cache = Cache::default().repo(repo);
-            (
-                cache
-                    .get("config.json")
-                    .ok_or(anyhow!("Missing config file in cache"))?,
-                cache
-                    .get("tokenizer.json")
-                    .ok_or(anyhow!("Missing tokenizer file in cache"))?,
-                cache
-                    .get("model.safetensors")
-                    .ok_or(anyhow!("Missing weights file in cache"))?,
-            )
-        } else {
+        let (config_filename, tokenizer_filename, weights_filename) = {
             let api = Api::new()?;
             let api = api.repo(repo);
-            (
-                api.get("config.json")?,
-                api.get("tokenizer.json")?,
-                api.get("model.safetensors")?,
-            )
+            let config = api.get("config.json")?;
+            let tokenizer = api.get("tokenizer.json")?;
+            let weights = if self.use_pth {
+                api.get("pytorch_model.bin")?
+            } else {
+                api.get("model.safetensors")?
+            };
+            (config, tokenizer, weights)
         };
         let config = std::fs::read_to_string(config_filename)?;
         let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
-        let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
+        let vb = if self.use_pth {
+            VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
+        } else {
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
+        };
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }
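With the new argument in place, the pickle weights can be exercised end to end. Assuming clap's default kebab-case naming for the derived field and the example being named bert in candle-examples (both inferred, not shown in this diff), an invocation along these lines selects pytorch_model.bin instead of model.safetensors; the prompt text is illustrative:

    cargo run --example bert --release -- --use-pth --prompt "Here is a test sentence"

Note that only the safetensors branch keeps the unsafe block: from_mmaped_safetensors memory-maps the weight file, while from_pth reads and deserializes the pickle archive into owned tensors, so no unsafe is needed there.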