Fix GLM4 alignment issue (#2723)

* Fix GLM4 alignment issue

* Cleanups.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Guoqing Bao
2025-01-21 05:51:46 +08:00
committed by GitHub
parent 17cbbe4286
commit e4c3a71f11
5 changed files with 54 additions and 22 deletions

View File

@ -4,7 +4,6 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
pub mod wav;
use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};
@ -147,3 +146,28 @@ pub fn hub_load_safetensors(
.collect::<Result<Vec<_>>>()?;
Ok(safetensors_files)
}
pub fn hub_load_local_safetensors<P: AsRef<std::path::Path>>(
path: P,
json_file: &str,
) -> Result<Vec<std::path::PathBuf>> {
let path = path.as_ref();
let jsfile = std::fs::File::open(path.join(json_file))?;
let json: serde_json::Value = serde_json::from_reader(&jsfile).map_err(candle::Error::wrap)?;
let weight_map = match json.get("weight_map") {
None => candle::bail!("no weight map in {json_file:?}"),
Some(serde_json::Value::Object(map)) => map,
Some(_) => candle::bail!("weight map in {json_file:?} is not a map"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
if let Some(file) = value.as_str() {
safetensors_files.insert(file);
}
}
let safetensors_files: Vec<_> = safetensors_files
.into_iter()
.map(|v| path.join(v))
.collect();
Ok(safetensors_files)
}