From d9cb1917ce88abcbb51658870027ecf9a96bd59a Mon Sep 17 00:00:00 2001
From: laurent <laurent.mazare@gmail.com>
Date: Tue, 20 Jun 2023 12:04:01 +0100
Subject: [PATCH] Add some unary ops.

---
 src/op.rs      |   2 +
 src/shape.rs   |   8 +--
 src/storage.rs | 140 +++++++++++++++++++++++++++++++++++--------
 src/tensor.rs  |  81 ++++++++++++++++++----------
 4 files changed, 158 insertions(+), 73 deletions(-)

diff --git a/src/op.rs b/src/op.rs
index fa09c6be..e991c120 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -4,5 +4,7 @@ use crate::Tensor;
 pub(crate) enum Op {
     Add(Tensor, Tensor),
     Mul(Tensor, Tensor),
+    Sqr(Tensor),
+    Sqrt(Tensor),
     // TODO: Support for custom ops.
 }
diff --git a/src/shape.rs b/src/shape.rs
index d0fe0483..3a23442d 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -91,18 +91,18 @@ impl Shape {
 }
 
 extract_dims!(r0, 0, |_: &Vec<usize>| (), ());
-extract_dims!(r1, 1, |d: &Vec<usize>| d[0], usize);
-extract_dims!(r2, 2, |d: &Vec<usize>| (d[0], d[1]), (usize, usize));
+extract_dims!(r1, 1, |d: &[usize]| d[0], usize);
+extract_dims!(r2, 2, |d: &[usize]| (d[0], d[1]), (usize, usize));
 extract_dims!(
     r3,
     3,
-    |d: &Vec<usize>| (d[0], d[1], d[2]),
+    |d: &[usize]| (d[0], d[1], d[2]),
     (usize, usize, usize)
 );
 extract_dims!(
     r4,
     4,
-    |d: &Vec<usize>| (d[0], d[1], d[2], d[3]),
+    |d: &[usize]| (d[0], d[1], d[2], d[3]),
     (usize, usize, usize, usize)
 );
 
diff --git a/src/storage.rs b/src/storage.rs
index 8bfde057..c2f47bea 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -81,6 +81,63 @@ pub enum Storage {
     Cpu(CpuStorage),
 }
 
