From e5eb9602d0eb385c53e7c1dd92687d732bb038e9 Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Wed, 7 Feb 2024 14:49:59 -0600 Subject: [PATCH] Add support for loading Fortran contiguous tensors (#1672) * Add support for loading Fortran contiguous tensors This commit introduces the ability to handle Fortran contiguous tensors in the tensor loading process. Previously, the code only supported loading tensors that were contiguous in memory, failing with an error for non-contiguous tensors. With this update, tensors identified as Fortran contiguous (column-major order) are now correctly handled by reversing their dimensions after loading. This enhancement ensures broader compatibility with different tensor layouts, improving the robustness of tensor loading operations. - Check if a tensor is Fortran contiguous using the `is_fortran_contiguous` flag. - For Fortran contiguous tensors, reverse the dimensions after loading to correctly represent their layout in memory. - Continue to bail out with an error for tensors that are neither C contiguous nor Fortran contiguous, maintaining the previous behavior for non-contiguous tensors without explicit support. This change addresses the issue of loading Fortran contiguous tensors, which was previously unsupported, thereby extending the functionality of the tensor loading mechanism to accommodate a wider variety of tensor layouts. * Add reshape step to handle fortran contiguous case * Skip fortran contiguous fix if rank is < 2 * Fail on rank 0, 1 if contiguous --- candle-core/src/pickle.rs | 20 +++++++++++++++--- candle-core/tests/fortran_tensor_3d.pth | Bin 0 -> 1486 bytes candle-core/tests/pth.py | 27 ++++++++++++++++++++++++ candle-core/tests/pth_tests.rs | 17 +++++++++++++++ 4 files changed, 61 insertions(+), 3 deletions(-) create mode 100644 candle-core/tests/fortran_tensor_3d.pth diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 2c189131..c7b9e434 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -736,10 +736,12 @@ impl PthTensors { let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?); let mut zip = zip::ZipArchive::new(zip_reader)?; let mut reader = zip.by_name(&tensor_info.path)?; + let is_fortran_contiguous = tensor_info.layout.is_fortran_contiguous(); + let rank = tensor_info.layout.shape().rank(); // Reading the data is a bit tricky as it can be strided, for now only support the basic - // case. - if !tensor_info.layout.is_contiguous() { + // case and when the tensor is fortran contiguous. + if !tensor_info.layout.is_contiguous() && !is_fortran_contiguous { crate::bail!( "cannot retrieve non-contiguous tensors {:?}", tensor_info.layout @@ -757,7 +759,19 @@ impl PthTensors { tensor_info.dtype, &mut reader, )?; - Ok(Some(tensor)) + + if rank > 1 && is_fortran_contiguous { + // Reverse the shape, e.g. Shape(2, 3, 4) -> Shape(4, 3, 2) + let shape_reversed: Vec<_> = tensor_info.layout.dims().iter().rev().cloned().collect(); + let tensor = tensor.reshape(shape_reversed)?; + + // Permute (transpose) the dimensions, e.g. Shape(4, 3, 2) -> Shape(2, 3, 4) + let dim_indeces_reversed: Vec<_> = (0..rank).rev().collect(); + let tensor = tensor.permute(dim_indeces_reversed)?; + Ok(Some(tensor)) + } else { + Ok(Some(tensor)) + } } } diff --git a/candle-core/tests/fortran_tensor_3d.pth b/candle-core/tests/fortran_tensor_3d.pth new file mode 100644 index 0000000000000000000000000000000000000000..bd50b03d9cd74135fa6f53e2bf39da2f7ac0879d GIT binary patch literal 1486 zcmbVMON-M`6uxcRw3C_f{b*;vl}H8cq_1fM4rq~Vu3&5tikp&2uAN|L<|elT3Zfes z`~xoi8LnNr)ZgHra3gq5ax?akg5!b1*QDoseErTjwWv}cq+BK|u|vw_We|$co{j~Z zMnO1kxqGfH?E3L5-|QF23yAmmx15YFK)0-sEj;IlAasuEF|D;O?2<9xF#&?}#<-dQ*_3!wA}nJZmn5c(c}+fVNvudj z=N{boHv9_DnCqRZ5+lR86Bb6WEM~uRZtLQCw=`7nI)k}hL5W)3v?BD(;LAu`VHkz$&w_)rOI+dOSykWlHOd3LLTKWc|6b;pN%gI@N0_} N().unwrap(), + [ + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]] + ] + ); +}