From 18ea92d83b6dca5c256068125237e6e7f4327665 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 15 Jul 2023 21:30:35 +0100
Subject: [PATCH] Iteration over strided blocks (#175)

* Introduce the strided blocks.

* Use the strided blocks to speed up the copy.

* Add more testing.
---
 candle-core/src/cpu_backend.rs    |  32 +++++--
 candle-core/src/layout.rs         |  33 ++++++-
 candle-core/src/lib.rs            |   2 +-
 candle-core/src/strided_index.rs  |  42 ++++++---
 candle-core/src/tensor.rs         |  12 ++-
 candle-core/tests/layout_tests.rs | 138 ++++++++++++++++++++++++++++++
 6 files changed, 236 insertions(+), 23 deletions(-)
 create mode 100644 candle-core/tests/layout_tests.rs

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 28b1f5b0..97e46e74 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -190,13 +190,17 @@ fn copy_strided_src_<T: Copy>(
     dst_offset: usize,
     src_l: &Layout,
 ) {
-    match src_l.contiguous_offsets() {
-        Some((o_dst1, o_dst2)) => {
-            let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1);
-            dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2])
+    match src_l.strided_blocks() {
+        crate::StridedBlocks::SingleBlock { start_offset, len } => {
+            let to_copy = (dst.len() - dst_offset).min(len);
+            dst[dst_offset..dst_offset + to_copy]
+                .copy_from_slice(&src[start_offset..start_offset + to_copy])
         }
-        None => {
-            for (dst_index, src_index) in src_l.strided_index().enumerate() {
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len: 1,
+        } => {
+            for (dst_index, src_index) in block_start_index.enumerate() {
                 let dst_index = dst_index + dst_offset;
                 if dst_index >= dst.len() {
                     break;
@@ -204,6 +208,22 @@ fn copy_strided_src_<T: Copy>(
                 dst[dst_index] = src[src_index]
             }
         }
+        crate::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            let mut dst_index = dst_offset;
+            for src_index in block_start_index {
+                let next_dst_index = dst_index + block_len;
+                if dst_index >= dst.len() {
+                    break;
+                }
+                let to_copy = usize::min(block_len, dst.len() - dst_index);
+                dst[dst_index..dst_index + to_copy]
+                    .copy_from_slice(&src[src_index..src_index + to_copy]);
+                dst_index = next_dst_index
+            }
+        }
     }
 }
 
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index d92864aa..22ad53cf 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -51,6 +51,8 @@ impl Layout {
     }
 
     /// Returns true if the data is stored in a C contiguous (aka row major) way.
+    /// Note that this does not imply that the start offset is 0 or that there are no extra
+    /// elements at the end of the storage.
     pub fn is_contiguous(&self) -> bool {
         self.shape.is_contiguous(&self.stride)
     }
@@ -146,6 +148,35 @@ impl Layout {
     }
 
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
-        crate::StridedIndex::new(self)
+        crate::StridedIndex::from_layout(self)
+    }
+
+    pub(crate) fn strided_blocks(&self) -> crate::StridedBlocks {
+        let mut block_len = 1;
+        let mut contiguous_dims = 0; // These are counted from the right.
+        for (&stride, &dim) in self.stride().iter().zip(self.dims().iter()).rev() {
+            if stride != block_len {
+                break;
+            }
+            block_len *= dim;
+            contiguous_dims += 1;
+        }
+        let index_dims = self.dims().len() - contiguous_dims;
+        if index_dims == 0 {
+            crate::StridedBlocks::SingleBlock {
+                start_offset: self.start_offset,
+                len: block_len,
+            }
+        } else {
+            let block_start_index = crate::StridedIndex::new(
+                &self.dims()[..index_dims],
+                &self.stride[..index_dims],
+                self.start_offset,
+            );
+            crate::StridedBlocks::MultipleBlocks {
+                block_start_index,
+                block_len,
+            }
+        }
+    }
 }
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index f11bad6e..bb5ecb01 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -67,7 +67,7 @@ pub use indexer::IndexOp;
 pub use layout::Layout;
 pub use shape::{Shape, D};
 pub use storage::Storage;
-use strided_index::StridedIndex;
+pub use strided_index::{StridedBlocks, StridedIndex};
 pub use tensor::{Tensor, TensorId};
 pub use variable::Var;
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs
index e6d2868b..455b903c 100644
--- a/candle-core/src/strided_index.rs
+++ b/candle-core/src/strided_index.rs
@@ -3,28 +3,35 @@ use crate::Layout;
 /// An iterator over offset positions for items of an N-dimensional array stored in a
 /// flat buffer using some potential strides.
 #[derive(Debug)]
-pub(crate) struct StridedIndex<'a> {
+pub struct StridedIndex<'a> {
     next_storage_index: Option<usize>,
     multi_index: Vec<usize>,
-    layout: &'a Layout,
+    dims: &'a [usize],
+    stride: &'a [usize],
+    start_offset: usize,
 }
 
 impl<'a> StridedIndex<'a> {
-    pub(crate) fn new(layout: &'a Layout) -> Self {
-        let dims = layout.dims();
+    pub(crate) fn new(dims: &'a [usize], stride: &'a [usize], start_offset: usize) -> Self {
         let elem_count: usize = dims.iter().product();
         let next_storage_index = if elem_count == 0 {
             None
         } else {
             // This applies to the scalar case.
-            Some(layout.start_offset())
+            Some(start_offset)
         };
         StridedIndex {
             next_storage_index,
             multi_index: vec![0; dims.len()],
-            layout,
+            dims,
+            stride,
+            start_offset,
         }
     }
+
+    pub(crate) fn from_layout(l: &'a Layout) -> Self {
+        Self::new(l.dims(), l.stride(), l.start_offset())
+    }
 }
 
 impl<'a> Iterator for StridedIndex<'a> {
@@ -36,12 +43,7 @@
             Some(storage_index) => storage_index,
         };
         let mut updated = false;
-        for (multi_i, max_i) in self
-            .multi_index
-            .iter_mut()
-            .zip(self.layout.dims().iter())
-            .rev()
-        {
+        for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
             let next_i = *multi_i + 1;
             if next_i < *max_i {
                 *multi_i = next_i;
@@ -55,10 +57,10 @@
             let next_storage_index = self
                 .multi_index
                 .iter()
-                .zip(self.layout.stride().iter())
+                .zip(self.stride.iter())
                 .map(|(&x, &y)| x * y)
                 .sum::<usize>()
-                + self.layout.start_offset();
+                + self.start_offset;
             Some(next_storage_index)
         } else {
             None
@@ -66,3 +68,15 @@
         Some(storage_index)
     }
 }
+
+#[derive(Debug)]
+pub enum StridedBlocks<'a> {
+    SingleBlock {
+        start_offset: usize,
+        len: usize,
+    },
+    MultipleBlocks {
+        block_start_index: StridedIndex<'a>,
+        block_len: usize,
+    },
+}
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index c048790c..a93514fc 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -834,10 +834,20 @@ impl Tensor {
         Ok(from_storage(storage, shape, op, false))
     }
 
-    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
+    /// Returns an iterator over the positions of the elements in the storage when ranging over
+    /// the index tuples in lexicographic order.
+    pub fn strided_index(&self) -> crate::StridedIndex {
         self.layout.strided_index()
     }
 
+    /// Similar to `strided_index` but returns the position of the start of each contiguous block
+    /// as well as the length of the contiguous blocks. For a contiguous tensor, the index
+    /// iterator will only return the start offset and the size will be the number of elements in
+    /// the tensor.
+    pub fn strided_blocks(&self) -> crate::StridedBlocks {
+        self.layout.strided_blocks()
+    }
+
     /// Returns the data contained in a 1D tensor as a vector of scalar values.
     pub fn to_vec1<S: crate::WithDType>(&self) -> Result<Vec<S>> {
         if self.rank() != 1 {
diff --git a/candle-core/tests/layout_tests.rs b/candle-core/tests/layout_tests.rs
new file mode 100644
index 00000000..29b3b5c0
--- /dev/null
+++ b/candle-core/tests/layout_tests.rs
@@ -0,0 +1,138 @@
+mod test_utils;
+use candle::{Device, IndexOp, Result, Tensor};
+
+fn contiguous(device: &Device) -> Result<()> {
+    let tensor = Tensor::arange(0u32, 24u32, device)?.reshape((2, 3, 4))?;
+    assert_eq!(
+        tensor.to_vec3::<u32>()?,
+        &[
+            [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]],
+            [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+        ]
+    );
+    assert_eq!(
+        tensor.t()?.contiguous()?.to_vec3::<u32>()?,
+        &[
+            [[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11]],
+            [[12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23]]
+        ]
+    );
+    assert_eq!(
+        tensor.transpose(0, 1)?.contiguous()?.to_vec3::<u32>()?,
+        &[
+            [[0, 1, 2, 3], [12, 13, 14, 15]],
+            [[4, 5, 6, 7], [16, 17, 18, 19]],
+            [[8, 9, 10, 11], [20, 21, 22, 23]]
+        ]
+    );
+    assert_eq!(
+        tensor.transpose(0, 1)?.flatten_all()?.to_vec1::<u32>()?,
+        &[0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 16, 17, 18, 19, 8, 9, 10, 11, 20, 21, 22, 23]
+    );
+    assert_eq!(
+        tensor
+            .i(1..)?
+            .transpose(0, 1)?
+            .contiguous()?
+            .to_vec3::<u32>()?,
+        &[[[12, 13, 14, 15]], [[16, 17, 18, 19]], [[20, 21, 22, 23]]]
+    );
+    assert_eq!(
+        tensor.transpose(0, 2)?.contiguous()?.to_vec3::<u32>()?,
+        &[
+            [[0, 12], [4, 16], [8, 20]],
+            [[1, 13], [5, 17], [9, 21]],
+            [[2, 14], [6, 18], [10, 22]],
+            [[3, 15], [7, 19], [11, 23]]
+        ]
+    );
+    Ok(())
+}
+
+test_device!(contiguous, contiguous_cpu, contiguous_gpu);
+
+#[test]
+fn strided_blocks() -> Result<()> {
+    use candle::Device::Cpu;
+    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
+    match tensor.strided_blocks() {
+        candle::StridedBlocks::SingleBlock { start_offset, len } => {
+            assert_eq!(start_offset, 0);
+            assert_eq!(len, 24);
+        }
+        candle::StridedBlocks::MultipleBlocks { .. } => {
+            panic!("unexpected block structure")
+        }
+    };
+    let tensor = Tensor::arange(0u32, 26u32, &Cpu)?
+        .i(2..)?
+        .reshape((2, 3, 4))?;
+    match tensor.strided_blocks() {
+        candle::StridedBlocks::SingleBlock { start_offset, len } => {
+            assert_eq!(start_offset, 2);
+            assert_eq!(len, 24);
+        }
+        candle::StridedBlocks::MultipleBlocks { .. } => {
+            panic!("unexpected block structure")
+        }
+    };
+    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
+    let tensor = tensor.i(1)?;
+    match tensor.strided_blocks() {
+        candle::StridedBlocks::SingleBlock { start_offset, len } => {
+            assert_eq!(start_offset, 12);
+            assert_eq!(len, 12);
+        }
+        candle::StridedBlocks::MultipleBlocks { .. } => {
+            panic!("unexpected block structure")
+        }
+    };
+    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
+    let tensor = tensor.i((.., 1))?;
+    match tensor.strided_blocks() {
+        candle::StridedBlocks::SingleBlock { start_offset, len } => {
+            assert_eq!(start_offset, 0);
+            assert_eq!(len, 8);
+            assert_eq!(tensor.to_vec2::<u32>()?, &[[4, 5, 6, 7], [16, 17, 18, 19]]);
+        }
+        candle::StridedBlocks::MultipleBlocks { .. } => {
+            panic!("unexpected block structure")
+        }
+    };
+    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
+    match tensor.t()?.strided_blocks() {
+        candle::StridedBlocks::SingleBlock { .. } => {
+            panic!("unexpected block structure")
+        }
+        candle::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            assert_eq!(block_len, 1);
+            assert_eq!(
+                block_start_index.collect::<Vec<usize>>(),
+                &[
+                    0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11, 12, 16, 20, 13, 17, 21, 14, 18, 22, 15,
+                    19, 23
+                ]
+            )
+        }
+    };
+    let tensor = Tensor::arange(0u32, 24u32, &Cpu)?.reshape((2, 3, 4))?;
+    match tensor.transpose(0, 1)?.strided_blocks() {
+        candle::StridedBlocks::SingleBlock { .. } => {
+            panic!("unexpected block structure")
+        }
+        candle::StridedBlocks::MultipleBlocks {
+            block_start_index,
+            block_len,
+        } => {
+            assert_eq!(block_len, 4);
+            assert_eq!(
+                block_start_index.collect::<Vec<usize>>(),
+                &[0, 12, 4, 16, 8, 20]
+            )
+        }
+    };
+    Ok(())
+}
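
The sketch below (not part of the patch) shows how the `strided_blocks` API introduced above can be consumed from user code. It relies only on items visible in this diff (`Tensor::strided_blocks`, the `StridedBlocks` enum, and the constructors used in the tests); the helper name `describe_blocks`, the `main` wrapper, and the printed summaries are illustrative assumptions, not anything shipped by the crate.

    use candle::{Device, Result, Tensor};

    // Illustrative helper: report how a tensor's elements sit in its flat
    // storage, either as one contiguous run or as several equally sized runs.
    fn describe_blocks(t: &Tensor) {
        match t.strided_blocks() {
            candle::StridedBlocks::SingleBlock { start_offset, len } => {
                println!("single block: start_offset={start_offset}, len={len}")
            }
            candle::StridedBlocks::MultipleBlocks {
                block_start_index,
                block_len,
            } => {
                // `block_start_index` lazily yields the storage offset of each block.
                let starts: Vec<usize> = block_start_index.collect();
                println!("{} blocks of len {}: {:?}", starts.len(), block_len, starts)
            }
        }
    }

    fn main() -> Result<()> {
        let t = Tensor::arange(0u32, 24u32, &Device::Cpu)?.reshape((2, 3, 4))?;
        describe_blocks(&t);                  // one block of 24 elements
        describe_blocks(&t.transpose(0, 1)?); // 6 blocks of 4 elements each
        describe_blocks(&t.t()?);             // 24 blocks of 1 element (fully strided)
        Ok(())
    }

The larger the reported `block_len`, the more work `copy_strided_src_` can hand to `copy_from_slice` per iteration instead of copying element by element, which is where the speed-up in this patch comes from.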