Retrieve more information from PyTorch checkpoints. (#515)

* Retrieve more information from PyTorch checkpoints.

* Add enough support to load dino-v2 backbone weights.
This commit is contained in:
Laurent Mazare
2023-08-19 15:05:34 +01:00
committed by GitHub
parent f861a9df6e
commit 607ffb9f1e
3 changed files with 75 additions and 20 deletions

View File

@ -88,9 +88,15 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
}
Format::PyTorch => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, dtype, shape) in tensors.iter() {
println!("{name}: [{shape:?}; {dtype:?}]")
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
"{}: [{:?}; {:?}] {:?}",
tensor_info.name,
tensor_info.layout.shape(),
tensor_info.dtype,
tensor_info.path,
)
}
}
Format::Pickle => {