diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 53c7ecf1..f1547b3c 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1,5 +1,5 @@ use crate::op::{BinaryOp, UnaryOp}; -use crate::{DType, Error, Result, Shape, StridedIndex}; +use crate::{DType, Error, Layout, Result, Shape}; use gemm::{gemm, Parallelism}; use half::{bf16, f16}; @@ -18,31 +18,31 @@ pub enum CpuStorage { fn wcond( pred: &[u32], - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &[T], - stride_t: &[usize], + layout_t: &Layout, f: &[T], - stride_f: &[usize], + layout_f: &Layout, ) -> Vec { - if shape.is_contiguous(stride) && shape.is_contiguous(stride_t) && shape.is_contiguous(stride_f) - { - let elem_count = shape.elem_count(); - let pred = &pred[..elem_count]; - let t = &t[..elem_count]; - let f = &f[..elem_count]; - pred.iter() - .zip(t.iter().zip(f.iter())) - .map(|(&p, (&t, &f))| if p > 0 { t } else { f }) - .collect::>() - } else { - let dims = shape.dims(); - let it_p = StridedIndex::new(dims, stride); - let it_t = StridedIndex::new(dims, stride_t); - let it_f = StridedIndex::new(dims, stride_f); - it_p.zip(it_t.zip(it_f)) + match ( + layout.contiguous_offsets(), + layout_t.contiguous_offsets(), + layout_f.contiguous_offsets(), + ) { + (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => { + let pred = &pred[o1..o2]; + let t = &t[o_t1..o_t2]; + let f = &f[o_f1..o_f2]; + pred.iter() + .zip(t.iter().zip(f.iter())) + .map(|(&p, (&t, &f))| if p > 0 { t } else { f }) + .collect::>() + } + _ => layout + .strided_index() + .zip(layout_t.strided_index().zip(layout_f.strided_index())) .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] }) - .collect::>() + .collect::>(), } } @@ -62,64 +62,50 @@ macro_rules! map1 { fn sum_impl1( src: &[T], dst_shape: &Shape, - src_dims: &[usize], - stride: &[usize], + src_layout: &Layout, to_dst_index: impl Fn(usize) -> usize, ) -> Result> { let mut dst = vec![T::zero(); dst_shape.elem_count()]; - for (unstr_index, src_index) in StridedIndex::new(src_dims, stride).enumerate() { + for (unstr_index, src_index) in src_layout.strided_index().enumerate() { dst[to_dst_index(unstr_index)] += src[src_index]; } Ok(dst) } -fn unary_map U>( - vs: &[T], - shape: &Shape, - stride: &[usize], - mut f: F, -) -> Vec { - if shape.is_contiguous(stride) { - vs[..shape.elem_count()].iter().map(|&v| f(v)).collect() - } else { - StridedIndex::new(shape.dims(), stride) - .map(|i| f(vs[i])) - .collect() +fn unary_map U>(vs: &[T], layout: &Layout, mut f: F) -> Vec { + match layout.contiguous_offsets() { + Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(), + None => layout.strided_index().map(|i| f(vs[i])).collect(), } } // This function maps over two strided index sequences. 
fn binary_map T>( - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, lhs: &[T], rhs: &[T], mut f: F, ) -> Vec { - let dims = shape.dims(); - if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) { - (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect() - } else { - let lhs_index = StridedIndex::new(dims, lhs_stride); - let rhs_index = StridedIndex::new(dims, rhs_stride); - lhs_index - .zip(rhs_index) + match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) { + (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2] + .iter() + .zip(rhs[o_r1..o_r2].iter()) + .map(|(&l, &r)| f(l, r)) + .collect(), + _ => lhs_l + .strided_index() + .zip(rhs_l.strided_index()) .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i])) - .collect() + .collect(), } } -fn take_impl1( - vs: &[T], - ids: &[u32], - shape: &Shape, - stride: &[usize], - vocab_size: usize, - hidden_size: usize, -) -> Result> { - let mut values = Vec::with_capacity(shape.elem_count() * hidden_size); - for index in StridedIndex::new(shape.dims(), stride) { +fn take_impl1(vs: &[T], ids: &[u32], layout: &Layout, rhs_l: &Layout) -> Result> { + // TODO: Optimize for the case where ids are contiguous. + let (vocab_size, hidden_size) = rhs_l.shape().r2()?; + let mut values = Vec::with_capacity(layout.shape().elem_count() * hidden_size); + for index in layout.strided_index() { let index = ids[index].try_into()?; if index >= vocab_size { return Err(Error::InvalidIndex { @@ -138,37 +124,40 @@ fn copy_strided_src_( src: &[T], dst: &mut [T], dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: &Layout, ) { - let src = &src[src_offset..]; - if src_shape.is_contiguous(src_stride) { - let elem_to_copy = (dst.len() - dst_offset).min(src.len()); - dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[..elem_to_copy]) - } else { - let src_indexes = StridedIndex::new(src_shape.dims(), src_stride); - for (dst_index, src_index) in src_indexes.enumerate() { - let dst_index = dst_index + dst_offset; - if dst_index >= dst.len() { - break; + match src_l.contiguous_offsets() { + Some((o_dst1, o_dst2)) => { + let elem_to_copy = (dst.len() - dst_offset).min(o_dst2 - o_dst1); + dst[dst_offset..dst_offset + elem_to_copy].copy_from_slice(&src[o_dst1..o_dst2]) + } + None => { + for (dst_index, src_index) in src_l.strided_index().enumerate() { + let dst_index = dst_index + dst_offset; + if dst_index >= dst.len() { + break; + } + dst[dst_index] = src[src_index] } - dst[dst_index] = src[src_index] } } } -fn matmul_impl( +fn matmul( lhs: &[T], rhs: &[T], (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result> { + let lhs = &lhs[lhs_l.start_offset()..]; + let rhs = &rhs[rhs_l.start_offset()..]; let a_skip: usize = m * k; let b_skip: usize = n * k; let c_skip: usize = m * n; + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); let rank = lhs_stride.len(); let lhs_cs = lhs_stride[rank - 1]; let lhs_rs = lhs_stride[rank - 2]; @@ -238,118 +227,114 @@ impl CpuStorage { D::cpu_storage_as_slice(self) } - pub fn as_mut_slice(&mut self) -> Result<&mut [D]> { - D::cpu_storage_as_mut_slice(self) - } - - pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result { + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { // TODO: find a way around the quadratic number of cases below. 
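Note: a sketch of the two CPU map helpers with their generic bounds spelled out (the exact bounds are reconstructed from the call sites, so treat them as an assumption rather than the literal source). unary_map is the helper that every cast arm of to_dtype below goes through.

fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(vs: &[T], layout: &Layout, mut f: F) -> Vec<U> {
    match layout.contiguous_offsets() {
        // Contiguous fast path: map directly over the relevant sub-slice.
        Some((o1, o2)) => vs[o1..o2].iter().map(|&v| f(v)).collect(),
        // Strided fallback: follow the layout's index iterator.
        None => layout.strided_index().map(|i| f(vs[i])).collect(),
    }
}

fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
    lhs_l: &Layout,
    rhs_l: &Layout,
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<T> {
    match (lhs_l.contiguous_offsets(), rhs_l.contiguous_offsets()) {
        // Both operands contiguous: zip the two sub-slices directly.
        (Some((o_l1, o_l2)), Some((o_r1, o_r2))) => lhs[o_l1..o_l2]
            .iter()
            .zip(rhs[o_r1..o_r2].iter())
            .map(|(&l, &r)| f(l, r))
            .collect(),
        // Otherwise walk both strided index iterators in lockstep.
        _ => lhs_l
            .strided_index()
            .zip(rhs_l.strided_index())
            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
            .collect(),
    }
}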
match (self, dtype) { (Self::U32(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| bf16::from_f32(v as f32)); Ok(Self::BF16(data)) } (Self::BF16(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::BF16(data)) } (Self::F16(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, |v| bf16::from_f32(v.to_f32())); + let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32())); Ok(Self::BF16(data)) } (Self::F32(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, bf16::from_f32); + let data = unary_map(storage, layout, bf16::from_f32); Ok(Self::BF16(data)) } (Self::F64(storage), DType::BF16) => { - let data = unary_map(storage, shape, stride, bf16::from_f64); + let data = unary_map(storage, layout, bf16::from_f64); Ok(Self::BF16(data)) } (Self::U32(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, |v| f16::from_f32(v as f32)); + let data = unary_map(storage, layout, |v| f16::from_f32(v as f32)); Ok(Self::F16(data)) } (Self::BF16(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, |v| f16::from_f32(v.to_f32())); + let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32())); Ok(Self::F16(data)) } (Self::F16(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::F16(data)) } (Self::F32(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, f16::from_f32); + let data = unary_map(storage, layout, f16::from_f32); Ok(Self::F16(data)) } (Self::F64(storage), DType::F16) => { - let data = unary_map(storage, shape, stride, f16::from_f64); + let data = unary_map(storage, layout, f16::from_f64); Ok(Self::F16(data)) } (Self::U32(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v as f32); + let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } (Self::BF16(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v.to_f32()); + let data = unary_map(storage, layout, |v| v.to_f32()); Ok(Self::F32(data)) } (Self::F16(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v.to_f32()); + let data = unary_map(storage, layout, |v| v.to_f32()); Ok(Self::F32(data)) } (Self::F32(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::F32(data)) } (Self::F64(storage), DType::F32) => { - let data = unary_map(storage, shape, stride, |v| v as f32); + let data = unary_map(storage, layout, |v| v as f32); Ok(Self::F32(data)) } (Self::U32(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::U32(data)) } (Self::BF16(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32); + let data = unary_map(storage, layout, |v| v.to_f32() as u32); Ok(Self::U32(data)) } (Self::F16(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v.to_f32() as u32); + let data = unary_map(storage, layout, |v| v.to_f32() as u32); Ok(Self::U32(data)) } (Self::F32(storage), DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v as u32); + let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } (Self::F64(storage), 
DType::U32) => { - let data = unary_map(storage, shape, stride, |v| v as u32); + let data = unary_map(storage, layout, |v| v as u32); Ok(Self::U32(data)) } (Self::U32(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v as f64); + let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) } (Self::BF16(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v.to_f64()); + let data = unary_map(storage, layout, |v| v.to_f64()); Ok(Self::F64(data)) } (Self::F16(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v.to_f64()); + let data = unary_map(storage, layout, |v| v.to_f64()); Ok(Self::F64(data)) } (Self::F32(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v as f64); + let data = unary_map(storage, layout, |v| v as f64); Ok(Self::F64(data)) } (Self::F64(storage), DType::F64) => { - let data = unary_map(storage, shape, stride, |v| v); + let data = unary_map(storage, layout, |v| v); Ok(Self::F64(data)) } } } - pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result { - let src_dims = shape.dims(); + pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { + let src_dims = layout.dims(); let mut dst_dims = src_dims.to_vec(); for &sum_dim in sum_dims.iter() { dst_dims[sum_dim] = 1; @@ -375,7 +360,7 @@ impl CpuStorage { dst_index }; // TODO: Maybe provide an implementation with higher precision accumulators? - map1!(self, sum_impl1, &dst_shape, src_dims, stride, to_dst_index) + map1!(self, sum_impl1, &dst_shape, layout, to_dst_index) } pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { @@ -461,65 +446,59 @@ impl CpuStorage { Ok(()) } - pub(crate) fn affine_impl( - &self, - shape: &Shape, - stride: &[usize], - mul: f64, - add: f64, - ) -> Result { + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { match self { Self::U32(storage) => { let mul = mul as u32; let add = add as u32; - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::U32(data)) } Self::BF16(storage) => { let mul = bf16::from_f64(mul); let add = bf16::from_f64(add); - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::BF16(data)) } Self::F16(storage) => { let mul = f16::from_f64(mul); let add = f16::from_f64(add); - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::F16(data)) } Self::F32(storage) => { let mul = mul as f32; let add = add as f32; - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::F32(data)) } Self::F64(storage) => { - let data = unary_map(storage, shape, stride, |v| v * mul + add); + let data = unary_map(storage, layout, |v| v * mul + add); Ok(Self::F64(data)) } } } - pub(crate) fn unary_impl(&self, shape: &Shape, stride: &[usize]) -> Result { + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { match self { Self::BF16(storage) => { - let data = unary_map(storage, shape, stride, B::bf16); + let data = unary_map(storage, layout, B::bf16); Ok(Self::BF16(data)) } Self::F16(storage) => { - let data = unary_map(storage, shape, stride, B::f16); + let data = unary_map(storage, layout, B::f16); Ok(Self::F16(data)) } Self::F32(storage) => { - let data = 
unary_map(storage, shape, stride, B::f32); + let data = unary_map(storage, layout, B::f32); Ok(Self::F32(data)) } Self::F64(storage) => { - let data = unary_map(storage, shape, stride, B::f64); + let data = unary_map(storage, layout, B::f64); Ok(Self::F64(data)) } Self::U32(storage) => { - let data = unary_map(storage, shape, stride, B::u32); + let data = unary_map(storage, layout, B::u32); Ok(Self::U32(data)) } } @@ -528,29 +507,28 @@ impl CpuStorage { pub(crate) fn binary_impl( &self, rhs: &Self, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result { match (self, rhs) { (Self::BF16(lhs), Self::BF16(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::bf16); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::bf16); Ok(Self::BF16(data)) } (Self::F16(lhs), Self::F16(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f16); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f16); Ok(Self::F16(data)) } (Self::F32(lhs), Self::F32(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f32); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f32); Ok(Self::F32(data)) } (Self::F64(lhs), Self::F64(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::f64); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::f64); Ok(Self::F64(data)) } (Self::U32(lhs), Self::U32(rhs)) => { - let data = binary_map(shape, lhs_stride, rhs_stride, lhs, rhs, B::u32); + let data = binary_map(lhs_l, rhs_l, lhs, rhs, B::u32); Ok(Self::U32(data)) } _ => { @@ -568,29 +546,14 @@ impl CpuStorage { &self, dst: &mut Self, dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: &Layout, ) -> Result<()> { - if src_shape.rank() != src_stride.len() { - panic!("incoherent shape and strides {src_shape:?} {src_stride:?}") - } match (self, dst) { - (Self::U32(src), Self::U32(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::BF16(src), Self::BF16(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::F16(src), Self::F16(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::F32(src), Self::F32(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::F64(src), Self::F64(dst)) => { - copy_strided_src_(src, dst, dst_offset, src_shape, src_stride, src_offset) - } + (Self::U32(src), Self::U32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::BF16(src), Self::BF16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F16(src), Self::F16(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F32(src), Self::F32(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), + (Self::F64(src), Self::F64(dst)) => copy_strided_src_(src, dst, dst_offset, src_l), (_, dst) => { // This should be covered by the dtype check above. return Err(Error::DTypeMismatchBinaryOp { @@ -605,34 +568,33 @@ impl CpuStorage { pub(crate) fn where_cond( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &Self, - stride_t: &[usize], + layout_t: &Layout, f: &Self, - stride_f: &[usize], + layout_f: &Layout, ) -> Result { // TODO: Support types that could be casted to a boolean. 
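Note: the where_cond dispatch that follows forwards to the wcond helper defined at the top of this file, reading the predicate through as_slice::<u32>(). A sketch of that helper with its elided generic written out (the T: Copy bound is an assumption based on how it is called):

fn wcond<T: Copy>(
    pred: &[u32],
    layout: &Layout,
    t: &[T],
    layout_t: &Layout,
    f: &[T],
    layout_f: &Layout,
) -> Vec<T> {
    match (
        layout.contiguous_offsets(),
        layout_t.contiguous_offsets(),
        layout_f.contiguous_offsets(),
    ) {
        // Fast path: all three operands are contiguous, walk the sub-slices directly.
        (Some((o1, o2)), Some((o_t1, o_t2)), Some((o_f1, o_f2))) => {
            let pred = &pred[o1..o2];
            let t = &t[o_t1..o_t2];
            let f = &f[o_f1..o_f2];
            pred.iter()
                .zip(t.iter().zip(f.iter()))
                .map(|(&p, (&t, &f))| if p > 0 { t } else { f })
                .collect()
        }
        // Slow path: follow the three strided index iterators in lockstep.
        _ => layout
            .strided_index()
            .zip(layout_t.strided_index().zip(layout_f.strided_index()))
            .map(|(i_p, (i_t, i_f))| if pred[i_p] > 0 { t[i_t] } else { f[i_f] })
            .collect(),
    }
}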
let pred = self.as_slice::()?; match (t, f) { (Self::BF16(t), Self::BF16(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::BF16(data)) } (Self::F16(t), Self::F16(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::F16(data)) } (Self::F32(t), Self::F32(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::F32(data)) } (Self::F64(t), Self::F64(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::F64(data)) } (Self::U32(t), Self::U32(f)) => { - let data = wcond(pred, shape, stride, t, stride_t, f, stride_f); + let data = wcond(pred, layout, t, layout_t, f, layout_f); Ok(Self::U32(data)) } _ => Err(Error::DTypeMismatchBinaryOp { @@ -643,36 +605,29 @@ impl CpuStorage { } } - pub(crate) fn embedding_impl( - &self, - shape: &Shape, - stride: &[usize], - vs: &Self, - hidden_size: usize, - vocab_size: usize, - ) -> Result { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { let ids = self.as_slice::()?; - map1!(vs, take_impl1, ids, shape, stride, vocab_size, hidden_size) + map1!(rhs, take_impl1, ids, layout, rhs_l) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, bmnk: (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result { match (self, rhs) { (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?; Ok(Self::F16(dst)) } (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?; Ok(Self::F32(dst)) } (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_l, rhs_l)?; Ok(Self::F64(dst)) } _ => Err(Error::DTypeMismatchBinaryOp { diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 9cbf82be..9d9a5f99 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -1,4 +1,4 @@ -use crate::{CpuStorage, DType, Shape}; +use crate::{CpuStorage, DType, Layout, Shape}; use candle_kernels as kernels; use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig}; use cudarc::driver::{CudaFunction, CudaSlice, DeviceSlice, LaunchAsync, LaunchConfig}; @@ -26,6 +26,9 @@ pub enum CudaError { #[error("internal error '{0}'")] InternalError(&'static str), + #[error("internal error '{0}'")] + WrappedError(Box), + #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] MatMulNonContiguous { lhs_stride: Vec, @@ -242,13 +245,14 @@ enum CudaStorageSlice { fn slice_src_and_dst<'a, T>( src: &'a CudaSlice, - src_offset: usize, + src_l: &Layout, dst: &'a mut CudaSlice, dst_offset: usize, ) -> ( cudarc::driver::CudaView<'a, T>, cudarc::driver::CudaViewMut<'a, T>, ) { + let src_offset = src_l.start_offset(); let to_copy = dst .len() .saturating_sub(dst_offset) @@ -268,12 +272,14 @@ fn gemm_config( alpha: T, beta: T, (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], 
+ lhs_l: &Layout, + rhs_l: &Layout, ) -> Result> { // https://docs.nvidia.com/cuda/cublas/index.html#cublas-t-gemm use cudarc::cublas::sys::cublasOperation_t; + let lhs_stride = lhs_l.stride(); + let rhs_stride = rhs_l.stride(); let rhs_m1 = rhs_stride[rhs_stride.len() - 1]; let rhs_m2 = rhs_stride[rhs_stride.len() - 2]; let lhs_m1 = lhs_stride[lhs_stride.len() - 1]; @@ -352,20 +358,27 @@ impl CudaStorage { &self.device } - pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result { + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { use cudarc::driver::DevicePtr; + let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; + let start_o = layout.start_offset(); + // This returns an i64 rather than a &i64, this is useful to get around some temporary + // lifetime issue and is safe as long as self.slice does not go out of scope before inp + // is used. let inp = match &self.slice { - CudaStorageSlice::U32(inp) => inp.device_ptr(), - CudaStorageSlice::BF16(inp) => inp.device_ptr(), - CudaStorageSlice::F16(inp) => inp.device_ptr(), - CudaStorageSlice::F32(inp) => inp.device_ptr(), - CudaStorageSlice::F64(inp) => inp.device_ptr(), + CudaStorageSlice::U32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::BF16(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F16(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F32(inp) => *inp.slice(start_o..).device_ptr(), + CudaStorageSlice::F64(inp) => *inp.slice(start_o..).device_ptr(), }; + let inp = &inp; + let kernel_name = format!("cast_{}_{}", self.dtype().as_str(), dtype.as_str()); let func = dev.get_or_load_func(&kernel_name, kernels::CAST)?; let slice = match dtype { @@ -406,20 +419,16 @@ impl CudaStorage { }) } - pub(crate) fn affine_impl( - &self, - shape: &Shape, - stride: &[usize], - mul: f64, - add: f64, - ) -> Result { + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { + let shape = layout.shape(); let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_u32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -429,6 +438,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_bf16", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -446,6 +456,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f16", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. 
let out = unsafe { dev.alloc::(el_count) }?; @@ -463,6 +474,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f32", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -472,6 +484,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("affine_f64", kernels::AFFINE)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -485,7 +498,8 @@ impl CudaStorage { Ok(Self { slice, device }) } - pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], sum_dims: &[usize]) -> Result { + pub(crate) fn sum(&self, layout: &Layout, sum_dims: &[usize]) -> Result { + let shape = layout.shape(); let src_dims = shape.dims(); let el = shape.elem_count(); let mut dst_el = el; @@ -503,9 +517,10 @@ impl CudaStorage { .collect(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([src_dims, stride, &sum_dims_l, &sum_dims_s].concat())?; + let ds = dev.htod_copy([src_dims, layout.stride(), &sum_dims_l, &sum_dims_s].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_u32", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -514,6 +529,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_bf16", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -522,6 +538,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f16", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -530,6 +547,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f32", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -538,6 +556,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func("sum_f64", kernels::REDUCE)?; let out = dev.alloc_zeros::(dst_el)?; let params = (el, src_dims.len(), sum_dims.len(), &ds, arg, &out); @@ -556,21 +575,19 @@ impl CudaStorage { )) } - pub(crate) fn unary_impl( - &self, - shape: &Shape, - stride: &[usize], - ) -> Result { + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { + let shape = layout.shape(); let dims = shape.dims(); let el_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &self.slice { CudaStorageSlice::U32(_arg) => { todo!("No unary kernels for u32"); } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = 
dev.get_or_load_func(U::KERNEL_BF16, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -580,6 +597,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F16, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -589,6 +607,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F32, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -598,6 +617,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(layout.start_offset()..); let func = dev.get_or_load_func(U::KERNEL_F64, kernels::UNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el_count) }?; @@ -614,17 +634,19 @@ impl CudaStorage { pub(crate) fn binary_impl( &self, rhs: &Self, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result { + let shape = lhs_l.shape(); let dims = shape.dims(); let elem_count = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(elem_count as u32); let dev = self.device(); - let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?; + let dims_and_strides = dev.htod_copy([dims, lhs_l.stride(), rhs_l.stride()].concat())?; let slice = match (&self.slice, &rhs.slice) { (CudaStorageSlice::BF16(lhs), CudaStorageSlice::BF16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_BF16, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(elem_count) }?; @@ -634,6 +656,8 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_F16, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(elem_count) }?; @@ -643,6 +667,8 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func(B::KERNEL_F32, kernels::BINARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(elem_count) }?; @@ -652,6 +678,8 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func(B::KERNEL_F64, kernels::BINARY)?; let out = unsafe { dev.alloc::(elem_count) }?; @@ -661,6 +689,8 @@ impl CudaStorage { CudaStorageSlice::F64(out) } (CudaStorageSlice::U32(lhs), CudaStorageSlice::U32(rhs)) => { + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); // SAFETY: Set later by running the kernel. 
let func = dev.get_or_load_func(B::KERNEL_U32, kernels::BINARY)?; let out = unsafe { dev.alloc::(elem_count) }?; @@ -708,28 +738,32 @@ impl CudaStorage { pub(crate) fn where_cond( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &Self, - stride_t: &[usize], + layout_t: &Layout, f: &Self, - stride_f: &[usize], + layout_f: &Layout, ) -> Result { let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice, + CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), _ => Err(CudaError::UnexpectedDType { msg: "where conditions should be u32", expected: DType::U32, got: self.dtype(), })?, }; + let ids = &ids; + let shape = layout.shape(); let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride, stride_t, stride_f].concat())?; + let ds = + dev.htod_copy([dims, layout.stride(), layout_t.stride(), layout_f.stride()].concat())?; let slice = match (&t.slice, &f.slice) { (CudaStorageSlice::BF16(t), CudaStorageSlice::BF16(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_bf16", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }?; @@ -739,6 +773,8 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } (CudaStorageSlice::F16(t), CudaStorageSlice::F16(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_f16", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }?; @@ -748,6 +784,8 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(t), CudaStorageSlice::F32(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); let func = dev.get_or_load_func("where_f32", kernels::TERNARY)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el) }?; @@ -757,6 +795,8 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(t), CudaStorageSlice::F64(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func("where_f64", kernels::TERNARY)?; let out = unsafe { dev.alloc::(el) }?; @@ -766,6 +806,8 @@ impl CudaStorage { CudaStorageSlice::F64(out) } (CudaStorageSlice::U32(t), CudaStorageSlice::U32(f)) => { + let t = &t.slice(layout_t.start_offset()..); + let f = &f.slice(layout_f.start_offset()..); // SAFETY: Set later by running the kernel. let func = dev.get_or_load_func("where_u32", kernels::TERNARY)?; let out = unsafe { dev.alloc::(el) }?; @@ -775,36 +817,36 @@ impl CudaStorage { CudaStorageSlice::U32(out) } // The dtypes should have been checked at this point so this is an internal error. 
- _ => return Err(CudaError::InternalError("dtype mismatch in binary op")), + _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?, }; let device = dev.clone(); Ok(Self { slice, device }) } - pub(crate) fn embedding_impl( - &self, - shape: &Shape, - stride: &[usize], - rhs: &Self, - h_size: usize, // hidden size - v_size: usize, // vocab size - ) -> Result { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { let ids = match &self.slice { - CudaStorageSlice::U32(slice) => slice, + CudaStorageSlice::U32(slice) => slice.slice(layout.start_offset()..), _ => Err(CudaError::UnexpectedDType { msg: "embedding ids should be u32", expected: DType::U32, got: self.dtype(), })?, }; + let ids = &ids; + let shape = layout.shape(); + let (v_size, h_size) = rhs_l + .shape() + .r2() + .map_err(|e| CudaError::WrappedError(Box::new(e)))?; let dims = shape.dims(); let el = shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el as u32); let dev = self.device(); - let ds = dev.htod_copy([dims, stride].concat())?; + let ds = dev.htod_copy([dims, layout.stride()].concat())?; let slice = match &rhs.slice { // The kernels below assume that rhs is contiguous. CudaStorageSlice::U32(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_u32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; @@ -814,6 +856,7 @@ impl CudaStorage { CudaStorageSlice::U32(out) } CudaStorageSlice::BF16(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_bf16", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; @@ -823,6 +866,7 @@ impl CudaStorage { CudaStorageSlice::BF16(out) } CudaStorageSlice::F16(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f16", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; @@ -832,6 +876,7 @@ impl CudaStorage { CudaStorageSlice::F16(out) } CudaStorageSlice::F32(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f32", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. let out = unsafe { dev.alloc::(el * h_size) }?; @@ -841,6 +886,7 @@ impl CudaStorage { CudaStorageSlice::F32(out) } CudaStorageSlice::F64(arg) => { + let arg = &arg.slice(rhs_l.start_offset()..); let func = dev.get_or_load_func("emb_f64", kernels::EMBEDDINGS)?; // SAFETY: Set later by running the kernel. 
let out = unsafe { dev.alloc::(el * h_size) }?; @@ -854,12 +900,12 @@ impl CudaStorage { Ok(Self { slice, device }) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_l: &Layout, + rhs_l: &Layout, ) -> Result { let elem_count = b * m * n; let dev = &self.device; @@ -868,7 +914,9 @@ impl CudaStorage { todo!("bf16") } (CudaStorageSlice::F16(lhs), CudaStorageSlice::F16(rhs)) => { - let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(f16::ONE, f16::ZERO, (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }?; unsafe { self.device @@ -878,7 +926,9 @@ impl CudaStorage { CudaStorageSlice::F16(out) } (CudaStorageSlice::F32(lhs), CudaStorageSlice::F32(rhs)) => { - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }?; unsafe { self.device @@ -888,7 +938,9 @@ impl CudaStorage { CudaStorageSlice::F32(out) } (CudaStorageSlice::F64(lhs), CudaStorageSlice::F64(rhs)) => { - let cfg = gemm_config(1., 0., (b, m, n, k), lhs_stride, rhs_stride)?; + let lhs = &lhs.slice(lhs_l.start_offset()..); + let rhs = &rhs.slice(rhs_l.start_offset()..); + let cfg = gemm_config(1., 0., (b, m, n, k), lhs_l, rhs_l)?; let mut out = unsafe { dev.alloc::(elem_count) }?; unsafe { self.device @@ -907,22 +959,18 @@ impl CudaStorage { &self, dst: &mut Self, dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: &Layout, ) -> Result<()> { - if src_shape.rank() != src_stride.len() { - panic!("incoherent shape and strides {src_shape:?} {src_stride:?}") - } + let src_shape = src_l.shape(); let dims = src_shape.dims(); let el_count = src_shape.elem_count(); let cfg = LaunchConfig::for_num_elems(el_count as u32); let dev = &self.device; - let ds = dev.htod_copy([dims, src_stride].concat())?; + let ds = dev.htod_copy([dims, src_l.stride()].concat())?; match (&self.slice, &mut dst.slice) { (CudaStorageSlice::BF16(src), CudaStorageSlice::BF16(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? } else { let func = dev.get_or_load_func("ucopy_bf16", kernels::UNARY)?; @@ -933,8 +981,8 @@ impl CudaStorage { } } (CudaStorageSlice::F16(src), CudaStorageSlice::F16(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? } else { let func = dev.get_or_load_func("ucopy_f16", kernels::UNARY)?; @@ -945,8 +993,8 @@ impl CudaStorage { } } (CudaStorageSlice::F32(src), CudaStorageSlice::F32(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? 
} else { let func = dev.get_or_load_func("ucopy_f32", kernels::UNARY)?; @@ -957,8 +1005,8 @@ impl CudaStorage { } } (CudaStorageSlice::U32(src), CudaStorageSlice::U32(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? } else { let func = dev.get_or_load_func("ucopy_u32", kernels::UNARY)?; @@ -969,8 +1017,8 @@ impl CudaStorage { } } (CudaStorageSlice::F64(src), CudaStorageSlice::F64(dst)) => { - let (src, mut dst) = slice_src_and_dst(src, src_offset, dst, dst_offset); - if src_shape.is_contiguous(src_stride) { + let (src, mut dst) = slice_src_and_dst(src, src_l, dst, dst_offset); + if src_l.is_contiguous() { dev.dtod_copy(&src, &mut dst)? } else { let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?; diff --git a/candle-core/src/dtype.rs b/candle-core/src/dtype.rs index 1711c2b4..fdbfdbba 100644 --- a/candle-core/src/dtype.rs +++ b/candle-core/src/dtype.rs @@ -41,7 +41,6 @@ pub trait WithDType: Sized + Copy { } fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; - fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>; fn cpu_storage_data(s: CpuStorage) -> Result>; } @@ -75,17 +74,6 @@ macro_rules! with_dtype { }), } } - - fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> { - match s { - CpuStorage::$dtype(data) => Ok(data), - _ => Err(Error::UnexpectedDType { - expected: DType::$dtype, - got: s.dtype(), - msg: "unexpected dtype", - }), - } - } } }; } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index babc6e7d..8193b1af 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -use crate::{CpuStorage, DType, Error, Result, Shape}; +use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; #[derive(thiserror::Error, Debug)] pub enum DummyError {} @@ -60,11 +60,11 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result { + pub(crate) fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result { + pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result { Err(Error::NotCompiledWithCudaSupport) } @@ -72,65 +72,49 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result { + pub(crate) fn to_dtype(&self, _: &Layout, _: DType) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn unary_impl(&self, _: &Shape, _: &[usize]) -> Result { + pub(crate) fn unary_impl(&self, _: &Layout) -> Result { Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn binary_impl( &self, _: &Self, - _: &Shape, - _: &[usize], - _: &[usize], + _: &Layout, + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn where_cond( &self, - _: &Shape, - _: &[usize], + _: &Layout, _: &Self, - _: &[usize], + _: &Layout, _: &Self, - _: &[usize], + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn embedding_impl( - &self, - _: &Shape, - _: &[usize], - _: &Self, - _: usize, - _: usize, - ) -> Result { + pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result { 
Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, _: &Self, _: (usize, usize, usize, usize), - _: &[usize], - _: &[usize], + _: &Layout, + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn copy_strided_src( - &self, - _: &mut Self, - _: usize, - _: &Shape, - _: &[usize], - _: usize, - ) -> Result<()> { + pub(crate) fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } } diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs new file mode 100644 index 00000000..3f629d50 --- /dev/null +++ b/candle-core/src/layout.rs @@ -0,0 +1,140 @@ +use crate::{Error, Result, Shape}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Layout { + shape: Shape, + // The strides are given in number of elements and not in bytes. + stride: Vec, + start_offset: usize, +} + +impl Layout { + pub fn contiguous_with_offset>(shape: S, start_offset: usize) -> Self { + let shape = shape.into(); + let stride = shape.stride_contiguous(); + Self { + shape, + stride, + start_offset, + } + } + + pub fn contiguous>(shape: S) -> Self { + Self::contiguous_with_offset(shape, 0) + } + + pub fn dims(&self) -> &[usize] { + self.shape.dims() + } + + pub fn shape(&self) -> &Shape { + &self.shape + } + + pub fn stride(&self) -> &[usize] { + &self.stride + } + + pub fn start_offset(&self) -> usize { + self.start_offset + } + + /// Returns the appropriate start and stop offset if the data is stored in a C + /// contiguous (aka row major) way. + pub fn contiguous_offsets(&self) -> Option<(usize, usize)> { + if self.is_contiguous() { + let start_o = self.start_offset; + Some((start_o, start_o + self.shape.elem_count())) + } else { + None + } + } + + /// Returns true if the data is stored in a C contiguous (aka row major) way. + pub fn is_contiguous(&self) -> bool { + self.shape.is_contiguous(&self.stride) + } + + /// Returns true if the data is stored in a Fortran contiguous (aka column major) way. + pub fn is_fortran_contiguous(&self) -> bool { + self.shape.is_fortran_contiguous(&self.stride) + } + + pub(crate) fn narrow(&self, dim: usize, start: usize, length: usize) -> Result { + let dims = self.shape().dims(); + if dim >= dims.len() { + Err(Error::UnexpectedNumberOfDims { + expected: dim + 1, + got: dims.len(), + shape: self.shape().clone(), + })? + } + if start + length > dims[dim] { + todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}") + } + let mut dims = dims.to_vec(); + dims[dim] = length; + Ok(Self { + shape: Shape::from(dims), + stride: self.stride.clone(), + start_offset: self.start_offset + self.stride[dim] * start, + }) + } + + pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result { + let rank = self.shape.rank(); + if rank <= dim1 || rank <= dim2 { + return Err(Error::UnexpectedNumberOfDims { + expected: usize::max(dim1, dim2), + got: rank, + shape: self.shape().clone(), + }); + } + let mut stride = self.stride().to_vec(); + let mut dims = self.shape().dims().to_vec(); + dims.swap(dim1, dim2); + stride.swap(dim1, dim2); + Ok(Self { + shape: Shape::from(dims), + stride, + start_offset: self.start_offset, + }) + } + + pub fn broadcast_as>(&self, shape: S) -> Result { + let shape = shape.into(); + if shape.rank() < self.shape().rank() { + Err(Error::BroadcastIncompatibleShapes { + src_shape: self.shape().clone(), + dst_shape: shape.clone(), + })? 
+ } + let added_dims = shape.rank() - self.shape().rank(); + let mut stride = vec![0; added_dims]; + for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..] + .iter() + .zip(self.dims().iter().zip(self.stride())) + { + let s = if dst_dim == src_dim { + src_stride + } else if src_dim != 1 { + return Err(Error::BroadcastIncompatibleShapes { + src_shape: self.shape().clone(), + dst_shape: shape, + }); + } else { + 0 + }; + stride.push(s) + } + Ok(Self { + shape, + stride, + start_offset: self.start_offset, + }) + } + + pub(crate) fn strided_index(&self) -> crate::StridedIndex { + crate::StridedIndex::new(self) + } +} diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 5771517f..6a860116 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -7,6 +7,7 @@ pub mod display; mod dtype; mod dummy_cuda_backend; mod error; +mod layout; mod npy; mod op; mod shape; @@ -19,6 +20,7 @@ pub use cpu_backend::CpuStorage; pub use device::{Device, DeviceLocation}; pub use dtype::{DType, WithDType}; pub use error::{Error, Result}; +pub use layout::Layout; pub use shape::Shape; pub use storage::Storage; use strided_index::StridedIndex; diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index e44a2db6..7acf6dd0 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -1,4 +1,4 @@ -use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Result, Shape}; +use crate::{op, CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape}; // We do not want to implement Clone on Storage as cloning may fail because of // out of memory. Instead try_clone should be used. @@ -53,38 +53,33 @@ impl Storage { } } - pub(crate) fn affine_impl( - &self, - shape: &Shape, - stride: &[usize], - mul: f64, - add: f64, - ) -> Result { + pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.affine_impl(shape, stride, mul, add)?; + let storage = storage.affine(layout, mul, add)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.affine_impl(shape, stride, mul, add)?; + let storage = storage.affine(layout, mul, add)?; Ok(Self::Cuda(storage)) } } } - pub(crate) fn sum(&self, shape: &Shape, stride: &[usize], s: &[usize]) -> Result { + pub(crate) fn sum(&self, layout: &Layout, s: &[usize]) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.sum(shape, stride, s)?; + let storage = storage.sum(layout, s)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.sum(shape, stride, s)?; + let storage = storage.sum(layout, s)?; Ok(Self::Cuda(storage)) } } } + // This assumes a contiguous layout and no offset. 
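Note: a small crate-internal usage sketch of the new Layout type introduced in layout.rs above, written as if inside a test in candle-core since transpose and strided_index are pub(crate); the concrete shape is illustrative only.

use crate::{Layout, Result};

fn layout_usage() -> Result<()> {
    // A freshly created layout is C contiguous: shape (2, 3), strides [3, 1], offset 0.
    let l = Layout::contiguous((2, 3));
    assert_eq!(l.contiguous_offsets(), Some((0, 6)));

    // Transposing swaps dims and strides without touching the data, so the result is
    // no longer C contiguous and ops fall back to the strided index iterator.
    let t = l.transpose(0, 1)?;
    assert!(!t.is_contiguous());
    assert_eq!(t.strided_index().collect::<Vec<_>>(), vec![0, 3, 1, 4, 2, 5]);
    Ok(())
}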
pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { match self { Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?, @@ -93,32 +88,28 @@ impl Storage { Ok(()) } - pub(crate) fn to_dtype(&self, shape: &Shape, stride: &[usize], dtype: DType) -> Result { + pub(crate) fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result { match self { Storage::Cpu(storage) => { - let storage = storage.to_dtype(shape, stride, dtype)?; + let storage = storage.to_dtype(layout, dtype)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.to_dtype(shape, stride, dtype)?; + let storage = storage.to_dtype(layout, dtype)?; Ok(Self::Cuda(storage)) } } } - pub(crate) fn unary_impl( - &self, - shape: &Shape, - stride: &[usize], - ) -> Result { + pub(crate) fn unary_impl(&self, layout: &Layout) -> Result { // TODO: Different code path for the contiguous case? match self { Storage::Cpu(storage) => { - let storage = storage.unary_impl::(shape, stride)?; + let storage = storage.unary_impl::(layout)?; Ok(Self::Cpu(storage)) } Self::Cuda(storage) => { - let storage = storage.unary_impl::(shape, stride)?; + let storage = storage.unary_impl::(layout)?; Ok(Self::Cuda(storage)) } } @@ -127,19 +118,18 @@ impl Storage { pub(crate) fn binary_impl( &self, rhs: &Self, - shape: &Shape, - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result { self.same_device(rhs, B::NAME)?; self.same_dtype(rhs, B::NAME)?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.binary_impl::(rhs, shape, lhs_stride, rhs_stride)?; + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.binary_impl::(rhs, shape, lhs_stride, rhs_stride)?; + let storage = lhs.binary_impl::(rhs, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => { @@ -156,49 +146,41 @@ impl Storage { pub(crate) fn where_cond( &self, - shape: &Shape, - stride: &[usize], + layout: &Layout, t: &Self, - stride_t: &[usize], + layout_t: &Layout, f: &Self, - stride_f: &[usize], + layout_f: &Layout, ) -> Result { self.same_device(t, "where")?; self.same_device(f, "where")?; t.same_dtype(f, "where")?; match (self, t, f) { (Storage::Cpu(cond), Storage::Cpu(t), Storage::Cpu(f)) => { - let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?; + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; Ok(Self::Cpu(storage)) } (Self::Cuda(cond), Self::Cuda(t), Self::Cuda(f)) => { - let storage = cond.where_cond(shape, stride, t, stride_t, f, stride_f)?; + let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?; Ok(Self::Cuda(storage)) } (_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), - op: "embedding", + op: "where", }), } } - pub(crate) fn embedding_impl( - &self, - shape: &Shape, - stride: &[usize], - rhs: &Self, - hidden_size: usize, - vocab_size: usize, - ) -> Result { + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { self.same_device(rhs, "embedding")?; match (self, rhs) { (Storage::Cpu(lhs), Storage::Cpu(rhs)) => { - let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?; + let storage = lhs.embedding(layout, rhs, rhs_l)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.embedding_impl(shape, stride, rhs, hidden_size, vocab_size)?; + let 
storage = lhs.embedding(layout, rhs, rhs_l)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { @@ -209,22 +191,22 @@ impl Storage { } } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, bmnk: (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result { self.same_device(rhs, "matmul")?; self.same_dtype(rhs, "matmul")?; match (self, rhs) { (Self::Cpu(lhs), Self::Cpu(rhs)) => { - let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { @@ -240,17 +222,11 @@ impl Storage { &self, dst: &mut Self, dst_offset: usize, - src_shape: &Shape, - src_stride: &[usize], - src_offset: usize, + src_l: &Layout, ) -> Result<()> { match (self, dst) { - (Self::Cpu(src), Self::Cpu(dst)) => { - src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset) - } - (Self::Cuda(src), Self::Cuda(dst)) => { - Ok(src.copy_strided_src(dst, dst_offset, src_shape, src_stride, src_offset)?) - } + (Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l), + (Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?), (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { lhs: lhs.device().location(), rhs: rhs.device().location(), diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs index 2a23e9ec..e6d2868b 100644 --- a/candle-core/src/strided_index.rs +++ b/candle-core/src/strided_index.rs @@ -1,27 +1,28 @@ +use crate::Layout; + /// An iterator over offset position for items of an N-dimensional arrays stored in a /// flat buffer using some potential strides. #[derive(Debug)] pub(crate) struct StridedIndex<'a> { next_storage_index: Option, multi_index: Vec, - dims: &'a [usize], - stride: &'a [usize], + layout: &'a Layout, } impl<'a> StridedIndex<'a> { - pub(crate) fn new(dims: &'a [usize], stride: &'a [usize]) -> Self { + pub(crate) fn new(layout: &'a Layout) -> Self { + let dims = layout.dims(); let elem_count: usize = dims.iter().product(); let next_storage_index = if elem_count == 0 { None } else { // This applies to the scalar case. 
- Some(0) + Some(layout.start_offset()) }; StridedIndex { next_storage_index, multi_index: vec![0; dims.len()], - dims, - stride, + layout, } } } @@ -35,7 +36,12 @@ impl<'a> Iterator for StridedIndex<'a> { Some(storage_index) => storage_index, }; let mut updated = false; - for (multi_i, max_i) in self.multi_index.iter_mut().zip(self.dims.iter()).rev() { + for (multi_i, max_i) in self + .multi_index + .iter_mut() + .zip(self.layout.dims().iter()) + .rev() + { let next_i = *multi_i + 1; if next_i < *max_i { *multi_i = next_i; @@ -49,9 +55,10 @@ impl<'a> Iterator for StridedIndex<'a> { let next_storage_index = self .multi_index .iter() - .zip(self.stride.iter()) + .zip(self.layout.stride().iter()) .map(|(&x, &y)| x * y) - .sum(); + .sum::() + + self.layout.start_offset(); Some(next_storage_index) } else { None diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index feb59d3c..f64bd6f2 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -1,4 +1,4 @@ -use crate::{op::Op, storage::Storage, DType, Device, Error, Result, Shape}; +use crate::{op::Op, storage::Storage, DType, Device, Error, Layout, Result, Shape}; use std::sync::Arc; /// Unique identifier for tensors. @@ -17,9 +17,7 @@ impl TensorId { pub struct Tensor_ { id: TensorId, storage: Arc, - shape: Shape, - // The strides are given in number of elements and not in bytes. - stride: Vec, + layout: Layout, op: Option, is_variable: bool, } @@ -50,7 +48,7 @@ macro_rules! unary_op { let shape = self.shape(); let storage = self .storage - .unary_impl::(self.shape(), self.stride())?; + .unary_impl::(self.layout())?; let op = if self.track_op() { Some(Op::$op_name(self.clone())) } else { @@ -67,9 +65,8 @@ macro_rules! binary_op { let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?; let storage = self.storage.binary_impl::( &rhs.storage, - shape, - self.stride(), - rhs.stride(), + self.layout(), + rhs.layout(), )?; let op = if self.track_op() || rhs.track_op() { Some(Op::$op_name(self.clone(), rhs.clone())) @@ -107,13 +104,10 @@ fn from_storage>( op: Option, is_variable: bool, ) -> Tensor { - let shape = shape.into(); - let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { id: TensorId::new(), storage: Arc::new(storage), - shape, - stride, + layout: Layout::contiguous(shape), op, is_variable, }; @@ -323,6 +317,7 @@ impl Tensor { unary_op!(sqrt, Sqrt); unary_op!(gelu, Gelu); unary_op!(relu, Relu); + pub fn to_scalar(&self) -> Result { if self.rank() != 0 { return Err(Error::UnexpectedNumberOfDims { @@ -342,8 +337,7 @@ impl Tensor { } pub fn affine(&self, mul: f64, add: f64) -> Result { - let shape = self.shape(); - let storage = self.storage.affine_impl(shape, self.stride(), mul, add)?; + let storage = self.storage.affine(self.layout(), mul, add)?; let op = if self.track_op() { Some(Op::Affine { arg: self.clone(), @@ -353,42 +347,25 @@ impl Tensor { } else { None }; - Ok(from_storage(storage, shape.clone(), op, false)) + Ok(from_storage(storage, self.shape(), op, false)) } /// Returns a new tensor that is a narrowed version of the input, the dimension `dim` /// ranges from `start` to `start + length`. - // TODO: Once we've refactored the shape and strides, make this return a view of the same data - // rather than copying. 
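Note: with narrow now deriving a new Layout over the same storage Arc instead of copying (see the new body just below), a narrowed tensor is in general no longer contiguous. A layout-level sketch of the derived view, crate-internal and with an illustrative shape:

use crate::{Layout, Result};

fn narrowed_view() -> Result<()> {
    // Narrow a contiguous (2, 3) layout to columns 1..3: shape (2, 2), strides [3, 1], offset 1.
    let l = Layout::contiguous((2, 3)).narrow(1, 1, 2)?;
    assert!(l.contiguous_offsets().is_none());
    // The strided iterator yields storage offsets 1, 2, 4, 5: column 0 is simply skipped.
    assert_eq!(l.strided_index().collect::<Vec<_>>(), vec![1, 2, 4, 5]);
    Ok(())
}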
     pub fn narrow(&self, dim: usize, start: usize, length: usize) -> Result<Self> {
-        let dims = self.shape().dims();
-        if dim >= dims.len() {
-            return Err(Error::UnexpectedNumberOfDims {
-                expected: dim + 1,
-                got: dims.len(),
-                shape: self.shape().clone(),
-            });
-        }
-        if start + length > dims[dim] {
-            todo!("add a proper error: out of bounds for narrow {dim} {start} {length} {dims:?}")
-        }
-        let mut dims = dims.to_vec();
-        dims[dim] = length;
-        let adjusted_shape = Shape::from(dims);
-        let mut storage = self.device().zeros(&adjusted_shape, self.dtype())?;
-        self.storage.copy_strided_src(
-            &mut storage,
-            /* dst_offset= */ 0,
-            &adjusted_shape,
-            &self.stride,
-            /* src_offest= */ self.stride[dim] * start,
-        )?;
         let op = if self.track_op() {
             Some(Op::Narrow(self.clone(), dim, start, length))
         } else {
             None
         };
-        Ok(from_storage(storage, adjusted_shape, op, false))
+        let tensor_ = Tensor_ {
+            id: TensorId::new(),
+            storage: self.storage.clone(),
+            layout: self.layout().narrow(dim, start, length)?,
+            op,
+            is_variable: false,
+        };
+        Ok(Tensor(Arc::new(tensor_)))
     }
 
     pub fn softmax(&self, dim: usize) -> Result<Self> {
@@ -401,9 +378,7 @@ impl Tensor {
             exp.broadcast_div(&sum_exp)
         } else {
             let shape = self.shape();
-            let mut storage = self
-                .storage
-                .unary_impl::(shape, self.stride())?;
+            let mut storage = self.storage.unary_impl::(self.layout())?;
             // The resulting storage is contiguous.
             storage.divide_by_sum_over_dim(shape, dim)?;
             let op = if self.track_op() {
@@ -416,7 +391,7 @@ impl Tensor {
     }
 
     pub fn sum(&self, sum_dims: &[usize]) -> Result<Self> {
-        let storage = self.storage.sum(self.shape(), &self.stride, sum_dims)?;
+        let storage = self.storage.sum(self.layout(), sum_dims)?;
         let op = if self.track_op() {
             Some(Op::Sum(self.clone(), sum_dims.to_vec()))
         } else {
@@ -458,11 +433,11 @@ impl Tensor {
         let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]);
         let batching: usize = a_dims[..dim - 2].iter().product();
 
-        let storage = self.storage.matmul_impl(
+        let storage = self.storage.matmul(
             &rhs.storage,
             (batching, m, n, k),
-            self.stride(),
-            rhs.stride(),
+            self.layout(),
+            rhs.layout(),
         )?;
         let op = if self.track_op() || rhs.track_op() {
             Some(Op::Matmul(self.clone(), rhs.clone()))
@@ -476,12 +451,11 @@ impl Tensor {
         let _shap = self.same_shape_binary_op(on_true, "where_cond")?;
         let shape = self.same_shape_binary_op(on_false, "where_cond")?;
         let storage = self.storage.where_cond(
-            shape,
-            self.stride(),
+            self.layout(),
             &on_true.storage,
-            on_true.stride(),
+            on_true.layout(),
             &on_false.storage,
-            on_false.stride(),
+            on_false.layout(),
         )?;
         let op = if self.track_op() || on_true.track_op() || on_false.track_op() {
             Some(Op::WhereCond(
@@ -498,23 +472,19 @@ impl Tensor {
     pub fn embedding(ids: &Self, rhs: &Self) -> Result<Self> {
         if !rhs.is_contiguous() {
             return Err(Error::RequiresContiguous { op: "embedding" });
-        } else if rhs.shape().rank() != 2 || ids.shape().rank() != 1 {
+        } else if rhs.rank() != 2 || ids.rank() != 1 {
             return Err(Error::ShapeMismatchBinaryOp {
-                lhs: ids.shape.clone(),
-                rhs: rhs.shape.clone(),
+                lhs: ids.shape().clone(),
+                rhs: rhs.shape().clone(),
                 op: "embedding",
             });
         }
         let ids_shape = ids.shape();
         let seq_len = ids_shape.r1()?;
-        let (vocab_size, hidden_size) = rhs.shape().r2()?;
-        let storage = ids.storage.embedding_impl(
-            ids_shape,
-            &ids.stride,
-            &rhs.storage,
-            hidden_size,
-            vocab_size,
-        )?;
+        let (_, hidden_size) = rhs.shape().r2()?;
+        let storage = ids
+            .storage
+            .embedding(ids.layout(), &rhs.storage, rhs.layout())?;
         let shape: Shape = (seq_len, hidden_size).into();
         let op = if ids.track_op() || rhs.track_op() {
             Some(Op::Embedding(ids.clone(), rhs.clone()))
@@ -525,7 +495,7 @@ impl Tensor {
     }
 
     pub(crate) fn strided_index(&self) -> crate::StridedIndex {
-        crate::StridedIndex::new(self.dims(), self.stride())
+        self.layout.strided_index()
     }
 
     /// Returns data from the underlying storage, this does not take the strides
@@ -618,15 +588,20 @@ impl Tensor {
     }
 
     pub fn shape(&self) -> &Shape {
-        &self.shape
+        self.layout().shape()
     }
 
     pub fn dims(&self) -> &[usize] {
         self.shape().dims()
     }
 
-    pub fn stride(&self) -> &[usize] {
-        &self.stride
+    pub fn layout(&self) -> &Layout {
+        &self.layout
+    }
+
+    // TODO: Rename to `stride` once the PR that introduced the layout has been merged.
+    pub fn stride_tmp(&self) -> &[usize] {
+        self.layout.stride()
     }
 
     pub fn rank(&self) -> usize {
@@ -704,18 +679,6 @@ impl Tensor {
     /// Returns a tensor that is a transposed version of the input, the given dimensions are
     /// swapped.
     pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
-        let rank = self.rank();
-        if rank <= dim1 || rank <= dim2 {
-            return Err(Error::UnexpectedNumberOfDims {
-                expected: usize::max(dim1, dim2),
-                got: rank,
-                shape: self.shape().clone(),
-            });
-        }
-        let mut stride = self.stride().to_vec();
-        let mut dims = self.shape().dims().to_vec();
-        dims.swap(dim1, dim2);
-        stride.swap(dim1, dim2);
         let op = if self.track_op() {
             Some(Op::Transpose(self.clone(), dim1, dim2))
         } else {
@@ -724,8 +687,7 @@ impl Tensor {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape: Shape::from(dims),
-            stride,
+            layout: self.layout.transpose(dim1, dim2)?,
             op,
             is_variable: false,
         };
@@ -734,12 +696,12 @@ impl Tensor {
 
     /// Returns true if the data is stored in a C contiguous (aka row major) way.
     pub fn is_contiguous(&self) -> bool {
-        self.shape.is_contiguous(&self.stride)
+        self.layout.is_contiguous()
     }
 
     /// Returns true if the data is stored in a Fortran contiguous (aka column major) way.
     pub fn is_fortran_contiguous(&self) -> bool {
-        self.shape.is_fortran_contiguous(&self.stride)
+        self.layout.is_fortran_contiguous()
     }
 
     /// Compared to clone, this copies the actual storage but may fail because of running out of
@@ -748,8 +710,7 @@ impl Tensor {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(self.storage.try_clone()?),
-            shape: self.shape.clone(),
-            stride: self.stride.clone(),
+            layout: self.layout.clone(),
             op: None, // TODO
             is_variable: false,
         };
@@ -762,8 +723,7 @@ impl Tensor {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape: self.shape.clone(),
-            stride: self.stride.clone(),
+            layout: self.layout.clone(),
             op: None,
             is_variable: false,
         };
@@ -796,8 +756,7 @@ impl Tensor {
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: Arc::new(storage),
-            shape: self.shape.clone(),
-            stride: self.stride.clone(),
+            layout: self.layout.clone(),
             op,
             is_variable: false,
         };
@@ -810,7 +769,7 @@ impl Tensor {
     pub fn broadcast_left<S: Into<Shape>>(&self, left_shape: S) -> Result<Self> {
         let left_shape = left_shape.into();
         let mut dims = left_shape.into_dims();
-        dims.extend(self.shape.dims());
+        dims.extend(self.dims());
         self.broadcast_as(dims)
     }
 
@@ -820,36 +779,10 @@ impl Tensor {
         } else {
             None
         };
-        let shape = shape.into();
-        if shape.rank() < self.rank() {
-            return Err(Error::BroadcastIncompatibleShapes {
-                src_shape: self.shape().clone(),
-                dst_shape: shape,
-            });
-        }
-        let added_dims = shape.rank() - self.rank();
-        let mut stride = vec![0; added_dims];
-        for (&dst_dim, (&src_dim, &src_stride)) in shape.dims()[added_dims..]
-            .iter()
-            .zip(self.dims().iter().zip(self.stride()))
-        {
-            let s = if dst_dim == src_dim {
-                src_stride
-            } else if src_dim != 1 {
-                return Err(Error::BroadcastIncompatibleShapes {
-                    src_shape: self.shape().clone(),
-                    dst_shape: shape,
-                });
-            } else {
-                0
-            };
-            stride.push(s)
-        }
         let tensor_ = Tensor_ {
             id: TensorId::new(),
             storage: self.storage.clone(),
-            shape,
-            stride,
+            layout: self.layout.broadcast_as(shape)?,
             op,
             is_variable: false,
         };
@@ -866,7 +799,7 @@ impl Tensor {
             Ok(self.clone())
         } else {
             let shape = self.shape();
-            let storage = self.storage.to_dtype(shape, self.stride(), dtype)?;
+            let storage = self.storage.to_dtype(self.layout(), dtype)?;
             let op = if self.track_op() {
                 Some(Op::ToDType(self.clone()))
             } else {
@@ -883,7 +816,7 @@ impl Tensor {
         let shape = self.shape();
         let mut storage = self.device().zeros(shape, self.dtype())?;
         self.storage
-            .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
+            .copy_strided_src(&mut storage, 0, self.layout())?;
         Ok(from_storage(
             storage,
             shape.clone(),
@@ -913,12 +846,10 @@ impl Tensor {
             None
         };
         if self.is_contiguous() {
-            let stride = shape.stride_contiguous();
             let tensor_ = Tensor_ {
                 id: TensorId::new(),
                 storage: self.storage.clone(),
-                shape,
-                stride,
+                layout: Layout::contiguous_with_offset(shape, self.layout.start_offset()),
                 op,
                 is_variable: false,
             };
@@ -926,7 +857,7 @@ impl Tensor {
         } else {
             let mut storage = self.device().zeros(&shape, self.dtype())?;
             self.storage
-                .copy_strided_src(&mut storage, 0, &self.shape, &self.stride, 0)?;
+                .copy_strided_src(&mut storage, 0, self.layout())?;
             Ok(from_storage(storage, shape, op, false))
         }
     }
@@ -1063,7 +994,7 @@ impl Tensor {
         for (arg, &offset) in args.iter().zip(offsets.iter()) {
             let arg = arg.as_ref();
             arg.storage
-                .copy_strided_src(&mut storage, offset, &arg.shape, &arg.stride, 0)?;
+                .copy_strided_src(&mut storage, offset, arg.layout())?;
         }
         Ok(from_storage(storage, shape, op, false))
     }
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 78ca4b05..8ac0c9f2 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -263,12 +263,12 @@ fn matmul(device: &Device) -> Result<()> {
     let a_tt = a.t()?.contiguous()?.t()?;
     assert!(!a_tt.is_contiguous());
     assert_eq!(a.dims(), a_tt.dims());
-    assert_eq!(a_tt.stride(), &[6, 1, 2]);
+    assert_eq!(a_tt.stride_tmp(), &[6, 1, 2]);
 
     let b_tt = b.t()?.contiguous()?.t()?;
     assert!(!b_tt.is_contiguous());
     assert_eq!(b.dims(), b_tt.dims());
-    assert_eq!(b_tt.stride(), &[6, 1, 3]);
+    assert_eq!(b_tt.stride_tmp(), &[6, 1, 3]);
     assert_eq!(a_tt.matmul(&b)?.to_vec3::()?, &expected);
     assert_eq!(a.matmul(&b_tt)?.to_vec3::()?, &expected);
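
Reviewer aside (not part of the patch): the standalone Rust sketch below illustrates the idea the diff relies on, namely that shape, strides, and a start offset travel together and that the strided iterator adds the start offset to every storage index it yields, which is what lets `narrow` become a zero-copy view. The `DemoLayout` type and its methods are hypothetical names for illustration only and do not match the actual `Layout` API in candle-core.

// Standalone sketch; `DemoLayout` and its methods are illustrative names only.
#[derive(Debug, Clone)]
struct DemoLayout {
    dims: Vec<usize>,
    stride: Vec<usize>,
    start_offset: usize,
}

impl DemoLayout {
    // Row-major (C contiguous) layout starting at `start_offset` in the flat buffer.
    fn contiguous(dims: Vec<usize>, start_offset: usize) -> Self {
        let mut stride = vec![1; dims.len()];
        for i in (0..dims.len().saturating_sub(1)).rev() {
            stride[i] = stride[i + 1] * dims[i + 1];
        }
        Self { dims, stride, start_offset }
    }

    // A narrowed view along `dim`: same buffer, smaller dim, shifted start offset.
    fn narrow(&self, dim: usize, start: usize, len: usize) -> Self {
        let mut dims = self.dims.clone();
        dims[dim] = len;
        Self {
            dims,
            stride: self.stride.clone(),
            start_offset: self.start_offset + self.stride[dim] * start,
        }
    }

    // Storage offsets in logical (row-major) order, start offset included,
    // mirroring what the patched `StridedIndex` iterator computes per element.
    fn storage_offsets(&self) -> Vec<usize> {
        let elem_count: usize = self.dims.iter().product();
        let mut offsets = Vec::with_capacity(elem_count);
        let mut multi_index = vec![0usize; self.dims.len()];
        for _ in 0..elem_count {
            let offset: usize = multi_index
                .iter()
                .zip(self.stride.iter())
                .map(|(&i, &s)| i * s)
                .sum::<usize>()
                + self.start_offset;
            offsets.push(offset);
            // Increment the multi-index, last dimension fastest.
            for d in (0..self.dims.len()).rev() {
                multi_index[d] += 1;
                if multi_index[d] < self.dims[d] {
                    break;
                }
                multi_index[d] = 0;
            }
        }
        offsets
    }
}

fn main() {
    // A 2x3 row-major view over a flat buffer, narrowed to the last two columns.
    let base = DemoLayout::contiguous(vec![2, 3], 0);
    let narrowed = base.narrow(1, 1, 2);
    // The narrowed 2x2 view touches storage slots 1, 2, 4, 5 without copying.
    assert_eq!(narrowed.storage_offsets(), vec![1, 2, 4, 5]);
    println!("{:?}", narrowed.storage_offsets());
}

The `start_offset + stride[dim] * start` arithmetic in the sketch is the same offset the old `Tensor::narrow` passed as `src_offset` when it copied, now presumably captured once inside `Layout::narrow` so the copy can be skipped.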