Add the recip op + use it in stable-diffusion. (#331)

* Add the recip unary op.

* Fix the cuda kernel.

* Use the recip op in sigmoid.
This commit is contained in:
Laurent Mazare
2023-08-06 22:14:52 +02:00
committed by GitHub
parent 1c062bf06b
commit 166bfd5847
5 changed files with 26 additions and 5 deletions

View File

@ -291,6 +291,11 @@ impl Tensor {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.sub(&grad)?
}
Op::Unary(arg, UnaryOp::Recip) => {
let sum_grad = grads.or_insert(arg)?;
let grad = (grad / arg.sqr()?)?;
*sum_grad = sum_grad.sub(&grad)?
}
&Op::Narrow(ref arg, dim, start_idx, len) => {
let arg_dims = arg.dims();
let left_pad = if start_idx == 0 {

View File

@ -51,6 +51,7 @@ pub enum UnaryOp {
Cos,
Abs,
Neg,
Recip,
Sqr,
Sqrt,
Gelu,
@ -264,6 +265,7 @@ pub(crate) struct Sin;
pub(crate) struct Cos;
pub(crate) struct Abs;
pub(crate) struct Neg;
pub(crate) struct Recip;
pub(crate) struct Sqr;
pub(crate) struct Sqrt;
pub(crate) struct Gelu;
@ -410,6 +412,7 @@ unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);

View File

@ -474,6 +474,7 @@ impl Tensor {
broadcast_binary_op!(broadcast_sub, sub);
broadcast_binary_op!(broadcast_div, div);
unary_op!(recip, Recip);
unary_op!(neg, Neg);
unary_op!(exp, Exp);
unary_op!(log, Log);