mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 11:08:52 +00:00)
Add the recip op + use it in stable-diffusion. (#331)
* Add the recip unary op.
* Fix the cuda kernel.
* Use the recip op in sigmoid.
@@ -291,6 +291,11 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.sub(&grad)?
                 }
+                Op::Unary(arg, UnaryOp::Recip) => {
+                    let sum_grad = grads.or_insert(arg)?;
+                    let grad = (grad / arg.sqr()?)?;
+                    *sum_grad = sum_grad.sub(&grad)?
+                }
                 &Op::Narrow(ref arg, dim, start_idx, len) => {
                     let arg_dims = arg.dims();
                     let left_pad = if start_idx == 0 {
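Since recip(x) = 1/x has derivative -1/x², the backward pass divides the incoming gradient by x² and subtracts it from the accumulated gradient; the sub (rather than add) carries the minus sign. A standalone check of that derivative with a central finite difference, in plain Rust with no candle types:

// Sanity check of the Recip backward rule: d/dx (1/x) = -1/x^2.
fn main() {
    let x = 3.0_f64;
    let eps = 1e-6;
    // Central difference approximation of the derivative of 1/x at x.
    let numeric = ((1.0 / (x + eps)) - (1.0 / (x - eps))) / (2.0 * eps);
    // Analytic derivative, matching the diff above: sum_grad -= grad / x^2.
    let analytic = -1.0 / (x * x);
    assert!((numeric - analytic).abs() < 1e-9);
    println!("numeric {numeric:.9} vs analytic {analytic:.9}");
}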
@@ -51,6 +51,7 @@ pub enum UnaryOp {
     Cos,
     Abs,
     Neg,
+    Recip,
     Sqr,
     Sqrt,
     Gelu,
@@ -264,6 +265,7 @@ pub(crate) struct Sin;
 pub(crate) struct Cos;
 pub(crate) struct Abs;
 pub(crate) struct Neg;
+pub(crate) struct Recip;
 pub(crate) struct Sqr;
 pub(crate) struct Sqrt;
 pub(crate) struct Gelu;
@@ -410,6 +412,7 @@ unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
 unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
 unary_op!(Abs, "abs", v, v.abs());
 unary_op!(Neg, "neg", v, -v);
+unary_op!(Recip, "recip", v, v.recip());
 unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
 unary_op!(Sqrt, "sqrt", v, v.sqrt(), vs_sqrt, vd_sqrt);
 
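The three hunks above are the usual plumbing for a new unary op: a UnaryOp enum variant, a marker struct, and a unary_op! invocation supplying the scalar kernel. A hypothetical, trimmed-down sketch of that macro pattern (the real unary_op! also covers u8/u32/f16/bf16 and the optional MKL-vectorized names such as vs_sqr/vd_sqr):

// Minimal sketch, not candle's actual definitions: one macro call
// implements a shared trait for a marker struct from one scalar expression.
trait UnaryOpT {
    const NAME: &'static str;
    fn f32(v: f32) -> f32;
    fn f64(v: f64) -> f64;
}

struct Recip;

macro_rules! unary_op {
    ($op:ident, $name:literal, $v:ident, $e:expr) => {
        impl UnaryOpT for $op {
            const NAME: &'static str = $name;
            // $e is typed against whichever $v parameter it lands in,
            // so the same expression serves both f32 and f64.
            fn f32($v: f32) -> f32 { $e }
            fn f64($v: f64) -> f64 { $e }
        }
    };
}

unary_op!(Recip, "recip", v, v.recip());

fn main() {
    assert_eq!(Recip::f32(4.0), 0.25);
    println!("{}(4.0) = {}", Recip::NAME, Recip::f64(4.0));
}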
@@ -474,6 +474,7 @@ impl Tensor {
     broadcast_binary_op!(broadcast_sub, sub);
     broadcast_binary_op!(broadcast_div, div);
 
+    unary_op!(recip, Recip);
     unary_op!(neg, Neg);
     unary_op!(exp, Exp);
     unary_op!(log, Log);
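This last invocation exposes the op as a Tensor method. A hedged usage sketch, assuming the public API matches present-day candle (Tensor::new, to_vec1):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let xs = Tensor::new(&[1.0f32, 2.0, 4.0], &Device::Cpu)?;
    // recip computes 1/x elementwise, replacing an explicit division.
    let ys = xs.recip()?;
    assert_eq!(ys.to_vec1::<f32>()?, [1.0, 0.5, 0.25]);
    Ok(())
}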
@@ -1,7 +1,8 @@
-use candle::{Result, Tensor};
+use candle::{Device, Result, Tensor};
 
-pub fn sigmoid(_: &Tensor) -> Result<Tensor> {
-    todo!()
+pub fn sigmoid(xs: &Tensor) -> Result<Tensor> {
+    // TODO: Add sigmoid as binary ops.
+    (xs.neg()?.exp()? + 1.0)?.recip()
 }
 
 pub fn avg_pool2d(_: &Tensor) -> Result<Tensor> {
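The new sigmoid body relies on the identity σ(x) = 1 / (1 + e^(-x)), with recip replacing the division. A scalar sanity check of that identity:

// Scalar version of the tensor expression above: (exp(-x) + 1).recip().
fn sigmoid(x: f64) -> f64 {
    ((-x).exp() + 1.0).recip()
}

fn main() {
    assert!((sigmoid(0.0) - 0.5).abs() < 1e-12);
    assert!(sigmoid(10.0) > 0.9999);  // saturates toward 1
    assert!(sigmoid(-10.0) < 0.0001); // saturates toward 0
    println!("sigmoid(0) = {}", sigmoid(0.0));
}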
@@ -16,6 +17,13 @@ pub fn upsample_nearest2d(_: &Tensor) -> Result<Tensor> {
     todo!()
 }
 
-pub fn linspace(_: f64, _: f64, _: usize) -> Result<Tensor> {
-    todo!()
+pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
+    if steps < 1 {
+        candle::bail!("cannot use linspace with steps {steps} <= 1")
+    }
+    let delta = (stop - start) / (steps - 1) as f64;
+    let vs = (0..steps)
+        .map(|step| start + step as f64 * delta)
+        .collect::<Vec<_>>();
+    Tensor::from_vec(vs, steps, &Device::Cpu)
 }
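linspace now produces `steps` evenly spaced values covering [start, stop] inclusive. A scalar sketch of the same computation; note that steps == 1 slips past the `steps < 1` guard and makes delta a division by zero, so the single value comes out NaN (the callers here pass steps > 1):

// Plain-Rust mirror of the tensor code above.
fn linspace(start: f64, stop: f64, steps: usize) -> Vec<f64> {
    // Spacing between consecutive values; endpoints are included.
    let delta = (stop - start) / (steps - 1) as f64;
    (0..steps).map(|step| start + step as f64 * delta).collect()
}

fn main() {
    assert_eq!(linspace(0.0, 1.0, 5), vec![0.0, 0.25, 0.5, 0.75, 1.0]);
}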
@@ -80,6 +80,7 @@ extern "C" __global__ void FN_NAME( \
 #if __CUDA_ARCH__ >= 800
 UNARY_OP(__nv_bfloat16, ucopy_bf16, x)
 UNARY_OP(__nv_bfloat16, uneg_bf16, -x)
+UNARY_OP(__nv_bfloat16, urecip_bf16, recipg(x))
 UNARY_OP(__nv_bfloat16, uexp_bf16, expg(x))
 UNARY_OP(__nv_bfloat16, ulog_bf16, logg(x))
 UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
@@ -95,6 +96,7 @@ UNARY_OP1(__nv_bfloat16, uelu_bf16, elu_fwd(x, param))
 #if __CUDA_ARCH__ >= 530
 UNARY_OP(__half, ucopy_f16, x)
 UNARY_OP(__half, uneg_f16, -x)
+UNARY_OP(__half, urecip_f16, recipg(x))
 UNARY_OP(__half, uexp_f16, expg(x))
 UNARY_OP(__half, ulog_f16, logg(x))
 UNARY_OP(__half, usin_f16, sing(x))
@@ -113,6 +115,8 @@ UNARY_OP(float, ucopy_f32, x)
 UNARY_OP(double, ucopy_f64, x)
 UNARY_OP(float, uneg_f32, -x)
 UNARY_OP(double, uneg_f64, -x)
+UNARY_OP(float, urecip_f32, recipg(x))
+UNARY_OP(double, urecip_f64, recipg(x))
 UNARY_OP(float, uexp_f32, expg(x))
 UNARY_OP(double, uexp_f64, expg(x))
 UNARY_OP(float, ulog_f32, logg(x))
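Each UNARY_OP line instantiates one elementwise kernel per dtype, with the f16/bf16 variants gated on __CUDA_ARCH__; recipg is assumed here to be candle's dtype-generic reciprocal helper, by analogy with expg/logg/sing. What one instantiation boils down to, sketched as a CPU loop in Rust (the GPU kernel walks the same elements with a grid-stride loop over blocks and threads, honoring input strides):

// CPU analogue of the urecip_f32 kernel: apply 1/x to every element.
fn unary_recip(src: &[f32], dst: &mut [f32]) {
    for (d, s) in dst.iter_mut().zip(src) {
        *d = s.recip();
    }
}

fn main() {
    let src = [1.0f32, 2.0, 4.0];
    let mut dst = [0.0f32; 3];
    unary_recip(&src, &mut dst);
    assert_eq!(dst, [1.0, 0.5, 0.25]);
}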