mirror of https://github.com/huggingface/candle.git
Small tweaks to tensor-tools. (#517)
@@ -6,7 +6,7 @@ enum Format {
     Safetensors,
     Npz,
     Ggml,
-    PyTorch,
+    Pth,
     Pickle,
 }
 
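The variant rename is not only cosmetic: clap's `ValueEnum` derive turns variant names into kebab-case CLI values, so `PyTorch` was accepted as `--format py-torch`, while `Pth` lines up with the `.pth` extension the tool already infers. A minimal sketch of that behaviour, assuming clap 4 with the `derive` feature; the `Cli` struct and `main` below are illustrative, not the actual tensor-tools code:

use clap::{Parser, ValueEnum};

// Hypothetical stand-in for the tool's Format enum.
#[derive(ValueEnum, Debug, Clone)]
enum Format {
    Safetensors,
    Npz,
    Ggml,
    Pth, // exposed on the command line as `pth`
    Pickle,
}

#[derive(Parser, Debug)]
struct Cli {
    #[arg(long, value_enum)]
    format: Option<Format>,
}

fn main() {
    // `--format pth` parses; with the old `PyTorch` variant the accepted
    // value would have been the less obvious `py-torch`.
    let cli = Cli::parse_from(["demo", "--format", "pth"]);
    println!("{:?}", cli.format);
}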
@@ -16,9 +16,10 @@ impl Format {
             .extension()
             .and_then(|e| e.to_str())
             .and_then(|e| match e {
+                // We don't infer any format for .bin as it can be used for ggml or pytorch.
                 "safetensors" | "safetensor" => Some(Self::Safetensors),
                 "npz" => Some(Self::Npz),
-                "pth" | "pt" => Some(Self::PyTorch),
+                "pth" | "pt" => Some(Self::Pth),
                 "ggml" => Some(Self::Ggml),
                 _ => None,
             })
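The inference keys purely off the file extension, and `.bin` is deliberately left unmapped because both ggml and PyTorch checkpoints commonly ship under that name; an ambiguous file forces an explicit `--format`. A self-contained sketch of the same logic (the function name and the trimmed-down enum are mine):

use std::path::Path;

#[derive(Debug, PartialEq)]
enum Format {
    Safetensors,
    Npz,
    Ggml,
    Pth,
}

// Mirror of the extension-based inference: unknown or ambiguous
// extensions (like .bin) yield None and require an explicit format.
fn infer_format(path: &Path) -> Option<Format> {
    path.extension()
        .and_then(|e| e.to_str())
        .and_then(|e| match e {
            "safetensors" | "safetensor" => Some(Format::Safetensors),
            "npz" => Some(Format::Npz),
            "pth" | "pt" => Some(Format::Pth),
            "ggml" => Some(Format::Ggml),
            _ => None,
        })
}

fn main() {
    assert_eq!(infer_format(Path::new("model.safetensors")), Some(Format::Safetensors));
    assert_eq!(infer_format(Path::new("weights.pt")), Some(Format::Pth));
    assert_eq!(infer_format(Path::new("model.bin")), None); // ambiguous on purpose
    println!("inference behaves as expected");
}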
@@ -29,18 +30,19 @@ impl Format {
 enum Command {
     Ls {
         files: Vec<std::path::PathBuf>,
 
         /// The file format to use, if unspecified infer from the file extension.
         #[arg(long, value_enum)]
         format: Option<Format>,
+
+        /// Enable verbose mode.
+        #[arg(short, long)]
+        verbose: bool,
     },
 }
 
 #[derive(Parser, Debug, Clone)]
 struct Args {
-    /// Enable verbose mode.
-    #[arg(short, long)]
-    verbose: bool,
-
     #[command(subcommand)]
     command: Command,
 }
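Moving `verbose` from the top-level `Args` onto the `Ls` variant makes it a per-subcommand flag, written after `ls` rather than globally, and `main` now has to destructure it out of the variant, as a later hunk shows. A small sketch of that shape, assuming clap 4 with the derive API; the program is illustrative and omits the `format` field for brevity:

use clap::{Parser, Subcommand};

#[derive(Subcommand, Debug, Clone)]
enum Command {
    Ls {
        files: Vec<std::path::PathBuf>,

        /// Enable verbose mode (now a per-subcommand flag).
        #[arg(short, long)]
        verbose: bool,
    },
}

#[derive(Parser, Debug, Clone)]
struct Args {
    #[command(subcommand)]
    command: Command,
}

fn main() {
    // The flag follows the subcommand: `tool ls --verbose a.pth b.pth`.
    let args = Args::parse_from(["tool", "ls", "--verbose", "a.pth", "b.pth"]);
    match args.command {
        Command::Ls { files, verbose } => {
            println!("verbose = {verbose}, files = {files:?}");
        }
    }
}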
@@ -86,7 +88,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
                 println!("{name}: [{shape:?}; {dtype}]")
             }
         }
-        Format::PyTorch => {
+        Format::Pth => {
            let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
            tensors.sort_by(|a, b| a.name.cmp(&b.name));
            for tensor_info in tensors.iter() {
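The `Pth` branch lists tensors from the pickle metadata alone, without loading any tensor data. A hedged, standalone sketch along the same lines, assuming `candle-core` and `anyhow` as dependencies (as of this commit `read_pth_tensor_info` takes only the file path):

use candle_core::pickle::read_pth_tensor_info;

// Minimal re-creation of the `ls` behaviour for a .pth checkpoint:
// print name, shape and dtype for every tensor, sorted by name.
fn main() -> anyhow::Result<()> {
    let path = std::env::args().nth(1).expect("usage: list-pth <file.pth>");
    let mut tensors = read_pth_tensor_info(&path)?;
    tensors.sort_by(|a, b| a.name.cmp(&b.name));
    for info in tensors.iter() {
        // `layout` and `dtype` come from the TensorInfo struct
        // shown in the pickle-module hunks further down.
        println!("{}: [{:?}; {:?}]", info.name, info.layout.shape(), info.dtype);
    }
    Ok(())
}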
@@ -126,13 +128,17 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     match args.command {
-        Command::Ls { files, format } => {
+        Command::Ls {
+            files,
+            format,
+            verbose,
+        } => {
             let multiple_files = files.len() > 1;
             for file in files.iter() {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file, format.clone(), args.verbose)?
+                run_ls(file, format.clone(), verbose)?
             }
         }
     }
@@ -490,13 +490,14 @@ impl From<Object> for E {
 
 // https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
 // Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
-fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
+fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
     let mut args = args.tuple()?;
     let stride = Vec::<usize>::try_from(args.remove(3))?;
     let size = Vec::<usize>::try_from(args.remove(2))?;
     let offset = args.remove(1).int()? as usize;
     let storage = args.remove(0).persistent_load()?;
     let mut storage = storage.tuple()?;
+    let storage_size = storage.remove(4).int()? as usize;
     let path = storage.remove(2).unicode()?;
     let (_module_name, class_name) = storage.remove(1).class()?;
     let dtype = match class_name.as_str() {
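The remaining hunks are in the pickle reader behind `candle_core::pickle::read_pth_tensor_info`. `rebuild_args` unpacks the argument tuple of PyTorch's `_rebuild_tensor_v2` by index, and the new line takes the storage's element count from position 4 of the persistent-storage tuple. The `remove` calls run from the highest index downward on purpose: removing a later element never shifts the earlier ones, so the hard-coded indices stay valid. A tiny standalone illustration of that point (the tuple contents are illustrative labels; only the indices used by the diff matter):

fn main() {
    // Stand-in for the persistent-storage tuple; the diff reads the class
    // at index 1, the path at index 2 and the element count at index 4.
    let mut storage = vec!["storage", "FloatStorage", "data/0", "cpu", "1024"];

    // Remove from the back first, as rebuild_args does: index 4, then 2, then 1.
    let storage_size = storage.remove(4); // "1024"
    let path = storage.remove(2);         // "data/0" (index unchanged by the removal above)
    let class_name = storage.remove(1);   // "FloatStorage"

    // Removing in ascending order instead would shift the later indices
    // and read the wrong fields.
    assert_eq!((storage_size, path, class_name), ("1024", "data/0", "FloatStorage"));
    println!("indices stay valid when removing back-to-front");
}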
@@ -510,7 +511,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
         }
     };
     let layout = Layout::new(crate::Shape::from(size), stride, offset);
-    Ok((layout, dtype, path))
+    Ok((layout, dtype, path, storage_size))
 }
 
 #[derive(Debug, Clone)]
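`Layout::new` takes the shape, strides, and start offset straight from the pickled metadata rather than recomputing them, since a transposed or sliced view keeps the same storage but different strides and offset. For the common contiguous case the strides are just running products of the trailing dimensions, as this small sketch (plain Rust, no candle types) shows:

// Contiguous (row-major) strides: stride[i] = product of dims after i.
fn contiguous_strides(dims: &[usize]) -> Vec<usize> {
    let mut strides = vec![1usize; dims.len()];
    for i in (0..dims.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }
    strides
}

fn main() {
    // A [2, 3, 4] tensor stored contiguously has strides [12, 4, 1]:
    // moving one step along dim 0 skips 3*4 elements, and so on.
    assert_eq!(contiguous_strides(&[2, 3, 4]), vec![12, 4, 1]);
    // A transposed or sliced view keeps the same storage but different
    // strides/offset, which is why the layout is rebuilt from the pickled
    // (size, stride, offset) triple.
    println!("strides for [2, 3, 4]: {:?}", contiguous_strides(&[2, 3, 4]));
}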
@@ -519,6 +520,7 @@ pub struct TensorInfo {
     pub dtype: DType,
     pub layout: Layout,
     pub path: String,
+    pub storage_size: usize,
 }
 
 pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
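As I read it, `storage_size` is the element count of the backing storage, not of the tensor view: a slice or a strided view can touch only part of a larger storage, so the layout's element count and `storage_size` can legitimately differ. A generic illustration of the distinction, independent of candle's types:

// Highest storage index touched by a strided view, plus one.
// For shape [d0, d1, ...] with strides [s0, s1, ...] and start offset o,
// the last element lives at o + sum((di - 1) * si).
fn span_end(shape: &[usize], stride: &[usize], offset: usize) -> usize {
    let last = shape
        .iter()
        .zip(stride)
        .map(|(d, s)| d.saturating_sub(1) * s)
        .sum::<usize>();
    offset + last + 1
}

fn main() {
    let storage_size = 12; // elements in the backing storage (the new field)

    // A [2, 3] view with strides [6, 1]: only the first half of each row
    // of a wider [2, 6] storage is visible.
    let needed = span_end(&[2, 3], &[6, 1], 0);
    assert_eq!(needed, 9); // touches indices 0..=8
    assert!(needed <= storage_size);
    println!("view needs {needed} elements out of a storage of {storage_size}");
}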
@@ -576,7 +578,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
             _ => continue,
         };
         match rebuild_args(args) {
-            Ok((layout, dtype, file_path)) => {
+            Ok((layout, dtype, file_path, storage_size)) => {
                 let mut path = dir_name.clone();
                 path.push(file_path);
                 tensor_infos.push(TensorInfo {
@@ -584,6 +586,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
                     dtype,
                     layout,
                     path: path.to_string_lossy().into_owned(),
+                    storage_size,
                 })
             }
             Err(err) => {
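With `storage_size` threaded through to `TensorInfo`, a caller can estimate how much raw data a checkpoint holds without reading the tensors. A hedged sketch, assuming `candle-core` and `anyhow` as dependencies and that `DType::size_in_bytes` gives the per-element width (my reading of candle's API, not something this diff shows):

use std::collections::HashSet;

// Rough size estimate for a .pth checkpoint: sum storage_size * element
// width over tensors, deduplicating storages shared via their path/key.
fn main() -> anyhow::Result<()> {
    let file = std::env::args().nth(1).expect("usage: pth-size <file.pth>");
    let infos = candle_core::pickle::read_pth_tensor_info(&file)?;

    let mut seen = HashSet::new();
    let mut total_bytes = 0usize;
    for info in infos.iter() {
        // Several views can point at the same storage; count each path once.
        if seen.insert(info.path.clone()) {
            total_bytes += info.storage_size * info.dtype.size_in_bytes();
        }
    }
    println!("{} tensors, ~{} bytes of storage", infos.len(), total_bytes);
    Ok(())
}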