Small tweaks to tensor-tools. (#517)

This commit is contained in:
Laurent Mazare
2023-08-19 16:50:26 +01:00
committed by GitHub
parent 6431140250
commit 551409092e
2 changed files with 21 additions and 12 deletions

View File

@ -6,7 +6,7 @@ enum Format {
Safetensors,
Npz,
Ggml,
PyTorch,
Pth,
Pickle,
}
@ -16,9 +16,10 @@ impl Format {
.extension()
.and_then(|e| e.to_str())
.and_then(|e| match e {
// We don't infer any format for .bin as it can be used for ggml or pytorch.
"safetensors" | "safetensor" => Some(Self::Safetensors),
"npz" => Some(Self::Npz),
"pth" | "pt" => Some(Self::PyTorch),
"pth" | "pt" => Some(Self::Pth),
"ggml" => Some(Self::Ggml),
_ => None,
})
@ -29,18 +30,19 @@ impl Format {
enum Command {
Ls {
files: Vec<std::path::PathBuf>,
/// The file format to use, if unspecified infer from the file extension.
#[arg(long, value_enum)]
format: Option<Format>,
/// Enable verbose mode.
#[arg(short, long)]
verbose: bool,
},
}
#[derive(Parser, Debug, Clone)]
struct Args {
/// Enable verbose mode.
#[arg(short, long)]
verbose: bool,
#[command(subcommand)]
command: Command,
}
@ -86,7 +88,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
println!("{name}: [{shape:?}; {dtype}]")
}
}
Format::PyTorch => {
Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
@ -126,13 +128,17 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
fn main() -> anyhow::Result<()> {
let args = Args::parse();
match args.command {
Command::Ls { files, format } => {
Command::Ls {
files,
format,
verbose,
} => {
let multiple_files = files.len() > 1;
for file in files.iter() {
if multiple_files {
println!("--- {file:?} ---");
}
run_ls(file, format.clone(), args.verbose)?
run_ls(file, format.clone(), verbose)?
}
}
}