diff --git a/src/backprop.rs b/src/backprop.rs
index ec6f4b59..3137dce8 100644
--- a/src/backprop.rs
+++ b/src/backprop.rs
@@ -37,10 +37,6 @@ impl Tensor {
                     | Op::Mul(lhs, rhs)
                     | Op::Sub(lhs, rhs)
                     | Op::Div(lhs, rhs)
-                    | Op::BroadcastAdd(lhs, rhs)
-                    | Op::BroadcastMul(lhs, rhs)
-                    | Op::BroadcastSub(lhs, rhs)
-                    | Op::BroadcastDiv(lhs, rhs)
                     | Op::Embedding(lhs, rhs)
                     | Op::Matmul(lhs, rhs) => {
                         let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -142,34 +138,6 @@ impl Tensor {
                         let rhs_sum_grad = grads.or_insert(rhs)?;
                         *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                     }
-                    Op::BroadcastAdd(lhs, rhs) => {
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
-                    }
-                    Op::BroadcastSub(lhs, rhs) => {
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
-                    }
-                    Op::BroadcastMul(lhs, rhs) => {
-                        let lhs_grad = grad.broadcast_mul(rhs)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                        let rhs_grad = grad.broadcast_mul(lhs)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                    }
-                    Op::BroadcastDiv(lhs, rhs) => {
-                        let lhs_grad = grad.broadcast_div(rhs)?;
-                        let lhs_sum_grad = grads.or_insert(lhs)?;
-                        *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                        let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
-                        let rhs_sum_grad = grads.or_insert(rhs)?;
-                        *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                    }
                     Op::WhereCond(_pred, _t, _f) => {
                         return Err(Error::BackwardNotSupported { op: "where_cond" })
                     }
diff --git a/src/cpu_backend.rs b/src/cpu_backend.rs
index 4571985a..61ffcb28 100644
--- a/src/cpu_backend.rs
+++ b/src/cpu_backend.rs
@@ -58,8 +58,7 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
     }
 }
 
-// This function maps over two strided index sequences. It supports broadcasting in case
-// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
+// This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     shape: &Shape,
     lhs_stride: &[usize],
@@ -69,52 +68,15 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     mut f: F,
 ) -> Vec<T> {
     let dims = shape.dims();
-    let broadcast_ldims = dims.len() - lhs_stride.len();
-    let broadcast_rdims = dims.len() - rhs_stride.len();
-    let elem_count = shape.elem_count();
-    if broadcast_ldims == 0 && broadcast_rdims == 0 {
-        if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
-            (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
-        } else {
-            let lhs_index = StridedIndex::new(dims, lhs_stride);
-            let rhs_index = StridedIndex::new(dims, rhs_stride);
-            lhs_index
-                .zip(rhs_index)
-                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                .collect()
-        }
-    } else if broadcast_rdims == 0 {
-        let mut res = Vec::new();
-        res.reserve(elem_count);
-        let lhs_v: Vec<T> = StridedIndex::new(dims, lhs_stride)
-            .map(|i| lhs[i])
-            .collect();
-        let mut i = 0;
-        for rhs_i in StridedIndex::new(dims, rhs_stride) {
-            res.push(f(lhs_v[i], rhs[rhs_i]));
-            i += 1;
-            if i >= lhs_v.len() {
-                i = 0
-            }
-        }
-        res
-    } else if broadcast_ldims == 0 {
-        let mut res = Vec::new();
-        res.reserve(elem_count);
-        let rhs_v: Vec<T> = StridedIndex::new(dims, rhs_stride)
-            .map(|i| rhs[i])
-            .collect();
-        let mut i = 0;
-        for lhs_i in StridedIndex::new(dims, lhs_stride) {
-            res.push(f(lhs[lhs_i], rhs_v[i]));
-            i += 1;
-            if i >= rhs_v.len() {
-                i = 0
-            }
-        }
-        res
+    if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
+        (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
     } else {
-        panic!("unexpected broadcasting dims: {shape:?} {lhs_stride:?} {rhs_stride:?}")
+        let lhs_index = StridedIndex::new(dims, lhs_stride);
+        let rhs_index = StridedIndex::new(dims, rhs_stride);
+        lhs_index
+            .zip(rhs_index)
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect()
     }
 }
 
diff --git a/src/op.rs b/src/op.rs
index da096b5c..82068dfd 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -6,10 +6,6 @@ pub(crate) enum Op {
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
-    BroadcastAdd(Tensor, Tensor),
-    BroadcastMul(Tensor, Tensor),
-    BroadcastSub(Tensor, Tensor),
-    BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),
diff --git a/src/tensor.rs b/src/tensor.rs
index 57695819..99a9199a 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -88,21 +88,20 @@ macro_rules! binary_op {
 }
 
 macro_rules! broadcast_binary_op {
-    ($fn_name:ident, $impl_name:ident, $op_name:ident) => {
+    ($fn_name:ident, $inner_fn_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
-            let shape = self.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage = self.storage.binary_impl::<crate::op::$impl_name>(
-                &rhs.storage,
-                shape,
-                self.stride(),
-                rhs.stride(),
-            )?;
-            let op = if self.track_op() || rhs.track_op() {
-                Some(Op::$op_name(self.clone(), rhs.clone()))
-            } else {
-                None
-            };
-            Ok(from_storage(storage, shape.clone(), op, false))
+            let lhs = self;
+            let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let l_broadcast = shape != *lhs.shape();
+            let r_broadcast = shape != *rhs.shape();
+            match (l_broadcast, r_broadcast) {
+                (true, true) => lhs
+                    .broadcast_as(&shape)?
+                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
+                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
+                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
+                (false, false) => lhs.$inner_fn_name(rhs),
+            }
         }
     };
 }
@@ -250,21 +249,41 @@ impl Tensor {
         &'a self,
         rhs: &'a Self,
         op: &'static str,
-    ) -> Result<&'a Shape> {
+    ) -> Result<Shape> {
         let lhs = self;
         let lhs_dims = lhs.shape().dims();
         let rhs_dims = rhs.shape().dims();
-        if lhs_dims.strip_suffix(rhs_dims).is_some() {
-            Ok(self.shape())
-        } else if rhs_dims.strip_suffix(lhs_dims).is_some() {
-            Ok(rhs.shape())
-        } else {
-            Err(Error::ShapeMismatchBinaryOp {
-                lhs: self.shape().clone(),
-                rhs: rhs.shape().clone(),
-                op,
-            })
+        let lhs_ndims = lhs_dims.len();
+        let rhs_ndims = rhs_dims.len();
+        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
+        let mut bcast_dims = vec![0; bcast_ndims];
+        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
+            let rev_idx = bcast_ndims - idx;
+            let l_value = if lhs_ndims < rev_idx {
+                1
+            } else {
+                lhs_dims[lhs_ndims - rev_idx]
+            };
+            let r_value = if rhs_ndims < rev_idx {
+                1
+            } else {
+                rhs_dims[rhs_ndims - rev_idx]
+            };
+            *bcast_value = if l_value == r_value {
+                l_value
+            } else if l_value == 1 {
+                r_value
+            } else if r_value == 1 {
+                l_value
+            } else {
+                Err(Error::ShapeMismatchBinaryOp {
+                    lhs: self.shape().clone(),
+                    rhs: rhs.shape().clone(),
+                    op,
+                })?
+            }
         }
+        Ok(Shape::from(bcast_dims))
     }
 
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
@@ -293,10 +312,10 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
    binary_op!(div, Div);
-    broadcast_binary_op!(broadcast_add, Add, BroadcastAdd);
-    broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
-    broadcast_binary_op!(broadcast_sub, Sub, BroadcastSub);
-    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
+    broadcast_binary_op!(broadcast_add, add);
+    broadcast_binary_op!(broadcast_mul, mul);
+    broadcast_binary_op!(broadcast_sub, sub);
+    broadcast_binary_op!(broadcast_div, div);
     unary_op!(neg, Neg);
     unary_op!(exp, Exp);
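
The shape rule that the reworked `broadcast_shape_binary_op` applies is the usual right-aligned broadcasting: dimensions are compared from the trailing end, a missing or size-1 dimension stretches to the other operand's size, and any other mismatch is a shape error. Below is a minimal standalone sketch of that rule for reference; the `broadcast_shape` helper and its `String` error are illustrative only and not part of the crate.

```rust
// Compute the broadcast shape of two dimension lists, mirroring the logic of
// `broadcast_shape_binary_op` above: align from the right, stretch size-1
// dimensions, reject anything else.
fn broadcast_shape(lhs_dims: &[usize], rhs_dims: &[usize]) -> Result<Vec<usize>, String> {
    let ndims = usize::max(lhs_dims.len(), rhs_dims.len());
    let mut bcast_dims = vec![0; ndims];
    for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
        let rev_idx = ndims - idx;
        // A missing leading dimension behaves like a dimension of size 1.
        let l = if lhs_dims.len() < rev_idx {
            1
        } else {
            lhs_dims[lhs_dims.len() - rev_idx]
        };
        let r = if rhs_dims.len() < rev_idx {
            1
        } else {
            rhs_dims[rhs_dims.len() - rev_idx]
        };
        *bcast_value = if l == r {
            l
        } else if l == 1 {
            r
        } else if r == 1 {
            l
        } else {
            return Err(format!("shape mismatch: {lhs_dims:?} vs {rhs_dims:?}"));
        };
    }
    Ok(bcast_dims)
}

fn main() {
    // (2, 3, 4) with (3, 1) broadcasts to (2, 3, 4).
    assert_eq!(broadcast_shape(&[2, 3, 4], &[3, 1]).unwrap(), vec![2, 3, 4]);
    // (4, 1) with (3,) broadcasts to (4, 3).
    assert_eq!(broadcast_shape(&[4, 1], &[3]).unwrap(), vec![4, 3]);
    // (2, 3) with (4,) is rejected.
    assert!(broadcast_shape(&[2, 3], &[4]).is_err());
}
```

With the dedicated `BroadcastAdd`/`BroadcastMul`/`BroadcastSub`/`BroadcastDiv` ops removed, `broadcast_add` and friends reduce to `broadcast_as` followed by the plain `add`/`mul`/`sub`/`div`, so the backward pass only has to handle the regular binary ops plus whatever op `broadcast_as` itself records.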