From c1bbbf94f67d91773e0e7491e0aed9a30d75144a Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 12:57:30 +0100 Subject: [PATCH 01/10] Start refactoring the stride. --- candle-core/src/cpu_backend.rs | 46 ++++++++------------- candle-core/src/layout.rs | 47 +++++++++++++++++++++ candle-core/src/lib.rs | 2 + candle-core/src/storage.rs | 63 ++++++++++++----------------- candle-core/src/tensor.rs | 74 +++++++++++++++------------------- 5 files changed, 124 insertions(+), 108 deletions(-) create mode 100644 candle-core/src/layout.rs diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 53c7ecf1..8cafec12 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1,5 +1,5 @@ use crate::op::{BinaryOp, UnaryOp}; -use crate::{DType, Error, Result, Shape, StridedIndex}; +use crate::{DType, Error, Layout, Result, Shape, StridedIndex}; use gemm::{gemm, Parallelism}; use half::{bf16, f16}; @@ -18,12 +18,11 @@ pub enum CpuStorage { fn wcond( pred: &[u32], - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &[T], - stride_t: &[usize], + layout_t: &Layout, f: &[T], - stride_f: &[usize], + layout_f: &Layout, ) -> Vec { if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f) { @@ -73,12 +72,7 @@ fn sum_impl1( Ok(dst) } -fn unary_map U>( - vs: &[T], - shape: &Shape, - stride: &[usize], - mut f: F, -) -> Vec { +fn unary_map U>(vs: &[T], layout: &Layout, mut f: F) -> Vec { if shape.is_contiguous(stride) { vs[..shape.elem_count()].iter().map(|&v| f(v)).collect() } else { @@ -461,65 +455,59 @@ impl CpuStorage { Ok(()) } - pub(crate) fn affine_impl( - &self, - shape: &Shape, - stride: &[usize], - mul: f64, - add: f64, - ) -> Result { + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { match self { Self::U32(storage) => { let mul = mul as u32; let add = add as u32; - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::U32(data)) } Self::BF16(storage) => { let mul = bf16::from_f64(mul); let add = bf16::from_f64(add); - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::BF16(data)) } Self::F16(storage) => { let mul = f16::from_f64(mul); let add = f16::from_f64(add); - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::F16(data)) } Self::F32(storage) => { let mul = mul as f32; let add = add as f32; - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::F32(data)) } Self::F64(storage) => { - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::F64(data)) } } } - pub(crate) fn unary_impl(&self, shape: &Shape, stride: &[usize]) -> Result { + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { match self { Self::BF16(storage) => { - let data = unary_map(storage, shape, stride, B::bf16); + let data = unary_map(storage, layout, B::bf16); Ok(Self::BF16(data)) } Self::F16(storage) => { - let data = unary_map(storage, shape, stride, B::f16); + let data = unary_map(storage, layout, B::f16); Ok(Self::F16(data)) } Self::F32(storage) => { - let data = unary_map(storage, shape, stride, B::f32); + let data = unary_map(storage, layout, B::f32); Ok(Self::F32(data)) } 
Self::F64(storage) => { - let data = unary_map(storage, shape, stride, B::f64); + let data = unary_map(storage, layout, B::f64); Ok(Self::F64(data)) } Self::U32(storage) => { - let data = unary_map(storage, shape, stride, B::u32); + let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) } } diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs new file mode 100644 index 00000000..a0f3639b --- /dev/null +++ b/candle-core/src/layout.rs @@ -0,0 +1,47 @@ +use crate::Shape; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Layout { + shape: Shape, + // The strides are given in number of elements and not in bytes. + stride: Vec, + start_offset: usize, +} + +impl Layout { + pub fn contiguous>(shape: S) -> Self { + let shape = shape.into(); + let stride = shape.stride_contiguous(); + Self { + shape, + stride, + start_offset: 0, + } + } + + pub fn dims(&self) -> &[usize] { + self.shape.dims() + } + + pub fn shape(&self) -> &Shape { + &self.shape + } + + pub fn stride(&self) -> &[usize] { + &self.stride + } + + pub fn start_offset(&self) -> usize { + self.start_offset + } + + /// Returns true if the data is stored in a C contiguous (aka row major) way. + pub fn is_contiguous(&self) -> bool { + self.shape.is_contiguous(&self.stride) + } + + /// Returns true if the data is stored in a Fortran contiguous (aka column major) way. + pub fn is_fortran_contiguous(&self) -> bool { + self.shape.is_fortran_contiguous(&self.stride) + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 5771517f..6a860116 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -7,6 +7,7 @@ pub mod display; mod dtype; mod dummy_cuda_backend; mod error; +mod layout; mod npy; mod op; mod shape; @@ -19,6 +20,7 @@ pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; pub use dtype::{DType, WithDType}; pub use error::{Error, Result}; +pub use layout::Layout; pub use shape::Shape; pub use storage::Storage; use strided_index::StridedIndex; diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index e44a2db6..3bcb0ff0 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -1,4 +1,4 @@ -use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape}; +use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape}; // We do not want to implement Clone on Storage as cloning may fail because of // out of memory. Instead try_clone should be used. 
@@ -53,33 +53,27 @@ impl Storage { } } - pub(crate) fn affine_impl( - &self, - shape: &Shape, - stride: &[usize], - mul: f64, - add: f64, - ) -> Result { + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.affine_impl(shape, stride, mul, add)?; + let storage = storage.affine(layout, mul, add)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.affine_impl(shape, stride, mul, add)?; + let storage = storage.affine(layout, mul, add)?; Ok(Self::Cuda(storage)) } } } - pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result { + pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.sum(shape, stride, s)?; + let storage = storage.sum(layout, s)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.sum(shape, stride, s)?; + let storage = storage.sum(layout, s)?; Ok(Self::Cuda(storage)) } } @@ -93,32 +87,28 @@ impl Storage { Ok(()) } - pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result { + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.to_dtype(shape, stride, dtype)?; + let storage = storage.to_dtype(layout, dtype)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.to_dtype(shape, stride, dtype)?; + let storage = storage.to_dtype(layout, dtype)?; Ok(Self::Cuda(storage)) } } } - pub(crate) fn unary_impl( - &self, - shape: &Shape, - stride: &[usize], - ) -> Result { + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { // TODO: Different code path for the contiguous case? 
match self { Storage::Cpu(storage) => { - let storage = storage.unary_impl::(shape, stride)?; + let storage = storage.unary_impl::(layout)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.unary_impl::(shape, stride)?; + let storage = storage.unary_impl::(layout)?; Ok(Self::Cuda(storage)) } } @@ -127,19 +117,18 @@ impl Storage { pub(crate) fn binary_impl( &self, rhs: &Self, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result { self.same_device(rhs, B::NAME)?; self.same_dtype(rhs, B::NAME)?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.binary_impl::(rhs, shape, lhs_stride, rhs_stride)?; + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.binary_impl::(rhs, shape, lhs_stride, rhs_stride)?; + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => { @@ -156,23 +145,22 @@ impl Storage { pub(crate) fn where_cond( &self, - shape: &Shape, - stride: &[usize], + layout: &Shape, t: &Self, - stride_t: &[usize], + layout_t: &Layout, f: &Self, - stride_f: &[usize], + layout_f: &Layout, ) -> Result { self.same_device(t, "where")?; self.same_device(f, "where")?; t.same_dtype(f, "where")?; match (self, t, f) { (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => { - let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?; + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; Ok(Self::Cpu(storage)) } (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => { - let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?; + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; Ok(Self::Cuda(storage)) } (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { @@ -185,8 +173,7 @@ impl Storage { pub(crate) fn embedding_impl( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, rhs: &Self, hidden_size: usize, vocab_size: usize, @@ -194,11 +181,11 @@ impl Storage { self.same_device(rhs, "embedding")?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index feb59d3c..51eeb9ae 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,4 +1,4 @@ -use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape}; +use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::Arc; /// Unique identifier for tensors. @@ -17,9 +17,7 @@ impl TensorId { pub struct Tensor_ { id: TensorId, storage: Arc, - shape: Shape, - // The strides are given in number of elements and not in bytes. - stride: Vec, + layout: Layout, op: Option, is_variable: bool, } @@ -50,7 +48,7 @@ macro_rules! 
unary_op { let shape = self.shape(); let storage = self .storage - .unary_impl::(self.shape(), self.stride())?; + .unary_impl::(self.layout())?; let op = if self.track_op() { Some(Op::$op_name(self.clone())) } else { @@ -67,9 +65,8 @@ macro_rules! binary_op { let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?; let storage = self.storage.binary_impl::( &rhs.storage, - shape, - self.stride(), - rhs.stride(), + self.layout(), + rhs.layout(), )?; let op = if self.track_op() || rhs.track_op() { Some(Op::$op_name(self.clone(), rhs.clone())) @@ -107,13 +104,10 @@ fn from_storage>( op: Option, is_variable: bool, ) -> Tensor { - let shape = shape.into(); - let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(storage), - shape, - stride, + layout: Layout::contiguous(shape), op, is_variable, }; @@ -342,8 +336,7 @@ impl Tensor { } pub fn affine(&self, mul: f64, add: f64) -> Result { - let shape = self.shape(); - let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?; + let storage = self.storage.affine(self.layout(), mul, add)?; let op = if self.track_op() { Some(Op::Affine { arg: self.clone(), @@ -353,7 +346,7 @@ impl Tensor { } else { None }; - Ok(from_storage(storage, shape.clone(), op, false)) + Ok(from_storage(storage, self.shape(), op, false)) } /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` @@ -401,9 +394,7 @@ impl Tensor { exp.broadcast_div(&sum_exp) } else { let shape = self.shape(); - let mut storage = self - .storage - .unary_impl::(shape, self.stride())?; + let mut storage = self.storage.unary_impl::(self.layout())?; // The resulting storage is contiguous. storage.divide_by_sum_over_dim(shape, dim)?; let op = if self.track_op() { @@ -416,7 +407,7 @@ impl Tensor { } pub fn sum(&self, sum_dims: &[usize]) -> Result { - let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?; + let storage = self.storage.sum(self.layout(), sum_dims)?; let op = if self.track_op() { Some(Op::Sum(self.clone(), sum_dims.to_vec())) } else { @@ -461,8 +452,8 @@ impl Tensor { let storage = self.storage.matmul_impl( &rhs.storage, (batching, m, n, k), - self.stride(), - rhs.stride(), + self.layout(), + rhs.layout(), )?; let op = if self.track_op() || rhs.track_op() { Some(Op::Matmul(self.clone(), rhs.clone())) @@ -476,12 +467,11 @@ impl Tensor { let _shap = self.same_shape_binary_op(on_true, "where_cond")?; let shape = self.same_shape_binary_op(on_false, "where_cond")?; let storage = self.storage.where_cond( - shape, - self.stride(), + self.layout(), &on_true.storage, - on_true.stride(), + on_true.layout(), &on_false.storage, - on_false.stride(), + on_false.layout(), )?; let op = if self.track_op() || on_true.track_op() || on_false.track_op() { Some(Op::WhereCond( @@ -498,10 +488,10 @@ impl Tensor { pub fn embedding(ids: &Self, rhs: &Self) -> Result { if !rhs.is_contiguous() { return Err(Error::RequiresContiguous { op: "embedding" }); - } else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 { + } else if rhs.rank() != 2 || ids.rank() != 1 { return Err(Error::ShapeMismatchBinaryOp { - lhs: ids.shape.clone(), - rhs: rhs.shape.clone(), + lhs: ids.shape().clone(), + rhs: rhs.shape().clone(), op: "embedding", }); } @@ -509,7 +499,7 @@ impl Tensor { let seq_len = ids_shape.r1()?; let (vocab_size, hidden_size) = rhs.shape().r2()?; let storage = ids.storage.embedding_impl( - ids_shape, + ids.layout(), &ids.stride, &rhs.storage, hidden_size, @@ -625,8 +615,13 @@ impl Tensor { 
self.shape().dims() } - pub fn stride(&self) -> &[usize] { - &self.stride + pub fn layout(&self) -> &Layout { + &self.layout + } + + // TODO: Rename to `stride` once the PR that introduced the layout has been merged. + pub fn stride_tmp(&self) -> &[usize] { + &self.layout.stride() } pub fn rank(&self) -> usize { @@ -734,12 +729,12 @@ impl Tensor { /// Returns true if the data is stored in a C contiguous (aka row major) way. pub fn is_contiguous(&self) -> bool { - self.shape.is_contiguous(&self.stride) + self.layout.is_contiguous() } /// Returns true if the data is stored in a Fortran contiguous (aka column major) way. pub fn is_fortran_contiguous(&self) -> bool { - self.shape.is_fortran_contiguous(&self.stride) + self.layout.is_fortran_contiguous() } /// Compared to clone, this copies the actual storage but may fail because of running out of @@ -748,8 +743,7 @@ impl Tensor { let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(self.storage.try_clone()?), - shape: self.shape.clone(), - stride: self.stride.clone(), + layout: self.layout.clone(), op: None, // TODO is_variable: false, }; @@ -762,8 +756,7 @@ impl Tensor { let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), - shape: self.shape.clone(), - stride: self.stride.clone(), + layout: self.layout.clone(), op: None, is_variable: false, }; @@ -796,8 +789,7 @@ impl Tensor { let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(storage), - shape: self.shape.clone(), - stride: self.stride.clone(), + layout: self.layout.clone(), op, is_variable: false, }; @@ -810,7 +802,7 @@ impl Tensor { pub fn broadcast_left>(&self, left_shape: S) -> Result { let left_shape = left_shape.into(); let mut dims = left_shape.into_dims(); - dims.extend(self.shape.dims()); + dims.extend(self.dims()); self.broadcast_as(dims) } @@ -866,7 +858,7 @@ impl Tensor { Ok(self.clone()) } else { let shape = self.shape(); - let storage = self.storage.to_dtype(shape, self.stride(), dtype)?; + let storage = self.storage.to_dtype(self.layout(), dtype)?; let op = if self.track_op() { Some(Op::ToDType(self.clone())) } else { From 30b355ccd2668db9fbf899453f159f1cdd5e85ba Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 13:09:59 +0100 Subject: [PATCH 02/10] Simplify the narrow implementation. --- candle-core/src/layout.rs | 23 +++++++++++++++++++++- candle-core/src/storage.rs | 8 +++----- candle-core/src/tensor.rs | 39 +++++++++++--------------------------- 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index a0f3639b..ac5fca93 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -1,4 +1,4 @@ -use crate::Shape; +use crate::{Error, Result, Shape}; #[derive(Debug, PartialEq, Eq, Clone)] pub struct Layout { @@ -44,4 +44,25 @@ impl Layout { pub fn is_fortran_contiguous(&self) -> bool { self.shape.is_fortran_contiguous(&self.stride) } + + pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result { + let dims = self.shape().dims(); + if dim >= dims.len() { + Err(Error::UnexpectedNumberOfDims { + expected: dim + 1, + got: dims.len(), + shape: self.shape().clone(), + })? 
+ } + if start + length > dims[dim] { + todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}") + } + let mut dims = dims.to_vec(); + dims[dim] = length; + Ok(Self { + shape: Shape::from(dims), + stride: self.stride.clone(), + start_offset: self.start_offset + self.stride[dim] * start, + }) + } } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 3bcb0ff0..b7c94b46 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -227,16 +227,14 @@ impl Storage { &self, dst: &mut Self, dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_layout: &Layout, ) -> Result<()> { match (self, dst) { (Self::Cpu(src), Self::Cpu(dst)) => { - src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset) + src.copy_strided_src(dst, dst_offset, src_layout, src_offset) } (Self::Cuda(src), Self::Cuda(dst)) => { - Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?) + Ok(src.copy_strided_src(dst, dst_offset, src_layout, src_offset)?) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 51eeb9ae..c7862250 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -351,37 +351,20 @@ impl Tensor { /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// ranges from `start` to `start + length`. - // TODO: Once we've refactored the shape and strides, make this return a view of the same data - // rather than copying. pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result { - let dims = self.shape().dims(); - if dim >= dims.len() { - return Err(Error::UnexpectedNumberOfDims { - expected: dim + 1, - got: dims.len(), - shape: self.shape().clone(), - }); - } - if start + length > dims[dim] { - todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}") - } - let mut dims = dims.to_vec(); - dims[dim] = length; - let adjusted_shape = Shape::from(dims); - let mut storage = self.device().zeros(&adjusted_shape, self.dtype())?; - self.storage.copy_strided_src( - &mut storage, - /* dst_offset= */ 0, - &adjusted_shape, - &self.stride, - /* src_offest= */ self.stride[dim] * start, - )?; let op = if self.track_op() { Some(Op::Narrow(self.clone(), dim, start, length)) } else { None }; - Ok(from_storage(storage, adjusted_shape, op, false)) + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + layout: self.layout().narrow(dim, start, length)?, + op, + is_variable: false, + }; + Ok(Tensor(Arc::new(tensor_))) } pub fn softmax(&self, dim: usize) -> Result { @@ -875,7 +858,7 @@ impl Tensor { let shape = self.shape(); let mut storage = self.device().zeros(shape, self.dtype())?; self.storage - .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?; + .copy_strided_src(&mut storage, 0, self.layout())?; Ok(from_storage( storage, shape.clone(), @@ -918,7 +901,7 @@ impl Tensor { } else { let mut storage = self.device().zeros(&shape, self.dtype())?; self.storage - .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?; + .copy_strided_src(&mut storage, 0, self.layout())?; Ok(from_storage(storage, shape, op, false)) } } @@ -1055,7 +1038,7 @@ impl Tensor { for (arg, &offset) in args.iter().zip(offsets.iter()) { let arg = arg.as_ref(); arg.storage - .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?; + .copy_strided_src(&mut storage, offset, 
arg.layout())?; } Ok(from_storage(storage, shape, op, false)) } From 303b853098330e05fca52b772723b1de87fda788 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 13:42:23 +0100 Subject: [PATCH 03/10] Propagate the layout refactoring. --- candle-core/src/cpu_backend.rs | 90 +++++++++++++------------------- candle-core/src/layout.rs | 67 ++++++++++++++++++++++-- candle-core/src/storage.rs | 18 +++---- candle-core/src/strided_index.rs | 23 +++++--- candle-core/src/tensor.rs | 61 +++------------------- 5 files changed, 130 insertions(+), 129 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 8cafec12..8a0cdb31 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -24,21 +24,22 @@ fn wcond( f: &[T], layout_f: &Layout, ) -> Vec { - if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f) - { - let elem_count = shape.elem_count(); - let pred = &pred[..elem_count]; - let t = &t[..elem_count]; - let f = &f[..elem_count]; + if layout.is_contiguous() && layout_t.is_contiguous() && layout_f.is_contiguous() { + let elem_count = layout.shape().elem_count(); + let offset = layout.start_offset(); + let offset_t = layout_t.start_offset(); + let offset_f = layout_f.start_offset(); + let pred = &pred[offset..offset + elem_count]; + let t = &t[offset_t..offset_t + elem_count]; + let f = &f[offset_f..offset_f + elem_count]; pred.iter() .zip(t.iter().zip(f.iter())) .map(|(&p, (&t, &f))| if p > 0 { t } else { f }) .collect::>() } else { - let dims = shape.dims(); - let it_p = StridedIndex::new(dims, stride); - let it_t = StridedIndex::new(dims, stride_t); - let it_f = StridedIndex::new(dims, stride_f); + let it_p = StridedIndex::new(layout); + let it_t = StridedIndex::new(layout_t); + let it_f = StridedIndex::new(layout_f); it_p.zip(it_t.zip(it_f)) .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] }) .collect::>() @@ -107,13 +108,13 @@ fn binary_map T>( fn take_impl1( vs: &[T], ids: &[u32], - shape: &Shape, - stride: &[usize], + layout: &Layout, vocab_size: usize, hidden_size: usize, ) -> Result> { - let mut values = Vec::with_capacity(shape.elem_count() * hidden_size); - for index in StridedIndex::new(shape.dims(), stride) { + // TODO: Optimize for the case where ids are contiguous. 
+ let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size); + for index in StridedIndex::new(layout) { let index = ids[index].try_into()?; if index >= vocab_size { return Err(Error::InvalidIndex { @@ -132,16 +133,14 @@ fn copy_strided_src_( src: &[T], dst: &mut [T], dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: &Layout, ) { - let src = &src[src_offset..]; - if src_shape.is_contiguous(src_stride) { + let src = &src[src_l.start_offset()..]; + if src_l.is_contiguous() { let elem_to_copy = (dst.len() - dst_offset).min(src.len()); dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy]) } else { - let src_indexes = StridedIndex::new(src_shape.dims(), src_stride); + let src_indexes = StridedIndex::new(src_l); for (dst_index, src_index) in src_indexes.enumerate() { let dst_index = dst_index + dst_offset; if dst_index >= dst.len() { @@ -556,29 +555,14 @@ impl CpuStorage { &self, dst: &mut Self, dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: Layout, ) -> Result<()> { - if src_shape.rank() != src_stride.len() { - panic!("incoherent shape and strides {src_shape:?} {src_stride:?}") - } match (self, dst) { - (Self::U32(src), Self::U32(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::BF16(src), Self::BF16(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::F16(src), Self::F16(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::F32(src), Self::F32(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::F64(src), Self::F64(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } + (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (_, dst) => { // This should be covered by the dtype check above. return Err(Error::DTypeMismatchBinaryOp { @@ -593,34 +577,33 @@ impl CpuStorage { pub(crate) fn where_cond( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &Self, - stride_t: &[usize], + layout_t: &Layout, f: &Self, - stride_f: &[usize], + layout_f: &Layout, ) -> Result { // TODO: Support types that could be casted to a boolean. 
let pred = self.as_slice::()?; match (t, f) { (Self::BF16(t), Self::BF16(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::BF16(data)) } (Self::F16(t), Self::F16(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::F16(data)) } (Self::F32(t), Self::F32(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::F32(data)) } (Self::F64(t), Self::F64(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::F64(data)) } (Self::U32(t), Self::U32(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::U32(data)) } _ => Err(Error::DTypeMismatchBinaryOp { @@ -631,16 +614,15 @@ impl CpuStorage { } } - pub(crate) fn embedding_impl( + pub(crate) fn embedding( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, vs: &Self, hidden_size: usize, vocab_size: usize, ) -> Result { let ids = self.as_slice::()?; - map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size) + map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) } pub(crate) fn matmul_impl( diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index ac5fca93..08d34c4b 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -9,16 +9,20 @@ pub struct Layout { } impl Layout { - pub fn contiguous>(shape: S) -> Self { + pub fn contiguous_with_offset>(shape: S, start_offset: usize) -> Self { let shape = shape.into(); let stride = shape.stride_contiguous(); Self { shape, stride, - start_offset: 0, + start_offset, } } + pub fn contiguous>(shape: S) -> Self { + Self::contiguous_with_offset(shape, 0) + } + pub fn dims(&self) -> &[usize] { self.shape.dims() } @@ -45,7 +49,7 @@ impl Layout { self.shape.is_fortran_contiguous(&self.stride) } - pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result { + pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result { let dims = self.shape().dims(); if dim >= dims.len() { Err(Error::UnexpectedNumberOfDims { @@ -65,4 +69,61 @@ impl Layout { start_offset: self.start_offset + self.stride[dim] * start, }) } + + pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result { + let rank = self.shape.rank(); + if rank <= dim1 || rank <= dim2 { + return Err(Error::UnexpectedNumberOfDims { + expected: usize::max(dim1, dim2), + got: rank, + shape: self.shape().clone(), + }); + } + let mut stride = self.stride().to_vec(); + let mut dims = self.shape().dims().to_vec(); + dims.swap(dim1, dim2); + stride.swap(dim1, dim2); + Ok(Self { + shape: Shape::from(dims), + stride, + start_offset: self.start_offset, + }) + } + + pub fn broadcast_as>(&self, shape: S) -> Result { + let shape = shape.into(); + if shape.rank() < self.shape().rank() { + Err(Error::BroadcastIncompatibleShapes { + src_shape: self.shape().clone(), + dst_shape: shape, + })? + } + let added_dims = shape.rank() - self.shape().rank(); + let mut stride = vec![0; added_dims]; + for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..] 
+ .iter() + .zip(self.dims().iter().zip(self.stride())) + { + let s = if dst_dim == src_dim { + src_stride + } else if src_dim != 1 { + return Err(Error::BroadcastIncompatibleShapes { + src_shape: self.shape().clone(), + dst_shape: shape, + }); + } else { + 0 + }; + stride.push(s) + } + Ok(Self { + shape, + stride, + start_offset: self.start_offset, + }) + } + + pub(crate) fn strided_index(&self) -> crate::StridedIndex { + crate::StridedIndex::new(&self) + } } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index b7c94b46..2f2d8b75 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -145,7 +145,7 @@ impl Storage { pub(crate) fn where_cond( &self, - layout: &Shape, + layout: &Layout, t: &Self, layout_t: &Layout, f: &Self, @@ -171,7 +171,7 @@ impl Storage { } } - pub(crate) fn embedding_impl( + pub(crate) fn embedding( &self, layout: &Layout, rhs: &Self, @@ -181,11 +181,11 @@ impl Storage { self.same_device(rhs, "embedding")?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { @@ -227,15 +227,11 @@ impl Storage { &self, dst: &mut Self, dst_offset: usize, - src_layout: &Layout, + src_l: &Layout, ) -> Result<()> { match (self, dst) { - (Self::Cpu(src), Self::Cpu(dst)) => { - src.copy_strided_src(dst, dst_offset, src_layout, src_offset) - } - (Self::Cuda(src), Self::Cuda(dst)) => { - Ok(src.copy_strided_src(dst, dst_offset, src_layout, src_offset)?) - } + (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l), + (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?), (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index 2a23e9ec..f8dc522f 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -1,15 +1,17 @@ +use crate::Layout; + /// An iterator over offset position for items of an N-dimensional arrays stored in a /// flat buffer using some potential strides. 
#[derive(Debug)] pub(crate) struct StridedIndex<'a> { next_storage_index: Option, multi_index: Vec, - dims: &'a [usize], - stride: &'a [usize], + layout: &'a Layout, } impl<'a> StridedIndex<'a> { - pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self { + pub(crate) fn new(layout: &'a Layout) -> Self { + let dims = layout.dims(); let elem_count: usize = dims.iter().product(); let next_storage_index = if elem_count == 0 { None @@ -20,8 +22,7 @@ impl<'a> StridedIndex<'a> { StridedIndex { next_storage_index, multi_index: vec![0; dims.len()], - dims, - stride, + layout, } } } @@ -35,7 +36,12 @@ impl<'a> Iterator for StridedIndex<'a> { Some(storage_index) => storage_index, }; let mut updated = false; - for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { + for (multi_i, max_i) in self + .multi_index + .iter_mut() + .zip(self.layout.dims().iter()) + .rev() + { let next_i = *multi_i + 1; if next_i < *max_i { *multi_i = next_i; @@ -49,9 +55,10 @@ impl<'a> Iterator for StridedIndex<'a> { let next_storage_index = self .multi_index .iter() - .zip(self.stride.iter()) + .zip(self.layout.stride().iter()) .map(|(&x, &y)| x * y) - .sum(); + .sum::() + + self.layout.start_offset(); Some(next_storage_index) } else { None diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index c7862250..094c60a3 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -481,13 +481,9 @@ impl Tensor { let ids_shape = ids.shape(); let seq_len = ids_shape.r1()?; let (vocab_size, hidden_size) = rhs.shape().r2()?; - let storage = ids.storage.embedding_impl( - ids.layout(), - &ids.stride, - &rhs.storage, - hidden_size, - vocab_size, - )?; + let storage = ids + .storage + .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?; let shape: Shape = (seq_len, hidden_size).into(); let op = if ids.track_op() || rhs.track_op() { Some(Op::Embedding(ids.clone(), rhs.clone())) @@ -498,7 +494,7 @@ impl Tensor { } pub(crate) fn strided_index(&self) -> crate::StridedIndex { - crate::StridedIndex::new(self.dims(), self.stride()) + self.layout.strided_index() } /// Returns data from the underlying storage, this does not take the strides @@ -591,7 +587,7 @@ impl Tensor { } pub fn shape(&self) -> &Shape { - &self.shape + &self.layout().shape() } pub fn dims(&self) -> &[usize] { @@ -682,18 +678,6 @@ impl Tensor { /// Returns a tensor that is a transposed version of the input, the given dimensions are /// swapped. 
pub fn transpose(&self, dim1: usize, dim2: usize) -> Result { - let rank = self.rank(); - if rank <= dim1 || rank <= dim2 { - return Err(Error::UnexpectedNumberOfDims { - expected: usize::max(dim1, dim2), - got: rank, - shape: self.shape().clone(), - }); - } - let mut stride = self.stride().to_vec(); - let mut dims = self.shape().dims().to_vec(); - dims.swap(dim1, dim2); - stride.swap(dim1, dim2); let op = if self.track_op() { Some(Op::Transpose(self.clone(), dim1, dim2)) } else { @@ -702,8 +686,7 @@ impl Tensor { let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), - shape: Shape::from(dims), - stride, + layout: self.layout.transpose(dim1, dim2)?, op, is_variable: false, }; @@ -795,36 +778,10 @@ impl Tensor { } else { None }; - let shape = shape.into(); - if shape.rank() < self.rank() { - return Err(Error::BroadcastIncompatibleShapes { - src_shape: self.shape().clone(), - dst_shape: shape, - }); - } - let added_dims = shape.rank() - self.rank(); - let mut stride = vec![0; added_dims]; - for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..] - .iter() - .zip(self.dims().iter().zip(self.stride())) - { - let s = if dst_dim == src_dim { - src_stride - } else if src_dim != 1 { - return Err(Error::BroadcastIncompatibleShapes { - src_shape: self.shape().clone(), - dst_shape: shape, - }); - } else { - 0 - }; - stride.push(s) - } let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), - shape, - stride, + layout: self.layout.broadcast_as(shape)?, op, is_variable: false, }; @@ -888,12 +845,10 @@ impl Tensor { None }; if self.is_contiguous() { - let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { id: TensorId::new(), storage: self.storage.clone(), - shape, - stride, + layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()), op, is_variable: false, }; From 54a6c40f2715d4ba6018d047bd3ec678fc0f3664 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 14:00:49 +0100 Subject: [PATCH 04/10] Propagate the changes on the cpu backend. 
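
The CPU kernels now pick between a contiguous fast path and the strided iterator by asking the layout, instead of threading separate `shape`/`stride` pairs around. The sketch below illustrates the pattern with simplified stand-ins: `Layout`, `contiguous_offsets`, `strided_index` and `unary_map` here are illustrative reimplementations written for this note, not the crate's exact code.

    // Simplified stand-ins for illustration only.
    struct Layout {
        dims: Vec<usize>,
        stride: Vec<usize>,
        start_offset: usize,
    }

    impl Layout {
        fn elem_count(&self) -> usize {
            self.dims.iter().product()
        }

        // Start/end offsets into the flat buffer when the layout is C contiguous.
        fn contiguous_offsets(&self) -> Option<(usize, usize)> {
            let mut expected = 1;
            for (&d, &s) in self.dims.iter().zip(self.stride.iter()).rev() {
                if s != expected {
                    return None;
                }
                expected *= d;
            }
            Some((self.start_offset, self.start_offset + self.elem_count()))
        }

        // Offset of every element, honoring both the strides and the start offset.
        fn strided_index(&self) -> impl Iterator<Item = usize> + '_ {
            (0..self.elem_count()).map(move |mut i| {
                let mut offset = self.start_offset;
                for (&d, &s) in self.dims.iter().zip(self.stride.iter()).rev() {
                    offset += (i % d) * s;
                    i /= d;
                }
                offset
            })
        }
    }

    fn unary_map<T: Copy, U>(vs: &[T], layout: &Layout, mut f: impl FnMut(T) -> U) -> Vec<U> {
        match layout.contiguous_offsets() {
            // Fast path: the data is one contiguous slice.
            Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
            // Slow path: walk the strided offsets one by one.
            None => layout.strided_index().map(|i| f(vs[i])).collect(),
        }
    }

    fn main() {
        // A 2x3 row-major view starting two elements into the buffer.
        let layout = Layout { dims: vec![2, 3], stride: vec![3, 1], start_offset: 2 };
        let buf: Vec<i32> = (0..8).collect();
        assert_eq!(unary_map(&buf, &layout, |v| v * 10), vec![20, 30, 40, 50, 60, 70]);
    }
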
--- candle-core/src/cpu_backend.rs | 159 ++++++++++++++++----------------- candle-core/src/layout.rs | 11 +++ 2 files changed, 89 insertions(+), 81 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 8a0cdb31..4f63ea98 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -24,25 +24,25 @@ fn wcond( f: &[T], layout_f: &Layout, ) -> Vec { - if layout.is_contiguous() && layout_t.is_contiguous() && layout_f.is_contiguous() { - let elem_count = layout.shape().elem_count(); - let offset = layout.start_offset(); - let offset_t = layout_t.start_offset(); - let offset_f = layout_f.start_offset(); - let pred = &pred[offset..offset + elem_count]; - let t = &t[offset_t..offset_t + elem_count]; - let f = &f[offset_f..offset_f + elem_count]; - pred.iter() - .zip(t.iter().zip(f.iter())) - .map(|(&p, (&t, &f))| if p > 0 { t } else { f }) - .collect::>() - } else { - let it_p = StridedIndex::new(layout); - let it_t = StridedIndex::new(layout_t); - let it_f = StridedIndex::new(layout_f); - it_p.zip(it_t.zip(it_f)) + match ( + layout.contiguous_offsets(), + layout_t.contiguous_offsets(), + layout_f.contiguous_offsets(), + ) { + (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => { + let pred = &pred[o1..o2]; + let t = &t[o_t1..o_t2]; + let f = &f[o_f1..o_f2]; + pred.iter() + .zip(t.iter().zip(f.iter())) + .map(|(&p, (&t, &f))| if p > 0 { t } else { f }) + .collect::>() + } + _ => layout + .strided_index() + .zip(layout_t.strided_index().zip(layout_f.strided_index())) .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] }) - .collect::>() + .collect::>(), } } @@ -62,42 +62,38 @@ macro_rules! map1 { fn sum_impl1( src: &[T], dst_shape: &Shape, - src_dims: &[usize], - stride: &[usize], + src_layout: &Layout, to_dst_index: impl Fn(usize) -> usize, ) -> Result> { let mut dst = vec![T::zero(); dst_shape.elem_count()]; - for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { + for (unstr_index, src_index) in src_layout.strided_index().enumerate() { dst[to_dst_index(unstr_index)] += src[src_index]; } Ok(dst) } fn unary_map U>(vs: &[T], layout: &Layout, mut f: F) -> Vec { - if shape.is_contiguous(stride) { - vs[..shape.elem_count()].iter().map(|&v| f(v)).collect() - } else { - StridedIndex::new(shape.dims(), stride) - .map(|i| f(vs[i])) - .collect() + match layout.contiguous_offsets() { + Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(), + None => layout.strided_index().map(|i| f(vs[i])).collect(), } } // This function maps over two strided index sequences. fn binary_map T>( shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, lhs: &[T], rhs: &[T], mut f: F, ) -> Vec { let dims = shape.dims(); - if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) { + if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() { (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect() } else { - let lhs_index = StridedIndex::new(dims, lhs_stride); - let rhs_index = StridedIndex::new(dims, rhs_stride); + let lhs_index = lhs_layout.strided_index(); + let rhs_index = rhs_layout.strided_index(); lhs_index .zip(rhs_index) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) @@ -114,7 +110,7 @@ fn take_impl1( ) -> Result> { // TODO: Optimize for the case where ids are contiguous. 
let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size); - for index in StridedIndex::new(layout) { + for index in layout.strided_index() { let index = ids[index].try_into()?; if index >= vocab_size { return Err(Error::InvalidIndex { @@ -135,18 +131,19 @@ fn copy_strided_src_( dst_offset: usize, src_l: &Layout, ) { - let src = &src[src_l.start_offset()..]; - if src_l.is_contiguous() { - let elem_to_copy = (dst.len() - dst_offset).min(src.len()); - dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy]) - } else { - let src_indexes = StridedIndex::new(src_l); - for (dst_index, src_index) in src_indexes.enumerate() { - let dst_index = dst_index + dst_offset; - if dst_index >= dst.len() { - break; + match src_l.contiguous_offsets() { + Some((o_dst1, o_dst2)) => { + let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1); + dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2]) + } + None => { + for (dst_index, src_index) in src_l.strided_index().enumerate() { + let dst_index = dst_index + dst_offset; + if dst_index >= dst.len() { + break; + } + dst[dst_index] = src[src_index] } - dst[dst_index] = src[src_index] } } } @@ -235,114 +232,114 @@ impl CpuStorage { D::cpu_storage_as_mut_slice(self) } - pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result { + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { // TODO: find a way around the quadratic number of cases below. match (self, dtype) { (Self::U32(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) } (Self::BF16(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::BF16(data)) } (Self::F16(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32())); + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); Ok(Self::BF16(data)) } (Self::F32(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, bf16::from_f32); + let data = unary_map(storage, layout, bf16::from_f32); Ok(Self::BF16(data)) } (Self::F64(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, bf16::from_f64); + let data = unary_map(storage, layout, bf16::from_f64); Ok(Self::BF16(data)) } (Self::U32(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) } (Self::BF16(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32())); + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); Ok(Self::F16(data)) } (Self::F16(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::F16(data)) } (Self::F32(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, f16::from_f32); + let data = unary_map(storage, layout, f16::from_f32); Ok(Self::F16(data)) } (Self::F64(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, f16::from_f64); + let data = unary_map(storage, layout, f16::from_f64); Ok(Self::F16(data)) } (Self::U32(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v as f32); + let data 
= unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } (Self::BF16(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v.to_f32()); + let data = unary_map(storage, layout, |v| v.to_f32()); Ok(Self::F32(data)) } (Self::F16(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v.to_f32()); + let data = unary_map(storage, layout, |v| v.to_f32()); Ok(Self::F32(data)) } (Self::F32(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::F32(data)) } (Self::F64(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v as f32); + let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } (Self::U32(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::U32(data)) } (Self::BF16(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32); + let data = unary_map(storage, layout, |v| v.to_f32() as u32); Ok(Self::U32(data)) } (Self::F16(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32); + let data = unary_map(storage, layout, |v| v.to_f32() as u32); Ok(Self::U32(data)) } (Self::F32(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v as u32); + let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } (Self::F64(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v as u32); + let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } (Self::U32(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v as f64); + let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) } (Self::BF16(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v.to_f64()); + let data = unary_map(storage, layout, |v| v.to_f64()); Ok(Self::F64(data)) } (Self::F16(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v.to_f64()); + let data = unary_map(storage, layout, |v| v.to_f64()); Ok(Self::F64(data)) } (Self::F32(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v as f64); + let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) } (Self::F64(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } } } - pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result { - let src_dims = shape.dims(); + pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { + let src_dims = layout.dims(); let mut dst_dims = src_dims.to_vec(); for &sum_dim in sum_dims.iter() { dst_dims[sum_dim] = 1; @@ -368,7 +365,7 @@ impl CpuStorage { dst_index }; // TODO: Maybe provide an implementation with higher precision accumulators? 
- map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index) + map1!(self, sum_impl1, &dst_shape, layout, to_dst_index) } pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { @@ -516,28 +513,28 @@ impl CpuStorage { &self, rhs: &Self, shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result { match (self, rhs) { (Self::BF16(lhs), Self::BF16(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16); + let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::bf16); Ok(Self::BF16(data)) } (Self::F16(lhs), Self::F16(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16); + let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f16); Ok(Self::F16(data)) } (Self::F32(lhs), Self::F32(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32); + let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f32); Ok(Self::F32(data)) } (Self::F64(lhs), Self::F64(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64); + let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f64); Ok(Self::F64(data)) } (Self::U32(lhs), Self::U32(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32); + let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::u32); Ok(Self::U32(data)) } _ => { @@ -555,7 +552,7 @@ impl CpuStorage { &self, dst: &mut Self, dst_offset: usize, - src_l: Layout, + src_l: &Layout, ) -> Result<()> { match (self, dst) { (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index 08d34c4b..6ba0d79a 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -39,6 +39,17 @@ impl Layout { self.start_offset } + /// Returns the appropriate start and stop offset if the data is stored in a C + /// contiguous (aka row major) way. + pub fn contiguous_offsets(&self) -> Option<(usize, usize)> { + if self.is_contiguous() { + let start_o = self.start_offset; + Some((start_o, start_o + self.shape.elem_count())) + } else { + None + } + } + /// Returns true if the data is stored in a C contiguous (aka row major) way. pub fn is_contiguous(&self) -> bool { self.shape.is_contiguous(&self.stride) From 14449ff80c7ce59b38b5b5d4ee706fddfd8c6762 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 14:12:38 +0100 Subject: [PATCH 05/10] Get the cpu backend to compile. --- candle-core/src/cpu_backend.rs | 36 ++++++++++----------- candle-core/src/dummy_cuda_backend.rs | 46 +++++++++------------------ candle-core/src/layout.rs | 4 +-- candle-core/src/storage.rs | 11 ++++--- candle-core/src/tensor.rs | 6 ++-- 5 files changed, 44 insertions(+), 59 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 4f63ea98..1c5caa82 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1,5 +1,5 @@ use crate::op::{BinaryOp, UnaryOp}; -use crate::{DType, Error, Layout, Result, Shape, StridedIndex}; +use crate::{DType, Error, Layout, Result, Shape}; use gemm::{gemm, Parallelism}; use half::{bf16, f16}; @@ -81,14 +81,13 @@ fn unary_map U>(vs: &[T], layout: &Layout, mut // This function maps over two strided index sequences. 
fn binary_map T>( - shape: &Shape, lhs_layout: &Layout, rhs_layout: &Layout, lhs: &[T], rhs: &[T], mut f: F, ) -> Vec { - let dims = shape.dims(); + let shape = lhs_layout.shape(); if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() { (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect() } else { @@ -148,17 +147,19 @@ fn copy_strided_src_( } } -fn matmul_impl( +fn matmul( lhs: &[T], rhs: &[T], (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result> { let a_skip: usize = m * k; let b_skip: usize = n * k; let c_skip: usize = m * n; + let lhs_stride = lhs_layout.stride(); + let rhs_stride = rhs_layout.stride(); let rank = lhs_stride.len(); let lhs_cs = lhs_stride[rank - 1]; let lhs_rs = lhs_stride[rank - 2]; @@ -512,29 +513,28 @@ impl CpuStorage { pub(crate) fn binary_impl( &self, rhs: &Self, - shape: &Shape, lhs_layout: &Layout, rhs_layout: &Layout, ) -> Result { match (self, rhs) { (Self::BF16(lhs), Self::BF16(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::bf16); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::bf16); Ok(Self::BF16(data)) } (Self::F16(lhs), Self::F16(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f16); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f16); Ok(Self::F16(data)) } (Self::F32(lhs), Self::F32(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f32); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f32); Ok(Self::F32(data)) } (Self::F64(lhs), Self::F64(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f64); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f64); Ok(Self::F64(data)) } (Self::U32(lhs), Self::U32(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::u32); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::u32); Ok(Self::U32(data)) } _ => { @@ -622,24 +622,24 @@ impl CpuStorage { map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, bmnk: (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result { match (self, rhs) { (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::F16(dst)) } (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::F32(dst)) } (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::F64(dst)) } _ => Err(Error::DTypeMismatchBinaryOp { diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index babc6e7d..ef079812 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -use crate::{CpuStorage, DType, Error, Result, Shape}; +use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; #[derive(thiserror::Error, Debug)] pub enum DummyError {} @@ -60,11 +60,11 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn affine_impl(&self, _: &Shape, _: 
&[usize], _: f64, _: f64) -> Result { + pub(crate) fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result { + pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result { Err(Error::NotCompiledWithCudaSupport) } @@ -72,65 +72,49 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result { + pub(crate) fn to_dtype(&self, _: &Layout, _: DType) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn unary_impl(&self, _: &Shape, _: &[usize]) -> Result { + pub(crate) fn unary_impl(&self, _: &Layout) -> Result { Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn binary_impl( &self, _: &Self, - _: &Shape, - _: &[usize], - _: &[usize], + _: &Layout, + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn where_cond( &self, - _: &Shape, - _: &[usize], + _: &Layout, _: &Self, - _: &[usize], + _: &Layout, _: &Self, - _: &[usize], + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn embedding_impl( - &self, - _: &Shape, - _: &[usize], - _: &Self, - _: usize, - _: usize, - ) -> Result { + pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: usize, _: usize) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, _: &Self, _: (usize, usize, usize, usize), - _: &[usize], - _: &[usize], + _: &Layout, + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn copy_strided_src( - &self, - _: &mut Self, - _: usize, - _: &Shape, - _: &[usize], - _: usize, - ) -> Result<()> { + pub(crate) fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } } diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index 6ba0d79a..3f629d50 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -106,7 +106,7 @@ impl Layout { if shape.rank() < self.shape().rank() { Err(Error::BroadcastIncompatibleShapes { src_shape: self.shape().clone(), - dst_shape: shape, + dst_shape: shape.clone(), })? } let added_dims = shape.rank() - self.shape().rank(); @@ -135,6 +135,6 @@ impl Layout { } pub(crate) fn strided_index(&self) -> crate::StridedIndex { - crate::StridedIndex::new(&self) + crate::StridedIndex::new(self) } } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 2f2d8b75..2c9624c7 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -79,6 +79,7 @@ impl Storage { } } + // This assumes a contiguous layout and no offset. 
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { match self { Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?, @@ -196,22 +197,22 @@ impl Storage { } } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, bmnk: (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result { self.same_device(rhs, "matmul")?; self.same_dtype(rhs, "matmul")?; match (self, rhs) { (Self::Cpu(lhs), Self::Cpu(rhs)) => { - let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 094c60a3..b04e90b1 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -432,7 +432,7 @@ impl Tensor { let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); let batching: usize = a_dims[..dim - 2].iter().product(); - let storage = self.storage.matmul_impl( + let storage = self.storage.matmul( &rhs.storage, (batching, m, n, k), self.layout(), @@ -587,7 +587,7 @@ impl Tensor { } pub fn shape(&self) -> &Shape { - &self.layout().shape() + self.layout().shape() } pub fn dims(&self) -> &[usize] { @@ -600,7 +600,7 @@ impl Tensor { // TODO: Rename to `stride` once the PR that introduced the layout has been merged. pub fn stride_tmp(&self) -> &[usize] { - &self.layout.stride() + self.layout.stride() } pub fn rank(&self) -> usize { From caafef6cc14fc355af8401985e0b596b4a481bb7 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 14:32:02 +0100 Subject: [PATCH 06/10] Get the cpu tests to run. --- candle-core/src/cpu_backend.rs | 6 ++---- candle-core/src/dtype.rs | 12 ------------ candle-core/src/strided_index.rs | 2 +- candle-core/src/tensor.rs | 1 + candle-core/tests/tensor_tests.rs | 4 ++-- 5 files changed, 6 insertions(+), 19 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 1c5caa82..a5fdb826 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -229,10 +229,6 @@ impl CpuStorage { D::cpu_storage_as_slice(self) } - pub fn as_mut_slice(&mut self) -> Result<&mut [D]> { - D::cpu_storage_as_mut_slice(self) - } - pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { // TODO: find a way around the quadratic number of cases below. match (self, dtype) { @@ -581,6 +577,7 @@ impl CpuStorage { layout_f: &Layout, ) -> Result { // TODO: Support types that could be casted to a boolean. + // TODO: this should use the layout. let pred = self.as_slice::()?; match (t, f) { (Self::BF16(t), Self::BF16(f)) => { @@ -618,6 +615,7 @@ impl CpuStorage { hidden_size: usize, vocab_size: usize, ) -> Result { + // TODO: this should use the layout. 
let ids = self.as_slice::()?; map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) } diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 1711c2b4..fdbfdbba 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -41,7 +41,6 @@ pub trait WithDType: Sized + Copy { } fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; - fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>; fn cpu_storage_data(s: CpuStorage) -> Result>; } @@ -75,17 +74,6 @@ macro_rules! with_dtype { }), } } - - fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> { - match s { - CpuStorage::$dtype(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::$dtype, - got: s.dtype(), - msg: "unexpected dtype", - }), - } - } } }; } diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index f8dc522f..e6d2868b 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -17,7 +17,7 @@ impl<'a> StridedIndex<'a> { None } else { // This applies to the scalar case. - Some(0) + Some(layout.start_offset()) }; StridedIndex { next_storage_index, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index b04e90b1..93846160 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -317,6 +317,7 @@ impl Tensor { unary_op!(sqrt, Sqrt); unary_op!(gelu, Gelu); unary_op!(relu, Relu); + pub fn to_scalar(&self) -> Result { if self.rank() != 0 { return Err(Error::UnexpectedNumberOfDims { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 78ca4b05..8ac0c9f2 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> { let a_tt = a.t()?.contiguous()?.t()?; assert!(!a_tt.is_contiguous()); assert_eq!(a.dims(), a_tt.dims()); - assert_eq!(a_tt.stride(), &[6, 1, 2]); + assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]); let b_tt = b.t()?.contiguous()?.t()?; assert!(!b_tt.is_contiguous()); assert_eq!(b.dims(), b_tt.dims()); - assert_eq!(b_tt.stride(), &[6, 1, 3]); + assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]); assert_eq!(a_tt.matmul(&b)?.to_vec3::()?, &expected); assert_eq!(a.matmul(&b_tt)?.to_vec3::()?, &expected); From 1c755c0e5bb060e1676ef6583e0d2007a5016026 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 14:33:06 +0100 Subject: [PATCH 07/10] Remove some todos. --- candle-core/src/cpu_backend.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index a5fdb826..f83bb5e6 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -577,7 +577,6 @@ impl CpuStorage { layout_f: &Layout, ) -> Result { // TODO: Support types that could be casted to a boolean. - // TODO: this should use the layout. let pred = self.as_slice::()?; match (t, f) { (Self::BF16(t), Self::BF16(f)) => { @@ -615,7 +614,6 @@ impl CpuStorage { hidden_size: usize, vocab_size: usize, ) -> Result { - // TODO: this should use the layout. let ids = self.as_slice::()?; map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) } From cca699be6c8167f565067ceb3c940dd3c1d87503 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 15:09:15 +0100 Subject: [PATCH 08/10] Fix some cpu issue. 
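The cpu issue fixed here: the old binary_map fast path indexed both buffers from 0, which ignored any non-zero start offset carried by the new Layout, and matmul passed the raw buffers to gemm without skipping to the layout's start offset. The patch below switches binary_map to slice the buffers through a Layout::contiguous_offsets helper and makes matmul advance both operands by start_offset first. The helper itself is not shown in this patch; the following is only a sketch of what it is assumed to return, and the actual implementation in layout.rs may differ:

    // Hypothetical sketch: the storage range covered by a C-contiguous
    // layout, so callers can slice the backing buffer directly.
    impl Layout {
        pub fn contiguous_offsets(&self) -> Option<(usize, usize)> {
            if self.is_contiguous() {
                let start = self.start_offset();
                Some((start, start + self.shape().elem_count()))
            } else {
                None
            }
        }
    }

With such a helper the contiguous case reduces to zipping lhs[o_l1..o_l2] with rhs[o_r1..o_r2], while non-contiguous layouts still fall back to the strided index path.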
--- candle-core/src/cpu_backend.rs | 57 ++++++++++++++++++---------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index f83bb5e6..9f0c8602 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -81,22 +81,23 @@ fn unary_map U>(vs: &[T], layout: &Layout, mut // This function maps over two strided index sequences. fn binary_map T>( - lhs_layout: &Layout, - rhs_layout: &Layout, + lhs_l: &Layout, + rhs_l: &Layout, lhs: &[T], rhs: &[T], mut f: F, ) -> Vec { - let shape = lhs_layout.shape(); - if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() { - (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect() - } else { - let lhs_index = lhs_layout.strided_index(); - let rhs_index = rhs_layout.strided_index(); - lhs_index - .zip(rhs_index) + match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { + (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2] + .iter() + .zip(rhs[o_r1..o_r2].iter()) + .map(|(&l, &r)| f(l, r)) + .collect(), + _ => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect() + .collect(), } } @@ -151,15 +152,17 @@ fn matmul( lhs: &[T], rhs: &[T], (b, m, n, k): (usize, usize, usize, usize), - lhs_layout: &Layout, - rhs_layout: &Layout, + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result> { + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; let a_skip: usize = m * k; let b_skip: usize = n * k; let c_skip: usize = m * n; - let lhs_stride = lhs_layout.stride(); - let rhs_stride = rhs_layout.stride(); + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); let rank = lhs_stride.len(); let lhs_cs = lhs_stride[rank - 1]; let lhs_rs = lhs_stride[rank - 2]; @@ -509,28 +512,28 @@ impl CpuStorage { pub(crate) fn binary_impl( &self, rhs: &Self, - lhs_layout: &Layout, - rhs_layout: &Layout, + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result { match (self, rhs) { (Self::BF16(lhs), Self::BF16(rhs)) => { - let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::bf16); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16); Ok(Self::BF16(data)) } (Self::F16(lhs), Self::F16(rhs)) => { - let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f16); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f16); Ok(Self::F16(data)) } (Self::F32(lhs), Self::F32(rhs)) => { - let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f32); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f32); Ok(Self::F32(data)) } (Self::F64(lhs), Self::F64(rhs)) => { - let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f64); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f64); Ok(Self::F64(data)) } (Self::U32(lhs), Self::U32(rhs)) => { - let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::u32); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32); Ok(Self::U32(data)) } _ => { @@ -622,20 +625,20 @@ impl CpuStorage { &self, rhs: &Self, bmnk: (usize, usize, usize, usize), - lhs_layout: &Layout, - rhs_layout: &Layout, + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result { match (self, rhs) { (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => { - let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; + let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?; Ok(Self::F16(dst)) } (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => { - let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; + let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?; Ok(Self::F32(dst)) } 
(CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { - let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; + let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?; Ok(Self::F64(dst)) } _ => Err(Error::DTypeMismatchBinaryOp { From 3f0d9fbb257baf94acde184de76eb9667e0fa025 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 15:43:03 +0100 Subject: [PATCH 09/10] Adapt the cuda bits. --- candle-core/src/cpu_backend.rs | 19 +--- candle-core/src/cuda_backend.rs | 157 ++++++++++++++++---------- candle-core/src/dummy_cuda_backend.rs | 2 +- candle-core/src/storage.rs | 14 +-- candle-core/src/tensor.rs | 4 +- 5 files changed, 109 insertions(+), 87 deletions(-) diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 9f0c8602..f1547b3c 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -101,14 +101,9 @@ fn binary_map T>( } } -fn take_impl1( - vs: &[T], - ids: &[u32], - layout: &Layout, - vocab_size: usize, - hidden_size: usize, -) -> Result> { +fn take_impl1(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result> { // TODO: Optimize for the case where ids are contiguous. + let (vocab_size, hidden_size) = rhs_l.shape().r2()?; let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size); for index in layout.strided_index() { let index = ids[index].try_into()?; @@ -610,15 +605,9 @@ impl CpuStorage { } } - pub(crate) fn embedding( - &self, - layout: &Layout, - vs: &Self, - hidden_size: usize, - vocab_size: usize, - ) -> Result { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { let ids = self.as_slice::()?; - map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) + map1!(rhs, take_impl1, ids, layout, rhs_l) } pub(crate) fn matmul( diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 9cbf82be..f50d7cbb 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,4 +1,4 @@ -use crate::{CpuStorage, DType, Shape}; +use crate::{CpuStorage, DType, Layout, Shape}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig}; @@ -26,6 +26,9 @@ pub enum CudaError { #[error("internal error '{0}'")] InternalError(&'static str), + #[error("internal error '{0}'")] + WrappedError(Box), + #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] MatMulNonContiguous { lhs_stride: Vec, @@ -268,12 +271,14 @@ fn gemm_config( alpha: T, beta: T, (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result> { // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm use cudarc::cublas::sys::cublasOperation_t; + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; @@ -352,19 +357,21 @@ impl CudaStorage { &self.device } - pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result { + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { use cudarc::driver::DevicePtr; + let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, 
stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; + let start_o = layout.start_offset(); let inp = match &self.slice { - CudaStorageSlice::U32(inp) => inp.device_ptr(), - CudaStorageSlice::BF16(inp) => inp.device_ptr(), - CudaStorageSlice::F16(inp) => inp.device_ptr(), - CudaStorageSlice::F32(inp) => inp.device_ptr(), - CudaStorageSlice::F64(inp) => inp.device_ptr(), + CudaStorageSlice::U32(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::BF16(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F16(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F32(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F64(inp) => inp.slice(start_o..).device_ptr(), }; let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str()); let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?; @@ -406,20 +413,16 @@ impl CudaStorage { }) } - pub(crate) fn affine_impl( - &self, - shape: &Shape, - stride: &[usize], - mul: f64, - add: f64, - ) -> Result { + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + let shape = layout.shape(); let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -429,6 +432,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -446,6 +450,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -463,6 +468,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -472,6 +478,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. 
let out = unsafe { dev.alloc::(el_count) }?; @@ -485,7 +492,8 @@ impl CudaStorage { Ok(Self { slice, device }) } - pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result { + pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { + let shape = layout.shape(); let src_dims = shape.dims(); let el = shape.elem_count(); let mut dst_el = el; @@ -503,9 +511,10 @@ impl CudaStorage { .collect(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?; + let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -514,6 +523,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -522,6 +532,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -530,6 +541,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -538,6 +550,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -556,21 +569,19 @@ impl CudaStorage { )) } - pub(crate) fn unary_impl( - &self, - shape: &Shape, - stride: &[usize], - ) -> Result { + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { + let shape = layout.shape(); let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(_arg) => { todo!("No unary kernels for u32"); } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -580,6 +591,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?; // SAFETY: Set later by running the kernel. 
let out = unsafe { dev.alloc::(el_count) }?; @@ -589,6 +601,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -598,6 +611,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -614,17 +628,19 @@ impl CudaStorage { pub(crate) fn binary_impl( &self, rhs: &Self, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result { + let shape = lhs_l.shape(); let dims = shape.dims(); let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); let dev = self.device(); - let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?; + let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; let slice = match (&self.slice, &rhs.slice) { (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(elem_count) }?; @@ -634,6 +650,8 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(elem_count) }?; @@ -643,6 +661,8 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(elem_count) }?; @@ -652,6 +672,8 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?; let out = unsafe { dev.alloc::(elem_count) }?; @@ -661,6 +683,8 @@ impl CudaStorage { CudaStorageSlice::F64(out) } (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); // SAFETY: Set later by running the kernel. 
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?; let out = unsafe { dev.alloc::(elem_count) }?; @@ -708,28 +732,31 @@ impl CudaStorage { pub(crate) fn where_cond( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &Self, - stride_t: &[usize], + layout_t: &Layout, f: &Self, - stride_f: &[usize], + layout_f: &Layout, ) -> Result { let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice, + CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..), _ => Err(CudaError::UnexpectedDType { msg: "where conditions should be u32", expected: DType::U32, got: self.dtype(), })?, }; + let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?; + let ds = + dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?; let slice = match (&t.slice, &f.slice) { (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }?; @@ -739,6 +766,8 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }?; @@ -748,6 +777,8 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }?; @@ -757,6 +788,8 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?; let out = unsafe { dev.alloc::(el) }?; @@ -766,6 +799,8 @@ impl CudaStorage { CudaStorageSlice::F64(out) } (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?; let out = unsafe { dev.alloc::(el) }?; @@ -775,36 +810,35 @@ impl CudaStorage { CudaStorageSlice::U32(out) } // The dtypes should have been checked at this point so this is an internal error. 
- _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; let device = dev.clone(); Ok(Self { slice, device }) } - pub(crate) fn embedding_impl( - &self, - shape: &Shape, - stride: &[usize], - rhs: &Self, - h_size: usize, // hidden size - v_size: usize, // vocab size - ) -> Result { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice, + CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..), _ => Err(CudaError::UnexpectedDType { msg: "embedding ids should be u32", expected: DType::U32, got: self.dtype(), })?, }; + let shape = layout.shape(); + let (v_size, h_size) = rhs_l + .shape() + .r2() + .map_err(|e| CudaError::WrappedError(Box::new(e)))?; let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &rhs.slice { // The kernels below assume that rhs is contiguous. CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; @@ -814,6 +848,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; @@ -823,6 +858,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; @@ -832,6 +868,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; @@ -841,6 +878,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. 
let out = unsafe { dev.alloc::(el * h_size) }?; @@ -854,12 +892,12 @@ impl CudaStorage { Ok(Self { slice, device }) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result { let elem_count = b * m * n; let dev = &self.device; @@ -868,7 +906,9 @@ impl CudaStorage { todo!("bf16") } (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { - let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }?; unsafe { self.device @@ -878,7 +918,9 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }?; unsafe { self.device @@ -888,7 +930,9 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }?; unsafe { self.device @@ -907,13 +951,8 @@ impl CudaStorage { &self, dst: &mut Self, dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: &Layout, ) -> Result<()> { - if src_shape.rank() != src_stride.len() { - panic!("incoherent shape and strides {src_shape:?} {src_stride:?}") - } let dims = src_shape.dims(); let el_count = src_shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index ef079812..8193b1af 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -100,7 +100,7 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: usize, _: usize) -> Result { + pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 2c9624c7..7acf6dd0 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -167,26 +167,20 @@ impl Storage { (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), - op: "embedding", + op: "where", }), } } - pub(crate) fn embedding( - &self, - layout: &Layout, - rhs: &Self, - hidden_size: usize, - vocab_size: usize, - ) -> Result { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { self.same_device(rhs, "embedding")?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding(layout, rhs, rhs_l)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.embedding(layout, rhs, hidden_size, vocab_size)?; + 
let storage = lhs.embedding(layout, rhs, rhs_l)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 93846160..f64bd6f2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -481,10 +481,10 @@ impl Tensor { } let ids_shape = ids.shape(); let seq_len = ids_shape.r1()?; - let (vocab_size, hidden_size) = rhs.shape().r2()?; + let (_, hidden_size) = rhs.shape().r2()?; let storage = ids .storage - .embedding(ids.layout(), &rhs.storage, hidden_size, vocab_size)?; + .embedding(ids.layout(), &rhs.storage, rhs.layout())?; let shape: Shape = (seq_len, hidden_size).into(); let op = if ids.track_op() || rhs.track_op() { Some(Op::Embedding(ids.clone(), rhs.clone())) From 6c9e6b5a99d4070be5c20d7c383e0ef7e3228260 Mon Sep 17 00:00:00 2001 From: laurent Date: Wed, 28 Jun 2023 15:53:23 +0100 Subject: [PATCH 10/10] Get the cuda tests to pass. --- candle-core/src/cuda_backend.rs | 49 +++++++++++++++++++-------------- 1 file changed, 29 insertions(+), 20 deletions(-) diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index f50d7cbb..9d9a5f99 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -27,7 +27,7 @@ pub enum CudaError { InternalError(&'static str), #[error("internal error '{0}'")] - WrappedError(Box), + WrappedError(Box), #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] MatMulNonContiguous { @@ -245,13 +245,14 @@ enum CudaStorageSlice { fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, - src_offset: usize, + src_l: &Layout, dst: &'a mut CudaSlice, dst_offset: usize, ) -> ( cudarc::driver::CudaView<'a, T>, cudarc::driver::CudaViewMut<'a, T>, ) { + let src_offset = src_l.start_offset(); let to_copy = dst .len() .saturating_sub(dst_offset) @@ -366,13 +367,18 @@ impl CudaStorage { let dev = self.device(); let ds = dev.htod_copy([dims, layout.stride()].concat())?; let start_o = layout.start_offset(); + // This returns an i64 rather than a &i64, this is useful to get around some temporary + // lifetime issue and is safe as long as self.slice does not go out of scope before inp + // is used. 
let inp = match &self.slice { - CudaStorageSlice::U32(inp) => inp.slice(start_o..).device_ptr(), - CudaStorageSlice::BF16(inp) => inp.slice(start_o..).device_ptr(), - CudaStorageSlice::F16(inp) => inp.slice(start_o..).device_ptr(), - CudaStorageSlice::F32(inp) => inp.slice(start_o..).device_ptr(), - CudaStorageSlice::F64(inp) => inp.slice(start_o..).device_ptr(), + CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(), }; + let inp = &inp; + let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str()); let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?; let slice = match dtype { @@ -739,13 +745,14 @@ impl CudaStorage { layout_f: &Layout, ) -> Result { let ids = match &self.slice { - CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..), + CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), _ => Err(CudaError::UnexpectedDType { msg: "where conditions should be u32", expected: DType::U32, got: self.dtype(), })?, }; + let ids = &ids; let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); @@ -818,13 +825,14 @@ impl CudaStorage { pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { let ids = match &self.slice { - CudaStorageSlice::U32(slice) => &slice.slice(layout.start_offset()..), + CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), _ => Err(CudaError::UnexpectedDType { msg: "embedding ids should be u32", expected: DType::U32, got: self.dtype(), })?, }; + let ids = &ids; let shape = layout.shape(); let (v_size, h_size) = rhs_l .shape() @@ -953,15 +961,16 @@ impl CudaStorage { dst_offset: usize, src_l: &Layout, ) -> Result<()> { + let src_shape = src_l.shape(); let dims = src_shape.dims(); let el_count = src_shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; - let ds = dev.htod_copy([dims, src_stride].concat())?; + let ds = dev.htod_copy([dims, src_l.stride()].concat())?; match (&self.slice, &mut dst.slice) { (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? } else { let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?; @@ -972,8 +981,8 @@ impl CudaStorage { } } (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? } else { let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?; @@ -984,8 +993,8 @@ impl CudaStorage { } } (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? 
} else { let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?; @@ -996,8 +1005,8 @@ impl CudaStorage { } } (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? } else { let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?; @@ -1008,8 +1017,8 @@ impl CudaStorage { } } (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? } else { let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;