Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 19:18:50 +00:00
Cleanup the broadcast setup.
@@ -37,10 +37,6 @@ impl Tensor {
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
-                | Op::BroadcastAdd(lhs, rhs)
-                | Op::BroadcastMul(lhs, rhs)
-                | Op::BroadcastSub(lhs, rhs)
-                | Op::BroadcastDiv(lhs, rhs)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);
@@ -142,34 +138,6 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::BroadcastAdd(lhs, rhs) => {
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
-                }
-                Op::BroadcastSub(lhs, rhs) => {
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
-                }
-                Op::BroadcastMul(lhs, rhs) => {
-                    let lhs_grad = grad.broadcast_mul(rhs)?;
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                    let rhs_grad = grad.broadcast_mul(lhs)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                }
-                Op::BroadcastDiv(lhs, rhs) => {
-                    let lhs_grad = grad.broadcast_div(rhs)?;
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                    let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                }
                 Op::WhereCond(_pred, _t, _f) => {
                     return Err(Error::BackwardNotSupported { op: "where_cond" })
                 }
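Note: the dedicated Broadcast* backward cases can go away because, after this change, each broadcast op is expressed as broadcast_as followed by the plain op, and the gradient of the expansion sums the upstream gradient over the broadcast dimensions. A minimal standalone sketch of that reduction (sum_over_leading_dim is a hypothetical helper for illustration, not candle code):

// Backward of broadcasting a [n] tensor to [m, n]: the upstream gradient
// of shape [m, n] is summed over the broadcast (leading) dimension to
// recover a gradient of shape [n].
fn sum_over_leading_dim(grad: &[f32], m: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0f32; n];
    for row in 0..m {
        for col in 0..n {
            out[col] += grad[row * n + col];
        }
    }
    out
}

fn main() {
    let grad = [1f32, 2., 3., 4., 5., 6.]; // upstream grad for a [2, 3] result
    assert_eq!(sum_over_leading_dim(&grad, 2, 3), vec![5., 7., 9.]);
}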
@@ -58,8 +58,7 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
     }
 }
 
-// This function maps over two strided index sequences. It supports broadcasting in case
-// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
+// This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     shape: &Shape,
     lhs_stride: &[usize],
@@ -69,52 +68,15 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     mut f: F,
 ) -> Vec<T> {
     let dims = shape.dims();
-    let broadcast_ldims = dims.len() - lhs_stride.len();
-    let broadcast_rdims = dims.len() - rhs_stride.len();
-    let elem_count = shape.elem_count();
-    if broadcast_ldims == 0 && broadcast_rdims == 0 {
-        if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
-            (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
-        } else {
-            let lhs_index = StridedIndex::new(dims, lhs_stride);
-            let rhs_index = StridedIndex::new(dims, rhs_stride);
-            lhs_index
-                .zip(rhs_index)
-                .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
-                .collect()
-        }
-    } else if broadcast_rdims == 0 {
-        let mut res = Vec::new();
-        res.reserve(elem_count);
-        let lhs_v: Vec<T> = StridedIndex::new(dims, lhs_stride)
-            .map(|i| lhs[i])
-            .collect();
-        let mut i = 0;
-        for rhs_i in StridedIndex::new(dims, rhs_stride) {
-            res.push(f(lhs_v[i], rhs[rhs_i]));
-            i += 1;
-            if i >= lhs_v.len() {
-                i = 0
-            }
-        }
-        res
-    } else if broadcast_ldims == 0 {
-        let mut res = Vec::new();
-        res.reserve(elem_count);
-        let rhs_v: Vec<T> = StridedIndex::new(dims, rhs_stride)
-            .map(|i| rhs[i])
-            .collect();
-        let mut i = 0;
-        for lhs_i in StridedIndex::new(dims, lhs_stride) {
-            res.push(f(lhs[lhs_i], rhs_v[i]));
-            i += 1;
-            if i >= rhs_v.len() {
-                i = 0
-            }
-        }
-        res
+    if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
+        (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
     } else {
-        panic!("unexpected broadcasting dims: {shape:?} {lhs_stride:?} {rhs_stride:?}")
+        let lhs_index = StridedIndex::new(dims, lhs_stride);
+        let rhs_index = StridedIndex::new(dims, rhs_stride);
+        lhs_index
+            .zip(rhs_index)
+            .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
+            .collect()
     }
 }
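Note: binary_map no longer handles shape-mismatched operands; broadcasting now happens earlier via broadcast_as, so both sides share the same shape here and only contiguous vs. strided layouts differ. A self-contained sketch of the offset computation a strided walk performs (illustrative, not the actual StridedIndex implementation):

// Flat offset for a multi-dimensional index given per-dimension strides.
// A stride of 0 makes a dimension repeat the same elements, which is how
// an expanded (broadcast) tensor can be traversed without copying data.
fn flat_offset(index: &[usize], strides: &[usize]) -> usize {
    index.iter().zip(strides).map(|(i, s)| i * s).sum()
}

fn main() {
    // A contiguous 2x3 tensor has strides [3, 1]: element (1, 2) sits at 1*3 + 2*1 = 5.
    assert_eq!(flat_offset(&[1, 2], &[3, 1]), 5);
    // A [3] tensor broadcast across rows uses stride 0 on the row axis.
    assert_eq!(flat_offset(&[1, 2], &[0, 1]), 2);
}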
@@ -6,10 +6,6 @@ pub(crate) enum Op {
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
-    BroadcastAdd(Tensor, Tensor),
-    BroadcastMul(Tensor, Tensor),
-    BroadcastSub(Tensor, Tensor),
-    BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),
@@ -88,21 +88,20 @@ macro_rules! binary_op {
 }
 
 macro_rules! broadcast_binary_op {
-    ($fn_name:ident, $impl_name:ident, $op_name:ident) => {
+    ($fn_name:ident, $inner_fn_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
-            let shape = self.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage = self.storage.binary_impl::<crate::op::$impl_name>(
-                &rhs.storage,
-                shape,
-                self.stride(),
-                rhs.stride(),
-            )?;
-            let op = if self.track_op() || rhs.track_op() {
-                Some(Op::$op_name(self.clone(), rhs.clone()))
-            } else {
-                None
-            };
-            Ok(from_storage(storage, shape.clone(), op, false))
+            let lhs = self;
+            let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let l_broadcast = shape != *lhs.shape();
+            let r_broadcast = shape != *rhs.shape();
+            match (l_broadcast, r_broadcast) {
+                (true, true) => lhs
+                    .broadcast_as(&shape)?
+                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
+                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
+                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
+                (false, false) => lhs.$inner_fn_name(rhs),
+            }
         }
     };
 }
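For reference, hand-expanding broadcast_binary_op!(broadcast_add, add) with the new definition gives roughly:

pub fn broadcast_add(&self, rhs: &Self) -> Result<Self> {
    let lhs = self;
    let shape = lhs.broadcast_shape_binary_op(rhs, "broadcast_add")?;
    let l_broadcast = shape != *lhs.shape();
    let r_broadcast = shape != *rhs.shape();
    match (l_broadcast, r_broadcast) {
        // Expand whichever side does not already have the target shape,
        // then defer to the same-shape op.
        (true, true) => lhs
            .broadcast_as(&shape)?
            .add(&rhs.broadcast_as(&shape)?),
        (false, true) => lhs.add(&rhs.broadcast_as(&shape)?),
        (true, false) => lhs.broadcast_as(&shape)?.add(rhs),
        (false, false) => lhs.add(rhs),
    }
}

So the broadcast variants are now thin wrappers over the plain ops, and the dedicated Broadcast* kernels and op variants become unnecessary.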
@@ -250,21 +249,41 @@ impl Tensor {
         &'a self,
         rhs: &'a Self,
         op: &'static str,
-    ) -> Result<&'a Shape> {
+    ) -> Result<Shape> {
         let lhs = self;
         let lhs_dims = lhs.shape().dims();
         let rhs_dims = rhs.shape().dims();
-        if lhs_dims.strip_suffix(rhs_dims).is_some() {
-            Ok(self.shape())
-        } else if rhs_dims.strip_suffix(lhs_dims).is_some() {
-            Ok(rhs.shape())
-        } else {
-            Err(Error::ShapeMismatchBinaryOp {
-                lhs: self.shape().clone(),
-                rhs: rhs.shape().clone(),
-                op,
-            })
-        }
+        let lhs_ndims = lhs_dims.len();
+        let rhs_ndims = rhs_dims.len();
+        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
+        let mut bcast_dims = vec![0; bcast_ndims];
+        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
+            let rev_idx = bcast_ndims - idx;
+            let l_value = if lhs_ndims < rev_idx {
+                1
+            } else {
+                lhs_dims[lhs_ndims - rev_idx]
+            };
+            let r_value = if rhs_ndims < rev_idx {
+                1
+            } else {
+                rhs_dims[rhs_ndims - rev_idx]
+            };
+            *bcast_value = if l_value == r_value {
+                l_value
+            } else if l_value == 1 {
+                r_value
+            } else if r_value == 1 {
+                l_value
+            } else {
+                Err(Error::ShapeMismatchBinaryOp {
+                    lhs: self.shape().clone(),
+                    rhs: rhs.shape().clone(),
+                    op,
+                })?
+            }
+        }
+        Ok(Shape::from(bcast_dims))
     }
 
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
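The shape rule implemented above is the NumPy-style one (replacing the old suffix-only check): align dimensions from the right; two sizes are compatible when they are equal or one of them is 1, and the result takes the larger size. A self-contained sketch of the same computation without the Shape and Error types:

// Right-aligned broadcast of two dim lists; None when sizes are incompatible.
fn broadcast_dims(lhs: &[usize], rhs: &[usize]) -> Option<Vec<usize>> {
    let ndims = usize::max(lhs.len(), rhs.len());
    (0..ndims)
        .map(|idx| {
            let rev = ndims - idx;
            // Missing leading dimensions are treated as size 1.
            let l = if lhs.len() < rev { 1 } else { lhs[lhs.len() - rev] };
            let r = if rhs.len() < rev { 1 } else { rhs[rhs.len() - rev] };
            if l == r {
                Some(l)
            } else if l == 1 {
                Some(r)
            } else if r == 1 {
                Some(l)
            } else {
                None
            }
        })
        .collect()
}

fn main() {
    assert_eq!(broadcast_dims(&[3, 1, 5], &[4, 5]), Some(vec![3, 4, 5]));
    assert_eq!(broadcast_dims(&[2, 3], &[4, 3]), None);
}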
@@ -293,10 +312,10 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
     binary_op!(div, Div);
-    broadcast_binary_op!(broadcast_add, Add, BroadcastAdd);
-    broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
-    broadcast_binary_op!(broadcast_sub, Sub, BroadcastSub);
-    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
+    broadcast_binary_op!(broadcast_add, add);
+    broadcast_binary_op!(broadcast_mul, mul);
+    broadcast_binary_op!(broadcast_sub, sub);
+    broadcast_binary_op!(broadcast_div, div);
 
     unary_op!(neg, Neg);
     unary_op!(exp, Exp);
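A usage sketch of the generated methods (assuming the crate is imported as candle and that Tensor::new and Device::Cpu exist with these signatures at this point in the repository; adjust imports to the actual crate layout):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    // [2, 3] + [3]: the right-hand side is broadcast across the rows.
    let a = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
    let b = Tensor::new(&[10f32, 20., 30.], &Device::Cpu)?;
    let c = a.broadcast_add(&b)?;
    assert_eq!(c.shape().dims().to_vec(), vec![2usize, 3]);
    Ok(())
}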