Small tweaks to tensor-tools. (#517)

This commit is contained in:
Laurent Mazare
2023-08-19 16:50:26 +01:00
committed by GitHub
parent 6431140250
commit 551409092e
2 changed files with 21 additions and 12 deletions

View File

@ -490,13 +490,14 @@ impl From<Object> for E {
// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
@ -510,7 +511,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
}
};
let layout = Layout::new(crate::Shape::from(size), stride, offset);
Ok((layout, dtype, path))
Ok((layout, dtype, path, storage_size))
}
#[derive(Debug, Clone)]
@ -519,6 +520,7 @@ pub struct TensorInfo {
pub dtype: DType,
pub layout: Layout,
pub path: String,
pub storage_size: usize,
}
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
@ -576,7 +578,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
_ => continue,
};
match rebuild_args(args) {
Ok((layout, dtype, file_path)) => {
Ok((layout, dtype, file_path, storage_size)) => {
let mut path = dir_name.clone();
path.push(file_path);
tensor_infos.push(TensorInfo {
@ -584,6 +586,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
dtype,
layout,
path: path.to_string_lossy().into_owned(),
storage_size,
})
}
Err(err) => {