diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 2c189131..c7b9e434 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -736,10 +736,12 @@ impl PthTensors {
         let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?);
         let mut zip = zip::ZipArchive::new(zip_reader)?;
         let mut reader = zip.by_name(&tensor_info.path)?;
+        let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous();
+        let rank = tensor_info.layout.shape().rank();
         // Reading the data is a bit tricky as it can be strided, for now only support the basic
-        // case.
-        if !tensor_info.layout.is_contiguous() {
+        // case and when the tensor is fortran contiguous.
+        if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous {
             crate::bail!(
                 "cannot retrieve non-contiguous tensors {:?}",
                 tensor_info.layout
             )
@@ -757,7 +759,19 @@ impl PthTensors {
             tensor_info.dtype,
             &mut reader,
         )?;
-        Ok(Some(tensor))
+
+        if rank > 1 && is_fortran_contiguous {
+            // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2)
+            let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect();
+            let tensor = tensor.reshape(shape_reversed)?;
+
+            // Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4)
+            let dim_indices_reversed: Vec<_> = (0..rank).rev().collect();
+            let tensor = tensor.permute(dim_indices_reversed)?;
+            Ok(Some(tensor))
+        } else {
+            Ok(Some(tensor))
+        }
     }
 }
 
diff --git a/candle-core/tests/fortran_tensor_3d.pth b/candle-core/tests/fortran_tensor_3d.pth
new file mode 100644
index 00000000..bd50b03d
Binary files /dev/null and b/candle-core/tests/fortran_tensor_3d.pth differ
diff --git a/candle-core/tests/pth.py b/candle-core/tests/pth.py
index cab94f2c..5c787c20 100644
--- a/candle-core/tests/pth.py
+++ b/candle-core/tests/pth.py
@@ -5,6 +5,33 @@ from collections import OrderedDict
 a= torch.tensor([[1,2,3,4], [5,6,7,8]])
 o = OrderedDict()
 o["test"] = a
+
+# Write a trivial tensor to a pt file
 torch.save(o, "test.pt")
+
+############################################################################################################
+# Write a trivial tensor to a pt file with a key
 torch.save({"model_state_dict": o}, "test_with_key.pt")
+
+############################################################################################################
+# Create a tensor with fortran contiguous memory layout
+import numpy as np
+
+# Step 1: Create a 3D NumPy array with Fortran order using a range of numbers
+# For example, creating a 2x3x4 array
+array_fortran = np.asfortranarray(np.arange(1, 2*3*4 + 1).reshape(2, 3, 4))
+
+# Verify the memory order
+print("Is Fortran contiguous (F order):", array_fortran.flags['F_CONTIGUOUS'])  # Should be True
+print("Is C contiguous (C order):", array_fortran.flags['C_CONTIGUOUS'])  # Should be False
+
+# Step 2: Convert the NumPy array to a PyTorch tensor
+tensor_fortran = torch.from_numpy(array_fortran)
+
+# Verify the tensor layout
+print("Tensor stride:", tensor_fortran.stride())  # Stride will reflect the Fortran memory layout
+
+# Step 3: Save the PyTorch tensor to a .pth file
+torch.save({"tensor_fortran": tensor_fortran}, 'fortran_tensor_3d.pth')
+
+print("3D Tensor saved with Fortran layout.")
 
diff --git a/candle-core/tests/pth_tests.rs b/candle-core/tests/pth_tests.rs
index ad788ed9..9521f9a0 100644
--- a/candle-core/tests/pth_tests.rs
+++ b/candle-core/tests/pth_tests.rs
@@ -12,3 +12,20 @@ fn test_pth_with_key() {
         .unwrap();
     tensors.get("test").unwrap().unwrap();
 }
+
+#[test]
+fn test_pth_fortran_contiguous() {
+    let tensors =
+        candle_core::pickle::PthTensors::new("tests/fortran_tensor_3d.pth", None).unwrap();
+    let tensor = tensors.get("tensor_fortran").unwrap().unwrap();
+
+    assert_eq!(tensor.dims3().unwrap(), (2, 3, 4));
+
+    assert_eq!(
+        tensor.to_vec3::<i64>().unwrap(),
+        [
+            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
+            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
+        ]
+    );
+}