diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index e648b6c7..65632ef3 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -86,6 +86,10 @@ impl CpuStorage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
+        let dims = shape.dims();
+        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
+            todo!("implement broadcast");
+        }
         // The ggml implementation has different paths based on whether the rhs is contiguous
         // or not, for now we only consider the general case but we should benchmark and do the
         // same if it helps.
diff --git a/src/cuda_backend.rs b/src/cuda_backend.rs
index d9958e3b..c11504d7 100644
--- a/src/cuda_backend.rs
+++ b/src/cuda_backend.rs
@@ -331,8 +331,11 @@ impl CudaStorage {
         lhs_stride: &[usize],
         rhs_stride: &[usize],
     ) -> Result<Self> {
-        let elem_count = shape.elem_count();
         let dims = shape.dims();
+        if dims.len() != lhs_stride.len() || dims.len() != rhs_stride.len() {
+            return Err(CudaError::InternalError("TODO: implement broadcast"));
+        }
+        let elem_count = shape.elem_count();
         let cfg = LaunchConfig::for_num_elems(elem_count as u32);
         let dev = self.device();
         let dims_and_strides = dev.htod_copy([dims, lhs_stride, rhs_stride].concat())?;
diff --git a/src/op.rs b/src/op.rs
index 577bc8f0..ae8fa80d 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -6,6 +6,10 @@ pub(crate) enum Op {
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
+    BroadcastAdd(Tensor, Tensor),
+    BroadcastMul(Tensor, Tensor),
+    BroadcastSub(Tensor, Tensor),
+    BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
diff --git a/src/storage.rs b/src/storage.rs
index 74ab7f40..2a590d4e 100644
--- a/src/storage.rs
+++ b/src/storage.rs
@@ -91,7 +91,6 @@ impl Storage {
         }
     }
 
-    // TODO: Support broadcasting?
     pub(crate) fn binary_impl<B: op::BinaryOp>(
         &self,
         rhs: &Self,
diff --git a/src/tensor.rs b/src/tensor.rs
index b92823a3..1a5cf7f4 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -95,6 +95,34 @@ macro_rules! binary_op {
     };
 }
 
+macro_rules! broadcast_binary_op {
+    ($fn_name:ident, $impl_name:ident, $op_name:ident) => {
+        pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
+            let shape = self.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let storage = self.storage.binary_impl::<crate::op::$impl_name>(
+                &rhs.storage,
+                shape,
+                self.stride(),
+                rhs.stride(),
+            )?;
+            let op = if self.track_op() || rhs.track_op() {
+                Some(Op::$op_name(self.clone(), rhs.clone()))
+            } else {
+                None
+            };
+            let tensor_ = Tensor_ {
+                id: TensorId::new(),
+                storage,
+                shape: shape.clone(),
+                stride: shape.stride_contiguous(),
+                op,
+                is_variable: false,
+            };
+            Ok(Self(Arc::new(tensor_)))
+        }
+    };
+}
+
 impl Tensor {
     fn ones_impl<S: Into<Shape>>(
         shape: S,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
@@ -210,6 +238,27 @@ impl Tensor {
         Self::new_impl(array, shape.into(), device, true)
     }
 
+    pub(crate) fn broadcast_shape_binary_op<'a>(
+        &'a self,
+        rhs: &'a Self,
+        op: &'static str,
+    ) -> Result<&'a Shape> {
+        let lhs = self;
+        let lhs_dims = lhs.shape().dims();
+        let rhs_dims = rhs.shape().dims();
+        if lhs_dims.strip_suffix(rhs_dims).is_some() {
+            Ok(self.shape())
+        } else if rhs_dims.strip_suffix(lhs_dims).is_some() {
+            Ok(rhs.shape())
+        } else {
+            Err(Error::ShapeMismatchBinaryOp {
+                lhs: self.shape().clone(),
+                rhs: rhs.shape().clone(),
+                op,
+            })
+        }
+    }
+
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
         let lhs = self.shape();
         let rhs = rhs.shape();
@@ -236,6 +285,10 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
     binary_op!(div, Div);
+    broadcast_binary_op!(broadcast_add, Add, BroadcastAdd);
+    broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
+    broadcast_binary_op!(broadcast_sub, Sub, BroadcastSub);
+    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
 
     unary_op!(neg, Neg);
     unary_op!(sqr, Sqr);
@@ -773,6 +826,10 @@ impl Tensor {
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
+                | Op::BroadcastAdd(lhs, rhs)
+                | Op::BroadcastMul(lhs, rhs)
+                | Op::BroadcastSub(lhs, rhs)
+                | Op::BroadcastDiv(lhs, rhs)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -865,6 +922,26 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
+                    Op::BroadcastAdd(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported {
+                            op: "broadcast_add",
+                        })
+                    }
+                    Op::BroadcastSub(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported {
+                            op: "broadcast_sub",
+                        })
+                    }
+                    Op::BroadcastMul(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported {
+                            op: "broadcast_mul",
+                        })
+                    }
+                    Op::BroadcastDiv(_lhs, _rhs) => {
+                        return Err(Error::BackwardNotSupported {
+                            op: "broadcast_div",
+                        })
+                    }
                     Op::Embedding(_lhs, _rhs) => {
                         return Err(Error::BackwardNotSupported { op: "embedding" })
                     }