Clean up the broadcast setup.

laurent committed 2023-06-26 10:49:34 +01:00
parent 217bdcdf4d
commit 5952c3fa91
4 changed files with 57 additions and 112 deletions


@@ -37,10 +37,6 @@ impl Tensor {
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
-                | Op::BroadcastAdd(lhs, rhs)
-                | Op::BroadcastMul(lhs, rhs)
-                | Op::BroadcastSub(lhs, rhs)
-                | Op::BroadcastDiv(lhs, rhs)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -142,34 +138,6 @@ impl Tensor {
                    let rhs_sum_grad = grads.or_insert(rhs)?;
                    *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                }
-                Op::BroadcastAdd(lhs, rhs) => {
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
-                }
-                Op::BroadcastSub(lhs, rhs) => {
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
-                }
-                Op::BroadcastMul(lhs, rhs) => {
-                    let lhs_grad = grad.broadcast_mul(rhs)?;
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                    let rhs_grad = grad.broadcast_mul(lhs)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                }
-                Op::BroadcastDiv(lhs, rhs) => {
-                    let lhs_grad = grad.broadcast_div(rhs)?;
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                    let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                }
                Op::WhereCond(_pred, _t, _f) => {
                    return Err(Error::BackwardNotSupported { op: "where_cond" })
                }
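Note: with the dedicated Broadcast* ops gone, broadcast arithmetic is expressed as broadcast_as plus the plain element-wise op (see the broadcast_binary_op! changes further down), so the backward pass only needs rules for those building blocks. Mathematically, the gradient flowing back into an operand that was broadcast is the output gradient summed over the broadcast dimensions. The snippet below is a self-contained illustration of that reduction in plain Rust; it is not part of the commit, does not use candle's API, and the shapes are made up.

// Illustration only: the reverse-mode gradient of broadcasting a (cols,) vector
// across `rows` rows is the output gradient summed over the broadcast axis.
fn reduce_leading_axis(grad: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    let mut out = vec![0.0; cols];
    for r in 0..rows {
        for c in 0..cols {
            out[c] += grad[r * cols + c];
        }
    }
    out
}

fn main() {
    // For `out = lhs (2, 3) + broadcast(rhs (3,))` with d(out) full of ones,
    // d(rhs) is a (3,) vector of twos: each rhs element was used twice.
    let grad = vec![1.0f32; 6];
    assert_eq!(reduce_leading_axis(&grad, 2, 3), vec![2.0, 2.0, 2.0]);
}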


@@ -58,8 +58,7 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
     }
 }
-// This function maps over two strided index sequences. It supports broadcasting in case
-// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
+// This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     shape: &Shape,
     lhs_stride: &[usize],
@@ -69,52 +68,15 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     mut f: F,
 ) -> Vec<T> {
     let dims = shape.dims();
-    let broadcast_ldims = dims.len() - lhs_stride.len();
-    let broadcast_rdims = dims.len() - rhs_stride.len();
-    let elem_count = shape.elem_count();
-    if broadcast_ldims == 0 && broadcast_rdims == 0 {
-        if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
-            (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
-        } else {
-            let lhs_index = StridedIndex::new(dims, lhs_stride);
-            let rhs_index = StridedIndex::new(dims, rhs_stride);
-            lhs_index
-                .zip(rhs_index)
-                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                .collect()
-        }
-    } else if broadcast_rdims == 0 {
-        let mut res = Vec::new();
-        res.reserve(elem_count);
-        let lhs_v: Vec<T> = StridedIndex::new(dims, lhs_stride)
-            .map(|i| lhs[i])
-            .collect();
-        let mut i = 0;
-        for rhs_i in StridedIndex::new(dims, rhs_stride) {
-            res.push(f(lhs_v[i], rhs[rhs_i]));
-            i += 1;
-            if i >= lhs_v.len() {
-                i = 0
-            }
-        }
-        res
-    } else if broadcast_ldims == 0 {
-        let mut res = Vec::new();
-        res.reserve(elem_count);
-        let rhs_v: Vec<T> = StridedIndex::new(dims, rhs_stride)
-            .map(|i| rhs[i])
-            .collect();
-        let mut i = 0;
-        for lhs_i in StridedIndex::new(dims, lhs_stride) {
-            res.push(f(lhs[lhs_i], rhs_v[i]));
-            i += 1;
-            if i >= rhs_v.len() {
-                i = 0
-            }
-        }
-        res
+    if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
+        (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
     } else {
-        panic!("unexpected broadcasting dims: {shape:?} {lhs_stride:?} {rhs_stride:?}")
+        let lhs_index = StridedIndex::new(dims, lhs_stride);
+        let rhs_index = StridedIndex::new(dims, rhs_stride);
+        lhs_index
+            .zip(rhs_index)
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect()
     }
 }
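With the broadcasting branches removed, binary_map is left with two paths: a fast path when both inputs are contiguous, and a walk over two StridedIndex iterators otherwise. The sketch below is not part of the commit and uses stand-in helpers so it runs on its own; strided_indices plays the role of StridedIndex, and the contiguity test is a simplified row-major check.

// Standalone sketch of the two remaining paths in `binary_map`; names and helpers
// here are illustrative stand-ins, not candle's types.
fn strided_indices(dims: &[usize], stride: &[usize]) -> Vec<usize> {
    // Enumerate flat offsets in logical (row-major) order for the given strides.
    let mut offsets = vec![0usize];
    for (&d, &s) in dims.iter().zip(stride.iter()) {
        let mut next = Vec::with_capacity(offsets.len() * d);
        for &base in &offsets {
            for i in 0..d {
                next.push(base + i * s);
            }
        }
        offsets = next;
    }
    offsets
}

fn binary_map_sketch<T: Copy, F: FnMut(T, T) -> T>(
    dims: &[usize],
    lhs_stride: &[usize],
    rhs_stride: &[usize],
    lhs: &[T],
    rhs: &[T],
    mut f: F,
) -> Vec<T> {
    // Row-major contiguity: each stride equals the product of the dimensions to its right.
    let is_contiguous = |stride: &[usize]| {
        let mut expected = 1usize;
        for (&d, &s) in dims.iter().zip(stride.iter()).rev() {
            if s != expected {
                return false;
            }
            expected *= d;
        }
        true
    };
    if is_contiguous(lhs_stride) && is_contiguous(rhs_stride) {
        // Fast path: both sides can be indexed linearly.
        (0..dims.iter().product::<usize>()).map(|i| f(lhs[i], rhs[i])).collect()
    } else {
        // Strided path: zip the two offset sequences, e.g. when one side is a transposed view.
        strided_indices(dims, lhs_stride)
            .into_iter()
            .zip(strided_indices(dims, rhs_stride))
            .map(|(l, r)| f(lhs[l], rhs[r]))
            .collect()
    }
}

fn main() {
    let lhs = [1, 2, 3, 4]; // contiguous (2, 2) data, strides (2, 1)
    let rhs = [10, 20, 30, 40]; // the same layout seen as a transposed (2, 2) view, strides (1, 2)
    let out = binary_map_sketch(&[2, 2], &[2, 1], &[1, 2], &lhs, &rhs, |a, b| a + b);
    assert_eq!(out, vec![11, 32, 23, 44]);
}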


@@ -6,10 +6,6 @@ pub(crate) enum Op {
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
-    BroadcastAdd(Tensor, Tensor),
-    BroadcastMul(Tensor, Tensor),
-    BroadcastSub(Tensor, Tensor),
-    BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),


@@ -88,21 +88,20 @@ macro_rules! binary_op {
 }
 macro_rules! broadcast_binary_op {
-    ($fn_name:ident, $impl_name:ident, $op_name:ident) => {
+    ($fn_name:ident, $inner_fn_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
-            let shape = self.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage = self.storage.binary_impl::<crate::op::$impl_name>(
-                &rhs.storage,
-                shape,
-                self.stride(),
-                rhs.stride(),
-            )?;
-            let op = if self.track_op() || rhs.track_op() {
-                Some(Op::$op_name(self.clone(), rhs.clone()))
-            } else {
-                None
-            };
-            Ok(from_storage(storage, shape.clone(), op, false))
+            let lhs = self;
+            let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let l_broadcast = shape != *lhs.shape();
+            let r_broadcast = shape != *rhs.shape();
+            match (l_broadcast, r_broadcast) {
+                (true, true) => lhs
+                    .broadcast_as(&shape)?
+                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
+                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
+                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
+                (false, false) => lhs.$inner_fn_name(rhs),
+            }
         }
     };
 }
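The rewritten macro no longer calls into the storage layer or records a dedicated Broadcast* op: it resolves the common shape, broadcasts whichever operand differs from it with broadcast_as, and then defers to the plain element-wise method, so the existing op tracking and backward rules apply unchanged. The toy below is not from the commit; Vek, its methods, and the deliberately simplified shape rule are made up so the same expansion pattern compiles on its own.

// Toy analogue of the new `broadcast_binary_op!` pattern; not candle's API.
struct Vek {
    shape: Vec<usize>,
    data: Vec<f32>,
}

impl Vek {
    // Element-wise add; both operands must already share a shape.
    fn add(&self, rhs: &Self) -> Vek {
        assert_eq!(self.shape, rhs.shape);
        let data = self.data.iter().zip(&rhs.data).map(|(a, b)| a + b).collect();
        Vek { shape: self.shape.clone(), data }
    }

    // Toy broadcast: assumes `self.shape` is a suffix of `shape`, so tiling the data suffices.
    fn broadcast_as(&self, shape: &[usize]) -> Vek {
        let len = shape.iter().product::<usize>();
        let data = self.data.iter().copied().cycle().take(len).collect();
        Vek { shape: shape.to_vec(), data }
    }
}

// Same dispatch as the macro in the diff: broadcast whichever side differs from the
// common shape, then call the inner element-wise method.
macro_rules! broadcast_binary_op {
    ($fn_name:ident, $inner_fn_name:ident) => {
        impl Vek {
            fn $fn_name(&self, rhs: &Self) -> Vek {
                // Toy shape resolution: the higher-rank side wins (enough for this demo).
                let shape = if self.shape.len() >= rhs.shape.len() {
                    self.shape.clone()
                } else {
                    rhs.shape.clone()
                };
                let l_broadcast = shape != self.shape;
                let r_broadcast = shape != rhs.shape;
                match (l_broadcast, r_broadcast) {
                    (true, true) => self
                        .broadcast_as(&shape)
                        .$inner_fn_name(&rhs.broadcast_as(&shape)),
                    (false, true) => self.$inner_fn_name(&rhs.broadcast_as(&shape)),
                    (true, false) => self.broadcast_as(&shape).$inner_fn_name(rhs),
                    (false, false) => self.$inner_fn_name(rhs),
                }
            }
        }
    };
}

broadcast_binary_op!(broadcast_add, add);

fn main() {
    let lhs = Vek { shape: vec![2, 3], data: vec![1., 2., 3., 4., 5., 6.] };
    let rhs = Vek { shape: vec![3], data: vec![10., 20., 30.] };
    assert_eq!(lhs.broadcast_add(&rhs).data, vec![11., 22., 33., 14., 25., 36.]);
}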
@@ -250,21 +249,41 @@ impl Tensor {
         &'a self,
         rhs: &'a Self,
         op: &'static str,
-    ) -> Result<&'a Shape> {
+    ) -> Result<Shape> {
         let lhs = self;
         let lhs_dims = lhs.shape().dims();
         let rhs_dims = rhs.shape().dims();
-        if lhs_dims.strip_suffix(rhs_dims).is_some() {
-            Ok(self.shape())
-        } else if rhs_dims.strip_suffix(lhs_dims).is_some() {
-            Ok(rhs.shape())
-        } else {
-            Err(Error::ShapeMismatchBinaryOp {
-                lhs: self.shape().clone(),
-                rhs: rhs.shape().clone(),
-                op,
-            })
+        let lhs_ndims = lhs_dims.len();
+        let rhs_ndims = rhs_dims.len();
+        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
+        let mut bcast_dims = vec![0; bcast_ndims];
+        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
+            let rev_idx = bcast_ndims - idx;
+            let l_value = if lhs_ndims < rev_idx {
+                1
+            } else {
+                lhs_dims[lhs_ndims - rev_idx]
+            };
+            let r_value = if rhs_ndims < rev_idx {
+                1
+            } else {
+                rhs_dims[rhs_ndims - rev_idx]
+            };
+            *bcast_value = if l_value == r_value {
+                l_value
+            } else if l_value == 1 {
+                r_value
+            } else if r_value == 1 {
+                l_value
+            } else {
+                Err(Error::ShapeMismatchBinaryOp {
+                    lhs: self.shape().clone(),
+                    rhs: rhs.shape().clone(),
+                    op,
+                })?
+            }
+        }
+        Ok(Shape::from(bcast_dims))
     }
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
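The new shape resolution walks the dimensions from the right, treating missing leading dimensions as 1, keeping equal sizes, stretching a size of 1 to the other side's size, and rejecting anything else (the usual numpy-style rule), instead of only accepting the case where one shape is a suffix of the other. Below is a self-contained restatement of that rule over plain slices; it is not part of the commit and the names are illustrative, not candle's API.

// Right-aligned broadcast rule, restated over plain slices (illustrative names only).
fn broadcast_shape(lhs: &[usize], rhs: &[usize]) -> Result<Vec<usize>, String> {
    let ndims = lhs.len().max(rhs.len());
    let mut out = vec![0; ndims];
    for idx in 0..ndims {
        let rev_idx = ndims - idx;
        // A dimension that does not exist on one side behaves like a 1.
        let l = if lhs.len() < rev_idx { 1 } else { lhs[lhs.len() - rev_idx] };
        let r = if rhs.len() < rev_idx { 1 } else { rhs[rhs.len() - rev_idx] };
        out[idx] = if l == r || r == 1 {
            l
        } else if l == 1 {
            r
        } else {
            return Err(format!("shape mismatch: {lhs:?} vs {rhs:?}"));
        };
    }
    Ok(out)
}

fn main() {
    assert_eq!(broadcast_shape(&[2, 3], &[3]).unwrap(), vec![2, 3]);
    assert_eq!(broadcast_shape(&[4, 1, 5], &[3, 1]).unwrap(), vec![4, 3, 5]);
    assert!(broadcast_shape(&[2, 3], &[4]).is_err());
}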
@@ -293,10 +312,10 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
     binary_op!(div, Div);
-    broadcast_binary_op!(broadcast_add, Add, BroadcastAdd);
-    broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
-    broadcast_binary_op!(broadcast_sub, Sub, BroadcastSub);
-    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
+    broadcast_binary_op!(broadcast_add, add);
+    broadcast_binary_op!(broadcast_mul, mul);
+    broadcast_binary_op!(broadcast_sub, sub);
+    broadcast_binary_op!(broadcast_div, div);
     unary_op!(neg, Neg);
     unary_op!(exp, Exp);