mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Small tweaks to tensor-tools. (#517)
This commit is contained in:
@ -6,7 +6,7 @@ enum Format {
|
||||
Safetensors,
|
||||
Npz,
|
||||
Ggml,
|
||||
PyTorch,
|
||||
Pth,
|
||||
Pickle,
|
||||
}
|
||||
|
||||
@ -16,9 +16,10 @@ impl Format {
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.and_then(|e| match e {
|
||||
// We don't infer any format for .bin as it can be used for ggml or pytorch.
|
||||
"safetensors" | "safetensor" => Some(Self::Safetensors),
|
||||
"npz" => Some(Self::Npz),
|
||||
"pth" | "pt" => Some(Self::PyTorch),
|
||||
"pth" | "pt" => Some(Self::Pth),
|
||||
"ggml" => Some(Self::Ggml),
|
||||
_ => None,
|
||||
})
|
||||
@ -29,18 +30,19 @@ impl Format {
|
||||
enum Command {
|
||||
Ls {
|
||||
files: Vec<std::path::PathBuf>,
|
||||
|
||||
/// The file format to use, if unspecified infer from the file extension.
|
||||
#[arg(long, value_enum)]
|
||||
format: Option<Format>,
|
||||
|
||||
/// Enable verbose mode.
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug, Clone)]
|
||||
struct Args {
|
||||
/// Enable verbose mode.
|
||||
#[arg(short, long)]
|
||||
verbose: bool,
|
||||
|
||||
#[command(subcommand)]
|
||||
command: Command,
|
||||
}
|
||||
@ -86,7 +88,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
||||
println!("{name}: [{shape:?}; {dtype}]")
|
||||
}
|
||||
}
|
||||
Format::PyTorch => {
|
||||
Format::Pth => {
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
|
||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
for tensor_info in tensors.iter() {
|
||||
@ -126,13 +128,17 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
match args.command {
|
||||
Command::Ls { files, format } => {
|
||||
Command::Ls {
|
||||
files,
|
||||
format,
|
||||
verbose,
|
||||
} => {
|
||||
let multiple_files = files.len() > 1;
|
||||
for file in files.iter() {
|
||||
if multiple_files {
|
||||
println!("--- {file:?} ---");
|
||||
}
|
||||
run_ls(file, format.clone(), args.verbose)?
|
||||
run_ls(file, format.clone(), verbose)?
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user