mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Helper function to load sharded safetensors files (#1481)
* Fix the quantized mistral example.
* Add a helper function to load sharded safetensors weights.
* Use the sharded loader.
This commit is contained in:
@ -17,8 +17,11 @@ candle-nn = { path = "../candle-nn", version = "0.3.2" }
|
|||||||
candle-transformers = { path = "../candle-transformers", version = "0.3.2" }
|
candle-transformers = { path = "../candle-transformers", version = "0.3.2" }
|
||||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.2", optional = true }
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.2", optional = true }
|
||||||
candle-onnx = { path = "../candle-onnx", version = "0.3.2", optional = true }
|
candle-onnx = { path = "../candle-onnx", version = "0.3.2", optional = true }
|
||||||
|
|
||||||
|
csv = "1.3.0"
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
half = { workspace = true, optional = true }
|
half = { workspace = true, optional = true }
|
||||||
|
hf-hub = { workspace = true, features=["tokio"]}
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
@ -28,13 +31,11 @@ safetensors = { workspace = true }
|
|||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
tokenizers = { workspace = true, features = ["onig"] }
|
tokenizers = { workspace = true, features = ["onig"] }
|
||||||
csv = "1.3.0"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
byteorder = { workspace = true }
|
byteorder = { workspace = true }
|
||||||
clap = { workspace = true }
|
clap = { workspace = true }
|
||||||
hf-hub = { workspace = true, features=["tokio"]}
|
|
||||||
imageproc = { workspace = true }
|
imageproc = { workspace = true }
|
||||||
memmap2 = { workspace = true }
|
memmap2 = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
|
@ -235,10 +235,7 @@ fn main() -> Result<()> {
|
|||||||
if args.quantized {
|
if args.quantized {
|
||||||
vec![repo.get("model-q4k.gguf")?]
|
vec![repo.get("model-q4k.gguf")?]
|
||||||
} else {
|
} else {
|
||||||
vec![
|
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||||
repo.get("model-00001-of-00002.safetensors")?,
|
|
||||||
repo.get("model-00002-of-00002.safetensors")?,
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -209,29 +209,7 @@ fn main() -> Result<()> {
|
|||||||
.split(',')
|
.split(',')
|
||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => {
|
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
vec![
|
|
||||||
repo.get("model-00001-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00002-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00003-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00004-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00005-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00006-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00007-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00008-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00009-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00010-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00011-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00012-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00013-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00014-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00015-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00016-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00017-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00018-of-00019.safetensors")?,
|
|
||||||
repo.get("model-00019-of-00019.safetensors")?,
|
|
||||||
]
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
@ -278,10 +278,10 @@ fn main() -> Result<()> {
|
|||||||
} else {
|
} else {
|
||||||
match args.model {
|
match args.model {
|
||||||
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
|
||||||
WhichModel::V2 => vec![
|
WhichModel::V2 => candle_examples::hub_load_safetensors(
|
||||||
repo.get("model-00001-of-00002.safetensors")?,
|
&repo,
|
||||||
repo.get("model-00002-of-00002.safetensors")?,
|
"model.safetensors.index.json",
|
||||||
],
|
)?,
|
||||||
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
|
WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2.safetensors")?],
|
||||||
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
|
WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B.safetensors")?],
|
||||||
}
|
}
|
||||||
|
@ -96,25 +96,9 @@ impl T5ModelBuilder {
|
|||||||
let api = api.repo(repo);
|
let api = api.repo(repo);
|
||||||
let config_filename = api.get("config.json")?;
|
let config_filename = api.get("config.json")?;
|
||||||
let tokenizer_filename = api.get("tokenizer.json")?;
|
let tokenizer_filename = api.get("tokenizer.json")?;
|
||||||
let weights_filename = if model_id == "google/flan-t5-xxl" {
|
let weights_filename = if model_id == "google/flan-t5-xxl" || model_id == "google/flan-ul2"
|
||||||
vec![
|
{
|
||||||
api.get("model-00001-of-00005.safetensors")?,
|
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
|
||||||
api.get("model-00002-of-00005.safetensors")?,
|
|
||||||
api.get("model-00003-of-00005.safetensors")?,
|
|
||||||
api.get("model-00004-of-00005.safetensors")?,
|
|
||||||
api.get("model-00005-of-00005.safetensors")?,
|
|
||||||
]
|
|
||||||
} else if model_id == "google/flan-ul2" {
|
|
||||||
vec![
|
|
||||||
api.get("model-00001-of-00008.safetensors")?,
|
|
||||||
api.get("model-00002-of-00008.safetensors")?,
|
|
||||||
api.get("model-00003-of-00008.safetensors")?,
|
|
||||||
api.get("model-00004-of-00008.safetensors")?,
|
|
||||||
api.get("model-00005-of-00008.safetensors")?,
|
|
||||||
api.get("model-00006-of-00008.safetensors")?,
|
|
||||||
api.get("model-00007-of-00008.safetensors")?,
|
|
||||||
api.get("model-00008-of-00008.safetensors")?,
|
|
||||||
]
|
|
||||||
} else {
|
} else {
|
||||||
vec![api.get("model.safetensors")?]
|
vec![api.get("model.safetensors")?]
|
||||||
};
|
};
|
||||||
|
@ -218,21 +218,7 @@ fn main() -> Result<()> {
|
|||||||
.split(',')
|
.split(',')
|
||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => match args.which {
|
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
Which::L6b => vec![
|
|
||||||
repo.get("model-00001-of-00002.safetensors")?,
|
|
||||||
repo.get("model-00002-of-00002.safetensors")?,
|
|
||||||
],
|
|
||||||
Which::L34b => vec![
|
|
||||||
repo.get("model-00001-of-00007.safetensors")?,
|
|
||||||
repo.get("model-00002-of-00007.safetensors")?,
|
|
||||||
repo.get("model-00003-of-00007.safetensors")?,
|
|
||||||
repo.get("model-00004-of-00007.safetensors")?,
|
|
||||||
repo.get("model-00005-of-00007.safetensors")?,
|
|
||||||
repo.get("model-00006-of-00007.safetensors")?,
|
|
||||||
repo.get("model-00007-of-00007.safetensors")?,
|
|
||||||
],
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
@ -117,3 +117,30 @@ pub fn save_image_resize<P: AsRef<std::path::Path>>(
|
|||||||
image.save(p).map_err(candle::Error::wrap)?;
|
image.save(p).map_err(candle::Error::wrap)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Loads the safetensors files for a model from the hub based on a json index file.
|
||||||
|
pub fn hub_load_safetensors(
|
||||||
|
repo: &hf_hub::api::sync::ApiRepo,
|
||||||
|
json_file: &str,
|
||||||
|
) -> Result<Vec<std::path::PathBuf>> {
|
||||||
|
let json_file = repo.get(json_file).map_err(candle::Error::wrap)?;
|
||||||
|
let json_file = std::fs::File::open(json_file)?;
|
||||||
|
let json: serde_json::Value =
|
||||||
|
serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
|
||||||
|
let weight_map = match json.get("weight_map") {
|
||||||
|
None => candle::bail!("no weight map in {json_file:?}"),
|
||||||
|
Some(serde_json::Value::Object(map)) => map,
|
||||||
|
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
|
||||||
|
};
|
||||||
|
let mut safetensors_files = std::collections::HashSet::new();
|
||||||
|
for value in weight_map.values() {
|
||||||
|
if let Some(file) = value.as_str() {
|
||||||
|
safetensors_files.insert(file.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let safetensors_files = safetensors_files
|
||||||
|
.iter()
|
||||||
|
.map(|v| repo.get(v).map_err(candle::Error::wrap))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
Ok(safetensors_files)
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user