diff --git a/kernels/src/unary.cu b/kernels/src/unary.cu
index ef672f63..03f9353b 100644
--- a/kernels/src/unary.cu
+++ b/kernels/src/unary.cu
@@ -37,6 +37,11 @@ __device__ T gelu_fwd(T x) {
 #if __CUDA_ARCH__ >= 530
 UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
+UNARY_OP(__half, uexp_f16, expg(x))
+UNARY_OP(__half, ulog_f16, logg(x))
+UNARY_OP(__half, usin_f16, sing(x))
+UNARY_OP(__half, ucos_f16, cosg(x))
+UNARY_OP(__half, uabs_f16, absg(x))
 UNARY_OP(__half, usqr_f16, x*x)
 UNARY_OP(__half, usqrt_f16, sqrtg(x))
 UNARY_OP(__half, gelu_f16, gelu_fwd(x))
@@ -46,6 +51,16 @@ UNARY_OP(float, ucopy_f32, x)
 UNARY_OP(double, ucopy_f64, x)
 UNARY_OP(float, uneg_f32, -x)
 UNARY_OP(double, uneg_f64, -x)
+UNARY_OP(float, uexp_f32, expg(x))
+UNARY_OP(double, uexp_f64, expg(x))
+UNARY_OP(float, ulog_f32, logg(x))
+UNARY_OP(double, ulog_f64, logg(x))
+UNARY_OP(float, usin_f32, sing(x))
+UNARY_OP(double, usin_f64, sing(x))
+UNARY_OP(float, ucos_f32, cosg(x))
+UNARY_OP(double, ucos_f64, cosg(x))
+UNARY_OP(float, uabs_f32, absg(x))
+UNARY_OP(double, uabs_f64, absg(x))
 UNARY_OP(float, usqr_f32, x*x)
 UNARY_OP(double, usqr_f64, x*x)
 UNARY_OP(float, usqrt_f32, sqrtg(x))
diff --git a/src/op.rs b/src/op.rs
index 77e87140..c4283d34 100644
--- a/src/op.rs
+++ b/src/op.rs
@@ -21,6 +21,11 @@ pub(crate) enum Op {
         mul: f64,
         add: f64,
     },
+    Exp(Tensor),
+    Log(Tensor),
+    Sin(Tensor),
+    Cos(Tensor),
+    Abs(Tensor),
     Neg(Tensor),
     Reshape(Tensor),
     #[allow(dead_code)]
@@ -60,6 +65,11 @@ pub(crate) struct Add;
 pub(crate) struct Div;
 pub(crate) struct Mul;
 pub(crate) struct Sub;
+pub(crate) struct Exp;
+pub(crate) struct Log;
+pub(crate) struct Sin;
+pub(crate) struct Cos;
+pub(crate) struct Abs;
 pub(crate) struct Neg;
 pub(crate) struct Sqr;
 pub(crate) struct Sqrt;
@@ -129,6 +139,81 @@ impl BinaryOp for Div {
     }
 }
 
+impl UnaryOp for Exp {
+    const NAME: &'static str = "exp";
+    fn f32(v1: f32) -> f32 {
+        v1.exp()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.exp()
+    }
+    fn u32(v1: u32) -> u32 {
+        (v1 as f64).exp() as u32
+    }
+    const KERNEL_F32: &'static str = "uexp_f32";
+    const KERNEL_F64: &'static str = "uexp_f64";
+}
+
+impl UnaryOp for Log {
+    const NAME: &'static str = "log";
+    fn f32(v1: f32) -> f32 {
+        v1.ln()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.ln()
+    }
+    fn u32(v1: u32) -> u32 {
+        (v1 as f64).ln() as u32
+    }
+    const KERNEL_F32: &'static str = "ulog_f32";
+    const KERNEL_F64: &'static str = "ulog_f64";
+}
+
+impl UnaryOp for Sin {
+    const NAME: &'static str = "sin";
+    fn f32(v1: f32) -> f32 {
+        v1.sin()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.sin()
+    }
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    const KERNEL_F32: &'static str = "usin_f32";
+    const KERNEL_F64: &'static str = "usin_f64";
+}
+
+impl UnaryOp for Cos {
+    const NAME: &'static str = "cos";
+    fn f32(v1: f32) -> f32 {
+        v1.cos()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.cos()
+    }
+    fn u32(_: u32) -> u32 {
+        0
+    }
+    const KERNEL_F32: &'static str = "ucos_f32";
+    const KERNEL_F64: &'static str = "ucos_f64";
+}
+
+impl UnaryOp for Abs {
+    const NAME: &'static str = "abs";
+    fn f32(v1: f32) -> f32 {
+        v1.abs()
+    }
+    fn f64(v1: f64) -> f64 {
+        v1.abs()
+    }
+    fn u32(v1: u32) -> u32 {
+        v1
+    }
+    const KERNEL_F32: &'static str = "uabs_f32";
+    const KERNEL_F64: &'static str = "uabs_f64";
+}
+
 impl UnaryOp for Neg {
     const NAME: &'static str = "neg";
     fn f32(v1: f32) -> f32 {
diff --git a/src/tensor.rs b/src/tensor.rs
index af084740..c1ebaae0 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -254,6 +254,11 @@ impl Tensor {
     broadcast_binary_op!(broadcast_mul, Mul, BroadcastMul);
     broadcast_binary_op!(broadcast_div, Div, BroadcastDiv);
     unary_op!(neg, Neg);
+    unary_op!(exp, Exp);
+    unary_op!(log, Log);
+    unary_op!(sin, Sin);
+    unary_op!(cos, Cos);
+    unary_op!(abs, Abs);
     unary_op!(sqr, Sqr);
     unary_op!(sqrt, Sqrt);
     unary_op!(gelu, Gelu);
@@ -774,6 +779,11 @@ impl Tensor {
                 | Op::Sqr(node)
                 | Op::Sqrt(node)
                 | Op::Gelu(node)
+                | Op::Exp(node)
+                | Op::Log(node)
+                | Op::Sin(node)
+                | Op::Cos(node)
+                | Op::Abs(node)
                 | Op::Neg(node) => {
                     let (tg, nodes) = walk(node, nodes, already_seen);
                     track_grad |= tg;
@@ -887,6 +897,23 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.add(&arg_grad)?
                 }
+                Op::Log(arg) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&(&grad / arg)?)?
+                }
+                Op::Sin(arg) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&(&grad * arg.cos())?)?
+                }
+                Op::Cos(arg) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.sub(&(&grad * arg.sin())?)?
+                }
+                Op::Abs(_args) => return Err(Error::BackwardNotSupported { op: "abs" }),
+                Op::Exp(arg) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    *sum_grad = sum_grad.add(&(&grad * *node)?)?
+                }
                 Op::Neg(arg) => {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.sub(&grad)?
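
The backward arms above implement d/dx exp x = exp x (reusing the node's forward output via `*node` rather than recomputing the exponential), d/dx log x = 1/x, d/dx sin x = cos x, and d/dx cos x = -sin x; `abs` deliberately returns `BackwardNotSupported` since its derivative is undefined at 0. As a quick cross-check of those formulas, here is a minimal, dependency-free sketch (not part of the patch; the probe points and tolerance are illustrative only) that compares each rule against central finite differences:

```rust
// Standalone sanity check for the derivative rules used in the backward
// pass above; no dependency on this crate. Central finite differences
// should agree with the analytic gradients to ~1e-4 at these points.
fn main() {
    let eps = 1e-5_f64;
    // (name, f, f') triples mirroring the exp/log/sin/cos arms above.
    let cases: [(&str, fn(f64) -> f64, fn(f64) -> f64); 4] = [
        ("exp", f64::exp, f64::exp),           // d/dx exp x = exp x
        ("log", f64::ln, |x: f64| 1.0 / x),    // d/dx ln x  = 1/x
        ("sin", f64::sin, f64::cos),           // d/dx sin x = cos x
        ("cos", f64::cos, |x: f64| -x.sin()),  // d/dx cos x = -sin x
    ];
    for (name, f, df) in cases {
        for x in [0.5_f64, 1.0, 2.0] {
            let numeric = (f(x + eps) - f(x - eps)) / (2.0 * eps);
            let analytic = df(x);
            assert!(
                (numeric - analytic).abs() < 1e-4,
                "{name}'({x}): numeric {numeric} vs analytic {analytic}"
            );
        }
    }
    println!("all unary gradient rules check out");
}
```

A crate-level test could run the same comparison end to end through the new `exp`/`log`/`sin`/`cos` tensor methods and `backward()`, checking the returned gradients against these finite differences.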