Add a broadcast variant to matmul. (#523)

* Add a broadcast variant to matmul.

* Get the test to pass.
This commit is contained in:
Laurent Mazare
2023-08-20 13:20:42 +01:00
committed by GitHub
parent a8f61e66cc
commit 2fcb386f17
3 changed files with 108 additions and 43 deletions

View File

@ -185,6 +185,69 @@ impl Shape {
self.0.extend(additional_dims);
self
}
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
let lhs_ndims = lhs_dims.len();
let rhs_ndims = rhs_dims.len();
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
let mut bcast_dims = vec![0; bcast_ndims];
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
let rev_idx = bcast_ndims - idx;
let l_value = if lhs_ndims < rev_idx {
1
} else {
lhs_dims[lhs_ndims - rev_idx]
};
let r_value = if rhs_ndims < rev_idx {
1
} else {
rhs_dims[rhs_ndims - rev_idx]
};
*bcast_value = if l_value == r_value {
l_value
} else if l_value == 1 {
r_value
} else if r_value == 1 {
l_value
} else {
Err(Error::ShapeMismatchBinaryOp {
lhs: lhs.clone(),
rhs: rhs.clone(),
op,
}
.bt())?
}
}
Ok(Shape::from(bcast_dims))
}
pub(crate) fn broadcast_shape_matmul(&self, rhs: &Self) -> Result<(Shape, Shape)> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
if lhs_dims.len() < 2 || rhs_dims.len() < 2 {
crate::bail!("only 2d matrixes are supported {lhs:?} {rhs:?}")
}
let (m, lhs_k) = (lhs_dims[lhs_dims.len() - 2], lhs_dims[lhs_dims.len() - 1]);
let (rhs_k, n) = (rhs_dims[rhs_dims.len() - 2], rhs_dims[rhs_dims.len() - 1]);
if lhs_k != rhs_k {
crate::bail!("different inner dimensions in broadcast matmul {lhs:?} {rhs:?}")
}
let lhs_b = Self::from(&lhs_dims[..lhs_dims.len() - 2]);
let rhs_b = Self::from(&rhs_dims[..rhs_dims.len() - 2]);
let bcast = lhs_b.broadcast_shape_binary_op(&rhs_b, "broadcast_matmul")?;
let bcast_dims = bcast.dims();
let bcast_lhs = [bcast_dims, &[m, lhs_k]].concat();
let bcast_rhs = [bcast_dims, &[rhs_k, n]].concat();
Ok((Shape::from(bcast_lhs), Shape::from(bcast_rhs)))
}
}
pub trait Dim {

View File

@ -106,7 +106,9 @@ macro_rules! broadcast_binary_op {
($fn_name:ident, $inner_fn_name:ident) => {
pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
let lhs = self;
let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
let shape = lhs
.shape()
.broadcast_shape_binary_op(rhs.shape(), stringify!($fn_name))?;
let l_broadcast = shape != *lhs.shape();
let r_broadcast = shape != *rhs.shape();
match (l_broadcast, r_broadcast) {
@ -415,48 +417,6 @@ impl Tensor {
Self::new_impl(array, shape.into(), device, false)
}
pub(crate) fn broadcast_shape_binary_op<'a>(
&'a self,
rhs: &'a Self,
op: &'static str,
) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.shape().dims();
let rhs_dims = rhs.shape().dims();
let lhs_ndims = lhs_dims.len();
let rhs_ndims = rhs_dims.len();
let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
let mut bcast_dims = vec![0; bcast_ndims];
for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
let rev_idx = bcast_ndims - idx;
let l_value = if lhs_ndims < rev_idx {
1
} else {
lhs_dims[lhs_ndims - rev_idx]
};
let r_value = if rhs_ndims < rev_idx {
1
} else {
rhs_dims[rhs_ndims - rev_idx]
};
*bcast_value = if l_value == r_value {
l_value
} else if l_value == 1 {
r_value
} else if r_value == 1 {
l_value
} else {
Err(Error::ShapeMismatchBinaryOp {
lhs: self.shape().clone(),
rhs: rhs.shape().clone(),
op,
}
.bt())?
}
}
Ok(Shape::from(bcast_dims))
}
pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
let lhs = self.shape();
let rhs = rhs.shape();
@ -961,6 +921,28 @@ impl Tensor {
Ok(from_storage(storage, c_shape, op, false))
}
/// Matrix-multiplication with broadcasting support.
///
/// Compared to `matmul` the two matrixes are allowed to have different dimensions as long as
/// they are compatible for broadcast. E.g. if `self` has shape `(j, 1, n, k)` and `rhs` has
/// shape `(l, k, m)`, the output will have shape `(j, l, n, m)`.
pub fn broadcast_matmul(&self, rhs: &Self) -> Result<Self> {
let lhs = self;
let (l_shape, r_shape) = lhs.shape().broadcast_shape_matmul(rhs.shape())?;
let l_broadcast = l_shape != *lhs.shape();
let r_broadcast = r_shape != *rhs.shape();
// TODO: Avoid concretising the broadcasted matrixes via contiguous.
match (l_broadcast, r_broadcast) {
(true, true) => lhs
.broadcast_as(&l_shape)?
.contiguous()?
.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
(false, true) => lhs.matmul(&rhs.broadcast_as(&r_shape)?.contiguous()?),
(true, false) => lhs.broadcast_as(&l_shape)?.contiguous()?.matmul(rhs),
(false, false) => lhs.matmul(rhs),
}
}
/// Returns a tensor with the same shape as the input tensor, the values are taken from
/// `on_true` if the input tensor value is not zero, and `on_false` at the positions where the
/// input tensor is equal to zero.

View File

@ -747,6 +747,25 @@ fn matmul(device: &Device) -> Result<()> {
Ok(())
}
fn broadcast_matmul(device: &Device) -> Result<()> {
let lhs = Tensor::randn(0f32, 1f32, (3, 1, 4, 5), device)?;
let rhs = Tensor::randn(0f32, 1f32, (6, 5, 2), device)?;
let out = lhs.broadcast_matmul(&rhs)?;
assert_eq!(out.dims(), &[3, 6, 4, 2]);
for idx1 in 0..3 {
for idx2 in 0..6 {
let out = out.i((idx1, idx2))?;
let lhs = lhs.i((idx1, 0))?;
let rhs = rhs.i(idx2)?;
let out2 = lhs.matmul(&rhs);
let sum_diff2 = (out - out2)?.sqr()?.sum_all()?;
// With cuda, we see errors of up to ~1e-12.
assert!(sum_diff2.to_vec0::<f32>()? < 1e-6)
}
}
Ok(())
}
fn broadcasting(device: &Device) -> Result<()> {
let t1 = Tensor::arange(0f32, 24f32, device)?.reshape((4, 2, 3))?;
let t2 = Tensor::new(&[100f32, 200f32], device)?;
@ -864,6 +883,7 @@ test_device!(binary_op, binary_op_cpu, binary_op_gpu);
test_device!(embeddings, embeddings_cpu, embeddings_gpu);
test_device!(cmp, cmp_cpu, cmp_gpu);
test_device!(matmul, matmul_cpu, matmul_gpu);
test_device!(broadcast_matmul, broadcast_matmul_cpu, broadcast_matmul_gpu);
test_device!(broadcasting, broadcasting_cpu, broadcasting_gpu);
test_device!(index_select, index_select_cpu, index_select_gpu);
test_device!(index_add, index_add_cpu, index_add_gpu);