diff --git a/Cargo.toml b/Cargo.toml index 8c601e18..72eb00ce 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,11 +16,11 @@ members = [ ] [dependencies] -ggblas = "0.1.0" safetensors = "0.3.1" thiserror = "1" cudarc = { version = "0.9.9", optional = true } candle-kernels = { path = "kernels", optional = true } +gemm = "0.15.4" [dev-dependencies] anyhow = "1" diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index c71536ed..0eb4270a 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -1,6 +1,6 @@ use crate::storage::{BinaryOp, UnaryOp}; use crate::{DType, Error, Result, Shape, StridedIndex}; -use ggblas::batched_sgemm; +use gemm::{gemm, Parallelism}; // TODO: Think about whether we would be better off with a dtype and // a buffer as an owned slice of bytes. @@ -113,28 +113,83 @@ impl CpuStorage { lhs_stride: &[usize], rhs_stride: &[usize], ) -> Result { - println!("rhs {rhs:?}"); - println!("lhs_stride {lhs_stride:?}"); - println!("rhs_stride {rhs_stride:?}"); - // todo!("matmul"); let a_skip: usize = m * k; let b_skip: usize = n * k; let c_skip: usize = m * n; - let mut c = Self::F32(vec![0.0; b * m * n]); + let rank = lhs_stride.len(); + let lhs_cs = lhs_stride[rank - 1]; + let lhs_rs = lhs_stride[rank - 2]; - batched_sgemm( - self.as_slice()?, - a_skip, - rhs.as_slice()?, - b_skip, - c.as_mut_slice()?, - c_skip, - m, - n, - k, - b, - ); + let rhs_cs = rhs_stride[rank - 1]; + let rhs_rs = rhs_stride[rank - 2]; + + if lhs_stride.len() > 2 { + let lhs_batch_stride = &lhs_stride[..rank - 2]; + let rhs_batch_stride = &rhs_stride[..rank - 2]; + + if lhs_batch_stride != &[a_skip] || rhs_batch_stride != &[b_skip] { + // Temporary error before we support abitrary striding. + return Err(Error::UnexpectedStriding); + } + } + + let mut dst = vec![0.0; b * m * n]; + + let dst_shape: Shape = (m, n).into(); + let dst_strides = dst_shape.stride_contiguous(); + let dst_rs = dst_strides[0]; + let dst_cs = dst_strides[1]; + + for step in 0..b { + let lhs_p = &self.as_slice::()?[step * a_skip..]; + let rhs_p = &rhs.as_slice::()?[step * b_skip..]; + let dst_p = &mut dst[step * c_skip..]; + unsafe { + gemm( + // m: usize, + m, + // n: usize, + n, + // k: usize, + k, + // dst: *mut T, + dst_p.as_mut_ptr(), + // dst_cs: isize, + dst_cs as isize, + // dst_rs: isize, + dst_rs as isize, + // read_dst: bool, + false, + // lhs: *const T, + lhs_p.as_ptr(), + // lhs_cs: isize, + lhs_cs as isize, + // lhs_rs: isize, + lhs_rs as isize, + // rhs: *const T, + rhs_p.as_ptr(), + // rhs_cs: isize, + rhs_cs as isize, + // rhs_rs: isize, + rhs_rs as isize, + // alpha: T, + 1.0, + // beta: T, + 1.0, + // conj_dst: bool, + false, + // conj_lhs: bool, + false, + // conj_rhs: bool, + true, + // parallelism: Parallelism + Parallelism::None, + ) + } + } + + let c = Self::F32(dst); Ok(c) } @@ -175,31 +230,31 @@ mod tests { #[test] fn simple_matmul() -> Result<()> { let data = vec![1.0f32, 2.0, 3.0, 4.0]; - let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?; + let a = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?; let data = vec![1.0f32, 2.0, 3.0, 4.0]; - let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?; + let b = Tensor::from_slice(&data, (2, 2), &Device::Cpu)?; let c = a.matmul(&b)?; assert_eq!(c.to_vec2::()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]); let data = vec![1.0f32, 2.0]; - let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?; + let a = Tensor::from_slice(&data, (2, 1), &Device::Cpu)?; let data = vec![3.0f32, 4.0]; - let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?; + let b = Tensor::from_slice(&data, (1, 2), &Device::Cpu)?; let c = a.matmul(&b)?; assert_eq!(c.to_vec2::()?, &[&[3.0, 4.0], &[6.0, 8.0]]); let data: Vec<_> = (0..6).map(|i| i as f32).collect(); - let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?; + let a = Tensor::from_slice(&data, (2, 3), &Device::Cpu)?; let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); - let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?; + let b = Tensor::from_slice(&data, (3, 2), &Device::Cpu)?; let c = a.matmul(&b)?; assert_eq!(c.to_vec2::()?, &[&[16., 19.], &[52., 64.]]); let data: Vec<_> = (0..12).map(|i| i as f32).collect(); - let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?; + let a = Tensor::from_slice(&data, (2, 2, 3), &Device::Cpu)?; let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); - let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?; + let b = Tensor::from_slice(&data, (2, 3, 2), &Device::Cpu)?; let c = a.matmul(&b)?; assert_eq!( c.to_vec3::()?, diff --git a/src/error.rs b/src/error.rs index 27201cb4..6f40622c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -40,6 +40,9 @@ pub enum Error { shape: Shape, }, + #[error("temporary error where matmul doesn't support arbitrary striding")] + UnexpectedStriding, + #[error(transparent)] Cuda(#[from] crate::CudaError), } diff --git a/src/tensor.rs b/src/tensor.rs index 7607171c..7274c557 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -147,10 +147,11 @@ impl Tensor { pub fn new_impl( array: A, + shape: Shape, device: &Device, is_variable: bool, ) -> Result { - let shape = array.shape()?; + // let shape = array.shape()?; let storage = device.storage(array)?; let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { @@ -165,31 +166,29 @@ impl Tensor { } pub fn new(array: A, device: &Device) -> Result { - Self::new_impl(array, device, false) + let shape = array.shape()?.clone(); + Self::new_impl(array, shape, device, false) } pub fn var(array: A, device: &Device) -> Result { - Self::new_impl(array, device, true) + let shape = array.shape()?.clone(); + Self::new_impl(array, shape, device, true) } pub fn from_slice, D: crate::WithDType>( - a: &[D], + array: &[D], shape: S, - device: Device, + device: &Device, ) -> Result { - let shape = shape.into(); - let storage = device.storage(a)?; - let stride = shape.stride_contiguous(); - let is_variable = false; - let tensor_ = Tensor_ { - id: TensorId::new(), - storage, - shape, - stride, - op: None, - is_variable, - }; - Ok(Self(Arc::new(tensor_))) + Self::new_impl(array, shape.into(), device, false) + } + + pub fn var_from_slice, D: crate::WithDType>( + array: &[D], + shape: S, + device: &Device, + ) -> Result { + Self::new_impl(array, shape.into(), device, true) } pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { @@ -260,6 +259,7 @@ impl Tensor { let dim = a_dims.len(); + // TODO // if dim < 2 { // return Err(SmeltError::InsufficientRank { minimum_rank: 2 }); // } @@ -309,6 +309,13 @@ impl Tensor { crate::StridedIndex::new(self.dims(), self.stride()) } + pub fn as_slice(&self) -> Result<&[S]> { + match &self.storage { + Storage::Cpu(cpu_storage) => S::cpu_storage_as_slice(cpu_storage), + Storage::Cuda { .. } => todo!(), + } + } + pub fn to_vec1(&self) -> Result> { if self.rank() != 1 { return Err(Error::UnexpectedNumberOfDims { @@ -404,6 +411,31 @@ impl Tensor { self.id } + pub fn t(&self) -> Result { + let mut stride = self.stride().to_vec(); + let mut shape = self.shape().clone(); + let n = stride.len(); + if n < 2 { + return Err(Error::UnexpectedNumberOfDims { + expected: 2, + got: n, + shape: self.shape().clone(), + }); + } + (shape.0[n - 2], shape.0[n - 1]) = (shape.0[n - 1], shape.0[n - 2]); + (stride[n - 2], stride[n - 1]) = (stride[n - 1], stride[n - 2]); + let tensor_ = Tensor_ { + id: TensorId::new(), + storage: self.storage.clone(), + shape, + stride, + // TODO The op should have a backward + op: None, + is_variable: false, + }; + Ok(Tensor(Arc::new(tensor_))) + } + pub fn is_contiguous(&self) -> bool { self.shape.is_contiguous(&self.stride) } @@ -514,37 +546,17 @@ impl Tensor { let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } - Op::Matmul(_lhs, _rhs) => { - // let (m, k) = lhs.shape; - // let n = rhs.shape.1; - // let strides = (m, n).strides(); - // Self::matmul( - // (m, n, k), - // true, - // grad_out.as_ptr(), - // strides, - // rhs.data.as_ptr(), - // [rhs.strides[1], rhs.strides[0]], - // grad_lhs.as_mut_ptr(), - // lhs.strides, - // ); - // Self::matmul( - // (k, m, n), - // true, - // lhs.data.as_ptr(), - // [lhs.strides[1], lhs.strides[0]], - // grad_out.as_ptr(), - // strides, - // grad_rhs.as_mut_ptr(), - // rhs.strides, - // ); + Op::Matmul(lhs, rhs) => { + // Skipping checks, the op went ok, we can skip + // the matmul size checks for now. - // let lhs_grad = grad.matmul(rhs)?; - // let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like()); - // *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; - // let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?; - // let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like()); - // *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; + let lhs_grad = grad.matmul(&rhs.t()?)?; + let lhs_sum_grad = grads.or_insert(lhs)?; + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; + + let rhs_grad = lhs.t()?.matmul(&grad)?; + let rhs_sum_grad = grads.or_insert(rhs)?; + *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } Op::Affine { arg, mul, .. } => { let arg_grad = grad.affine(*mul, 0.)?; diff --git a/tests/grad_tests.rs b/tests/grad_tests.rs index 56186e5d..77a32dfe 100644 --- a/tests/grad_tests.rs +++ b/tests/grad_tests.rs @@ -1,5 +1,5 @@ use anyhow::{Context, Result}; -use candle::{Device, Tensor}; +use candle::{Device, Shape, Tensor}; #[test] fn simple_grad() -> Result<()> { @@ -14,3 +14,27 @@ fn simple_grad() -> Result<()> { assert_eq!(grad_x.to_vec1::()?, [11., 7., 13.]); Ok(()) } + +#[test] +fn matmul_grad() -> Result<()> { + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let x = Tensor::var_from_slice(&data, (2, 2, 3), &Device::Cpu)?; + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let y = Tensor::var_from_slice(&data, (2, 3, 2), &Device::Cpu)?; + + let c = x.matmul(&y)?; + let grads = c.backward()?; + let grad_x = grads.get(&x).context("no grad for x")?; + let grad_y = grads.get(&y).context("no grad for y")?; + assert_eq!(grad_x.shape(), &Shape::from((2, 2, 3))); + assert_eq!(grad_y.shape(), &Shape::from((2, 3, 2))); + assert_eq!( + grad_x.as_slice::()?, + &[1., 5., 9., 1., 5., 9., 13., 17., 21., 13., 17., 21.] + ); + assert_eq!( + grad_y.as_slice::()?, + &[3., 3., 5., 5., 7., 7., 15., 15., 17., 17., 19., 19.] + ); + Ok(()) +}