From 2fcb386f17558f56c8b8e24fbc36c9d3686e73ba Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 20 Aug 2023 13:20:42 +0100
Subject: [PATCH] Add a broadcast variant to matmul. (#523)

* Add a broadcast variant to matmul.

* Get the test to pass.
---
 candle-core/src/shape.rs          | 63 ++++++++++++++++++++++++++++
 candle-core/src/tensor.rs         | 68 ++++++++++++-------------------
 candle-core/tests/tensor_tests.rs | 20 +++++++++
 3 files changed, 108 insertions(+), 43 deletions(-)

diff --git a/candle-core/src/shape.rs b/candle-core/src/shape.rs
index d8f8f756..49fbf022 100644
--- a/candle-core/src/shape.rs
+++ b/candle-core/src/shape.rs
@@ -185,6 +185,69 @@ impl Shape {
         self.0.extend(additional_dims);
         self
     }
+
+    /// Check whether the two shapes are compatible for broadcast, and if it is the case return the
+    /// broadcasted shape. This is to be used for binary pointwise ops.
+    pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
+        let lhs = self;
+        let lhs_dims = lhs.dims();
+        let rhs_dims = rhs.dims();
+        let lhs_ndims = lhs_dims.len();
+        let rhs_ndims = rhs_dims.len();
+        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
+        let mut bcast_dims = vec![0; bcast_ndims];
+        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
+            let rev_idx = bcast_ndims - idx;
+            let l_value = if lhs_ndims < rev_idx {
+                1
+            } else {
+                lhs_dims[lhs_ndims - rev_idx]
+            };
+            let r_value = if rhs_ndims < rev_idx {
+                1
+            } else {
+                rhs_dims[rhs_ndims - rev_idx]
+            };
+            *bcast_value = if l_value == r_value {
+                l_value
+            } else if l_value == 1 {
+                r_value
+            } else if r_value == 1 {
+                l_value
+            } else {
+                Err(Error::ShapeMismatchBinaryOp {
+                    lhs: lhs.clone(),
+                    rhs: rhs.clone(),
+                    op,
+                }
+                .bt())?
+            }
+        }
+        Ok(Shape::from(bcast_dims))
+    }
+
+    pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
+        let lhs = self;
+        let lhs_dims = lhs.dims();
+        let rhs_dims = rhs.dims();
+        if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
+            crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
+        }
+        let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
+        let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
+        if lhs_k != rhs_k {
+            crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
+        }
+
+        let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
+        let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
+        let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
+        let bcast_dims = bcast.dims();
+
+        let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
+        let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
+        Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
+    }
 }
 
 pub trait Dim {
diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs
index 45aa07bc..4ea66186 100644
--- a/candle-core/src/tensor.rs
+++ b/candle-core/src/tensor.rs
@@ -106,7 +106,9 @@ macro_rules! broadcast_binary_op {
     ($fn_name:ident, $inner_fn_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
             let lhs = self;
-            let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let shape = lhs
+                .shape()
+                .broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
             let l_broadcast = shape != *lhs.shape();
             let r_broadcast = shape != *rhs.shape();
             match (l_broadcast, r_broadcast) {
@@ -415,48 +417,6 @@ impl Tensor {
         Self::new_impl(array, shape.into(), device, false)
     }
 
-    pub(crate) fn broadcast_shape_binary_op<'a>(
-        &'a self,
-        rhs: &'a Self,
-        op: &'static str,
-    ) -> Result<Shape> {
-        let lhs = self;
-        let lhs_dims = lhs.shape().dims();
-        let rhs_dims = rhs.shape().dims();
-        let lhs_ndims = lhs_dims.len();
-        let rhs_ndims = rhs_dims.len();
-        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
-        let mut bcast_dims = vec![0; bcast_ndims];
-        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
-            let rev_idx = bcast_ndims - idx;
-            let l_value = if lhs_ndims < rev_idx {
-                1
-            } else {
-                lhs_dims[lhs_ndims - rev_idx]
-            };
-            let r_value = if rhs_ndims < rev_idx {
-                1
-            } else {
-                rhs_dims[rhs_ndims - rev_idx]
-            };
-            *bcast_value = if l_value == r_value {
-                l_value
-            } else if l_value == 1 {
-                r_value
-            } else if r_value == 1 {
-                l_value
-            } else {
-                Err(Error::ShapeMismatchBinaryOp {
-                    lhs: self.shape().clone(),
-                    rhs: rhs.shape().clone(),
-                    op,
-                }
-                .bt())?
-            }
-        }
-        Ok(Shape::from(bcast_dims))
-    }
-
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
         let lhs = self.shape();
         let rhs = rhs.shape();
@@ -961,6 +921,28 @@ impl Tensor {
         Ok(from_storage(storage, c_shape, op, false))
     }
 
+    /// Matrix-multiplication with broadcasting support.
+    ///
+    /// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as
+    /// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
+    /// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
+    pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
+        let lhs = self;
+        let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
+        let l_broadcast = l_shape != *lhs.shape();
+        let r_broadcast = r_shape != *rhs.shape();
+        // TODO: Avoid concretising the broadcasted matrixes via contiguous.
+        match (l_broadcast, r_broadcast) {
+            (true, true) => lhs
+                .broadcast_as(&l_shape)?
+                .contiguous()?
+                .matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
+            (false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
+            (true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
+            (false, false) => lhs.matmul(rhs),
+        }
+    }
+
     /// Returns a tensor with the same shape as the input tensor, the values are taken from
     /// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
     /// input tensor is equal to zero.
diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs
index 0b77f1a5..907c876e 100644
--- a/candle-core/tests/tensor_tests.rs
+++ b/candle-core/tests/tensor_tests.rs
@@ -747,6 +747,25 @@ fn matmul(device: &Device) -> Result<()> {
     Ok(())
 }
 
+fn broadcast_matmul(device: &Device) -> Result<()> {
+    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
+    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
+    let out = lhs.broadcast_matmul(&rhs)?;
+    assert_eq!(out.dims(), &[3, 6, 4, 2]);
+    for idx1 in 0..3 {
+        for idx2 in 0..6 {
+            let out = out.i((idx1, idx2))?;
+            let lhs = lhs.i((idx1, 0))?;
+            let rhs = rhs.i(idx2)?;
+            let out2 = lhs.matmul(&rhs);
+            let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
+            // With cuda, we see errors of up to ~1e-12.
+            assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
+        }
+    }
+    Ok(())
+}
+
 fn broadcasting(device: &Device) -> Result<()> {
     let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
     let t2 = Tensor::new(&[100f32, 200f32], device)?;
@@ -864,6 +883,7 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
 test_device!(embeddings, embeddings_cpu, embeddings_gpu);
 test_device!(cmp, cmp_cpu, cmp_gpu);
 test_device!(matmul, matmul_cpu, matmul_gpu);
+test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
 test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
 test_device!(index_select, index_select_cpu, index_select_gpu);
 test_device!(index_add, index_add_cpu, index_add_gpu);
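
Note (not part of the patch): a minimal usage sketch for the new `broadcast_matmul` API, assuming a binary crate that depends on a `candle-core` build containing this change. The shapes mirror the `broadcast_matmul` test above: the lhs batch dims `(3, 1)` broadcast against the rhs batch dim `(6,)`, giving output batch dims `(3, 6)`.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // lhs: (3, 1, 4, 5), rhs: (6, 5, 2) -> out: (3, 6, 4, 2).
    let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), &device)?;
    let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), &device)?;
    let out = lhs.broadcast_matmul(&rhs)?;
    assert_eq!(out.dims(), &[3, 6, 4, 2]);
    Ok(())
}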
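
For reference, a self-contained sketch of the right-aligned broadcasting rule that `Shape::broadcast_shape_binary_op` implements: dimensions are compared from the trailing end, and a size of 1 broadcasts against any other size. `broadcast_dims` is a hypothetical helper written for illustration only, not part of candle.

// Compute the broadcast of two dim lists, or None if they are incompatible.
fn broadcast_dims(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    let ndims = usize::max(lhs.len(), rhs.len());
    let mut out = vec![0; ndims];
    for (idx, value) in out.iter_mut().enumerate() {
        // Index from the trailing dimension; missing leading dims count as 1.
        let rev_idx = ndims - idx;
        let l = if lhs.len() < rev_idx { 1 } else { lhs[lhs.len() - rev_idx] };
        let r = if rhs.len() < rev_idx { 1 } else { rhs[rhs.len() - rev_idx] };
        *value = if l == r || r == 1 {
            l
        } else if l == 1 {
            r
        } else {
            return None; // incompatible, e.g. (3, 4) vs (5, 4)
        };
    }
    Some(out)
}

fn main() {
    // (3, 1, 4) and (2, 4) broadcast to (3, 2, 4).
    assert_eq!(broadcast_dims(&[3, 1, 4], &[2, 4]), Some(vec![3, 2, 4]));
    // (3, 4) and (5, 4) are incompatible.
    assert_eq!(broadcast_dims(&[3, 4], &[5, 4]), None);
}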