From ce977b489e4863cf5c53495a093c0efef2d41013 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 21 Jun 2023 16:52:35 +0200 Subject: [PATCH] Adding matmul? --- Cargo.toml | 1 + src/cpu_backend.rs | 83 +++++++++++++++++++++++++++++ src/device.rs | 2 +- src/dtype.rs | 11 ++++ src/op.rs | 1 + src/storage.rs | 18 +++++++ src/tensor.rs | 130 ++++++++++++++++++++++++++++++++++++++++++++- 7 files changed, 243 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 883664fc..b2e2a890 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ members = [ ] [dependencies] +ggblas = "0.1.0" safetensors = "0.3.1" thiserror = "1" cudarc = { version = "0.9.9", optional = true } diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs index 01c17245..c71536ed 100644 --- a/src/cpu_backend.rs +++ b/src/cpu_backend.rs @@ -1,5 +1,6 @@ use crate::storage::{BinaryOp, UnaryOp}; use crate::{DType, Error, Result, Shape, StridedIndex}; +use ggblas::batched_sgemm; // TODO: Think about whether we would be better off with a dtype and // a buffer as an owned slice of bytes. @@ -17,6 +18,14 @@ impl CpuStorage { } } + pub fn as_slice(&self) -> Result<&[D]> { + D::cpu_storage_as_slice(self) + } + + pub fn as_mut_slice(&mut self) -> Result<&mut [D]> { + D::cpu_storage_as_mut_slice(self) + } + pub(crate) fn affine_impl( &self, shape: &Shape, @@ -97,6 +106,38 @@ impl CpuStorage { } } + pub(crate) fn matmul_impl( + &self, + rhs: &Self, + (b, m, n, k): (usize, usize, usize, usize), + lhs_stride: &[usize], + rhs_stride: &[usize], + ) -> Result { + println!("rhs {rhs:?}"); + println!("lhs_stride {lhs_stride:?}"); + println!("rhs_stride {rhs_stride:?}"); + // todo!("matmul"); + let a_skip: usize = m * k; + let b_skip: usize = n * k; + let c_skip: usize = m * n; + + let mut c = Self::F32(vec![0.0; b * m * n]); + + batched_sgemm( + self.as_slice()?, + a_skip, + rhs.as_slice()?, + b_skip, + c.as_mut_slice()?, + c_skip, + m, + n, + k, + b, + ); + Ok(c) + } + pub(crate) fn ones_impl(shape: &Shape, dtype: DType) -> Self { let elem_count = shape.elem_count(); match dtype { @@ -125,3 +166,45 @@ impl CpuStorage { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{Device, Tensor}; + + #[test] + fn simple_matmul() -> Result<()> { + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let a = Tensor::from_slice(&data, (2, 2), Device::Cpu)?; + let data = vec![1.0f32, 2.0, 3.0, 4.0]; + let b = Tensor::from_slice(&data, (2, 2), Device::Cpu)?; + + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[7.0f32, 10.0], &[15.0, 22.0]]); + + let data = vec![1.0f32, 2.0]; + let a = Tensor::from_slice(&data, (2, 1), Device::Cpu)?; + let data = vec![3.0f32, 4.0]; + let b = Tensor::from_slice(&data, (1, 2), Device::Cpu)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[3.0, 4.0], &[6.0, 8.0]]); + + let data: Vec<_> = (0..6).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 3), Device::Cpu)?; + let data: Vec<_> = (0..6).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (3, 2), Device::Cpu)?; + let c = a.matmul(&b)?; + assert_eq!(c.to_vec2::()?, &[&[16., 19.], &[52., 64.]]); + + let data: Vec<_> = (0..12).map(|i| i as f32).collect(); + let a = Tensor::from_slice(&data, (2, 2, 3), Device::Cpu)?; + let data: Vec<_> = (0..12).map(|i| (i + 2) as f32).collect(); + let b = Tensor::from_slice(&data, (2, 3, 2), Device::Cpu)?; + let c = a.matmul(&b)?; + assert_eq!( + c.to_vec3::()?, + &[&[&[16., 19.], &[52., 64.]], &[&[214., 235.], &[304., 334.]]] + ); + Ok(()) + } +} diff --git a/src/device.rs b/src/device.rs index ab7bad26..8acb1338 100644 --- a/src/device.rs +++ b/src/device.rs @@ -101,7 +101,7 @@ impl Device { } } - pub(crate) fn tensor(&self, array: A) -> Result { + pub(crate) fn storage(&self, array: A) -> Result { match self { Device::Cpu => Ok(Storage::Cpu(array.to_cpu_storage())), Device::Cuda(device) => { diff --git a/src/dtype.rs b/src/dtype.rs index fd0eaa1b..f6249ff2 100644 --- a/src/dtype.rs +++ b/src/dtype.rs @@ -25,6 +25,7 @@ pub trait WithDType: Sized + Copy { } fn cpu_storage_as_slice(s: &CpuStorage) -> Result<&[Self]>; + fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]>; } macro_rules! with_dtype { @@ -45,6 +46,16 @@ macro_rules! with_dtype { }), } } + + fn cpu_storage_as_mut_slice(s: &mut CpuStorage) -> Result<&mut [Self]> { + match s { + CpuStorage::$dtype(data) => Ok(data), + _ => Err(Error::UnexpectedDType { + expected: DType::$dtype, + got: s.dtype(), + }), + } + } } }; } diff --git a/src/op.rs b/src/op.rs index 240ecba3..157ce3b3 100644 --- a/src/op.rs +++ b/src/op.rs @@ -5,6 +5,7 @@ pub(crate) enum Op { Mul(Tensor, Tensor), Sub(Tensor, Tensor), Div(Tensor, Tensor), + Matmul(Tensor, Tensor), #[allow(dead_code)] // add is currently unused. Affine { diff --git a/src/storage.rs b/src/storage.rs index 573cf945..f1a2d5a0 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -241,4 +241,22 @@ impl Storage { pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result { self.unary_impl::(shape, stride) } + + pub(crate) fn matmul_impl( + &self, + rhs: &Self, + bmnk: (usize, usize, usize, usize), + lhs_stride: &[usize], + rhs_stride: &[usize], + ) -> Result { + self.same_device(rhs, "matmul")?; + self.same_dtype(rhs, "matmul")?; + match (self, rhs) { + (Storage::Cpu(storage), Storage::Cpu(rhs_storage)) => { + let storage = storage.matmul_impl(rhs_storage, bmnk, lhs_stride, rhs_stride)?; + Ok(Self::Cpu(storage)) + } + _ => todo!(), + } + } } diff --git a/src/tensor.rs b/src/tensor.rs index e8e01d5c..e55050c6 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -151,7 +151,7 @@ impl Tensor { is_variable: bool, ) -> Result { let shape = array.shape()?; - let storage = device.tensor(array)?; + let storage = device.storage(array)?; let stride = shape.stride_contiguous(); let tensor_ = Tensor_ { id: TensorId::new(), @@ -172,6 +172,26 @@ impl Tensor { Self::new_impl(array, device, true) } + pub fn from_slice, D: crate::WithDType>( + a: &[D], + shape: S, + device: Device, + ) -> Result { + let shape = shape.into(); + let storage = device.storage(a); + let stride = shape.stride_contiguous(); + let is_variable = false; + let tensor_ = Tensor_ { + id: TensorId::new(), + storage, + shape, + stride, + op: None, + is_variable, + }; + Ok(Self(Arc::new(tensor_))) + } + pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> { let lhs = self.shape(); let rhs = rhs.shape(); @@ -234,6 +254,57 @@ impl Tensor { Ok(Self(Arc::new(tensor_))) } + pub fn matmul(&self, rhs: &Self) -> Result { + let a_dims = self.shape().dims(); + let b_dims = rhs.shape().dims(); + + let dim = a_dims.len(); + + // if dim < 2 { + // return Err(SmeltError::InsufficientRank { minimum_rank: 2 }); + // } + if b_dims.len() != dim { + return Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + }); + } + + let m = a_dims[dim - 2]; + let k = a_dims[dim - 1]; + let k2 = b_dims[dim - 2]; + let n = b_dims[dim - 1]; + if k != k2 { + return Err(Error::ShapeMismatchBinaryOp { + lhs: self.shape().clone(), + rhs: rhs.shape().clone(), + op: "matmul", + }); + } + + let mut c_shape: Vec<_> = a_dims[..dim - 2].into(); + c_shape.extend(&[m, n]); + let c_shape: Shape = Shape(c_shape); + let batching: usize = a_dims[..dim - 2].iter().product(); + + let storage = self.storage.matmul_impl( + &rhs.storage, + (batching, m, n, k), + self.stride(), + rhs.stride(), + )?; + let tensor_ = Tensor_ { + id: TensorId::new(), + storage, + shape: c_shape.clone(), + stride: c_shape.stride_contiguous(), + op: Some(Op::Matmul(self.clone(), rhs.clone())), + is_variable: false, + }; + Ok(Self(Arc::new(tensor_))) + } + pub(crate) fn strided_index(&self) -> crate::StridedIndex { crate::StridedIndex::new(self.dims(), self.stride()) } @@ -279,6 +350,28 @@ impl Tensor { } } + pub fn to_vec3(&self) -> Result>>> { + let (dim1, dim2, dim3) = self.shape().r3()?; + match &self.storage { + Storage::Cpu(cpu_storage) => { + let data = S::cpu_storage_as_slice(cpu_storage)?; + let mut top_rows = vec![]; + let mut src_index = self.strided_index(); + for _idx in 0..dim1 { + let mut rows = vec![]; + for _jdx in 0..dim2 { + let row = (0..dim3).map(|_| data[src_index.next().unwrap()]).collect(); + rows.push(row) + } + top_rows.push(rows); + } + assert!(src_index.next().is_none()); + Ok(top_rows) + } + Storage::Cuda { .. } => todo!(), + } + } + pub fn dtype(&self) -> DType { self.storage.dtype() } @@ -340,7 +433,8 @@ impl Tensor { Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) | Op::Sub(lhs, rhs) - | Op::Div(lhs, rhs) => { + | Op::Div(lhs, rhs) + | Op::Matmul(lhs, rhs) => { let (tg, nodes) = walk(lhs, nodes, already_seen); track_grad |= tg; let (tg, nodes) = walk(rhs, nodes, already_seen); @@ -420,6 +514,38 @@ impl Tensor { let rhs_sum_grad = grads.or_insert(rhs)?; *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; } + Op::Matmul(lhs, rhs) => { + // let (m, k) = lhs.shape; + // let n = rhs.shape.1; + // let strides = (m, n).strides(); + // Self::matmul( + // (m, n, k), + // true, + // grad_out.as_ptr(), + // strides, + // rhs.data.as_ptr(), + // [rhs.strides[1], rhs.strides[0]], + // grad_lhs.as_mut_ptr(), + // lhs.strides, + // ); + // Self::matmul( + // (k, m, n), + // true, + // lhs.data.as_ptr(), + // [lhs.strides[1], lhs.strides[0]], + // grad_out.as_ptr(), + // strides, + // grad_rhs.as_mut_ptr(), + // rhs.strides, + // ); + + let lhs_grad = grad.matmul(rhs)?; + let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like()); + *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?; + let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?; + let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like()); + *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?; + } Op::Affine { arg, mul, .. } => { let arg_grad = grad.affine(*mul, 0.)?; let sum_grad = grads.or_insert(arg)?;