Retrieve tensor data from PyTorch files. (#516)

This commit is contained in:
Laurent Mazare
2023-08-19 15:57:18 +01:00
committed by GitHub
parent 607ffb9f1e
commit 6431140250
3 changed files with 65 additions and 9 deletions

View File

@ -45,7 +45,7 @@ struct Args {
command: Command,
}
fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
let format = match format {
Some(format) => format,
None => match Format::infer(file) {
@ -91,12 +91,14 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
"{}: [{:?}; {:?}] {:?}",
"{}: [{:?}; {:?}]",
tensor_info.name,
tensor_info.layout.shape(),
tensor_info.dtype,
tensor_info.path,
)
);
if verbose {
println!(" {:?}", tensor_info);
}
}
}
Format::Pickle => {
@ -130,7 +132,7 @@ fn main() -> anyhow::Result<()> {
if multiple_files {
println!("--- {file:?} ---");
}
run_ls(file, format.clone())?
run_ls(file, format.clone(), args.verbose)?
}
}
}