mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Handle LongStorage in pytorch checkpoints. (#1152)
This commit is contained in:
@@ -609,6 +609,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
|||||||
"HalfStorage" => DType::F16,
|
"HalfStorage" => DType::F16,
|
||||||
"BFloat16Storage" => DType::BF16,
|
"BFloat16Storage" => DType::BF16,
|
||||||
"ByteStorage" => DType::U8,
|
"ByteStorage" => DType::U8,
|
||||||
|
"LongStorage" => DType::I64,
|
||||||
other => {
|
other => {
|
||||||
crate::bail!("unsupported storage type {other}")
|
crate::bail!("unsupported storage type {other}")
|
||||||
}
|
}
|
||||||
|
@@ -5,11 +5,11 @@ extern crate intel_mkl_src;
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||||
|
|
||||||
use anyhow::{anyhow, Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::Tensor;
|
use candle::Tensor;
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::{PaddingParams, Tokenizer};
|
use tokenizers::{PaddingParams, Tokenizer};
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@@ -19,10 +19,6 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
|
|
||||||
/// Run offline (you must have the files already cached)
|
|
||||||
#[arg(long)]
|
|
||||||
offline: bool,
|
|
||||||
|
|
||||||
/// Enable tracing (generates a trace-timestamp.json file).
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tracing: bool,
|
tracing: bool,
|
||||||
@@ -38,6 +34,10 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
prompt: Option<String>,
|
prompt: Option<String>,
|
||||||
|
|
||||||
|
/// Use the pytorch weights rather than the safetensors ones
|
||||||
|
#[arg(long)]
|
||||||
|
use_pth: bool,
|
||||||
|
|
||||||
/// The number of times to run the prompt.
|
/// The number of times to run the prompt.
|
||||||
#[arg(long, default_value = "1")]
|
#[arg(long, default_value = "1")]
|
||||||
n: usize,
|
n: usize,
|
||||||
@@ -60,34 +60,27 @@ impl Args {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
|
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||||
let cache = Cache::default().repo(repo);
|
|
||||||
(
|
|
||||||
cache
|
|
||||||
.get("config.json")
|
|
||||||
.ok_or(anyhow!("Missing config file in cache"))?,
|
|
||||||
cache
|
|
||||||
.get("tokenizer.json")
|
|
||||||
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
|
|
||||||
cache
|
|
||||||
.get("model.safetensors")
|
|
||||||
.ok_or(anyhow!("Missing weights file in cache"))?,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
let api = Api::new()?;
|
let api = Api::new()?;
|
||||||
let api = api.repo(repo);
|
let api = api.repo(repo);
|
||||||
(
|
let config = api.get("config.json")?;
|
||||||
api.get("config.json")?,
|
let tokenizer = api.get("tokenizer.json")?;
|
||||||
api.get("tokenizer.json")?,
|
let weights = if self.use_pth {
|
||||||
api.get("model.safetensors")?,
|
api.get("pytorch_model.bin")?
|
||||||
)
|
} else {
|
||||||
|
api.get("model.safetensors")?
|
||||||
|
};
|
||||||
|
(config, tokenizer, weights)
|
||||||
};
|
};
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
let config = std::fs::read_to_string(config_filename)?;
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let vb =
|
let vb = if self.use_pth {
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
|
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
||||||
|
} else {
|
||||||
|
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||||
|
};
|
||||||
let model = BertModel::load(vb, &config)?;
|
let model = BertModel::load(vb, &config)?;
|
||||||
Ok((model, tokenizer))
|
Ok((model, tokenizer))
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user