Retrieve tensor data from PyTorch files. (#516)

Laurent Mazare
2023-08-19 15:57:18 +01:00
committed by GitHub
parent 607ffb9f1e
commit 6431140250
3 changed files with 65 additions and 9 deletions


@@ -45,7 +45,7 @@ struct Args {
     command: Command,
 }
 
-fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
+fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> Result<()> {
     let format = match format {
         Some(format) => format,
         None => match Format::infer(file) {
@@ -91,12 +91,14 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
             tensors.sort_by(|a, b| a.name.cmp(&b.name));
             for tensor_info in tensors.iter() {
                 println!(
-                    "{}: [{:?}; {:?}] {:?}",
+                    "{}: [{:?}; {:?}]",
                     tensor_info.name,
                     tensor_info.layout.shape(),
                     tensor_info.dtype,
-                    tensor_info.path,
-                )
+                );
+                if verbose {
+                    println!(" {:?}", tensor_info);
+                }
             }
         }
         Format::Pickle => {
@@ -130,7 +132,7 @@ fn main() -> anyhow::Result<()> {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file, format.clone())?
+                run_ls(file, format.clone(), args.verbose)?
             }
         }
     }
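
The new verbose parameter is read from args.verbose in main, so Args presumably also gains a boolean flag via clap's derive macros; that declaration is not part of the hunks above. A minimal, hypothetical sketch of how such a flag could be declared (the field attributes and the Command variant shape are illustrative assumptions, not this commit's code):

// Hypothetical sketch only: not part of this diff.
use clap::{Parser, Subcommand};

#[derive(Subcommand, Debug)]
enum Command {
    /// List tensors stored in the given files (variant shape assumed).
    Ls { files: Vec<std::path::PathBuf> },
}

#[derive(Parser, Debug)]
struct Args {
    /// Print the full TensorInfo for each entry instead of a one-line summary.
    #[arg(long)]
    verbose: bool,

    #[command(subcommand)]
    command: Command,
}

fn main() {
    let args = Args::parse();
    println!("{args:?}");
}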


@@ -196,7 +196,11 @@ impl Header {
 
 impl Tensor {
     // TODO: Add the possibility to read directly to a device?
-    fn from_reader<R: std::io::Read>(shape: Shape, dtype: DType, reader: &mut R) -> Result<Self> {
+    pub(crate) fn from_reader<R: std::io::Read>(
+        shape: Shape,
+        dtype: DType,
+        reader: &mut R,
+    ) -> Result<Self> {
         let elem_count = shape.elem_count();
         match dtype {
             DType::BF16 => {
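
For context, Tensor::from_reader is made pub(crate) here so the pickle code in this commit can reuse it; it reads shape.elem_count() elements of the given dtype from the reader, with one branch per dtype (only the DType::BF16 arm is visible above). Below is a self-contained sketch of the same reading pattern for f32 data using the byteorder crate; the helper name is illustrative and not candle's API, and little-endian order is an assumption taken from the byteorder imports elsewhere in this commit:

use byteorder::{LittleEndian, ReadBytesExt};

// Read `elem_count` little-endian f32 values from any reader into a Vec,
// mirroring one dtype branch of from_reader.
fn read_f32s<R: std::io::Read>(reader: &mut R, elem_count: usize) -> std::io::Result<Vec<f32>> {
    let mut data = vec![0f32; elem_count];
    reader.read_f32_into::<LittleEndian>(&mut data)?;
    Ok(data)
}

fn main() -> std::io::Result<()> {
    // Two f32 values serialized in little-endian byte order.
    let bytes: Vec<u8> = [1.0f32, 2.0f32].iter().flat_map(|v| v.to_le_bytes()).collect();
    let mut cursor = std::io::Cursor::new(bytes);
    let values = read_f32s(&mut cursor, 2)?;
    assert_eq!(values, vec![1.0, 2.0]);
    Ok(())
}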


@@ -1,7 +1,7 @@
 // Just enough pickle support to be able to read PyTorch checkpoints.
 // This hardcodes objects that are required for tensor reading, we may want to make this a bit more
 // composable/tensor agnostic at some point.
-use crate::{DType, Error as E, Layout, Result};
+use crate::{DType, Error as E, Layout, Result, Tensor};
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
 use std::io::BufRead;
@@ -518,7 +518,7 @@ pub struct TensorInfo {
     pub name: String,
     pub dtype: DType,
     pub layout: Layout,
-    pub path: std::path::PathBuf,
+    pub path: String,
 }
 
 pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
@@ -583,7 +583,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
                         name,
                         dtype,
                         layout,
-                        path,
+                        path: path.to_string_lossy().into_owned(),
                     })
                 }
                 Err(err) => {
@@ -595,3 +595,53 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
     }
     Ok(tensor_infos)
 }
+
+/// Lazy tensor loader.
+pub struct PthTensors {
+    tensor_infos: HashMap<String, TensorInfo>,
+    path: std::path::PathBuf,
+    // We do not store a zip reader as it needs mutable access to extract data. Instead we
+    // re-create a zip reader for each tensor.
+}
+
+impl PthTensors {
+    pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
+        let tensor_infos = read_pth_tensor_info(path.as_ref())?;
+        let tensor_infos = tensor_infos
+            .into_iter()
+            .map(|ti| (ti.name.to_string(), ti))
+            .collect();
+        let path = path.as_ref().to_owned();
+        Ok(Self { tensor_infos, path })
+    }
+
+    pub fn tensor_infos(&self) -> &HashMap<String, TensorInfo> {
+        &self.tensor_infos
+    }
+
+    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
+        let tensor_info = match self.tensor_infos.get(name) {
+            None => return Ok(None),
+            Some(tensor_info) => tensor_info,
+        };
+        // We hope that the file has not changed since first reading it.
+        let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
+        let mut zip = zip::ZipArchive::new(zip_reader)?;
+        let mut reader = zip.by_name(&tensor_info.path)?;
+
+        // Reading the data is a bit tricky as it can be strided, use an offset, etc.
+        // For now only support the basic case.
+        if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
+            crate::bail!(
+                "cannot retrieve non-contiguous tensors {:?}",
+                tensor_info.layout
+            )
+        }
+        let tensor = Tensor::from_reader(
+            tensor_info.layout.shape().clone(),
+            tensor_info.dtype,
+            &mut reader,
+        )?;
+        Ok(Some(tensor))
+    }
+}
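
Taken together, PthTensors gives a lazy, name-indexed view over a PyTorch checkpoint: the pickle metadata is parsed once in new, and tensor data is only extracted from the zip archive when get is called. A minimal usage sketch, assuming the crate is exposed as candle_core with this module at candle_core::pickle; the checkpoint path and tensor name are placeholders:

// Usage sketch only: crate/module path, file name, and tensor name are assumptions.
use candle_core::pickle::PthTensors;

fn main() -> candle_core::Result<()> {
    let pth = PthTensors::new("model.pth")?;

    // The metadata comes from the pickle index; no tensor data is read yet.
    for (name, info) in pth.tensor_infos() {
        println!("{name}: {:?} {:?}", info.layout.shape(), info.dtype);
    }

    // Data is read lazily: a fresh zip reader is opened for this single entry.
    if let Some(tensor) = pth.get("embeddings.weight")? {
        println!("loaded tensor with shape {:?}", tensor.shape());
    }
    Ok(())
}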