mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Retrieve more information from PyTorch checkpoints. (#515)
* Retrieve more information from PyTorch checkpoints. * Add enough support to load dino-v2 backbone weights.
This commit is contained in:
@ -88,9 +88,15 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
|
|||||||
}
|
}
|
||||||
Format::PyTorch => {
|
Format::PyTorch => {
|
||||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
|
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
for (name, dtype, shape) in tensors.iter() {
|
for tensor_info in tensors.iter() {
|
||||||
println!("{name}: [{shape:?}; {dtype:?}]")
|
println!(
|
||||||
|
"{}: [{:?}; {:?}] {:?}",
|
||||||
|
tensor_info.name,
|
||||||
|
tensor_info.layout.shape(),
|
||||||
|
tensor_info.dtype,
|
||||||
|
tensor_info.path,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Pickle => {
|
Format::Pickle => {
|
||||||
|
@ -9,6 +9,14 @@ pub struct Layout {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Layout {
|
impl Layout {
|
||||||
|
pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
shape,
|
||||||
|
stride,
|
||||||
|
start_offset,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
|
pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let stride = shape.stride_contiguous();
|
let stride = shape.stride_contiguous();
|
||||||
|
@ -1,11 +1,13 @@
|
|||||||
// Just enough pickle support to be able to read PyTorch checkpoints.
|
// Just enough pickle support to be able to read PyTorch checkpoints.
|
||||||
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
|
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
|
||||||
// composable/tensor agnostic at some point.
|
// composable/tensor agnostic at some point.
|
||||||
use crate::{DType, Error as E, Result};
|
use crate::{DType, Error as E, Layout, Result};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::io::BufRead;
|
use std::io::BufRead;
|
||||||
|
|
||||||
|
const VERBOSE: bool = false;
|
||||||
|
|
||||||
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
|
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
|
||||||
#[repr(u8)]
|
#[repr(u8)]
|
||||||
#[derive(Debug, Eq, PartialEq, Clone)]
|
#[derive(Debug, Eq, PartialEq, Clone)]
|
||||||
@ -352,8 +354,10 @@ impl Stack {
|
|||||||
match op_code {
|
match op_code {
|
||||||
OpCode::Proto => {
|
OpCode::Proto => {
|
||||||
let version = r.read_u8()?;
|
let version = r.read_u8()?;
|
||||||
|
if VERBOSE {
|
||||||
println!("proto {version}");
|
println!("proto {version}");
|
||||||
}
|
}
|
||||||
|
}
|
||||||
OpCode::Global => {
|
OpCode::Global => {
|
||||||
let module_name = read_to_newline(r)?;
|
let module_name = read_to_newline(r)?;
|
||||||
let class_name = read_to_newline(r)?;
|
let class_name = read_to_newline(r)?;
|
||||||
@ -486,11 +490,14 @@ impl From<Object> for E {
|
|||||||
|
|
||||||
// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
|
// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
|
||||||
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
|
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
|
||||||
fn rebuild_args(args: Object) -> Result<(Vec<usize>, DType)> {
|
fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
|
||||||
let mut args = args.tuple()?;
|
let mut args = args.tuple()?;
|
||||||
|
let stride = Vec::<usize>::try_from(args.remove(3))?;
|
||||||
let size = Vec::<usize>::try_from(args.remove(2))?;
|
let size = Vec::<usize>::try_from(args.remove(2))?;
|
||||||
|
let offset = args.remove(1).int()? as usize;
|
||||||
let storage = args.remove(0).persistent_load()?;
|
let storage = args.remove(0).persistent_load()?;
|
||||||
let mut storage = storage.tuple()?;
|
let mut storage = storage.tuple()?;
|
||||||
|
let path = storage.remove(2).unicode()?;
|
||||||
let (_module_name, class_name) = storage.remove(1).class()?;
|
let (_module_name, class_name) = storage.remove(1).class()?;
|
||||||
let dtype = match class_name.as_str() {
|
let dtype = match class_name.as_str() {
|
||||||
"FloatStorage" => DType::F32,
|
"FloatStorage" => DType::F32,
|
||||||
@ -502,12 +509,19 @@ fn rebuild_args(args: Object) -> Result<(Vec<usize>, DType)> {
|
|||||||
crate::bail!("unsupported storage type {other}")
|
crate::bail!("unsupported storage type {other}")
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
Ok((size, dtype))
|
let layout = Layout::new(crate::Shape::from(size), stride, offset);
|
||||||
|
Ok((layout, dtype, path))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
#[derive(Debug, Clone)]
|
||||||
file: P,
|
pub struct TensorInfo {
|
||||||
) -> Result<Vec<(String, DType, Vec<usize>)>> {
|
pub name: String,
|
||||||
|
pub dtype: DType,
|
||||||
|
pub layout: Layout,
|
||||||
|
pub path: std::path::PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
|
||||||
let file = std::fs::File::open(file)?;
|
let file = std::fs::File::open(file)?;
|
||||||
let zip_reader = std::io::BufReader::new(file);
|
let zip_reader = std::io::BufReader::new(file);
|
||||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||||
@ -516,26 +530,44 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
.map(|f| f.to_string())
|
.map(|f| f.to_string())
|
||||||
.collect::<Vec<String>>();
|
.collect::<Vec<String>>();
|
||||||
|
|
||||||
let mut tensor_info = vec![];
|
let mut tensor_infos = vec![];
|
||||||
for name in zip_file_names.iter() {
|
for file_name in zip_file_names.iter() {
|
||||||
if !name.ends_with("data.pkl") {
|
if !file_name.ends_with("data.pkl") {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let reader = zip.by_name(name)?;
|
let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
|
||||||
|
let reader = zip.by_name(file_name)?;
|
||||||
let mut reader = std::io::BufReader::new(reader);
|
let mut reader = std::io::BufReader::new(reader);
|
||||||
let mut stack = Stack::empty();
|
let mut stack = Stack::empty();
|
||||||
stack.read_loop(&mut reader)?;
|
stack.read_loop(&mut reader)?;
|
||||||
let obj = stack.finalize()?;
|
let obj = stack.finalize()?;
|
||||||
|
if VERBOSE {
|
||||||
|
println!("{obj:?}");
|
||||||
|
}
|
||||||
if let Object::Dict(key_values) = obj {
|
if let Object::Dict(key_values) = obj {
|
||||||
for (key, value) in key_values.into_iter() {
|
for (name, value) in key_values.into_iter() {
|
||||||
let key = match key.unicode() {
|
let name = match name.unicode() {
|
||||||
Ok(key) => key,
|
Ok(name) => name,
|
||||||
Err(_) => continue,
|
Err(_) => continue,
|
||||||
};
|
};
|
||||||
let (callable, args) = match value.reduce() {
|
let (callable, args) = match value.reduce() {
|
||||||
Ok(callable_args) => callable_args,
|
Ok(callable_args) => callable_args,
|
||||||
_ => continue,
|
_ => continue,
|
||||||
};
|
};
|
||||||
|
let (callable, args) = match callable {
|
||||||
|
Object::Class {
|
||||||
|
module_name,
|
||||||
|
class_name,
|
||||||
|
} if module_name == "torch._tensor"
|
||||||
|
&& class_name == "_rebuild_from_type_v2" =>
|
||||||
|
{
|
||||||
|
let mut args = args.tuple()?;
|
||||||
|
let callable = args.remove(0);
|
||||||
|
let args = args.remove(1);
|
||||||
|
(callable, args)
|
||||||
|
}
|
||||||
|
_ => (callable, args),
|
||||||
|
};
|
||||||
match callable {
|
match callable {
|
||||||
Object::Class {
|
Object::Class {
|
||||||
module_name,
|
module_name,
|
||||||
@ -544,13 +576,22 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
_ => continue,
|
_ => continue,
|
||||||
};
|
};
|
||||||
match rebuild_args(args) {
|
match rebuild_args(args) {
|
||||||
Ok((size, dtype)) => tensor_info.push((key, dtype, size)),
|
Ok((layout, dtype, file_path)) => {
|
||||||
|
let mut path = dir_name.clone();
|
||||||
|
path.push(file_path);
|
||||||
|
tensor_infos.push(TensorInfo {
|
||||||
|
name,
|
||||||
|
dtype,
|
||||||
|
layout,
|
||||||
|
path,
|
||||||
|
})
|
||||||
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
eprintln!("skipping {key}: {err:?}")
|
eprintln!("skipping {name}: {err:?}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(tensor_info)
|
Ok(tensor_infos)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user