Retrieve tensor data from PyTorch files. (#516)

This commit is contained in:
Laurent Mazare
2023-08-19 15:57:18 +01:00
committed by GitHub
parent 607ffb9f1e
commit 6431140250
3 changed files with 65 additions and 9 deletions

View File

@ -45,7 +45,7 @@ struct Args {
command: Command, command: Command,
} }
fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> { fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
let format = match format { let format = match format {
Some(format) => format, Some(format) => format,
None => match Format::infer(file) { None => match Format::infer(file) {
@ -91,12 +91,14 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
tensors.sort_by(|a, b| a.name.cmp(&b.name)); tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() { for tensor_info in tensors.iter() {
println!( println!(
"{}: [{:?}; {:?}] {:?}", "{}: [{:?}; {:?}]",
tensor_info.name, tensor_info.name,
tensor_info.layout.shape(), tensor_info.layout.shape(),
tensor_info.dtype, tensor_info.dtype,
tensor_info.path, );
) if verbose {
println!(" {:?}", tensor_info);
}
} }
} }
Format::Pickle => { Format::Pickle => {
@ -130,7 +132,7 @@ fn main() -> anyhow::Result<()> {
if multiple_files { if multiple_files {
println!("--- {file:?} ---"); println!("--- {file:?} ---");
} }
run_ls(file, format.clone())? run_ls(file, format.clone(), args.verbose)?
} }
} }
} }

View File

@ -196,7 +196,11 @@ impl Header {
impl Tensor { impl Tensor {
// TODO: Add the possibility to read directly to a device? // TODO: Add the possibility to read directly to a device?
fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> { pub(crate) fn from_reader<R: std::io::Read>(
shape: Shape,
dtype: DType,
reader: &mut R,
) -> Result<Self> {
let elem_count = shape.elem_count(); let elem_count = shape.elem_count();
match dtype { match dtype {
DType::BF16 => { DType::BF16 => {

View File

@ -1,7 +1,7 @@
// Just enough pickle support to be able to read PyTorch checkpoints. // Just enough pickle support to be able to read PyTorch checkpoints.
// This hardcodes objects that are required for tensor reading, we may want to make this a bit more // This hardcodes objects that are required for tensor reading, we may want to make this a bit more
// composable/tensor agnostic at some point. // composable/tensor agnostic at some point.
use crate::{DType, Error as E, Layout, Result}; use crate::{DType, Error as E, Layout, Result, Tensor};
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap; use std::collections::HashMap;
use std::io::BufRead; use std::io::BufRead;
@ -518,7 +518,7 @@ pub struct TensorInfo {
pub name: String, pub name: String,
pub dtype: DType, pub dtype: DType,
pub layout: Layout, pub layout: Layout,
pub path: std::path::PathBuf, pub path: String,
} }
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> { pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
@ -583,7 +583,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
name, name,
dtype, dtype,
layout, layout,
path, path: path.to_string_lossy().into_owned(),
}) })
} }
Err(err) => { Err(err) => {
@ -595,3 +595,53 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
} }
Ok(tensor_infos) Ok(tensor_infos)
} }
/// Lazy tensor loader.
pub struct PthTensors {
tensor_infos: HashMap<String, TensorInfo>,
path: std::path::PathBuf,
// We do not store a zip reader as it needs mutable access to extract data. Instead we
// re-create a zip reader for each tensor.
}
impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref())?;
let tensor_infos = tensor_infos
.into_iter()
.map(|ti| (ti.name.to_string(), ti))
.collect();
let path = path.as_ref().to_owned();
Ok(Self { tensor_infos, path })
}
pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
&self.tensor_infos
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
let tensor_info = match self.tensor_infos.get(name) {
None => return Ok(None),
Some(tensor_info) => tensor_info,
};
// We hope that the file has not changed since first reading it.
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_name(&tensor_info.path)?;
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
// For now only support the basic case.
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
crate::bail!(
"cannot retrieve non-contiguous tensors {:?}",
tensor_info.layout
)
}
let tensor = Tensor::from_reader(
tensor_info.layout.shape().clone(),
tensor_info.dtype,
&mut reader,
)?;
Ok(Some(tensor))
}
}