mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Preliminary support for importing PyTorch weights. (#511)
* Pickle work-in-progress. * More unpickling. * More pickling. * Proper handling of setitems. * Clippy. * Again more pickling. * Restore the example. * Add enough pickle support to get the list of tensors. * Read the data from zip files. * Retrieve the tensor shape. * Extract the size and dtype. * More storage types. * Improve the destructuring.
This commit is contained in:
@ -45,6 +45,22 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
|
||||
println!("{name}: [{shape:?}; {dtype}]")
|
||||
}
|
||||
}
|
||||
Some("pt") | Some("pth") => {
|
||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
|
||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
for (name, dtype, shape) in tensors.iter() {
|
||||
println!("{name}: [{shape:?}; {dtype:?}]")
|
||||
}
|
||||
}
|
||||
Some("pkl") => {
|
||||
let file = std::fs::File::open(file)?;
|
||||
let mut reader = std::io::BufReader::new(file);
|
||||
let mut stack = candle_core::pickle::Stack::empty();
|
||||
stack.read_loop(&mut reader)?;
|
||||
for (i, obj) in stack.stack().iter().enumerate() {
|
||||
println!("{i} {obj:?}");
|
||||
}
|
||||
}
|
||||
Some(_) => {
|
||||
println!("{file:?}: unsupported file extension")
|
||||
}
|
||||
|
Reference in New Issue
Block a user