mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Retrieve tensor data from PyTorch files. (#516)
This commit is contained in:
@ -45,7 +45,7 @@ struct Args {
|
|||||||
command: Command,
|
command: Command,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
|
fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
|
||||||
let format = match format {
|
let format = match format {
|
||||||
Some(format) => format,
|
Some(format) => format,
|
||||||
None => match Format::infer(file) {
|
None => match Format::infer(file) {
|
||||||
@ -91,12 +91,14 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
|
|||||||
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
tensors.sort_by(|a, b| a.name.cmp(&b.name));
|
||||||
for tensor_info in tensors.iter() {
|
for tensor_info in tensors.iter() {
|
||||||
println!(
|
println!(
|
||||||
"{}: [{:?}; {:?}] {:?}",
|
"{}: [{:?}; {:?}]",
|
||||||
tensor_info.name,
|
tensor_info.name,
|
||||||
tensor_info.layout.shape(),
|
tensor_info.layout.shape(),
|
||||||
tensor_info.dtype,
|
tensor_info.dtype,
|
||||||
tensor_info.path,
|
);
|
||||||
)
|
if verbose {
|
||||||
|
println!(" {:?}", tensor_info);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Format::Pickle => {
|
Format::Pickle => {
|
||||||
@ -130,7 +132,7 @@ fn main() -> anyhow::Result<()> {
|
|||||||
if multiple_files {
|
if multiple_files {
|
||||||
println!("--- {file:?} ---");
|
println!("--- {file:?} ---");
|
||||||
}
|
}
|
||||||
run_ls(file, format.clone())?
|
run_ls(file, format.clone(), args.verbose)?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,7 +196,11 @@ impl Header {
|
|||||||
|
|
||||||
impl Tensor {
|
impl Tensor {
|
||||||
// TODO: Add the possibility to read directly to a device?
|
// TODO: Add the possibility to read directly to a device?
|
||||||
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
|
pub(crate) fn from_reader<R: std::io::Read>(
|
||||||
|
shape: Shape,
|
||||||
|
dtype: DType,
|
||||||
|
reader: &mut R,
|
||||||
|
) -> Result<Self> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
match dtype {
|
match dtype {
|
||||||
DType::BF16 => {
|
DType::BF16 => {
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
// Just enough pickle support to be able to read PyTorch checkpoints.
|
// Just enough pickle support to be able to read PyTorch checkpoints.
|
||||||
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
|
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
|
||||||
// composable/tensor agnostic at some point.
|
// composable/tensor agnostic at some point.
|
||||||
use crate::{DType, Error as E, Layout, Result};
|
use crate::{DType, Error as E, Layout, Result, Tensor};
|
||||||
use byteorder::{LittleEndian, ReadBytesExt};
|
use byteorder::{LittleEndian, ReadBytesExt};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::io::BufRead;
|
use std::io::BufRead;
|
||||||
@ -518,7 +518,7 @@ pub struct TensorInfo {
|
|||||||
pub name: String,
|
pub name: String,
|
||||||
pub dtype: DType,
|
pub dtype: DType,
|
||||||
pub layout: Layout,
|
pub layout: Layout,
|
||||||
pub path: std::path::PathBuf,
|
pub path: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
|
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
|
||||||
@ -583,7 +583,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
|
|||||||
name,
|
name,
|
||||||
dtype,
|
dtype,
|
||||||
layout,
|
layout,
|
||||||
path,
|
path: path.to_string_lossy().into_owned(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
@ -595,3 +595,53 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
|
|||||||
}
|
}
|
||||||
Ok(tensor_infos)
|
Ok(tensor_infos)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Lazy tensor loader.
|
||||||
|
pub struct PthTensors {
|
||||||
|
tensor_infos: HashMap<String, TensorInfo>,
|
||||||
|
path: std::path::PathBuf,
|
||||||
|
// We do not store a zip reader as it needs mutable access to extract data. Instead we
|
||||||
|
// re-create a zip reader for each tensor.
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PthTensors {
|
||||||
|
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
|
||||||
|
let tensor_infos = read_pth_tensor_info(path.as_ref())?;
|
||||||
|
let tensor_infos = tensor_infos
|
||||||
|
.into_iter()
|
||||||
|
.map(|ti| (ti.name.to_string(), ti))
|
||||||
|
.collect();
|
||||||
|
let path = path.as_ref().to_owned();
|
||||||
|
Ok(Self { tensor_infos, path })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
|
||||||
|
&self.tensor_infos
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||||
|
let tensor_info = match self.tensor_infos.get(name) {
|
||||||
|
None => return Ok(None),
|
||||||
|
Some(tensor_info) => tensor_info,
|
||||||
|
};
|
||||||
|
// We hope that the file has not changed since first reading it.
|
||||||
|
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
||||||
|
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||||
|
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||||
|
|
||||||
|
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||||
|
// For now only support the basic case.
|
||||||
|
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||||
|
crate::bail!(
|
||||||
|
"cannot retrieve non-contiguous tensors {:?}",
|
||||||
|
tensor_info.layout
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let tensor = Tensor::from_reader(
|
||||||
|
tensor_info.layout.shape().clone(),
|
||||||
|
tensor_info.dtype,
|
||||||
|
&mut reader,
|
||||||
|
)?;
|
||||||
|
Ok(Some(tensor))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user