mirror of https://github.com/huggingface/candle.git (synced 2025-06-19 11:56:45 +00:00)
Add support for loading Fortran contiguous tensors (#1672)
* Add support for loading Fortran contiguous tensors

This commit introduces the ability to handle Fortran contiguous tensors in the tensor loading process. Previously, the code only supported loading tensors that were C contiguous in memory and failed with an error for anything else. With this update, tensors identified as Fortran contiguous (column-major order) are handled correctly by reversing their dimensions after loading and permuting the axes back into their logical order (see the sketch after the commit metadata below).

- Check whether a tensor is Fortran contiguous using the `is_fortran_contiguous` flag.
- For Fortran contiguous tensors, reverse the dimensions after loading to correctly represent their layout in memory.
- Continue to bail out with an error for tensors that are neither C contiguous nor Fortran contiguous, preserving the previous behavior for unsupported layouts.

This change extends the tensor loading mechanism to accommodate a wider variety of tensor layouts.

* Add reshape step to handle fortran contiguous case

* Skip fortran contiguous fix if rank is < 2

* Fail on rank 0, 1 if contiguous
committed by GitHub
parent b75e8945bc
commit e5eb9602d0
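The trick the commit message describes is easy to see in miniature: a Fortran (column-major) buffer can be read as if it were C contiguous with the dimensions reversed, then permuted back into logical order. Below is a minimal sketch of that trick using candle's public API; it is illustrative only, and the data and variable names are not from the commit.

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        // Column-major storage of the logical 2x3 matrix [[1, 2, 3], [4, 5, 6]]:
        // elements are laid out column by column.
        let fortran_data: Vec<i64> = vec![1, 4, 2, 5, 3, 6];
        // Step 1: read the raw buffer as C contiguous with the dims reversed (3x2).
        let t = Tensor::from_vec(fortran_data, (3, 2), &Device::Cpu)?;
        // Step 2: permute the axes back into their logical order (2x3).
        let t = t.permute((1, 0))?;
        assert_eq!(t.to_vec2::<i64>()?, [[1, 2, 3], [4, 5, 6]]);
        Ok(())
    }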
@@ -736,10 +736,12 @@ impl PthTensors {
         let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut reader = zip.by_name(&tensor_info.path)?;
+        let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
+        let rank = tensor_info.layout.shape().rank();

         // Reading the data is a bit tricky as it can be strided, for now only support the basic
-        // case.
-        if !tensor_info.layout.is_contiguous() {
+        // case and when the tensor is fortran contiguous.
+        if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
             crate::bail!(
                 "cannot retrieve non-contiguous tensors {:?}",
                 tensor_info.layout
@@ -757,7 +759,19 @@ impl PthTensors {
             tensor_info.dtype,
             &mut reader,
         )?;
-        Ok(Some(tensor))
+
+        if rank > 1 && is_fortran_contiguous {
+            // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2).
+            let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
+            let tensor = tensor.reshape(shape_reversed)?;
+
+            // Permute (transpose) the dimensions back, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4).
+            let dim_indices_reversed: Vec<_> = (0..rank).rev().collect();
+            let tensor = tensor.permute(dim_indices_reversed)?;
+            Ok(Some(tensor))
+        } else {
+            Ok(Some(tensor))
+        }
     }
 }
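For context, a layout is Fortran contiguous when each dimension's stride equals the product of all earlier dimensions, so the first axis varies fastest in memory. The sketch below captures that semantics; it is an assumption for illustration, not candle's actual is_fortran_contiguous implementation.

    // Assumed semantics of a Fortran-contiguous layout (illustrative, not candle's code).
    fn is_fortran_contiguous(dims: &[usize], strides: &[usize]) -> bool {
        let mut expected = 1;
        for (&dim, &stride) in dims.iter().zip(strides.iter()) {
            if stride != expected {
                return false;
            }
            expected *= dim;
        }
        true
    }

    fn main() {
        // A 2x3x4 tensor stored column-major has element strides [1, 2, 6] ...
        assert!(is_fortran_contiguous(&[2, 3, 4], &[1, 2, 6]));
        // ... while row-major (C order) storage of the same shape has strides [12, 4, 1].
        assert!(!is_fortran_contiguous(&[2, 3, 4], &[12, 4, 1]));
    }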
BIN  candle-core/tests/fortran_tensor_3d.pth (new file)
Binary file not shown.
@@ -5,6 +5,33 @@ from collections import OrderedDict
 a= torch.tensor([[1,2,3,4], [5,6,7,8]])
 o = OrderedDict()
 o["test"] = a

 # Write a trivial tensor to a pt file
 torch.save(o, "test.pt")
+
+############################################################################################################
+# Write a trivial tensor to a pt file with a key
+torch.save({"model_state_dict": o}, "test_with_key.pt")
+
+############################################################################################################
+# Create a tensor with fortran contiguous memory layout
+import numpy as np
+
+# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
+# For example, creating a 2x3x4 array
+array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
+
+# Verify the memory order
+print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS'])  # Should be True
+print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS'])  # Should be False
+
+# Step 2: Convert the NumPy array to a PyTorch tensor
+tensor_fortran = torch.from_numpy(array_fortran)
+
+# Verify the tensor layout
+print("Tensor stride:", tensor_fortran.stride())  # Stride will reflect the Fortran memory layout
+
+# Step 3: Save the PyTorch tensor to a .pth file
+torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
+
+print("3D Tensor saved with Fortran layout.")
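The stride the script prints follows directly from the layout: for a column-major 2x3x4 tensor the element strides are (1, 2, 6), i.e. each stride is the running product of the preceding dimensions. A small sketch computing this (illustrative, not code from the repository):

    // Compute element strides for a column-major (Fortran) layout.
    fn fortran_strides(dims: &[usize]) -> Vec<usize> {
        let mut strides = Vec::with_capacity(dims.len());
        let mut acc = 1;
        for &dim in dims {
            strides.push(acc);
            acc *= dim;
        }
        strides
    }

    fn main() {
        // Matches the (1, 2, 6) that tensor_fortran.stride() prints above.
        assert_eq!(fortran_strides(&[2, 3, 4]), vec![1, 2, 6]);
    }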
@@ -12,3 +12,20 @@ fn test_pth_with_key() {
         .unwrap();
     tensors.get("test").unwrap().unwrap();
 }
+
+#[test]
+fn test_pth_fortran_contiguous() {
+    let tensors =
+        candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
+    let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
+
+    assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
+
+    assert_eq!(
+        tensor.to_vec3::<i64>().unwrap(),
+        [
+            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
+            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
+        ]
+    );
+}
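Outside the test harness, loading the saved file goes through the same two calls the test exercises. A hedged usage sketch (the relative path is illustrative):

    fn main() -> candle_core::Result<()> {
        let tensors = candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None)?;
        let tensor = tensors.get("tensor_fortran")?.expect("tensor not found");
        println!("dims: {:?}", tensor.dims()); // [2, 3, 4]
        Ok(())
    }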