diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 4f63ea98..1c5caa82 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -1,5 +1,5 @@ use crate::op::{BinaryOp, UnaryOp}; -use crate::{DType, Error, Layout, Result, Shape, StridedIndex}; +use crate::{DType, Error, Layout, Result, Shape}; use gemm::{gemm, Parallelism}; use half::{bf16, f16}; @@ -81,14 +81,13 @@ fn unary_map U>(vs: &[T], layout: &Layout, mut // This function maps over two strided index sequences. fn binary_map T>( - shape: &Shape, lhs_layout: &Layout, rhs_layout: &Layout, lhs: &[T], rhs: &[T], mut f: F, ) -> Vec { - let dims = shape.dims(); + let shape = lhs_layout.shape(); if lhs_layout.is_contiguous() && rhs_layout.is_contiguous() { (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect() } else { @@ -148,17 +147,19 @@ fn copy_strided_src_( } } -fn matmul_impl( +fn matmul( lhs: &[T], rhs: &[T], (b, m, n, k): (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result> { let a_skip: usize = m * k; let b_skip: usize = n * k; let c_skip: usize = m * n; + let lhs_stride = lhs_layout.stride(); + let rhs_stride = rhs_layout.stride(); let rank = lhs_stride.len(); let lhs_cs = lhs_stride[rank - 1]; let lhs_rs = lhs_stride[rank - 2]; @@ -512,29 +513,28 @@ impl CpuStorage { pub(crate) fn binary_impl( &self, rhs: &Self, - shape: &Shape, lhs_layout: &Layout, rhs_layout: &Layout, ) -> Result { match (self, rhs) { (Self::BF16(lhs), Self::BF16(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::bf16); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::bf16); Ok(Self::BF16(data)) } (Self::F16(lhs), Self::F16(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f16); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f16); Ok(Self::F16(data)) } (Self::F32(lhs), Self::F32(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f32); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f32); Ok(Self::F32(data)) } (Self::F64(lhs), Self::F64(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::f64); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::f64); Ok(Self::F64(data)) } (Self::U32(lhs), Self::U32(rhs)) => { - let data = binary_map(shape, lhs_layout, rhs_layout, lhs, rhs, B::u32); + let data = binary_map(lhs_layout, rhs_layout, lhs, rhs, B::u32); Ok(Self::U32(data)) } _ => { @@ -622,24 +622,24 @@ impl CpuStorage { map1!(vs, take_impl1, ids, layout, vocab_size, hidden_size) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, bmnk: (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result { match (self, rhs) { (CpuStorage::F16(lhs), CpuStorage::F16(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::F16(dst)) } (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::F32(dst)) } (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => { - let dst = matmul_impl(lhs, rhs, bmnk, lhs_stride, rhs_stride)?; + let dst = matmul(lhs, rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::F64(dst)) } _ => Err(Error::DTypeMismatchBinaryOp { diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index babc6e7d..ef079812 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -1,5 +1,5 @@ #![allow(dead_code)] -use crate::{CpuStorage, DType, Error, Result, Shape}; +use crate::{CpuStorage, DType, Error, Layout, Result, Shape}; #[derive(thiserror::Error, Debug)] pub enum DummyError {} @@ -60,11 +60,11 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn affine_impl(&self, _: &Shape, _: &[usize], _: f64, _: f64) -> Result { + pub(crate) fn affine(&self, _: &Layout, _: f64, _: f64) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn sum(&self, _: &Shape, _: &[usize], _: &[usize]) -> Result { + pub(crate) fn sum(&self, _: &Layout, _: &[usize]) -> Result { Err(Error::NotCompiledWithCudaSupport) } @@ -72,65 +72,49 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn to_dtype(&self, _: &Shape, _: &[usize], _: DType) -> Result { + pub(crate) fn to_dtype(&self, _: &Layout, _: DType) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn unary_impl(&self, _: &Shape, _: &[usize]) -> Result { + pub(crate) fn unary_impl(&self, _: &Layout) -> Result { Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn binary_impl( &self, _: &Self, - _: &Shape, - _: &[usize], - _: &[usize], + _: &Layout, + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } pub(crate) fn where_cond( &self, - _: &Shape, - _: &[usize], + _: &Layout, _: &Self, - _: &[usize], + _: &Layout, _: &Self, - _: &[usize], + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn embedding_impl( - &self, - _: &Shape, - _: &[usize], - _: &Self, - _: usize, - _: usize, - ) -> Result { + pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: usize, _: usize) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, _: &Self, _: (usize, usize, usize, usize), - _: &[usize], - _: &[usize], + _: &Layout, + _: &Layout, ) -> Result { Err(Error::NotCompiledWithCudaSupport) } - pub(crate) fn copy_strided_src( - &self, - _: &mut Self, - _: usize, - _: &Shape, - _: &[usize], - _: usize, - ) -> Result<()> { + pub(crate) fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> { Err(Error::NotCompiledWithCudaSupport) } } diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs index 6ba0d79a..3f629d50 100644 --- a/candle-core/src/layout.rs +++ b/candle-core/src/layout.rs @@ -106,7 +106,7 @@ impl Layout { if shape.rank() < self.shape().rank() { Err(Error::BroadcastIncompatibleShapes { src_shape: self.shape().clone(), - dst_shape: shape, + dst_shape: shape.clone(), })? } let added_dims = shape.rank() - self.shape().rank(); @@ -135,6 +135,6 @@ impl Layout { } pub(crate) fn strided_index(&self) -> crate::StridedIndex { - crate::StridedIndex::new(&self) + crate::StridedIndex::new(self) } } diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 2f2d8b75..2c9624c7 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -79,6 +79,7 @@ impl Storage { } } + // This assumes a contiguous layout and no offset. pub(crate) fn divide_by_sum_over_dim(&mut self, shape: &Shape, dim: usize) -> Result<()> { match self { Storage::Cpu(storage) => storage.divide_by_sum_over_dim(shape, dim)?, @@ -196,22 +197,22 @@ impl Storage { } } - pub(crate) fn matmul_impl( + pub(crate) fn matmul( &self, rhs: &Self, bmnk: (usize, usize, usize, usize), - lhs_stride: &[usize], - rhs_stride: &[usize], + lhs_layout: &Layout, + rhs_layout: &Layout, ) -> Result { self.same_device(rhs, "matmul")?; self.same_dtype(rhs, "matmul")?; match (self, rhs) { (Self::Cpu(lhs), Self::Cpu(rhs)) => { - let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cpu(storage)) } (Self::Cuda(lhs), Self::Cuda(rhs)) => { - let storage = lhs.matmul_impl(rhs, bmnk, lhs_stride, rhs_stride)?; + let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?; Ok(Self::Cuda(storage)) } (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 094c60a3..b04e90b1 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -432,7 +432,7 @@ impl Tensor { let c_shape = Shape::from(&a_dims[..dim - 2]).extend(&[m, n]); let batching: usize = a_dims[..dim - 2].iter().product(); - let storage = self.storage.matmul_impl( + let storage = self.storage.matmul( &rhs.storage, (batching, m, n, k), self.layout(), @@ -587,7 +587,7 @@ impl Tensor { } pub fn shape(&self) -> &Shape { - &self.layout().shape() + self.layout().shape() } pub fn dims(&self) -> &[usize] { @@ -600,7 +600,7 @@ impl Tensor { // TODO: Rename to `stride` once the PR that introduced the layout has been merged. pub fn stride_tmp(&self) -> &[usize] { - &self.layout.stride() + self.layout.stride() } pub fn rank(&self) -> usize {