Preliminary support for importing PyTorch weights. (#511)

* Pickle work-in-progress.

* More unpickling.

* More pickling.

* Proper handling of setitems.

* Clippy.

* Again more pickling.

* Restore the example.

* Add enough pickle support to get the list of tensors.

* Read the data from zip files.

* Retrieve the tensor shape.

* Extract the size and dtype.

* More storage types.

* Improve the destructuring.
This commit is contained in:
Laurent Mazare
2023-08-19 11:26:32 +01:00
committed by GitHub
parent 90ff04e77e
commit ad33715c61
3 changed files with 573 additions and 0 deletions

View File

@ -45,6 +45,22 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
println!("{name}: [{shape:?}; {dtype}]")
}
}
Some("pt") | Some("pth") => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, dtype, shape) in tensors.iter() {
println!("{name}: [{shape:?}; {dtype:?}]")
}
}
Some("pkl") => {
let file = std::fs::File::open(file)?;
let mut reader = std::io::BufReader::new(file);
let mut stack = candle_core::pickle::Stack::empty();
stack.read_loop(&mut reader)?;
for (i, obj) in stack.stack().iter().enumerate() {
println!("{i} {obj:?}");
}
}
Some(_) => {
println!("{file:?}: unsupported file extension")
}