mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Handle start-offset when loading a tensor from a pickle file. (#1546)
This commit is contained in:
@ -703,6 +703,7 @@ impl PthTensors {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||||
|
use std::io::Read;
|
||||||
let tensor_info = match self.tensor_infos.get(name) {
|
let tensor_info = match self.tensor_infos.get(name) {
|
||||||
None => return Ok(None),
|
None => return Ok(None),
|
||||||
Some(tensor_info) => tensor_info,
|
Some(tensor_info) => tensor_info,
|
||||||
@ -712,14 +713,21 @@ impl PthTensors {
|
|||||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||||
|
|
||||||
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||||
// For now only support the basic case.
|
// case.
|
||||||
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
if !tensor_info.layout.is_contiguous() {
|
||||||
crate::bail!(
|
crate::bail!(
|
||||||
"cannot retrieve non-contiguous tensors {:?}",
|
"cannot retrieve non-contiguous tensors {:?}",
|
||||||
tensor_info.layout
|
tensor_info.layout
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
let start_offset = tensor_info.layout.start_offset();
|
||||||
|
if start_offset > 0 {
|
||||||
|
std::io::copy(
|
||||||
|
&mut reader.by_ref().take(start_offset as u64),
|
||||||
|
&mut std::io::sink(),
|
||||||
|
)?;
|
||||||
|
}
|
||||||
let tensor = Tensor::from_reader(
|
let tensor = Tensor::from_reader(
|
||||||
tensor_info.layout.shape().clone(),
|
tensor_info.layout.shape().clone(),
|
||||||
tensor_info.dtype,
|
tensor_info.dtype,
|
||||||
|
Reference in New Issue
Block a user