Add a couple of unary ops.

This commit is contained in:
laurent
2023-06-23 20:19:20 +01:00
parent fe75a01188
commit 8ed350dc94
3 changed files with 127 additions and 0 deletions

View File

@ -37,6 +37,11 @@ __device__ T gelu_fwd(T x) {
#if __CUDA_ARCH__ >= 530
// Half-precision (__half) instantiations of the unary kernels. These sit under the
// `__CUDA_ARCH__ >= 530` guard above, so they are only compiled for those architectures.
// Each line is: element type, exported kernel name, expression applied to `x`.
UNARY_OP(__half, ucopy_f16, x)
UNARY_OP(__half, uneg_f16, -x)
UNARY_OP(__half, uexp_f16, expg(x))
UNARY_OP(__half, ulog_f16, logg(x))
UNARY_OP(__half, usin_f16, sing(x))
UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)
UNARY_OP(__half, usqrt_f16, sqrtg(x))
UNARY_OP(__half, gelu_f16, gelu_fwd(x))
@ -46,6 +51,16 @@ UNARY_OP(float, ucopy_f32, x)
// Single- and double-precision instantiations of the unary kernels.
// Each line is: element type, exported kernel name, expression applied to `x`.
// Kernel names follow the `u<op>_<dtype>` pattern and have to match the
// KERNEL_F32 / KERNEL_F64 string constants declared by the Rust `UnaryOp` impls
// (e.g. "uabs_f32"), so the abs kernels are named `uabs_*`, not `uabsg_*`.
UNARY_OP(double, ucopy_f64, x)
UNARY_OP(float, uneg_f32, -x)
UNARY_OP(double, uneg_f64, -x)
UNARY_OP(float, uexp_f32, expg(x))
UNARY_OP(double, uexp_f64, expg(x))
UNARY_OP(float, ulog_f32, logg(x))
UNARY_OP(double, ulog_f64, logg(x))
UNARY_OP(float, usin_f32, sing(x))
UNARY_OP(double, usin_f64, sing(x))
UNARY_OP(float, ucos_f32, cosg(x))
UNARY_OP(double, ucos_f64, cosg(x))
UNARY_OP(float, uabs_f32, absg(x))
UNARY_OP(double, uabs_f64, absg(x))
UNARY_OP(float, usqr_f32, x*x)
UNARY_OP(double, usqr_f64, x*x)
UNARY_OP(float, usqrt_f32, sqrtg(x))

View File

@ -21,6 +21,11 @@ pub(crate) enum Op {
mul: f64,
add: f64,
},
/// Exponential; the payload is the tensor the op was applied to.
Exp(Tensor),
/// Natural logarithm; the payload is the tensor the op was applied to.
Log(Tensor),
/// Sine; the payload is the tensor the op was applied to.
Sin(Tensor),
/// Cosine; the payload is the tensor the op was applied to.
Cos(Tensor),
/// Absolute value; the payload is the tensor the op was applied to.
/// The backward pass reports `BackwardNotSupported` for this variant.
Abs(Tensor),
Neg(Tensor),
Reshape(Tensor),
#[allow(dead_code)]
@ -60,6 +65,11 @@ pub(crate) struct Add;
pub(crate) struct Div;
pub(crate) struct Mul;
pub(crate) struct Sub;
/// Marker type for the exponential op; its behaviour lives in its `UnaryOp` impl.
pub(crate) struct Exp;
/// Marker type for the natural-logarithm op; its behaviour lives in its `UnaryOp` impl.
pub(crate) struct Log;
/// Marker type for the sine op; its behaviour lives in its `UnaryOp` impl.
pub(crate) struct Sin;
/// Marker type for the cosine op; its behaviour lives in its `UnaryOp` impl.
pub(crate) struct Cos;
/// Marker type for the absolute-value op; its behaviour lives in its `UnaryOp` impl.
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
@ -129,6 +139,81 @@ impl BinaryOp for Div {
}
}
/// Scalar implementations and kernel names for the exponential op.
impl UnaryOp for Exp {
    const NAME: &'static str = "exp";
    // Names of the matching `UNARY_OP` kernels in the CUDA source.
    const KERNEL_F32: &'static str = "uexp_f32";
    const KERNEL_F64: &'static str = "uexp_f64";

    fn f32(x: f32) -> f32 {
        f32::exp(x)
    }

    fn f64(x: f64) -> f64 {
        f64::exp(x)
    }

