From c1bbbf94f67d91773e0e7491e0aed9a30d75144a Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 28 Jun 2023 12:57:30 +0100
Subject: [PATCH] Start refactoring the stride.

---
 candle-core/src/cpu_backend.rs | 46 ++++++++------------
 candle-core/src/layout.rs      | 47 +++++++++++++++++++++
 candle-core/src/lib.rs         |  2 +
 candle-core/src/storage.rs     | 63 ++++++++++++-----------------
 candle-core/src/tensor.rs      | 74 +++++++++++++++-------------------
 5 files changed, 124 insertions(+), 108 deletions(-)
 create mode 100644 candle-core/src/layout.rs

diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs
index 53c7ecf1..8cafec12 100644
--- a/candle-core/src/cpu_backend.rs
+++ b/candle-core/src/cpu_backend.rs
@@ -1,5 +1,5 @@
 use crate::op::{BinaryOp, UnaryOp};
-use crate::{DType, Error, Result, Shape, StridedIndex};
+use crate::{DType, Error, Layout, Result, Shape, StridedIndex};
 use gemm::{gemm, Parallelism};
 use half::{bf16, f16};
@@ -18,12 +18,10 @@ pub enum CpuStorage {
 
 fn wcond<T: Copy>(
     pred: &[u32],
-    shape: &Shape,
-    stride: &[usize],
+    layout: &Layout,
     t: &[T],
-    stride_t: &[usize],
+    layout_t: &Layout,
     f: &[T],
-    stride_f: &[usize],
+    layout_f: &Layout,
 ) -> Vec<T> {
-    if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f)
-    {
+    if layout.is_contiguous() && layout_t.is_contiguous() && layout_f.is_contiguous() {
@@ -73,12 +72,7 @@ fn sum_impl1(
     Ok(dst)
 }
 
-fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
-    vs: &[T],
-    shape: &Shape,
-    stride: &[usize],
-    mut f: F,
-) -> Vec<U> {
-    if shape.is_contiguous(stride) {
-        vs[..shape.elem_count()].iter().map(|&v| f(v)).collect()
+fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
+    if layout.is_contiguous() {
+        vs[..layout.shape().elem_count()].iter().map(|&v| f(v)).collect()
     } else {
@@ -461,65 +455,59 @@ impl CpuStorage {
         Ok(())
     }
 
-    pub(crate) fn affine_impl(
-        &self,
-        shape: &Shape,
-        stride: &[usize],
-        mul: f64,
-        add: f64,
-    ) -> Result<Self> {
+    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         match self {
             Self::U32(storage) => {
                 let mul = mul as u32;
                 let add = add as u32;
-                let data = unary_map(storage, shape, stride, |v| v * mul + add);
+                let data = unary_map(storage, layout, |v| v * mul + add);
                 Ok(Self::U32(data))
             }
             Self::BF16(storage) => {
                 let mul = bf16::from_f64(mul);
                 let add = bf16::from_f64(add);
-                let data = unary_map(storage, shape, stride, |v| v * mul + add);
+                let data = unary_map(storage, layout, |v| v * mul + add);
                 Ok(Self::BF16(data))
             }
             Self::F16(storage) => {
                 let mul = f16::from_f64(mul);
                 let add = f16::from_f64(add);
-                let data = unary_map(storage, shape, stride, |v| v * mul + add);
+                let data = unary_map(storage, layout, |v| v * mul + add);
                 Ok(Self::F16(data))
             }
             Self::F32(storage) => {
                 let mul = mul as f32;
                 let add = add as f32;
-                let data = unary_map(storage, shape, stride, |v| v * mul + add);
+                let data = unary_map(storage, layout, |v| v * mul + add);
                 Ok(Self::F32(data))
             }
             Self::F64(storage) => {
-                let data = unary_map(storage, shape, stride, |v| v * mul + add);
+                let data = unary_map(storage, layout, |v| v * mul + add);
                 Ok(Self::F64(data))
             }
         }
     }
 
-    pub(crate) fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+    pub(crate) fn unary_impl<B: UnaryOp>(&self, layout: &Layout) -> Result<Self> {
         match self {
             Self::BF16(storage) => {
-                let data = unary_map(storage, shape, stride, B::bf16);
+                let data = unary_map(storage, layout, B::bf16);
                 Ok(Self::BF16(data))
             }
             Self::F16(storage) => {
-                let data = unary_map(storage, shape, stride, B::f16);
+                let data = unary_map(storage, layout, B::f16);
                 Ok(Self::F16(data))
             }
             Self::F32(storage) => {
-                let data = unary_map(storage, shape, stride, B::f32);
+                let data = unary_map(storage, layout, B::f32);
                 Ok(Self::F32(data))
             }
             Self::F64(storage) => {
-                let data = unary_map(storage, shape, stride, B::f64);
+                let data = unary_map(storage, layout, B::f64);
                 Ok(Self::F64(data))
             }
             Self::U32(storage) => {
-                let data = unary_map(storage, shape, stride, B::u32);
+                let data = unary_map(storage, layout, B::u32);
                 Ok(Self::U32(data))
             }
         }
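The contiguous fast path in `unary_map` above only covers tensors whose elements sit back to back in storage; the `else` branch (not shown in the hunk) has to translate every logical index through the per-dimension strides, which is what `StridedIndex` does. As a rough standalone illustration of that translation, independent of candle's own `StridedIndex` (whose layout-based signature is not part of this patch), the arithmetic looks like this:

// Standalone sketch: map a flat logical index to a storage offset using
// dims/strides, the same arithmetic a strided iterator performs.
// `dims` and `strides` are plain slices here, not candle types.
fn storage_offset(mut logical_index: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    // Walk dimensions from innermost to outermost, peeling off one
    // coordinate at a time.
    for (&dim, &stride) in dims.iter().rev().zip(strides.iter().rev()) {
        offset += (logical_index % dim) * stride;
        logical_index /= dim;
    }
    offset
}

fn main() {
    // A 2x3 tensor stored transposed: strides (1, 2) instead of the contiguous (3, 1).
    let dims = [2, 3];
    let strides = [1, 2];
    let offsets: Vec<usize> = (0..6).map(|i| storage_offset(i, &dims, &strides)).collect();
    assert_eq!(offsets, vec![0, 2, 4, 1, 3, 5]);
}

For a contiguous layout this offset is just the logical index itself, which is why the fast path can slice `vs` directly.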
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
new file mode 100644
index 00000000..a0f3639b
--- /dev/null
+++ b/candle-core/src/layout.rs
@@ -0,0 +1,47 @@
+use crate::Shape;
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct Layout {
+    shape: Shape,
+    // The strides are given in number of elements and not in bytes.
+    stride: Vec<usize>,
+    start_offset: usize,
+}
+
+impl Layout {
+    pub fn contiguous<S: Into<Shape>>(shape: S) -> Self {
+        let shape = shape.into();
+        let stride = shape.stride_contiguous();
+        Self {
+            shape,
+            stride,
+            start_offset: 0,
+        }
+    }
+
+    pub fn dims(&self) -> &[usize] {
+        self.shape.dims()
+    }
+
+    pub fn shape(&self) -> &Shape {
+        &self.shape
+    }
+
+    pub fn stride(&self) -> &[usize] {
+        &self.stride
+    }
+
+    pub fn start_offset(&self) -> usize {
+        self.start_offset
+    }
+
+    /// Returns true if the data is stored in a C contiguous (aka row major) way.
+    pub fn is_contiguous(&self) -> bool {
+        self.shape.is_contiguous(&self.stride)
+    }
+
+    /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
+    pub fn is_fortran_contiguous(&self) -> bool {
+        self.shape.is_fortran_contiguous(&self.stride)
+    }
+}
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 5771517f..6a860116 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -7,6 +7,7 @@ pub mod display;
 mod dtype;
 mod dummy_cuda_backend;
 mod error;
+mod layout;
 mod npy;
 mod op;
 mod shape;
@@ -19,6 +20,7 @@ pub use cpu_backend::CpuStorage;
 pub use device::{Device, DeviceLocation};
 pub use dtype::{DType, WithDType};
 pub use error::{Error, Result};
+pub use layout::Layout;
 pub use shape::Shape;
 pub use storage::Storage;
 use strided_index::StridedIndex;
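Everything in the new `Layout` type is shown above, so its intended use can be sketched directly. A minimal example, assuming the crate is consumed as `candle_core` and that `Shape`'s existing tuple conversions (`Into<Shape>` for `(usize, usize, usize)` and friends) are available:

use candle_core::Layout;

fn main() {
    // Build a contiguous (row major) layout for a 2x3x4 tensor.
    let layout = Layout::contiguous((2, 3, 4));
    assert_eq!(layout.dims(), &[2, 3, 4]);
    // Row-major strides: the last dimension is the innermost one.
    assert_eq!(layout.stride(), &[12, 4, 1]);
    assert_eq!(layout.start_offset(), 0);
    assert!(layout.is_contiguous());
    assert!(!layout.is_fortran_contiguous());
}

Since `contiguous` is the only constructor in this patch, `start_offset` is always 0 for now; presumably later commits add non-contiguous constructors (for narrow, transpose and the like) that exercise it.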
diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs
index e44a2db6..3bcb0ff0 100644
--- a/candle-core/src/storage.rs
+++ b/candle-core/src/storage.rs
@@ -1,4 +1,4 @@
-use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape};
+use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
 
 // We do not want to implement Clone on Storage as cloning may fail because of
 // out of memory. Instead try_clone should be used.
@@ -53,33 +53,27 @@ impl Storage {
         }
     }
 
-    pub(crate) fn affine_impl(
-        &self,
-        shape: &Shape,
-        stride: &[usize],
-        mul: f64,
-        add: f64,
-    ) -> Result<Self> {
+    pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
-                let storage = storage.affine_impl(shape, stride, mul, add)?;
+                let storage = storage.affine(layout, mul, add)?;
                 Ok(Self::Cpu(storage))
             }
             Self::Cuda(storage) => {
-                let storage = storage.affine_impl(shape, stride, mul, add)?;
+                let storage = storage.affine(layout, mul, add)?;
                 Ok(Self::Cuda(storage))
             }
         }
     }
 
-    pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result<Self> {
+    pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
-                let storage = storage.sum(shape, stride, s)?;
+                let storage = storage.sum(layout, s)?;
                 Ok(Self::Cpu(storage))
             }
             Self::Cuda(storage) => {
-                let storage = storage.sum(shape, stride, s)?;
+                let storage = storage.sum(layout, s)?;
                 Ok(Self::Cuda(storage))
             }
         }
     }
@@ -93,32 +87,28 @@ impl Storage {
         Ok(())
     }
 
-    pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result<Self> {
+    pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
-                let storage = storage.to_dtype(shape, stride, dtype)?;
+                let storage = storage.to_dtype(layout, dtype)?;
                 Ok(Self::Cpu(storage))
             }
             Self::Cuda(storage) => {
-                let storage = storage.to_dtype(shape, stride, dtype)?;
+                let storage = storage.to_dtype(layout, dtype)?;
                 Ok(Self::Cuda(storage))
             }
         }
     }
 
-    pub(crate) fn unary_impl<B: op::UnaryOp>(
-        &self,
-        shape: &Shape,
-        stride: &[usize],
-    ) -> Result<Self> {
+    pub(crate) fn unary_impl<B: op::UnaryOp>(&self, layout: &Layout) -> Result<Self> {
         // TODO: Different code path for the contiguous case?
         match self {
             Storage::Cpu(storage) => {
-                let storage = storage.unary_impl::<B>(shape, stride)?;
+                let storage = storage.unary_impl::<B>(layout)?;
                 Ok(Self::Cpu(storage))
             }
             Self::Cuda(storage) => {
-                let storage = storage.unary_impl::<B>(shape, stride)?;
+                let storage = storage.unary_impl::<B>(layout)?;
                 Ok(Self::Cuda(storage))
             }
         }
     }
@@ -127,19 +117,18 @@ impl Storage {
     pub(crate) fn binary_impl<B: op::BinaryOp>(
         &self,
         rhs: &Self,
-        shape: &Shape,
-        lhs_stride: &[usize],
-        rhs_stride: &[usize],
+        lhs_layout: &Layout,
+        rhs_layout: &Layout,
     ) -> Result<Self> {
         self.same_device(rhs, B::NAME)?;
         self.same_dtype(rhs, B::NAME)?;
         match (self, rhs) {
             (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
-                let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
+                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cpu(storage))
             }
             (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.binary_impl::<B>(rhs, shape, lhs_stride, rhs_stride)?;
+                let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
                 Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => {
@@ -156,23 +145,22 @@ impl Storage {
     pub(crate) fn where_cond(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         t: &Self,
-        stride_t: &[usize],
+        layout_t: &Layout,
         f: &Self,
-        stride_f: &[usize],
+        layout_f: &Layout,
     ) -> Result<Self> {
         self.same_device(t, "where")?;
         self.same_device(f, "where")?;
         t.same_dtype(f, "where")?;
         match (self, t, f) {
             (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => {
-                let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?;
+                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                 Ok(Self::Cpu(storage))
             }
             (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => {
-                let storage = cond.where_cond(shape, stride, t, layout_t, f, layout_f)?;
+                let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
                 Ok(Self::Cuda(storage))
             }
             (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
@@ -185,8 +173,7 @@ impl Storage {
 
     pub(crate) fn embedding_impl(
         &self,
-        shape: &Shape,
-        stride: &[usize],
+        layout: &Layout,
         rhs: &Self,
         hidden_size: usize,
         vocab_size: usize,
     ) -> Result<Self> {
         self.same_device(rhs, "embedding")?;
         match (self, rhs) {
             (Storage::Cpu(lhs), Storage::Cpu(rhs)) => {
-                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cpu(storage))
             }
             (Self::Cuda(lhs), Self::Cuda(rhs)) => {
-                let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?;
+                let storage = lhs.embedding_impl(layout, rhs, hidden_size, vocab_size)?;
                 Ok(Self::Cuda(storage))
             }
             (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
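The `Storage` changes are mechanical: every method now threads a single `&Layout` per operand where it previously took a `&Shape` plus a `&[usize]` of strides, and the per-device dispatch is untouched. This also removes a class of bugs where a call site could pass a stride slice that disagrees with the shape. Reduced to a toy (hypothetical `Toy*` types, not candle's API), the pattern each of these methods follows is:

// Toy reduction of the Storage dispatch above (hypothetical types, not
// candle's API): one enum over backends, each op taking one layout per
// operand and matching on the variant to forward the call.
struct ToyLayout {
    dims: Vec<usize>,
    stride: Vec<usize>,
}

enum ToyStorage {
    Cpu(Vec<f32>),
    // A real Cuda variant would wrap a device buffer rather than host memory.
    Cuda(Vec<f32>),
}

impl ToyStorage {
    // Mirrors the shape of Storage::affine: the same layout is handed to
    // whichever backend actually owns the data.
    fn affine(&self, _layout: &ToyLayout, mul: f32, add: f32) -> ToyStorage {
        match self {
            Self::Cpu(vs) => Self::Cpu(vs.iter().map(|v| v * mul + add).collect()),
            Self::Cuda(vs) => Self::Cuda(vs.iter().map(|v| v * mul + add).collect()),
        }
    }
}

fn main() {
    let layout = ToyLayout { dims: vec![3], stride: vec![1] };
    let s = ToyStorage::Cpu(vec![1.0, 2.0, 3.0]);
    // With mul = 2 and add = 0.5 this yields Cpu([2.5, 4.5, 6.5]).
    let _ = s.affine(&layout, 2.0, 0.5);
}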
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index feb59d3c..51eeb9ae 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -1,4 +1,4 @@
-use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape};
+use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::Arc;
 
 /// Unique identifier for tensors.
@@ -17,9 +17,7 @@ impl TensorId {
 pub struct Tensor_ {
     id: TensorId,
     storage: Arc<Storage>,
-    shape: Shape,
-    // The strides are given in number of elements and not in bytes.
-    stride: Vec<usize>,
+    layout: Layout,
     op: Option<Op>,
     is_variable: bool,
 }
@@ -50,7 +48,7 @@ macro_rules! unary_op {
             let shape = self.shape();
             let storage = self
                 .storage
-                .unary_impl::<crate::op::$op_name>(self.shape(), self.stride())?;
+                .unary_impl::<crate::op::$op_name>(self.layout())?;
             let op = if self.track_op() {
                 Some(Op::$op_name(self.clone()))
             } else {
@@ -67,9 +65,8 @@ macro_rules! binary_op {
             let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
             let storage = self.storage.binary_impl::<crate::op::$op_name>(
                 &rhs.storage,
-                shape,
-                self.stride(),
-                rhs.stride(),
+                self.layout(),
+                rhs.layout(),
             )?;
             let op = if self.track_op() || rhs.track_op() {
                 Some(Op::$op_name(self.clone(), rhs.clone()))
@@ -107,13 +104,10 @@ fn from_storage<S: Into<Shape>>(
     op: Option<Op>,
     is_variable: bool,
 ) -> Tensor {
-    let shape = shape.into();
-    let stride = shape.stride_contiguous();
     let tensor_ = Tensor_ {
         id: TensorId::new(),
         storage: Arc::new(storage),
-        shape,
-        stride,
+        layout: Layout::contiguous(shape),
         op,
         is_variable,
     };
@@ -342,8 +336,7 @@ impl Tensor {
     }
 
     pub fn affine(&self, mul: f64, add: f64) -> Result<Self> {
-        let shape = self.shape();
-        let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?;
+        let storage = self.storage.affine(self.layout(), mul, add)?;
         let op = if self.track_op() {
             Some(Op::Affine {
                 arg: self.clone(),
@@ -353,7 +346,7 @@ impl Tensor {
         } else {
             None
         };
-        Ok(from_storage(storage, shape.clone(), op, false))
+        Ok(from_storage(storage, self.shape(), op, false))
     }
 
     /// Returns a new tensor that is a narrowed version of the input, the dimension `dim`
@@ -401,9 +394,7 @@ impl Tensor {
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
-            let mut storage = self
-                .storage
-                .unary_impl::<crate::op::Exp>(shape, self.stride())?;
+            let mut storage = self.storage.unary_impl::<crate::op::Exp>(self.layout())?;
             // The resulting storage is contiguous.
             storage.divide_by_sum_over_dim(shape, dim)?;
             let op = if self.track_op() {
@@ -416,7 +407,7 @@ impl Tensor {
     }
 
     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
-        let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
+        let storage = self.storage.sum(self.layout(), sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -461,8 +452,8 @@ impl Tensor {
         let storage = self.storage.matmul_impl(
             &rhs.storage,
             (batching, m, n, k),
-            self.stride(),
-            rhs.stride(),
+            self.layout(),
+            rhs.layout(),
         )?;
         let op = if self.track_op() || rhs.track_op() {
             Some(Op::Matmul(self.clone(), rhs.clone()))
@@ -476,12 +467,11 @@ impl Tensor {
         let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
         let shape = self.same_shape_binary_op(on_false, "where_cond")?;
         let storage = self.storage.where_cond(
-            shape,
-            self.stride(),
+            self.layout(),
             &on_true.storage,
-            on_true.stride(),
+            on_true.layout(),
             &on_false.storage,
-            on_false.stride(),
+            on_false.layout(),
         )?;
         let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
             Some(Op::WhereCond(
@@ -498,10 +488,10 @@ impl Tensor {
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
         if !rhs.is_contiguous() {
             return Err(Error::RequiresContiguous { op: "embedding" });
-        } else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 {
+        } else if rhs.rank() != 2 || ids.rank() != 1 {
             return Err(Error::ShapeMismatchBinaryOp {
-                lhs: ids.shape.clone(),
-                rhs: rhs.shape.clone(),
+                lhs: ids.shape().clone(),
+                rhs: rhs.shape().clone(),
                 op: "embedding",
             });
         }
@@ -509,7 +499,6 @@ impl Tensor {
         let seq_len = ids_shape.r1()?;
         let (vocab_size, hidden_size) = rhs.shape().r2()?;
         let storage = ids.storage.embedding_impl(
-            ids_shape,
-            &ids.stride,
+            ids.layout(),
             &rhs.storage,
             hidden_size,
@@ -625,8 +615,13 @@ impl Tensor {
         self.shape().dims()
     }
 
-    pub fn stride(&self) -> &[usize] {
-        &self.stride
+    pub fn layout(&self) -> &Layout {
+        &self.layout
+    }
+
+    // TODO: Rename to `stride` once the PR that introduced the layout has been merged.
+    pub fn stride_tmp(&self) -> &[usize] {
+        self.layout.stride()
     }
 
     pub fn rank(&self) -> usize {
@@ -734,12 +729,12 @@ impl Tensor {
 
     /// Returns true if the data is stored in a C contiguous (aka row major) way.
     pub fn is_contiguous(&self) -> bool {
-        self.shape.is_contiguous(&self.stride)
+        self.layout.is_contiguous()
     }
 
     /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
    pub fn is_fortran_contiguous(&self) -> bool {
-        self.shape.is_fortran_contiguous(&self.stride)
+        self.layout.is_fortran_contiguous()
     }
 
     /// Compared to clone, this copies the actual storage but may fail because of running out of
     /// memory.
     pub fn try_clone(&self) -> Result<Self> {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(self.storage.try_clone()?),
-            shape: self.shape.clone(),
-            stride: self.stride.clone(),
+            layout: self.layout.clone(),
             op: None, // TODO
             is_variable: false,
         };
 
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape: self.shape.clone(),
-            stride: self.stride.clone(),
+            layout: self.layout.clone(),
             op: None,
             is_variable: false,
         };
 
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(storage),
-            shape: self.shape.clone(),
-            stride: self.stride.clone(),
+            layout: self.layout.clone(),
             op,
             is_variable: false,
         };
@@ -810,7 +802,7 @@ impl Tensor {
     pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
         let left_shape = left_shape.into();
         let mut dims = left_shape.into_dims();
-        dims.extend(self.shape.dims());
+        dims.extend(self.dims());
         self.broadcast_as(dims)
     }
 
             Ok(self.clone())
         } else {
             let shape = self.shape();
-            let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
+            let storage = self.storage.to_dtype(self.layout(), dtype)?;
             let op = if self.track_op() {
                 Some(Op::ToDType(self.clone()))
             } else {
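At the tensor level the observable behavior should be unchanged: tensors built through `from_storage` get a contiguous `Layout`, and the contiguity queries now just delegate to it. A sketch of what that looks like from the public API, assuming the usual constructors (a `Tensor::zeros` taking a shape, dtype and device, as elsewhere in candle at this point) are available:

use candle_core::{DType, Device, Tensor};

fn main() -> candle_core::Result<()> {
    // from_storage builds a contiguous layout, so a fresh tensor reports
    // row-major strides and a zero start offset.
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    assert!(t.is_contiguous());
    assert!(!t.is_fortran_contiguous());
    assert_eq!(t.layout().stride(), &[3, 1]);
    assert_eq!(t.layout().start_offset(), 0);
    Ok(())
}

Once ops start producing genuinely strided results (transpose, narrow), `layout()` becomes the single place that records how the logical view maps onto storage, which is the point of this refactor.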