mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add a couple operators.
This commit is contained in:
@ -190,7 +190,10 @@ impl Tensor {
|
||||
// if this can create cycles in the compute graph.
|
||||
binary_op!(add, Add, add_impl);
|
||||
binary_op!(mul, Mul, mul_impl);
|
||||
binary_op!(sub, Sub, sub_impl);
|
||||
binary_op!(div, Div, div_impl);
|
||||
|
||||
unary_op!(neg, Neg, neg_impl);
|
||||
unary_op!(sqr, Sqr, sqr_impl);
|
||||
unary_op!(sqrt, Sqrt, sqrt_impl);
|
||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||
@ -320,7 +323,10 @@ impl Tensor {
|
||||
nodes
|
||||
} else if let Some(op) = &node.op {
|
||||
match op {
|
||||
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
|
||||
Op::Add(lhs, rhs)
|
||||
| Op::Mul(lhs, rhs)
|
||||
| Op::Sub(lhs, rhs)
|
||||
| Op::Div(lhs, rhs) => {
|
||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
||||
@ -336,7 +342,7 @@ impl Tensor {
|
||||
nodes
|
||||
}
|
||||
}
|
||||
Op::Sqr(node) | Op::Sqrt(node) => {
|
||||
Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
|
||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||
track_grad |= tg;
|
||||
nodes
|
||||
@ -378,6 +384,12 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
|
||||
}
|
||||
Op::Sub(lhs, rhs) => {
|
||||
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
|
||||
*lhs_sum_grad = lhs_sum_grad.add(&grad)?;
|
||||
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&grad.neg()?)?;
|
||||
}
|
||||
Op::Mul(lhs, rhs) => {
|
||||
let lhs_grad = grad.mul(rhs)?;
|
||||
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
|
||||
@ -386,22 +398,33 @@ impl Tensor {
|
||||
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||
}
|
||||
Op::Div(lhs, rhs) => {
    // Backward pass for `out = lhs / rhs`, where `grad` is the gradient
    // flowing into `out`. By the quotient rule:
    //   d out / d lhs =  1 / rhs
    //   d out / d rhs = -lhs / rhs^2
    let lhs_grad = grad.div(rhs)?;
    let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
    // Magnitude of the rhs contribution: grad * lhs / rhs^2.
    let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
    let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
    // The derivative w.r.t. the divisor is negative, so this term has to be
    // subtracted from the accumulated gradient rather than added to it.
    *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
|
||||
Op::Affine { arg, mul, .. } => {
|
||||
let arg_grad = grad.affine(*mul, 0.)?;
|
||||
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Neg(arg) => {
|
||||
let arg_grad = grad.neg()?;
|
||||
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Sqr(arg) => {
|
||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
|
||||
*sum_grad = sum_grad.add(&arg_grad)?
|
||||
}
|
||||
Op::Sqrt(_arg) => {
|
||||
todo!()
|
||||
// TODO: Add div to enable the following.
|
||||
// let arg_grad = grad / (2 * arg)
|
||||
// let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
|
||||
// *sum_grad = sum_grad.add(arg_grad)?
|
||||
Op::Sqrt(arg) => {
    // Backward pass for `out = sqrt(arg)`:
    //   d out / d arg = 1 / (2 * sqrt(arg))
    // The divisor must be sqrt(arg), not arg itself, so the forward value is
    // recomputed from the input before dividing.
    let arg_grad = grad.div(&arg.sqrt()?)?.affine(0.5, 0.)?;
    let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
    *sum_grad = sum_grad.add(&arg_grad)?
}
|
||||
};
|
||||
}
|
||||
|
Reference in New Issue
Block a user