mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Add support for loading Fortran contiguous tensors (#1672)
* Add support for loading Fortran contiguous tensors This commit introduces the ability to handle Fortran contiguous tensors in the tensor loading process. Previously, the code only supported loading tensors that were contiguous in memory, failing with an error for non-contiguous tensors. With this update, tensors identified as Fortran contiguous (column-major order) are now correctly handled by reversing their dimensions after loading. This enhancement ensures broader compatibility with different tensor layouts, improving the robustness of tensor loading operations. - Check if a tensor is Fortran contiguous using the `is_fortran_contiguous` flag. - For Fortran contiguous tensors, reverse the dimensions after loading to correctly represent their layout in memory. - Continue to bail out with an error for tensors that are neither C contiguous nor Fortran contiguous, maintaining the previous behavior for non-contiguous tensors without explicit support. This change addresses the issue of loading Fortran contiguous tensors, which was previously unsupported, thereby extending the functionality of the tensor loading mechanism to accommodate a wider variety of tensor layouts. * Add reshape step to handle fortran contiguous case * Skip fortran contiguous fix if rank is < 2 * Fail on rank 0, 1 if contiguous
This commit is contained in:

committed by
GitHub

parent
b75e8945bc
commit
e5eb9602d0
@ -736,10 +736,12 @@ impl PthTensors {
|
||||
let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
|
||||
let rank = tensor_info.layout.shape().rank();
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||
// case.
|
||||
if !tensor_info.layout.is_contiguous() {
|
||||
// case and when the tensor is fortran contiguous.
|
||||
if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
@ -757,7 +759,19 @@ impl PthTensors {
|
||||
tensor_info.dtype,
|
||||
&mut reader,
|
||||
)?;
|
||||
Ok(Some(tensor))
|
||||
|
||||
if rank > 1 && is_fortran_contiguous {
|
||||
// Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
|
||||
let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
|
||||
let tensor = tensor.reshape(shape_reversed)?;
|
||||
|
||||
// Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
|
||||
let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect();
|
||||
let tensor = tensor.permute(dim_indeces_reversed)?;
|
||||
Ok(Some(tensor))
|
||||
} else {
|
||||
Ok(Some(tensor))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user