+trait UnaryOp {
+    const NAME: &'static str;
+    fn f32(v1: f32) -> f32;
+    fn f64(v1: f64) -> f64;
+}
+
+trait BinaryOp {
+    const NAME: &'static str;
+    fn f32(v1: f32, v2: f32) -> f32;
+    fn f64(v1: f64, v2: f64) -> f64;
+}
+
+struct Add;
+struct Mul;
+struct Sqr;
+struct Sqrt;
+
+impl BinaryOp for Add {
+    const NAME: &'static str = "add";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 + v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 + v2
+    }
+}
+
+impl BinaryOp for Mul {
+    const NAME: &'static str = "mul";
+    fn f32(v1: f32, v2: f32) -> f32 {
+        v1 * v2
+    }
+    fn f64(v1: f64, v2: f64) -> f64 {
+        v1 * v2
+    }
+}
+
+impl UnaryOp for Sqr {
+    const NAME: &'static str = "sqr";
+    fn f32(v1: f32) -> f32 {
+        v1 * v1
+    }
+    fn f64(v1: f64) -> f64 {
+        v1 * v1
+    }
+}
+
+impl UnaryOp for Sqrt {
+    const NAME: &'static str = "sqrt";
+    fn f32(v1: f32) -> f32 {
+        v1.sqrt()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.sqrt()
+    }
+}
+
 impl Storage {
     pub fn device(&self) -> Device {
         match self {
@@ -114,16 +171,34 @@ impl Storage {
         }
     }
 
+    fn unary_impl<B: UnaryOp>(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        // TODO: Different code path for the contiguous case?
+        match self {
+            Storage::Cpu(storage) => match storage {
+                CpuStorage::F32(storage) => {
+                    let index = StridedIndex::new(shape.dims(), stride);
+                    let data = index.map(|i| B::f32(storage[i])).collect();
+                    Ok(Storage::Cpu(CpuStorage::F32(data)))
+                }
+                CpuStorage::F64(storage) => {
+                    let index = StridedIndex::new(shape.dims(), stride);
+                    let data = index.map(|i| B::f64(storage[i])).collect();
+                    Ok(Storage::Cpu(CpuStorage::F64(data)))
+                }
+            },
+        }
+    }
+
     // TODO: Support broadcasting?
-    pub(crate) fn add_impl(
+    fn binary_impl<B: BinaryOp>(
         &self,
         rhs: &Self,
         shape: &Shape,
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
-        self.same_device(rhs, "add")?;
-        self.same_dtype(rhs, "add")?;
+        self.same_device(rhs, B::NAME)?;
+        self.same_dtype(rhs, B::NAME)?;
         // The ggml implementation has different paths based on whether the rhs is contiguous
         // or not, for now we only consider the general case but we should benchmark and do the
         // same if it helps.
@@ -135,7 +210,7 @@ impl Storage {
                     let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
                     let data = lhs_index
                         .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| lhs[lhs_i] + rhs[rhs_i])
+                        .map(|(lhs_i, rhs_i)| B::f32(lhs[lhs_i], rhs[rhs_i]))
                         .collect();
                     Ok(Storage::Cpu(CpuStorage::F32(data)))
                 }
@@ -144,7 +219,7 @@ impl Storage {
                     let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
                     let data = lhs_index
                         .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| lhs[lhs_i] + rhs[rhs_i])
+                        .map(|(lhs_i, rhs_i)| B::f64(lhs[lhs_i], rhs[rhs_i]))
                         .collect();
                     Ok(Storage::Cpu(CpuStorage::F64(data)))
                 }
@@ -153,14 +228,23 @@ impl Storage {
                     Err(Error::DTypeMismatchBinaryOp {
                         lhs: lhs.dtype(),
                         rhs: rhs.dtype(),
-                        op: "add",
+                        op: B::NAME,
                     })
                 }
             },
         }
     }
 
-    // TODO: Support broadcasting?
+    pub(crate) fn add_impl(
+        &self,
+        rhs: &Self,
+        shape: &Shape,
+        lhs_stride: &[usize],
+        rhs_stride: &[usize],
+    ) -> Result<Self> {
+        self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
+    }
+
     pub(crate) fn mul_impl(
         &self,
         rhs: &Self,
@@ -168,38 +252,14 @@ impl Storage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
-        self.same_device(rhs, "mul")?;
-        self.same_dtype(rhs, "mul")?;
-        // TODO: share this code with the add implementation, using a macro or a trait?
-        match (self, rhs) {
-            (Storage::Cpu(lhs), Storage::Cpu(rhs)) => match (lhs, rhs) {
-                (CpuStorage::F32(lhs), CpuStorage::F32(rhs)) => {
-                    let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
-                    let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
-                    let data = lhs_index
-                        .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| lhs[lhs_i] * rhs[rhs_i])
-                        .collect();
-                    Ok(Storage::Cpu(CpuStorage::F32(data)))
-                }
-                (CpuStorage::F64(lhs), CpuStorage::F64(rhs)) => {
-                    let lhs_index = StridedIndex::new(shape.dims(), lhs_stride);
-                    let rhs_index = StridedIndex::new(shape.dims(), rhs_stride);
-                    let data = lhs_index
-                        .zip(rhs_index)
-                        .map(|(lhs_i, rhs_i)| lhs[lhs_i] * rhs[rhs_i])
-                        .collect();
-                    Ok(Storage::Cpu(CpuStorage::F64(data)))
-                }
-                _ => {
-                    // This should be covered by the dtype check above.
-                    Err(Error::DTypeMismatchBinaryOp {
-                        lhs: lhs.dtype(),
-                        rhs: rhs.dtype(),
-                        op: "add",
-                    })
-                }
-            },
-        }
+        self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
+    }
+
+    pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        self.unary_impl::<Sqr>(shape, stride)
+    }
+
+    pub(crate) fn sqrt_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
+        self.unary_impl::<Sqrt>(shape, stride)
     }
 }
