mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a couple operators.
This commit is contained in:
@ -2,13 +2,17 @@ use crate::Tensor;
|
|||||||
|
|
||||||
pub(crate) enum Op {
|
pub(crate) enum Op {
|
||||||
Add(Tensor, Tensor),
|
Add(Tensor, Tensor),
|
||||||
|
Mul(Tensor, Tensor),
|
||||||
|
Sub(Tensor, Tensor),
|
||||||
|
Div(Tensor, Tensor),
|
||||||
|
|
||||||
#[allow(dead_code)] // add is currently unused.
|
#[allow(dead_code)] // add is currently unused.
|
||||||
Affine {
|
Affine {
|
||||||
arg: Tensor,
|
arg: Tensor,
|
||||||
mul: f64,
|
mul: f64,
|
||||||
add: f64,
|
add: f64,
|
||||||
},
|
},
|
||||||
Mul(Tensor, Tensor),
|
Neg(Tensor),
|
||||||
Sqr(Tensor),
|
Sqr(Tensor),
|
||||||
Sqrt(Tensor),
|
Sqrt(Tensor),
|
||||||
// TODO: Support for custom ops.
|
// TODO: Support for custom ops.
|
||||||
|
@ -95,7 +95,10 @@ trait BinaryOp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct Add;
|
struct Add;
|
||||||
|
struct Div;
|
||||||
struct Mul;
|
struct Mul;
|
||||||
|
struct Sub;
|
||||||
|
struct Neg;
|
||||||
struct Sqr;
|
struct Sqr;
|
||||||
struct Sqrt;
|
struct Sqrt;
|
||||||
|
|
||||||
@ -109,6 +112,16 @@ impl BinaryOp for Add {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Subtraction operator: scalar kernels for the `f32` and `f64` dtypes.
impl BinaryOp for Sub {
    /// Operator name exposed through the `BinaryOp` trait.
    const NAME: &'static str = "sub";

    /// Single-precision `lhs - rhs`.
    fn f32(lhs: f32, rhs: f32) -> f32 {
        lhs - rhs
    }

    /// Double-precision `lhs - rhs`.
    fn f64(lhs: f64, rhs: f64) -> f64 {
        lhs - rhs
    }
}
|
||||||
|
|
||||||
impl BinaryOp for Mul {
|
impl BinaryOp for Mul {
|
||||||
const NAME: &'static str = "mul";
|
const NAME: &'static str = "mul";
|
||||||
fn f32(v1: f32, v2: f32) -> f32 {
|
fn f32(v1: f32, v2: f32) -> f32 {
|
||||||
@ -119,6 +132,26 @@ impl BinaryOp for Mul {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Division operator: scalar kernels for the `f32` and `f64` dtypes.
impl BinaryOp for Div {
    /// Operator name exposed through the `BinaryOp` trait.
    const NAME: &'static str = "div";

    /// Single-precision `numerator / denominator` (plain float division).
    fn f32(numerator: f32, denominator: f32) -> f32 {
        numerator / denominator
    }

    /// Double-precision `numerator / denominator` (plain float division).
    fn f64(numerator: f64, denominator: f64) -> f64 {
        numerator / denominator
    }
}
|
||||||
|
|
||||||
|
/// Negation operator: scalar kernels for the `f32` and `f64` dtypes.
impl UnaryOp for Neg {
    /// Operator name exposed through the `UnaryOp` trait.
    const NAME: &'static str = "neg";

    /// Single-precision sign flip.
    fn f32(value: f32) -> f32 {
        -value
    }

    /// Double-precision sign flip.
    fn f64(value: f64) -> f64 {
        -value
    }
}
|
||||||
|
|
||||||
impl UnaryOp for Sqr {
|
impl UnaryOp for Sqr {
|
||||||
const NAME: &'static str = "sqr";
|
const NAME: &'static str = "sqr";
|
||||||
fn f32(v1: f32) -> f32 {
|
fn f32(v1: f32) -> f32 {
|
||||||
@ -272,6 +305,16 @@ impl Storage {
|
|||||||
self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
|
self.binary_impl::<Add>(rhs, shape, lhs_stride, rhs_stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn sub_impl(
|
||||||
|
&self,
|
||||||
|
rhs: &Self,
|
||||||
|
shape: &Shape,
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.binary_impl::<Sub>(rhs, shape, lhs_stride, rhs_stride)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn mul_impl(
|
pub(crate) fn mul_impl(
|
||||||
&self,
|
&self,
|
||||||
rhs: &Self,
|
rhs: &Self,
|
||||||
@ -282,6 +325,20 @@ impl Storage {
|
|||||||
self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
|
self.binary_impl::<Mul>(rhs, shape, lhs_stride, rhs_stride)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn div_impl(
|
||||||
|
&self,
|
||||||
|
rhs: &Self,
|
||||||
|
shape: &Shape,
|
||||||
|
lhs_stride: &[usize],
|
||||||
|
rhs_stride: &[usize],
|
||||||
|
) -> Result<Self> {
|
||||||
|
self.binary_impl::<Div>(rhs, shape, lhs_stride, rhs_stride)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn neg_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||||
|
self.unary_impl::<Neg>(shape, stride)
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
pub(crate) fn sqr_impl(&self, shape: &Shape, stride: &[usize]) -> Result<Self> {
|
||||||
self.unary_impl::<Sqr>(shape, stride)
|
self.unary_impl::<Sqr>(shape, stride)
|
||||||
}
|
}
|
||||||
|
@ -190,7 +190,10 @@ impl Tensor {
|
|||||||
// if this can create cycles in the compute graph.
|
// if this can create cycles in the compute graph.
|
||||||
binary_op!(add, Add, add_impl);
|
binary_op!(add, Add, add_impl);
|
||||||
binary_op!(mul, Mul, mul_impl);
|
binary_op!(mul, Mul, mul_impl);
|
||||||
|
binary_op!(sub, Sub, sub_impl);
|
||||||
|
binary_op!(div, Div, div_impl);
|
||||||
|
|
||||||
|
unary_op!(neg, Neg, neg_impl);
|
||||||
unary_op!(sqr, Sqr, sqr_impl);
|
unary_op!(sqr, Sqr, sqr_impl);
|
||||||
unary_op!(sqrt, Sqrt, sqrt_impl);
|
unary_op!(sqrt, Sqrt, sqrt_impl);
|
||||||
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
pub fn to_scalar<S: crate::WithDType>(&self) -> Result<S> {
|
||||||
@ -320,7 +323,10 @@ impl Tensor {
|
|||||||
nodes
|
nodes
|
||||||
} else if let Some(op) = &node.op {
|
} else if let Some(op) = &node.op {
|
||||||
match op {
|
match op {
|
||||||
Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) => {
|
Op::Add(lhs, rhs)
|
||||||
|
| Op::Mul(lhs, rhs)
|
||||||
|
| Op::Sub(lhs, rhs)
|
||||||
|
| Op::Div(lhs, rhs) => {
|
||||||
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
let (tg, nodes) = walk(lhs, nodes, already_seen);
|
||||||
track_grad |= tg;
|
track_grad |= tg;
|
||||||
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
let (tg, nodes) = walk(rhs, nodes, already_seen);
|
||||||
@ -336,7 +342,7 @@ impl Tensor {
|
|||||||
nodes
|
nodes
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Op::Sqr(node) | Op::Sqrt(node) => {
|
Op::Sqr(node) | Op::Sqrt(node) | Op::Neg(node) => {
|
||||||
let (tg, nodes) = walk(node, nodes, already_seen);
|
let (tg, nodes) = walk(node, nodes, already_seen);
|
||||||
track_grad |= tg;
|
track_grad |= tg;
|
||||||
nodes
|
nodes
|
||||||
@ -378,6 +384,12 @@ impl Tensor {
|
|||||||
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
||||||
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
|
*rhs_sum_grad = rhs_sum_grad.add(&grad)?;
|
||||||
}
|
}
|
||||||
|
Op::Sub(lhs, rhs) => {
    // out = lhs - rhs, so d(out)/d(lhs) = 1: the incoming gradient is
    // accumulated onto the lhs entry unchanged.
    let minuend_acc = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
    *minuend_acc = minuend_acc.add(&grad)?;
    // d(out)/d(rhs) = -1: the negated gradient is accumulated onto rhs.
    let subtrahend_acc = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
    let negated_grad = grad.neg()?;
    *subtrahend_acc = subtrahend_acc.add(&negated_grad)?;
}
|
||||||
Op::Mul(lhs, rhs) => {
|
Op::Mul(lhs, rhs) => {
|
||||||
let lhs_grad = grad.mul(rhs)?;
|
let lhs_grad = grad.mul(rhs)?;
|
||||||
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
|
let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
|
||||||
@ -386,22 +398,33 @@ impl Tensor {
|
|||||||
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
|
||||||
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
*rhs_sum_grad = rhs_sum_grad.add(&rhs_grad)?;
|
||||||
}
|
}
|
||||||
|
Op::Div(lhs, rhs) => {
    // out = lhs / rhs.
    // d(out)/d(lhs) = 1 / rhs, so the lhs contribution is grad / rhs.
    let lhs_grad = grad.div(rhs)?;
    let lhs_sum_grad = grads.entry(lhs.id).or_insert_with(|| lhs.zeros_like());
    *lhs_sum_grad = lhs_sum_grad.add(&lhs_grad)?;
    // d(out)/d(rhs) = -lhs / rhs^2 (quotient rule). `rhs_grad` holds the
    // magnitude grad * lhs / rhs^2, so it has to be *subtracted* from the
    // accumulator; adding it gives a gradient with the wrong sign.
    let rhs_grad = grad.mul(lhs)?.div(&rhs.sqr()?)?;
    let rhs_sum_grad = grads.entry(rhs.id).or_insert_with(|| rhs.zeros_like());
    *rhs_sum_grad = rhs_sum_grad.sub(&rhs_grad)?;
}
|
||||||
Op::Affine { arg, mul, .. } => {
|
Op::Affine { arg, mul, .. } => {
|
||||||
let arg_grad = grad.affine(*mul, 0.)?;
|
let arg_grad = grad.affine(*mul, 0.)?;
|
||||||
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
|
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
|
||||||
*sum_grad = sum_grad.add(&arg_grad)?
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
}
|
}
|
||||||
|
Op::Neg(arg) => {
    // out = -arg, so d(out)/d(arg) = -1: propagate the negated gradient.
    let flipped_grad = grad.neg()?;
    let accumulator = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
    *accumulator = accumulator.add(&flipped_grad)?;
}
|
||||||
Op::Sqr(arg) => {
|
Op::Sqr(arg) => {
|
||||||
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
let arg_grad = arg.mul(&grad)?.affine(2., 0.)?;
|
||||||
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
|
let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
|
||||||
*sum_grad = sum_grad.add(&arg_grad)?
|
*sum_grad = sum_grad.add(&arg_grad)?
|
||||||
}
|
}
|
||||||
Op::Sqrt(arg) => {
    // out = sqrt(arg), so d(out)/d(arg) = 1 / (2 * sqrt(arg)).
    // The divisor must be sqrt(arg), not arg itself: dividing by `arg`
    // yields grad / (2 * arg), which is only correct when arg == 1.
    // NOTE(review): sqrt(arg) is recomputed here because the forward output
    // tensor is not visibly bound in this arm; if it is in scope in the
    // enclosing loop, dividing by it directly would avoid the extra op.
    let arg_grad = grad.div(&arg.sqrt()?)?.affine(0.5, 0.)?;
    let sum_grad = grads.entry(arg.id).or_insert_with(|| arg.zeros_like());
    *sum_grad = sum_grad.add(&arg_grad)?
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user