Add a broadcast variant to matmul. (#523)

* Add a broadcast variant to matmul.

* Get the test to pass.
This commit is contained in:
Laurent Mazare
2023-08-20 13:20:42 +01:00
committed by GitHub
parent a8f61e66cc
commit 2fcb386f17
3 changed files with 108 additions and 43 deletions

View File

@ -106,7 +106,9 @@ macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let lhs = self;
let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
let shape = lhs
.shape()
.broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
let l_broadcast = shape != *lhs.shape();
let r_broadcast = shape != *rhs.shape();
match (l_broadcast, r_broadcast) {
@ -415,48 +417,6 @@ impl Tensor {
Self::new_impl(array, shape.into(), device, false)
}
pub(crate) fn broadcast_shape_binary_op<'a>(
&'a self,
rhs: &'a Self,
op: &'static str,
) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.shape().dims();
let rhs_dims = rhs.shape().dims();
let lhs_ndims = lhs_dims.len();
let rhs_ndims = rhs_dims.len();
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
let mut bcast_dims = vec![0; bcast_ndims];
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
let rev_idx = bcast_ndims - idx;
let l_value = if lhs_ndims < rev_idx {
1
} else {
lhs_dims[lhs_ndims - rev_idx]
};
let r_value = if rhs_ndims < rev_idx {
1
} else {
rhs_dims[rhs_ndims - rev_idx]
};
*bcast_value = if l_value == r_value {
l_value
} else if l_value == 1 {
r_value
} else if r_value == 1 {
l_value
} else {
Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),
op,
}
.bt())?
}
}
Ok(Shape::from(bcast_dims))
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
let lhs = self.shape();
let rhs = rhs.shape();
@ -961,6 +921,28 @@ impl Tensor {
Ok(from_storage(storage, c_shape, op, false))
}
/// Matrix-multiplication with broadcasting support.
///
/// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as
/// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
/// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
let lhs = self;
let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
let l_broadcast = l_shape != *lhs.shape();
let r_broadcast = r_shape != *rhs.shape();
// TODO: Avoid concretising the broadcasted matrixes via contiguous.
match (l_broadcast, r_broadcast) {
(true, true) => lhs
.broadcast_as(&l_shape)?
.contiguous()?
.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
(false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
(true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
(false, false) => lhs.matmul(rhs),
}
}
/// Returns a tensor with the same shape as the input tensor, the values are taken from
/// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
/// input tensor is equal to zero.