diff --git a/src/tensor.rs b/src/tensor.rs
index 97573158..50717c60 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -27,6 +27,40 @@ impl std::fmt::Debug for Tensor {
     }
 }
 
+macro_rules! unary_op {
+    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+        pub fn $fn_name(&self) -> Result<Self> {
+            let shape = self.shape();
+            let storage = self.storage.$impl_name(self.shape(), self.stride())?;
+            let tensor_ = Tensor_ {
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: Some(Op::$op_name(self.clone())),
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
+
+macro_rules! binary_op {
+    ($fn_name:ident, $op_name:ident, $impl_name:ident) => {
+        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
+            let shape = self.same_shape_binary_op(rhs, stringify!($fn_name))?;
+            let storage =
+                self.storage
+                    .$impl_name(&rhs.storage, shape, self.stride(), rhs.stride())?;
+            let tensor_ = Tensor_ {
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op: Some(Op::$op_name(self.clone(), rhs.clone())),
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
+
 impl Tensor {
     pub fn zeros<S: Into<Shape>>(shape: S, dtype: DType, device: Device) -> Self {
         let shape = shape.into();
@@ -70,34 +104,11 @@ impl Tensor {
 
     // TODO: Also make an inplace version or a pre-allocated? This could be tricky
     // if this can create cycles in the compute graph.
-    pub fn add(&self, rhs: &Self) -> Result<Self> {
-        let shape = self.same_shape_binary_op(rhs, "add")?;
-        let storage = self
-            .storage
-            .add_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
-        let tensor_ = Tensor_ {
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op: Some(Op::Add(self.clone(), rhs.clone())),
-        };
-        Ok(Self(Arc::new(tensor_)))
-    }
-
-    pub fn mul(&self, rhs: &Self) -> Result<Self> {
-        let shape = self.same_shape_binary_op(rhs, "mul")?;
-        let storage = self
-            .storage
-            .mul_impl(&rhs.storage, shape, self.stride(), rhs.stride())?;
-        let tensor_ = Tensor_ {
-            storage,
-            shape: shape.clone(),
-            stride: shape.stride_contiguous(),
-            op: Some(Op::Mul(self.clone(), rhs.clone())),
-        };
-        Ok(Self(Arc::new(tensor_)))
-    }
+    binary_op!(add, Add, add_impl);
+    binary_op!(mul, Mul, mul_impl);
+    unary_op!(sqr, Sqr, sqr_impl);
+    unary_op!(sqrt, Sqrt, sqrt_impl);
 
     pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
         if self.rank() != 0 {
             return Err(Error::UnexpectedNumberOfDims {
@@ -135,8 +146,20 @@ impl Tensor {
     }
 
     pub fn to_vec2<S: crate::WithDType>(&self) -> Result<Vec<Vec<S>>> {
-        // TODO: Similar to to_vec1 then reshape the resulting vec?
-        todo!()
+        let (dim1, dim2) = self.shape().r2()?;
+        match &self.storage {
+            Storage::Cpu(cpu_storage) => {
+                let data = S::cpu_storage_as_slice(cpu_storage)?;
+                let mut rows = vec![];
+                let mut src_index = self.strided_index();
+                for _idx_row in 0..dim1 {
+                    let row = (0..dim2).map(|_| data[src_index.next().unwrap()]).collect();
+                    rows.push(row)
+                }
+                assert!(src_index.next().is_none());
+                Ok(rows)
+            }
+        }
     }
 
     pub fn dtype(&self) -> DType {