mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Refactor the pth tensor exctraction. (#1109)
This commit is contained in:
@ -193,6 +193,50 @@ impl Object {
|
|||||||
_ => Err(self),
|
_ => Err(self),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn into_tensor_info(
|
||||||
|
self,
|
||||||
|
name: Self,
|
||||||
|
dir_name: &std::path::Path,
|
||||||
|
) -> Result<Option<TensorInfo>> {
|
||||||
|
let name = match name.unicode() {
|
||||||
|
Ok(name) => name,
|
||||||
|
Err(_) => return Ok(None),
|
||||||
|
};
|
||||||
|
let (callable, args) = match self.reduce() {
|
||||||
|
Ok(callable_args) => callable_args,
|
||||||
|
_ => return Ok(None),
|
||||||
|
};
|
||||||
|
let (callable, args) = match callable {
|
||||||
|
Object::Class {
|
||||||
|
module_name,
|
||||||
|
class_name,
|
||||||
|
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
|
||||||
|
let mut args = args.tuple()?;
|
||||||
|
let callable = args.remove(0);
|
||||||
|
let args = args.remove(1);
|
||||||
|
(callable, args)
|
||||||
|
}
|
||||||
|
_ => (callable, args),
|
||||||
|
};
|
||||||
|
match callable {
|
||||||
|
Object::Class {
|
||||||
|
module_name,
|
||||||
|
class_name,
|
||||||
|
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
|
||||||
|
_ => return Ok(None),
|
||||||
|
};
|
||||||
|
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
|
||||||
|
let mut path = dir_name.to_path_buf();
|
||||||
|
path.push(file_path);
|
||||||
|
Ok(Some(TensorInfo {
|
||||||
|
name,
|
||||||
|
dtype,
|
||||||
|
layout,
|
||||||
|
path: path.to_string_lossy().into_owned(),
|
||||||
|
storage_size,
|
||||||
|
}))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<Object> for String {
|
impl TryFrom<Object> for String {
|
||||||
@ -623,50 +667,10 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
|
|||||||
};
|
};
|
||||||
if let Object::Dict(key_values) = obj {
|
if let Object::Dict(key_values) = obj {
|
||||||
for (name, value) in key_values.into_iter() {
|
for (name, value) in key_values.into_iter() {
|
||||||
let name = match name.unicode() {
|
match value.into_tensor_info(name, &dir_name) {
|
||||||
Ok(name) => name,
|
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
|
||||||
Err(_) => continue,
|
Ok(None) => {}
|
||||||
};
|
Err(err) => eprintln!("skipping: {err:?}"),
|
||||||
let (callable, args) = match value.reduce() {
|
|
||||||
Ok(callable_args) => callable_args,
|
|
||||||
_ => continue,
|
|
||||||
};
|
|
||||||
let (callable, args) = match callable {
|
|
||||||
Object::Class {
|
|
||||||
module_name,
|
|
||||||
class_name,
|
|
||||||
} if module_name == "torch._tensor"
|
|
||||||
&& class_name == "_rebuild_from_type_v2" =>
|
|
||||||
{
|
|
||||||
let mut args = args.tuple()?;
|
|
||||||
let callable = args.remove(0);
|
|
||||||
let args = args.remove(1);
|
|
||||||
(callable, args)
|
|
||||||
}
|
|
||||||
_ => (callable, args),
|
|
||||||
};
|
|
||||||
match callable {
|
|
||||||
Object::Class {
|
|
||||||
module_name,
|
|
||||||
class_name,
|
|
||||||
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
|
|
||||||
_ => continue,
|
|
||||||
};
|
|
||||||
match rebuild_args(args) {
|
|
||||||
Ok((layout, dtype, file_path, storage_size)) => {
|
|
||||||
let mut path = dir_name.clone();
|
|
||||||
path.push(file_path);
|
|
||||||
tensor_infos.push(TensorInfo {
|
|
||||||
name,
|
|
||||||
dtype,
|
|
||||||
layout,
|
|
||||||
path: path.to_string_lossy().into_owned(),
|
|
||||||
storage_size,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Err(err) => {
|
|
||||||
eprintln!("skipping {name}: {err:?}")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user