    /// Evaluated in `f64`, then cast back: the `as` cast truncates toward zero
    /// and saturates at `u32::MAX` for large results.
    fn u32(x: u32) -> u32 {
        let widened = f64::from(x);
        widened.exp() as u32
    }
}
/// Scalar implementations and kernel names for the natural-logarithm op.
impl UnaryOp for Log {
    const NAME: &'static str = "log";
    // Names of the matching `UNARY_OP` kernels in the CUDA source.
    const KERNEL_F32: &'static str = "ulog_f32";
    const KERNEL_F64: &'static str = "ulog_f64";

    fn f32(x: f32) -> f32 {
        f32::ln(x)
    }

    fn f64(x: f64) -> f64 {
        f64::ln(x)
    }

    /// Evaluated in `f64`, then cast back: the `as` cast truncates toward zero,
    /// and `ln(0) = -inf` saturates to `0`.
    fn u32(x: u32) -> u32 {
        let widened = f64::from(x);
        widened.ln() as u32
    }
}
/// Scalar implementations and kernel names for the sine op.
impl UnaryOp for Sin {
    const NAME: &'static str = "sin";
    // Names of the matching `UNARY_OP` kernels in the CUDA source.
    const KERNEL_F32: &'static str = "usin_f32";
    const KERNEL_F64: &'static str = "usin_f64";

    fn f32(x: f32) -> f32 {
        f32::sin(x)
    }

    fn f64(x: f64) -> f64 {
        f64::sin(x)
    }

    /// The integer variant ignores its input and always yields zero.
    fn u32(_x: u32) -> u32 {
        0u32
    }
}
/// Scalar implementations and kernel names for the cosine op.
impl UnaryOp for Cos {
    const NAME: &'static str = "cos";
    // Names of the matching `UNARY_OP` kernels in the CUDA source.
    const KERNEL_F32: &'static str = "ucos_f32";
    const KERNEL_F64: &'static str = "ucos_f64";

    fn f32(x: f32) -> f32 {
        f32::cos(x)
    }

    fn f64(x: f64) -> f64 {
        f64::cos(x)
    }

    /// The integer variant ignores its input and always yields zero
    /// (including for an input of 0).
    fn u32(_x: u32) -> u32 {
        0u32
    }
}
/// Scalar implementations and kernel names for the absolute-value op.
impl UnaryOp for Abs {
    const NAME: &'static str = "abs";
    // Kernel names this op declares for the f32 / f64 CUDA kernels.
    const KERNEL_F32: &'static str = "uabs_f32";
    const KERNEL_F64: &'static str = "uabs_f64";

    fn f32(x: f32) -> f32 {
        f32::abs(x)
    }

    fn f64(x: f64) -> f64 {
        f64::abs(x)
    }

    /// Unsigned values are already non-negative, so this is the identity.
    fn u32(x: u32) -> u32 {
        x
    }
}
impl UnaryOp for Neg {
const NAME: &'static str = "neg";
fn f32(v1: f32) -> f32 {

View File

@ -254,6 +254,11 @@ impl Tensor {
broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
// One `unary_op!` invocation per unary operation: the first argument is the
// lower-case op name, the second the marker type that implements `UnaryOp`.
// NOTE(review): the macro is defined outside this view; presumably it generates a
// `Tensor` method of that name (the backward pass calls `arg.cos()` / `arg.sin()`) — confirm.
unary_op!(neg, Neg);
unary_op!(exp, Exp);
unary_op!(log, Log);
unary_op!(sin, Sin);
unary_op!(cos, Cos);
unary_op!(abs, Abs);
unary_op!(sqr, Sqr);
unary_op!(sqrt, Sqrt);
unary_op!(gelu, Gelu);
@ -774,6 +779,11 @@ impl Tensor {
| Op::Sqr(node)
| Op::Sqrt(node)
| Op::Gelu(node)
| Op::Exp(node)
| Op::Log(node)
| Op::Sin(node)
| Op::Cos(node)
| Op::Abs(node)
| Op::Neg(node) => {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
@ -887,6 +897,23 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&arg_grad)?
}
Op::Log(arg) => {
    // d/dx ln(x) = 1 / x, so the incoming gradient is divided by the argument.
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&(&grad / arg)?)?
}
Op::Sin(arg) => {
    // d/dx sin(x) = cos(x)
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
}
Op::Cos(arg) => {
    // d/dx cos(x) = -sin(x), hence the subtraction.
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
}
// No gradient is implemented for abs; report it rather than silently skipping.
Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
Op::Exp(arg) => {
    // d/dx exp(x) = exp(x). NOTE(review): this relies on `node` being the tensor
    // produced by this op (i.e. exp(arg)), which avoids recomputing the exponential.
    let sum_grad = grads.or_insert(arg)?;
    *sum_grad = sum_grad.add(&(&grad * *node)?)?
}
Op::Neg(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?