diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 0013113a..4a2c65fd 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -193,6 +193,50 @@ impl Object { _ => Err(self), } } + + pub fn into_tensor_info( + self, + name: Self, + dir_name: &std::path::Path, + ) -> Result> { + let name = match name.unicode() { + Ok(name) => name, + Err(_) => return Ok(None), + }; + let (callable, args) = match self.reduce() { + Ok(callable_args) => callable_args, + _ => return Ok(None), + }; + let (callable, args) = match callable { + Object::Class { + module_name, + class_name, + } if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => { + let mut args = args.tuple()?; + let callable = args.remove(0); + let args = args.remove(1); + (callable, args) + } + _ => (callable, args), + }; + match callable { + Object::Class { + module_name, + class_name, + } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {} + _ => return Ok(None), + }; + let (layout, dtype, file_path, storage_size) = rebuild_args(args)?; + let mut path = dir_name.to_path_buf(); + path.push(file_path); + Ok(Some(TensorInfo { + name, + dtype, + layout, + path: path.to_string_lossy().into_owned(), + storage_size, + })) + } } impl TryFrom for String { @@ -623,50 +667,10 @@ pub fn read_pth_tensor_info>( }; if let Object::Dict(key_values) = obj { for (name, value) in key_values.into_iter() { - let name = match name.unicode() { - Ok(name) => name, - Err(_) => continue, - }; - let (callable, args) = match value.reduce() { - Ok(callable_args) => callable_args, - _ => continue, - }; - let (callable, args) = match callable { - Object::Class { - module_name, - class_name, - } if module_name == "torch._tensor" - && class_name == "_rebuild_from_type_v2" => - { - let mut args = args.tuple()?; - let callable = args.remove(0); - let args = args.remove(1); - (callable, args) - } - _ => (callable, args), - }; - match callable { - Object::Class { - module_name, - class_name, - } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {} - _ => continue, - }; - match rebuild_args(args) { - Ok((layout, dtype, file_path, storage_size)) => { - let mut path = dir_name.clone(); - path.push(file_path); - tensor_infos.push(TensorInfo { - name, - dtype, - layout, - path: path.to_string_lossy().into_owned(), - storage_size, - }) - } - Err(err) => { - eprintln!("skipping {name}: {err:?}") - } + match value.into_tensor_info(name, &dir_name) { + Ok(Some(tensor_info)) => tensor_infos.push(tensor_info), + Ok(None) => {} + Err(err) => eprintln!("skipping: {err:?}"), } } }