From 303b853098330e05fca52b772723b1de87fda788 Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 13:42:23 +0100
Subject: [PATCH] Propagate the layout refactoring.

---
 candle-core/src/cpu_backend.rs   | 90 +++++++++++++-------------------
 candle-core/src/layout.rs        | 67 ++++++++++++++++++++++--
 candle-core/src/storage.rs       | 18 +++----
 candle-core/src/strided_index.rs | 23 +++++---
 candle-core/src/tensor.rs        | 61 +++------------------
 5 files changed, 130 insertions(+), 129 deletions(-)

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 8cafec12..8a0cdb31 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -24,21 +24,22 @@ fn wcond<T: Copy>(
     f: &[T],
     layout_f: &Layout,
 ) -> Vec<T> {
-    if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
-    {
-        let elem_count = shape.elem_count();
-        let pred = &pred[..elem_count];
-        let t = &t[..elem_count];
-        let f = &f[..elem_count];
+    if layout.is_contiguous() && layout_t.is_contiguous() && layout_f.is_contiguous() {
+        let elem_count = layout.shape().elem_count();
+        let offset = layout.start_offset();
+        let offset_t = layout_t.start_offset();
+        let offset_f = layout_f.start_offset();
+        let pred = &pred[offset..offset + elem_count];
+        let t = &t[offset_t..offset_t + elem_count];
+        let f = &f[offset_f..offset_f + elem_count];
         pred.iter()
             .zip(t.iter().zip(f.iter()))
             .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
             .collect::<Vec<_>>()
     } else {
-        let dims = shape.dims();
-        let it_p = StridedIndex::new(dims, stride);
-        let it_t = StridedIndex::new(dims, stride_t);
-        let it_f = StridedIndex::new(dims, stride_f);
+        let it_p = StridedIndex::new(layout);
+        let it_t = StridedIndex::new(layout_t);
+        let it_f = StridedIndex::new(layout_f);
         it_p.zip(it_t.zip(it_f))
             .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
             .collect::<Vec<_>>()
@@ -107,13 +108,13 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
 fn take_impl1<T: Copy>(
     vs: &[T],
     ids: &[u32],
-    shape: &Shape,
-    stride: &[usize],
+    layout: &Layout,
     vocab_size: usize,
     hidden_size: usize,
 ) -> Result<Vec<T>> {
-    let mut values = Vec::with_capacity(shape.elem_count() * hidden_size);
-    for index in StridedIndex::new(shape.dims(), stride) {
+    // TODO: Optimize for the case where ids are contiguous.
+    let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size);
+    for index in StridedIndex::new(layout) {
         let index = ids[index].try_into()?;
         if index >= vocab_size {
             return Err(Error::InvalidIndex {
@@ -132,16 +133,14 @@ fn copy_strided_src_<T: Copy>(
     src: &[T],
     dst: &mut [T],
     dst_offset: usize,
-    src_shape: &Shape,
-    src_stride: &[usize],
-    src_offset: usize,
+    src_l: &Layout,
 ) {
-    let src = &src[src_offset..];
-    if src_shape.is_contiguous(src_stride) {
+    let src = &src[src_l.start_offset()..];
+    if src_l.is_contiguous() {
         let elem_to_copy = (dst.len() - dst_offset).min(src.len());
         dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy])
     } else {
-        let src_indexes = StridedIndex::new(src_shape.dims(), src_stride);
+        let src_indexes = StridedIndex::new(src_l);
         for (dst_index, src_index) in src_indexes.enumerate() {
             let dst_index = dst_index + dst_offset;
             if dst_index >= dst.len() {
@@ -556,29 +555,14 @@ impl CpuStorage {
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_shape: &Shape,
-        src_stride: &[usize],
-        src_offset: usize,
+        src_l: &Layout,
     ) -> Result<()> {
-        if src_shape.rank() != src_stride.len() {
-            panic!("incoherent shape and strides {src_shape:?} {src_stride:?}")
-        }
         match (self, dst) {
-            (Self::U32(src), Self::U32(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::BF16(src), Self::BF16(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F16(src), Self::F16(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F32(src), Self::F32(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
-            (Self::F64(src), Self::F64(dst)) => {
-                copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset)
-            }
+            (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
+            (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
             (_, dst) => {
                 // This should be covered by the dtype check above.
                 return Err(Error::DTypeMismatchBinaryOp {
@@ -593,34 +577,33 @@ impl CpuStorage {
 
     pub(crate) fn where_cond(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         t: &Self,
-        stride_t: &[usize],
+        layout_t: &Layout,
         f: &Self,
-        stride_f: &[usize],
+        layout_f: &Layout,
     ) -> Result<Self> {
         // TODO: Support types that could be casted to a boolean.
         let pred = self.as_slice::<u32>()?;
         match (t, f) {
             (Self::BF16(t), Self::BF16(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::BF16(data))
             }
             (Self::F16(t), Self::F16(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F16(data))
             }
             (Self::F32(t), Self::F32(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F32(data))
             }
             (Self::F64(t), Self::F64(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::F64(data))
             }
             (Self::U32(t), Self::U32(f)) => {
-                let data = wcond(pred, shape, stride, t, stride_t, f, stride_f);
+                let data = wcond(pred, layout, t, layout_t, f, layout_f);
                 Ok(Self::U32(data))
             }
             _ => Err(Error::DTypeMismatchBinaryOp {
@@ -631,16 +614,15 @@ impl CpuStorage {
         }
     }
 
-    pub(crate) fn embedding_impl(
+    pub(crate) fn embedding(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         vs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         let ids = self.as_slice::<u32>()?;
-        map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size)
+        map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size)
     }
 
     pub(crate) fn matmul_impl(
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index ac5fca93..08d34c4b 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -9,16 +9,20 @@ pub struct Layout {
 }
 
 impl Layout {
-    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
+    pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
         let shape = shape.into();
         let stride = shape.stride_contiguous();
         Self {
             shape,
             stride,
-            start_offset: 0,
+            start_offset,
         }
     }
 
+    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
+        Self::contiguous_with_offset(shape, 0)
+    }
+
     pub fn dims(&self) -> &[usize] {
         self.shape.dims()
     }
@@ -45,7 +49,7 @@ impl Layout {
         self.shape.is_fortran_contiguous(&self.stride)
     }
 
-    pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
+    pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
         let dims = self.shape().dims();
         if dim >= dims.len() {
             Err(Error::UnexpectedNumberOfDims {
@@ -65,4 +69,61 @@ impl Layout {
             start_offset: self.start_offset + self.stride[dim] * start,
         })
     }
+
+    pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
+        let rank = self.shape.rank();
+        if rank <= dim1 || rank <= dim2 {
+            return Err(Error::UnexpectedNumberOfDims {
+                expected: usize::max(dim1, dim2),
+                got: rank,
+                shape: self.shape().clone(),
+            });
+        }
+        let mut stride = self.stride().to_vec();
+        let mut dims = self.shape().dims().to_vec();
+        dims.swap(dim1, dim2);
+        stride.swap(dim1, dim2);
+        Ok(Self {
+            shape: Shape::from(dims),
+            stride,
+            start_offset: self.start_offset,
+        })
+    }
+
+    pub fn broadcast_as<S: Into<Shape>>(&self, shape: S) -> Result<Self> {
+        let shape = shape.into();
+        if shape.rank() < self.shape().rank() {
+            Err(Error::BroadcastIncompatibleShapes {
+                src_shape: self.shape().clone(),
+                dst_shape: shape,
+            })?
+        }
+        let added_dims = shape.rank() - self.shape().rank();
+        let mut stride = vec![0; added_dims];
+        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
+            .iter()
+            .zip(self.dims().iter().zip(self.stride()))
+        {
+            let s = if dst_dim == src_dim {
+                src_stride
+            } else if src_dim != 1 {
+                return Err(Error::BroadcastIncompatibleShapes {
+                    src_shape: self.shape().clone(),
+                    dst_shape: shape,
+                });
+            } else {
+                0
+            };
+            stride.push(s)
+        }
+        Ok(Self {
+            shape,
+            stride,
+            start_offset: self.start_offset,
+        })
+    }
+
+    pub(crate) fn strided_index(&self) -> crate::StridedIndex {
+        crate::StridedIndex::new(&self)
+    }
 }
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index b7c94b46..2f2d8b75 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -145,7 +145,7 @@ impl Storage {
 
     pub(crate) fn where_cond(
         &self,
-        layout: &Shape,
+        layout: &Layout,
         t: &Self,
         layout_t: &Layout,
         f: &Self,
@@ -171,7 +171,7 @@ impl Storage {
         }
     }
 
-    pub(crate) fn embedding_impl(
+    pub(crate) fn embedding(
         &self,
         layout: &Layout,
         rhs: &Self,
@@ -181,11 +181,11 @@ impl Storage {
         self.same_device(rhs, "embedding")?;
         match (self, rhs) {
             (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
-                let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cpu(storage))
             }
             (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
@@ -227,15 +227,11 @@ impl Storage {
         &self,
         dst: &mut Self,
         dst_offset: usize,
-        src_layout: &Layout,
+        src_l: &Layout,
     ) -> Result<()> {
         match (self, dst) {
-            (Self::Cpu(src), Self::Cpu(dst)) => {
-                src.copy_strided_src(dst, dst_offset, src_layout, src_offset)
-            }
-            (Self::Cuda(src), Self::Cuda(dst)) => {
-                Ok(src.copy_strided_src(dst, dst_offset, src_layout, src_offset)?)
-            }
+            (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
+            (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
                 lhs: lhs.device().location(),
                 rhs: rhs.device().location(),
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs
index 2a23e9ec..f8dc522f 100644
--- a/candle-core/src/strided_index.rs
+++ b/candle-core/src/strided_index.rs
@@ -1,15 +1,17 @@
+use crate::Layout;
+
 /// An iterator over offset position for items of an N-dimensional arrays stored in a
 /// flat buffer using some potential strides.
 #[derive(Debug)]
 pub(crate) struct StridedIndex<'a> {
     next_storage_index: Option<usize>,
     multi_index: Vec<usize>,
-    dims: &'a [usize],
-    stride: &'a [usize],
+    layout: &'a Layout,
 }
 
 impl<'a> StridedIndex<'a> {
-    pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self {
+    pub(crate) fn new(layout: &'a Layout) -> Self {
+        let dims = layout.dims();
         let elem_count: usize = dims.iter().product();
         let next_storage_index = if elem_count == 0 {
             None
@@ -20,8 +22,7 @@ impl<'a> StridedIndex<'a> {
         StridedIndex {
             next_storage_index,
             multi_index: vec![0; dims.len()],
-            dims,
-            stride,
+            layout,
         }
     }
 }
@@ -35,7 +36,12 @@ impl<'a> Iterator for StridedIndex<'a> {
             Some(storage_index) => storage_index,
         };
         let mut updated = false;
-        for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() {
+        for (multi_i, max_i) in self
+            .multi_index
+            .iter_mut()
+            .zip(self.layout.dims().iter())
+            .rev()
+        {
             let next_i = *multi_i + 1;
             if next_i < *max_i {
                 *multi_i = next_i;
@@ -49,9 +55,10 @@ impl<'a> Iterator for StridedIndex<'a> {
             let next_storage_index = self
                 .multi_index
                 .iter()
-                .zip(self.stride.iter())
+                .zip(self.layout.stride().iter())
                 .map(|(&x, &y)| x * y)
-                .sum();
+                .sum::<usize>()
+                + self.layout.start_offset();
             Some(next_storage_index)
         } else {
             None
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index c7862250..094c60a3 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -481,13 +481,9 @@ impl Tensor {
         let ids_shape = ids.shape();
         let seq_len = ids_shape.r1()?;
         let (vocab_size, hidden_size) = rhs.shape().r2()?;
-        let storage = ids.storage.embedding_impl(
-            ids.layout(),
-            &ids.stride,
-            &rhs.storage,
-            hidden_size,
-            vocab_size,
-        )?;
+        let storage = ids
+            .storage
+            .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
@@ -498,7 +494,7 @@ impl Tensor {
     }
 
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
-        crate::StridedIndex::new(self.dims(), self.stride())
+        self.layout.strided_index()
     }
 
     /// Returns data from the underlying storage, this does not take the strides
@@ -591,7 +587,7 @@ impl Tensor {
     }
 
     pub fn shape(&self) -> &Shape {
-        &self.shape
+        &self.layout().shape()
     }
 
     pub fn dims(&self) -> &[usize] {
@@ -682,18 +678,6 @@ impl Tensor {
     /// Returns a tensor that is a transposed version of the input, the given dimensions are
     /// swapped.
     pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
-        let rank = self.rank();
-        if rank <= dim1 || rank <= dim2 {
-            return Err(Error::UnexpectedNumberOfDims {
-                expected: usize::max(dim1, dim2),
-                got: rank,
-                shape: self.shape().clone(),
-            });
-        }
-        let mut stride = self.stride().to_vec();
-        let mut dims = self.shape().dims().to_vec();
-        dims.swap(dim1, dim2);
-        stride.swap(dim1, dim2);
         let op = if self.track_op() {
             Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
@@ -702,8 +686,7 @@ impl Tensor {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape: Shape::from(dims),
-            stride,
+            layout: self.layout.transpose(dim1, dim2)?,
             op,
             is_variable: false,
         };
@@ -795,36 +778,10 @@ impl Tensor {
         } else {
             None
         };
-        let shape = shape.into();
-        if shape.rank() < self.rank() {
-            return Err(Error::BroadcastIncompatibleShapes {
-                src_shape: self.shape().clone(),
-                dst_shape: shape,
-            });
-        }
-        let added_dims = shape.rank() - self.rank();
-        let mut stride = vec![0; added_dims];
-        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
-            .iter()
-            .zip(self.dims().iter().zip(self.stride()))
-        {
-            let s = if dst_dim == src_dim {
-                src_stride
-            } else if src_dim != 1 {
-                return Err(Error::BroadcastIncompatibleShapes {
-                    src_shape: self.shape().clone(),
-                    dst_shape: shape,
-                });
-            } else {
-                0
-            };
-            stride.push(s)
-        }
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape,
-            stride,
+            layout: self.layout.broadcast_as(shape)?,
             op,
             is_variable: false,
         };
@@ -888,12 +845,10 @@ impl Tensor {
             None
         };
         if self.is_contiguous() {
-            let stride = shape.stride_contiguous();
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage: self.storage.clone(),
-                shape,
-                stride,
+                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
                 op,
                 is_variable: false,
            };
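
Note on the offset arithmetic (editor's illustration, not part of the patch): the main semantic change is that `StridedIndex` and the contiguous fast paths now add the layout's `start_offset` to every computed position, instead of threading separate `src_offset` arguments around. A minimal, self-contained Rust sketch of that mapping follows; the `Layout` below is a simplified stand-in for candle's type, kept only to make the arithmetic runnable on its own.

// Simplified stand-in for candle's Layout (illustrative only).
struct Layout {
    dims: Vec<usize>,
    stride: Vec<usize>,
    start_offset: usize,
}

// Mirrors the sum computed in StridedIndex::next after this patch:
// offset = start_offset + sum_i(multi_index[i] * stride[i]).
fn storage_offset(layout: &Layout, multi_index: &[usize]) -> usize {
    layout.start_offset
        + multi_index
            .iter()
            .zip(layout.stride.iter())
            .map(|(&i, &s)| i * s)
            .sum::<usize>()
}

fn main() {
    // A 2x3 row-major view that starts at offset 4 in the flat buffer,
    // e.g. the result of narrowing a larger contiguous tensor.
    let layout = Layout {
        dims: vec![2, 3],
        stride: vec![3, 1],
        start_offset: 4,
    };
    // The iterator would visit dims.iter().product() = 6 positions in total.
    let elem_count: usize = layout.dims.iter().product();
    assert_eq!(elem_count, 6);
    // Element (1, 2) lives at 4 + 1 * 3 + 2 * 1 = 9.
    assert_eq!(storage_offset(&layout, &[1, 2]), 9);
}

With the old `(dims, stride)` pair this offset was implicitly 0, which is why narrowed or offset views needed the extra `src_offset` arguments and manual slicing that this patch removes.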