Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 19:47:12 +00:00
Cleanup the broadcast setup.

@@ -37,10 +37,6 @@ impl Tensor {
                 | Op::Mul(lhs, rhs)
                 | Op::Sub(lhs, rhs)
                 | Op::Div(lhs, rhs)
-                | Op::BroadcastAdd(lhs, rhs)
-                | Op::BroadcastMul(lhs, rhs)
-                | Op::BroadcastSub(lhs, rhs)
-                | Op::BroadcastDiv(lhs, rhs)
                 | Op::Embedding(lhs, rhs)
                 | Op::Matmul(lhs, rhs) => {
                     let (tg, nodes) = walk(lhs, nodes, already_seen);

@@ -142,34 +138,6 @@ impl Tensor {
                     let rhs_sum_grad = grads.or_insert(rhs)?;
                     *rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
                 }
-                Op::BroadcastAdd(lhs, rhs) => {
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&grad)?;
-                }
-                Op::BroadcastSub(lhs, rhs) => {
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&grad)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_sub(&grad)?;
-                }
-                Op::BroadcastMul(lhs, rhs) => {
-                    let lhs_grad = grad.broadcast_mul(rhs)?;
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                    let rhs_grad = grad.broadcast_mul(lhs)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                }
-                Op::BroadcastDiv(lhs, rhs) => {
-                    let lhs_grad = grad.broadcast_div(rhs)?;
-                    let lhs_sum_grad = grads.or_insert(lhs)?;
-                    *lhs_sum_grad = lhs_sum_grad.broadcast_add(&lhs_grad)?;
-                    let rhs_grad = grad.broadcast_mul(lhs)?.broadcast_div(&rhs.sqr()?)?;
-                    let rhs_sum_grad = grads.or_insert(rhs)?;
-                    *rhs_sum_grad = rhs_sum_grad.broadcast_add(&rhs_grad)?;
-                }
                 Op::WhereCond(_pred, _t, _f) => {
                     return Err(Error::BackwardNotSupported { op: "where_cond" })
                 }

@@ -58,8 +58,7 @@ fn unary_map<T: Copy, U: Copy, F: FnMut(T) -> U>(
     }
 }
 
-// This function maps over two strided index sequences. It supports broadcasting in case
-// `lhs_stride` or `rhs_stride` has a length shorter than `shape`.
+// This function maps over two strided index sequences.
 fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     shape: &Shape,
     lhs_stride: &[usize],

@@ -69,10 +68,6 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
     mut f: F,
 ) -> Vec<T> {
     let dims = shape.dims();
-    let broadcast_ldims = dims.len() - lhs_stride.len();
-    let broadcast_rdims = dims.len() - rhs_stride.len();
-    let elem_count = shape.elem_count();
-    if broadcast_ldims == 0 && broadcast_rdims == 0 {
     if shape.is_contiguous(lhs_stride) && shape.is_contiguous(rhs_stride) {
         (0..shape.elem_count()).map(|i| f(lhs[i], rhs[i])).collect()
     } else {

@@ -83,39 +78,6 @@ fn binary_map<T: Copy, F: FnMut(T, T) -> T>(
             .map(|(lhs_i, rhs_i)| f(lhs[lhs_i], rhs[rhs_i]))
             .collect()
     }
-    } else if broadcast_rdims == 0 {
-        let mut res = Vec::new();
-        res.reserve(elem_count);
-        let lhs_v: Vec<T> = StridedIndex::new(dims, lhs_stride)
-            .map(|i| lhs[i])
-            .collect();
-        let mut i = 0;
-        for rhs_i in StridedIndex::new(dims, rhs_stride) {
-            res.push(f(lhs_v[i], rhs[rhs_i]));
-            i += 1;
-            if i >= lhs_v.len() {
-                i = 0
-            }
-        }
-        res
-    } else if broadcast_ldims == 0 {
-        let mut res = Vec::new();
-        res.reserve(elem_count);
-        let rhs_v: Vec<T> = StridedIndex::new(dims, rhs_stride)
-            .map(|i| rhs[i])
-            .collect();
-        let mut i = 0;
-        for lhs_i in StridedIndex::new(dims, lhs_stride) {
-            res.push(f(lhs[lhs_i], rhs_v[i]));
-            i += 1;
-            if i >= rhs_v.len() {
-                i = 0
-            }
-        }
-        res
-    } else {
-        panic!("unexpected broadcasting dims: {shape:?} {lhs_stride:?} {rhs_stride:?}")
-    }
 }
 
 fn take<T: Copy>(
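
Note on the simplified binary_map: with the broadcasting branches removed, the function only walks two strided index sequences over the same logical shape; any broadcasting is expected to be handled before the call by materializing a broadcast view of the operand (commonly by giving the expanded dimension a stride of 0). The sketch below is a minimal, self-contained illustration of that idea in plain Rust; strided_offsets is a hypothetical stand-in for the crate's StridedIndex iterator, not its actual implementation.

// Flat offsets for a row-major walk of `dims` through a buffer with the given strides.
// Hypothetical stand-in for a strided-index iterator, for illustration only.
fn strided_offsets(dims: &[usize], stride: &[usize]) -> Vec<usize> {
    let mut offsets = vec![0usize];
    for (&d, &s) in dims.iter().zip(stride.iter()) {
        let mut next = Vec::with_capacity(offsets.len() * d);
        for &o in &offsets {
            for i in 0..d {
                next.push(o + i * s);
            }
        }
        offsets = next;
    }
    offsets
}

fn main() {
    // lhs: a 2x3 matrix stored contiguously, strides [3, 1].
    let lhs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    // rhs: a 3-element row broadcast to 2x3 via a zero stride on the first dim.
    let rhs = [10.0, 20.0, 30.0];
    let lhs_idx = strided_offsets(&[2, 3], &[3, 1]); // [0, 1, 2, 3, 4, 5]
    let rhs_idx = strided_offsets(&[2, 3], &[0, 1]); // [0, 1, 2, 0, 1, 2]
    // Zip the two sequences and apply a binary function, as binary_map does.
    let out: Vec<f64> = lhs_idx
        .iter()
        .zip(&rhs_idx)
        .map(|(&l, &r)| lhs[l] + rhs[r])
        .collect();
    assert_eq!(out, vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
}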

@@ -6,10 +6,6 @@ pub(crate) enum Op {
     Mul(Tensor, Tensor),
     Sub(Tensor, Tensor),
     Div(Tensor, Tensor),
-    BroadcastAdd(Tensor, Tensor),
-    BroadcastMul(Tensor, Tensor),
-    BroadcastSub(Tensor, Tensor),
-    BroadcastDiv(Tensor, Tensor),
     Matmul(Tensor, Tensor),
     Embedding(Tensor, Tensor),
     WhereCond(Tensor, Tensor, Tensor),

@@ -88,21 +88,20 @@ macro_rules! binary_op {
 }
 
 macro_rules! broadcast_binary_op {
-    ($fn_name:ident, $impl_name:ident, $op_name:ident) => {
+    ($fn_name:ident, $inner_fn_name:ident) => {
         pub fn $fn_name(&self, rhs: &Self) -> Result<Self> {
-            let shape = self.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
-            let storage = self.storage.binary_impl::<crate::op::$impl_name>(
-                &rhs.storage,
-                shape,
-                self.stride(),
-                rhs.stride(),
-            )?;
-            let op = if self.track_op() || rhs.track_op() {
-                Some(Op::$op_name(self.clone(), rhs.clone()))
-            } else {
-                None
-            };
-            Ok(from_storage(storage, shape.clone(), op, false))
+            let lhs = self;
+            let shape = lhs.broadcast_shape_binary_op(rhs, stringify!($fn_name))?;
+            let l_broadcast = shape != *lhs.shape();
+            let r_broadcast = shape != *rhs.shape();
+            match (l_broadcast, r_broadcast) {
+                (true, true) => lhs
+                    .broadcast_as(&shape)?
+                    .$inner_fn_name(&rhs.broadcast_as(&shape)?),
+                (false, true) => lhs.$inner_fn_name(&rhs.broadcast_as(&shape)?),
+                (true, false) => lhs.broadcast_as(&shape)?.$inner_fn_name(rhs),
+                (false, false) => lhs.$inner_fn_name(rhs),
+            }
         }
     };
 }
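
Note on the reworked broadcast_binary_op macro: each broadcast method now computes the common broadcast shape, broadcasts whichever operand needs it with broadcast_as, and then defers to the ordinary element-wise method, so the dedicated Broadcast* ops and the direct call into the storage kernel are no longer needed. As a rough sketch, broadcast_binary_op!(broadcast_add, add); expands to approximately the method below (an illustrative expansion, not compiler output):

pub fn broadcast_add(&self, rhs: &Self) -> Result<Self> {
    let lhs = self;
    // Common broadcast shape of the two operands.
    let shape = lhs.broadcast_shape_binary_op(rhs, "broadcast_add")?;
    let l_broadcast = shape != *lhs.shape();
    let r_broadcast = shape != *rhs.shape();
    // Broadcast only the operand(s) whose shape differs, then use the plain add.
    match (l_broadcast, r_broadcast) {
        (true, true) => lhs.broadcast_as(&shape)?.add(&rhs.broadcast_as(&shape)?),
        (false, true) => lhs.add(&rhs.broadcast_as(&shape)?),
        (true, false) => lhs.broadcast_as(&shape)?.add(rhs),
        (false, false) => lhs.add(rhs),
    }
}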

@@ -250,22 +249,42 @@ impl Tensor {
         &'a self,
         rhs: &'a Self,
         op: &'static str,
-    ) -> Result<&'a Shape> {
+    ) -> Result<Shape> {
         let lhs = self;
         let lhs_dims = lhs.shape().dims();
         let rhs_dims = rhs.shape().dims();
-        if lhs_dims.strip_suffix(rhs_dims).is_some() {
-            Ok(self.shape())
-        } else if rhs_dims.strip_suffix(lhs_dims).is_some() {
-            Ok(rhs.shape())
+        let lhs_ndims = lhs_dims.len();
+        let rhs_ndims = rhs_dims.len();
+        let bcast_ndims = usize::max(lhs_ndims, rhs_ndims);
+        let mut bcast_dims = vec![0; bcast_ndims];
+        for (idx, bcast_value) in bcast_dims.iter_mut().enumerate() {
+            let rev_idx = bcast_ndims - idx;
+            let l_value = if lhs_ndims < rev_idx {
+                1
+            } else {
+                lhs_dims[lhs_ndims - rev_idx]
+            };
+            let r_value = if rhs_ndims < rev_idx {
+                1
+            } else {
+                rhs_dims[rhs_ndims - rev_idx]
+            };
+            *bcast_value = if l_value == r_value {
+                l_value
+            } else if l_value == 1 {
+                r_value
+            } else if r_value == 1 {
+                l_value
             } else {
                 Err(Error::ShapeMismatchBinaryOp {
                     lhs: self.shape().clone(),
                     rhs: rhs.shape().clone(),
                     op,
-            })
+                })?
             }
         }
+        Ok(Shape::from(bcast_dims))
+    }
 
     pub(crate) fn same_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<&Shape> {
         let lhs = self.shape();
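
Note on the new broadcast_shape_binary_op: it follows NumPy-style broadcasting, aligning dimensions from the right, treating missing leading dimensions as size 1, and letting a size-1 dimension stretch to match the other operand; any other size mismatch is a shape error. Below is a self-contained sketch of the same rule over plain dimension slices, using a hypothetical helper name and a String error for illustration only:

// Right-aligned broadcast of two dimension lists, NumPy-style (illustrative only).
fn broadcast_dims(lhs: &[usize], rhs: &[usize]) -> Result<Vec<usize>, String> {
    let ndims = usize::max(lhs.len(), rhs.len());
    let mut out = vec![0; ndims];
    for idx in 0..ndims {
        let rev = ndims - idx;
        // Missing leading dimensions behave as size 1.
        let l = if lhs.len() < rev { 1 } else { lhs[lhs.len() - rev] };
        let r = if rhs.len() < rev { 1 } else { rhs[rhs.len() - rev] };
        out[idx] = if l == r || r == 1 {
            l
        } else if l == 1 {
            r
        } else {
            return Err(format!("cannot broadcast {lhs:?} with {rhs:?}"));
        };
    }
    Ok(out)
}

fn main() {
    // [3, 1, 5] broadcast with [4, 5] gives [3, 4, 5].
    assert_eq!(broadcast_dims(&[3, 1, 5], &[4, 5]), Ok(vec![3, 4, 5]));
    // [2, 3] and [4, 3] cannot be broadcast together.
    assert!(broadcast_dims(&[2, 3], &[4, 3]).is_err());
}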

@@ -293,10 +312,10 @@ impl Tensor {
     binary_op!(mul, Mul);
     binary_op!(sub, Sub);
     binary_op!(div, Div);
-    broadcast_binary_op!(broadcast_add, Add, BroadcastAdd);
-    broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
-    broadcast_binary_op!(broadcast_sub, Sub, BroadcastSub);
-    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
+    broadcast_binary_op!(broadcast_add, add);
+    broadcast_binary_op!(broadcast_mul, mul);
+    broadcast_binary_op!(broadcast_sub, sub);
+    broadcast_binary_op!(broadcast_div, div);
 
     unary_op!(neg, Neg);
     unary_op!(exp, Exp);