Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)
Add a couple unary ops.
@@ -37,6 +37,11 @@ __device__ T gelu_fwd(T x) {
#if __CUDA_ARCH__ >= 530
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, uexp_f16, expg(x))
UNARY_OP(__half, ulog_f16, logg(x))
UNARY_OP(__half, usin_f16, sing(x))
UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, gelu_f16, gelu_fwd(x))
@@ -46,6 +51,16 @@ UNARY_OP(float, ucopy_f32, x)
UNARY_OP(double, ucopy_f64, x)
UNARY_OP(float, uneg_f32, -x)
UNARY_OP(double, uneg_f64, -x)
UNARY_OP(float, uexp_f32, expg(x))
UNARY_OP(double, uexp_f64, expg(x))
UNARY_OP(float, ulog_f32, logg(x))
UNARY_OP(double, ulog_f64, logg(x))
UNARY_OP(float, usin_f32, sing(x))
UNARY_OP(double, usin_f64, sing(x))
UNARY_OP(float, ucos_f32, cosg(x))
UNARY_OP(double, ucos_f64, cosg(x))
UNARY_OP(float, uabs_f32, absg(x))
UNARY_OP(double, uabs_f64, absg(x))
UNARY_OP(float, usqr_f32, x*x)
UNARY_OP(double, usqr_f64, x*x)
UNARY_OP(float, usqrt_f32, sqrtg(x))
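The kernel names above are what the Rust side presumably hands to the CUDA runtime: each `UnaryOp` impl in `src/op.rs` below carries matching `KERNEL_F32`/`KERNEL_F64` constants (e.g. `"uexp_f32"`). The `UnaryOp` trait definition itself is not part of this diff; judging from the impls, its shape is roughly the following sketch, which may differ from the real definition:

```rust
// Sketch of the UnaryOp trait as implied by the impls in this commit;
// illustrative only -- the actual definition lives elsewhere in src/op.rs.
pub(crate) trait UnaryOp {
    const NAME: &'static str; // op name, used e.g. in error messages
    const KERNEL_F32: &'static str; // CUDA kernel name, e.g. "uexp_f32"
    const KERNEL_F64: &'static str; // CUDA kernel name, e.g. "uexp_f64"
    // CPU scalar fallbacks, applied elementwise per dtype.
    fn f32(v1: f32) -> f32;
    fn f64(v1: f64) -> f64;
    fn u32(v1: u32) -> u32;
}
```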
src/op.rs (85 changed lines)
@@ -21,6 +21,11 @@ pub(crate) enum Op {
        mul: f64,
        add: f64,
    },
    Exp(Tensor),
    Log(Tensor),
    Sin(Tensor),
    Cos(Tensor),
    Abs(Tensor),
    Neg(Tensor),
    Reshape(Tensor),
    #[allow(dead_code)]
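Each new variant stores the op's input `Tensor`, so the backward pass further down can recover the argument when it routes gradients. In tape terms, `Op` records how a tensor was produced; a stripped-down illustration of the idea (toy types, not candle's):

```rust
// Minimal tape idea: an output remembers the op and input that made it.
#[derive(Clone)]
struct T {
    data: f64,
    op: Option<Box<Op>>,
}

#[derive(Clone)]
#[allow(dead_code)]
enum Op {
    Exp(T),
    Log(T),
}

// Forward op: compute the value and record the provenance.
fn exp(x: &T) -> T {
    T { data: x.data.exp(), op: Some(Box::new(Op::Exp(x.clone()))) }
}

fn main() {
    let x = T { data: 1.0, op: None };
    let y = exp(&x);
    assert!(matches!(y.op.as_deref(), Some(Op::Exp(_))));
}
```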
@@ -60,6 +65,11 @@ pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
pub(crate) struct Exp;
pub(crate) struct Log;
pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
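These marker structs carry no fields; each is a zero-sized tag whose only job is to select a set of kernels at compile time through a generic parameter. A tiny standalone check of the zero-sized property:

```rust
// The op tags are zero-sized: passing one around costs nothing at runtime.
struct Exp;
struct Sqrt;

fn main() {
    assert_eq!(std::mem::size_of::<Exp>(), 0);
    assert_eq!(std::mem::size_of::<Sqrt>(), 0);
}
```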
@@ -129,6 +139,81 @@ impl BinaryOp for Div {
    }
}

impl UnaryOp for Exp {
    const NAME: &'static str = "exp";
    fn f32(v1: f32) -> f32 {
        v1.exp()
    }
    fn f64(v1: f64) -> f64 {
        v1.exp()
    }
    fn u32(v1: u32) -> u32 {
        (v1 as f64).exp() as u32
    }
    const KERNEL_F32: &'static str = "uexp_f32";
    const KERNEL_F64: &'static str = "uexp_f64";
}

impl UnaryOp for Log {
    const NAME: &'static str = "log";
    fn f32(v1: f32) -> f32 {
        v1.ln()
    }
    fn f64(v1: f64) -> f64 {
        v1.ln()
    }
    fn u32(v1: u32) -> u32 {
        (v1 as f64).ln() as u32
    }
    const KERNEL_F32: &'static str = "ulog_f32";
    const KERNEL_F64: &'static str = "ulog_f64";
}

impl UnaryOp for Sin {
    const NAME: &'static str = "sin";
    fn f32(v1: f32) -> f32 {
        v1.sin()
    }
    fn f64(v1: f64) -> f64 {
        v1.sin()
    }
    fn u32(_: u32) -> u32 {
        0
    }
    const KERNEL_F32: &'static str = "usin_f32";
    const KERNEL_F64: &'static str = "usin_f64";
}

impl UnaryOp for Cos {
    const NAME: &'static str = "cos";
    fn f32(v1: f32) -> f32 {
        v1.cos()
    }
    fn f64(v1: f64) -> f64 {
        v1.cos()
    }
    fn u32(_: u32) -> u32 {
        0
    }
    const KERNEL_F32: &'static str = "ucos_f32";
    const KERNEL_F64: &'static str = "ucos_f64";
}

impl UnaryOp for Abs {
    const NAME: &'static str = "abs";
    fn f32(v1: f32) -> f32 {
        v1.abs()
    }
    fn f64(v1: f64) -> f64 {
        v1.abs()
    }
    fn u32(v1: u32) -> u32 {
        v1
    }
    const KERNEL_F32: &'static str = "uabs_f32";
    const KERNEL_F64: &'static str = "uabs_f64";
}

impl UnaryOp for Neg {
    const NAME: &'static str = "neg";
    fn f32(v1: f32) -> f32 {
@@ -254,6 +254,11 @@ impl Tensor {
    broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);

    unary_op!(neg, Neg);
    unary_op!(exp, Exp);
    unary_op!(log, Log);
    unary_op!(sin, Sin);
    unary_op!(cos, Cos);
    unary_op!(abs, Abs);
    unary_op!(sqr, Sqr);
    unary_op!(sqrt, Sqrt);
    unary_op!(gelu, Gelu);
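The `unary_op!` macro definition is not part of this diff. As a self-contained illustration of the pattern it implements — one forwarding method stamped out per tag struct — here is a toy version; `MiniTensor` and `apply` are hypothetical names, not candle's API:

```rust
// Toy re-creation of the unary_op! pattern; illustrative only.
trait UnaryOp {
    fn f32(v: f32) -> f32;
}

struct Exp;
impl UnaryOp for Exp {
    fn f32(v: f32) -> f32 {
        v.exp()
    }
}

struct MiniTensor(Vec<f32>);
impl MiniTensor {
    // Generic over the tag: monomorphized per op, no runtime branching.
    fn apply<O: UnaryOp>(&self) -> MiniTensor {
        MiniTensor(self.0.iter().copied().map(O::f32).collect())
    }
}

// `unary_op!(exp, Exp)` generates `fn exp(&self)` forwarding to `apply::<Exp>`.
macro_rules! unary_op {
    ($fn_name:ident, $op:ident) => {
        impl MiniTensor {
            fn $fn_name(&self) -> MiniTensor {
                self.apply::<$op>()
            }
        }
    };
}
unary_op!(exp, Exp);

fn main() {
    println!("{:?}", MiniTensor(vec![0.0, 1.0]).exp().0); // [1.0, 2.7182817]
}
```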
@@ -774,6 +779,11 @@ impl Tensor {
            | Op::Sqr(node)
            | Op::Sqrt(node)
            | Op::Gelu(node)
            | Op::Exp(node)
            | Op::Log(node)
            | Op::Sin(node)
            | Op::Cos(node)
            | Op::Abs(node)
            | Op::Neg(node) => {
                let (tg, nodes) = walk(node, nodes, already_seen);
                track_grad |= tg;
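This single match arm handles all the unary variants at once via or-patterns: every variant carries exactly one `Tensor`, so one `node` binding covers them all and the graph walk recurses into it uniformly. A compact standalone illustration of the construct:

```rust
// Or-patterns binding the same name across variants.
enum Op {
    Exp(String),
    Log(String),
}

fn input(op: &Op) -> &str {
    match op {
        // One binding serves both variants because their shapes match.
        Op::Exp(node) | Op::Log(node) => node,
    }
}

fn main() {
    assert_eq!(input(&Op::Exp("x".to_string())), "x");
}
```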
@@ -887,6 +897,23 @@ impl Tensor {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&arg_grad)?
                }
                Op::Log(arg) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&(&grad / arg)?)?
                }
                Op::Sin(arg) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
                }
                Op::Cos(arg) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
                }
                Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
                Op::Exp(arg) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&(&grad * *node)?)?
                }
                Op::Neg(arg) => {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.sub(&grad)?
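These backward arms implement the textbook derivatives: d/dx exp(x) = exp(x) (the node's own output, hence `&grad * *node`), d/dx ln(x) = 1/x (hence `&grad / arg`), d/dx sin(x) = cos(x), and d/dx cos(x) = -sin(x) (hence `sub`). A quick standalone central-difference check of those identities:

```rust
// Finite-difference sanity check for the derivatives used in backward.
fn main() {
    let eps = 1e-6f64;
    let x = 0.7f64;
    // Central difference: (f(x+eps) - f(x-eps)) / (2 eps) ~ f'(x).
    let fd = |f: fn(f64) -> f64| (f(x + eps) - f(x - eps)) / (2.0 * eps);
    assert!((fd(f64::exp) - x.exp()).abs() < 1e-5); // Exp: grad * node
    assert!((fd(f64::ln) - 1.0 / x).abs() < 1e-5); // Log: grad / arg
    assert!((fd(f64::sin) - x.cos()).abs() < 1e-5); // Sin: grad * cos(arg)
    assert!((fd(f64::cos) + x.sin()).abs() < 1e-5); // Cos: -grad * sin(arg)
    println!("all derivative identities check out");
}
```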