Mirror of https://github.com/huggingface/candle.git, synced 2025-06-16 18:48:51 +00:00
Add the rounding operators. (#1030)
* Add the rounding operators.
* Avoid tracking gradients for the rounding operations.
* Add some rounding tests.
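The commit wires the three ops end to end: new UnaryOp variants, CPU scalar implementations, autograd handling, Tensor methods, tests, and CUDA kernels. A minimal usage sketch of the resulting API (assuming candle-core built from this commit; the input values and the expected outputs in the comments are illustrative, following the Rust ceil/floor/round semantics that the scalar implementations below delegate to):

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let t = Tensor::new(&[-1.7f32, -0.5, 0.5, 1.7], &Device::Cpu)?;
    println!("{:?}", t.ceil()?.to_vec1::<f32>()?);  // [-1.0, -0.0, 1.0, 2.0]
    println!("{:?}", t.floor()?.to_vec1::<f32>()?); // [-2.0, -1.0, 0.0, 1.0]
    println!("{:?}", t.round()?.to_vec1::<f32>()?); // [-2.0, -1.0, 1.0, 2.0] (halves round away from zero)
    Ok(())
}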
@@ -91,6 +91,9 @@ impl Tensor {
                            nodes
                        }
                    }
                    Op::Unary(_node, UnaryOp::Ceil)
                    | Op::Unary(_node, UnaryOp::Floor)
                    | Op::Unary(_node, UnaryOp::Round) => nodes,
                    Op::Reshape(node)
                    | Op::UpsampleNearest1D(node)
                    | Op::UpsampleNearest2D(node)

@@ -451,6 +454,13 @@ impl Tensor {
                        let sum_grad = grads.or_insert(arg)?;
                        *sum_grad = sum_grad.add(&arg_grad)?
                    }
                    Op::Unary(_, UnaryOp::Ceil) => Err(Error::BackwardNotSupported { op: "ceil" })?,
                    Op::Unary(_, UnaryOp::Floor) => {
                        Err(Error::BackwardNotSupported { op: "floor" })?
                    }
                    Op::Unary(_, UnaryOp::Round) => {
                        Err(Error::BackwardNotSupported { op: "round" })?
                    }
                    Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
                    Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
                    Op::Unary(_, UnaryOp::GeluErf) => {

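The two hunks above implement the "avoid tracking gradients" part of the commit message: the rounding ops are skipped during the topological graph walk, and the corresponding backward match arms only raise explicit BackwardNotSupported errors. A hedged sketch of the expected effect (names are from candle-core; the gradient stated in the comment is an expectation based on the hunks above, not an assertion taken from the PR):

use candle_core::{Device, Result, Var};

fn main() -> Result<()> {
    let x = Var::new(&[1.2f32, -0.7, 2.5], &Device::Cpu)?;
    // The rounded branch should behave like a constant in the graph walk,
    // so only the differentiable sqr() branch contributes to x's gradient.
    let y = (x.sqr()? + x.round()?)?.sum_all()?;
    let grads = y.backward()?;
    if let Some(g) = grads.get(&x) {
        println!("dy/dx = {:?}", g.to_vec1::<f32>()?); // expected: 2 * x
    }
    Ok(())
}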
@@ -62,6 +62,9 @@ pub enum UnaryOp {
    Erf,
    Relu,
    Tanh,
    Floor,
    Ceil,
    Round,
}

#[derive(Clone)]

@@ -332,6 +335,9 @@ pub(crate) struct GeluErf;
pub(crate) struct Erf;
pub(crate) struct Relu;
pub(crate) struct Tanh;
pub(crate) struct Floor;
pub(crate) struct Ceil;
pub(crate) struct Round;

macro_rules! bin_op {
    ($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {

@@ -660,6 +666,108 @@ impl UnaryOpT for Erf {
    }
}

impl UnaryOpT for Ceil {
    const NAME: &'static str = "ceil";
    const KERNEL: &'static str = "uceil";
    const V: Self = Ceil;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.ceil()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.ceil()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.ceil()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.ceil()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for Floor {
    const NAME: &'static str = "floor";
    const KERNEL: &'static str = "ufloor";
    const V: Self = Floor;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.floor()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.floor()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.floor()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.floor()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for Round {
    const NAME: &'static str = "round";
    const KERNEL: &'static str = "uround";
    const V: Self = Round;
    #[inline(always)]
    fn bf16(v: bf16) -> bf16 {
        v.round()
    }
    #[inline(always)]
    fn f16(v: f16) -> f16 {
        v.round()
    }
    #[inline(always)]
    fn f32(v: f32) -> f32 {
        v.round()
    }
    #[inline(always)]
    fn f64(v: f64) -> f64 {
        v.round()
    }
    #[inline(always)]
    fn u8(v: u8) -> u8 {
        v
    }
    #[inline(always)]
    fn u32(v: u32) -> u32 {
        v
    }
    #[inline(always)]
    fn i64(v: i64) -> i64 {
        v
    }
}

impl UnaryOpT for GeluErf {
    const NAME: &'static str = "gelu_erf";
    const KERNEL: &'static str = "ugelu_erf";

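Note that the u8/u32/i64 paths above return the value unchanged, so on integer dtypes the new ops act as identity functions (on the CPU backend at least; the CUDA kernels added further down only cover the float dtypes). A small hedged sketch of that consequence:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let ints = Tensor::new(&[1u32, 2, 3], &Device::Cpu)?;
    // Integer dtypes map through the identity u32 impl above.
    assert_eq!(ints.round()?.to_vec1::<u32>()?, [1, 2, 3]);
    Ok(())
}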
@@ -492,6 +492,9 @@ impl Tensor {
    unary_op!(gelu_erf, GeluErf);
    unary_op!(erf, Erf);
    unary_op!(relu, Relu);
    unary_op!(ceil, Ceil);
    unary_op!(floor, Floor);
    unary_op!(round, Round);

    /// Retrieves the single scalar value hold in the tensor. If the tensor contains multiple
    /// dimensions, an error is returned instead.

@@ -93,6 +93,18 @@ fn unary_op(device: &Device) -> Result<()> {
            [0.9999, -0.9891, -0.3079, 0.9891, 0.9999]
        ]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.ceil()?, 4)?,
        [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -1.0, -0.0, 2.0, 3.0]]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.floor()?, 4)?,
        [[-3.0, 1.0, 4.0, -1.0, 0.0], [2.0, -2.0, -1.0, 1.0, 2.0]]
    );
    assert_eq!(
        test_utils::to_vec2_round(&tensor.round()?, 4)?,
        [[-3.0, 1.0, 4.0, -0.0, 1.0], [3.0, -2.0, -0.0, 2.0, 3.0]]
    );
    Ok(())
}

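The new assertions reuse the tensor defined earlier in the same test together with the test_utils::to_vec2_round helper. To run just this test locally, filtering by name with cargo should work, e.g. cargo test -p candle-core unary_op (assuming the standard candle workspace layout).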
@@ -131,6 +131,12 @@ __device__ __forceinline__ float tanhg(float a) { return tanhf(a); }
__device__ __forceinline__ double tanhg(double a) { return tanh(a); }
__device__ __forceinline__ float erfg(float a) { return erff(a); }
__device__ __forceinline__ double erfg(double a) { return erf(a); }
__device__ __forceinline__ float ceilg(float a) { return ceilf(a); }
__device__ __forceinline__ double ceilg(double a) { return ceil(a); }
__device__ __forceinline__ float floorg(float a) { return floorf(a); }
__device__ __forceinline__ double floorg(double a) { return floor(a); }
__device__ __forceinline__ float roundg(float a) { return roundf(a); }
__device__ __forceinline__ double roundg(double a) { return round(a); }
__device__ __forceinline__ float normcdfg(float a) { return normcdff(a); }
__device__ __forceinline__ double normcdfg(double a) { return normcdf(a); }
__device__ __forceinline__ float maxg(float a, float b) { return fmaxf(a, b); }

@@ -162,6 +168,9 @@ __device__ __forceinline__ __half recipg(__half a) { __half one = 1.0; return on
__device__ __forceinline__ __half maxg(__half a, __half b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __half tanhg(__half a) { return __float2half(tanhf(__half2float(a))); }
__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); }
__device__ __forceinline__ __half ceilg(__half a) { return __float2half(ceilf(__half2float(a))); }
__device__ __forceinline__ __half floorg(__half a) { return __float2half(floorf(__half2float(a))); }
__device__ __forceinline__ __half roundg(__half a) { return __float2half(roundf(__half2float(a))); }
__device__ __forceinline__ __half normcdfg(__half a) { return __float2half(normcdff(__half2float(a))); }
__device__ __forceinline__ __half ming(__half a, __half b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __half logg(__half a) { return hlog(a); }

@@ -180,6 +189,9 @@ __device__ __forceinline__ __nv_bfloat16 recipg(__nv_bfloat16 a) { __nv_bfloat16
__device__ __forceinline__ __nv_bfloat16 maxg(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmax_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 tanhg(__nv_bfloat16 a) { return __float2bfloat16(tanhf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 erfg(__nv_bfloat16 a) { return __float2bfloat16(erff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ceilg(__nv_bfloat16 a) { return __float2bfloat16(ceilf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 floorg(__nv_bfloat16 a) { return __float2bfloat16(floorf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 roundg(__nv_bfloat16 a) { return __float2bfloat16(roundf(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 normcdfg(__nv_bfloat16 a) { return __float2bfloat16(normcdff(__bfloat162float(a))); }
__device__ __forceinline__ __nv_bfloat16 ming(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmin_nan(a, b); }
__device__ __forceinline__ __nv_bfloat16 logg(__nv_bfloat16 a) { return hlog(a); }

@@ -92,6 +92,9 @@ UNARY_OP(__nv_bfloat16, usin_bf16, sing(x))
UNARY_OP(__nv_bfloat16, ucos_bf16, cosg(x))
UNARY_OP(__nv_bfloat16, utanh_bf16, tanhg(x))
UNARY_OP(__nv_bfloat16, uerf_bf16, erfg(x))
UNARY_OP(__nv_bfloat16, uceil_bf16, ceilg(x))
UNARY_OP(__nv_bfloat16, ufloor_bf16, floorg(x))
UNARY_OP(__nv_bfloat16, uround_bf16, roundg(x))
UNARY_OP(__nv_bfloat16, unormcdf_bf16, normcdfg(x))
UNARY_OP(__nv_bfloat16, uabs_bf16, absg(x))
UNARY_OP(__nv_bfloat16, usqr_bf16, x*x)

@@ -113,6 +116,9 @@ UNARY_OP(__half, usin_f16, sing(x))
UNARY_OP(__half, ucos_f16, cosg(x))
UNARY_OP(__half, utanh_f16, tanhg(x))
UNARY_OP(__half, uerf_f16, erfg(x))
UNARY_OP(__half, uceil_f16, ceilg(x))
UNARY_OP(__half, ufloor_f16, floorg(x))
UNARY_OP(__half, uround_f16, roundg(x))
UNARY_OP(__half, unormcdf_f16, normcdfg(x))
UNARY_OP(__half, uabs_f16, absg(x))
UNARY_OP(__half, usqr_f16, x*x)

@@ -145,6 +151,12 @@ UNARY_OP(float, utanh_f32, tanhg(x))
UNARY_OP(double, utanh_f64, tanhg(x))
UNARY_OP(float, uerf_f32, erfg(x))
UNARY_OP(double, uerf_f64, erfg(x))
UNARY_OP(float, uceil_f32, ceilg(x))
UNARY_OP(double, uceil_f64, ceilg(x))
UNARY_OP(float, ufloor_f32, floorg(x))
UNARY_OP(double, ufloor_f64, floorg(x))
UNARY_OP(float, uround_f32, roundg(x))
UNARY_OP(double, uround_f64, roundg(x))
UNARY_OP(float, unormcdf_f32, normcdfg(x))
UNARY_OP(double, unormcdf_f64, normcdfg(x))
UNARY_OP(float, uabs_f32, absg(x))
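The ceilg/floorg/roundg helpers and the UNARY_OP instantiations above give the CUDA backend uceil_/ufloor_/uround_ kernels for bf16, f16, f32 and f64, matching the KERNEL prefixes declared on the Rust side. A hedged sketch of exercising them through the same API, assuming a CUDA-enabled build of candle-core and a GPU at index 0:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::new_cuda(0)?;
    let t = Tensor::new(&[-1.7f32, -0.5, 0.5, 1.7], &dev)?;
    // Presumably dispatched to the uround_f32 kernel defined above.
    println!("{:?}", t.round()?.to_vec1::<f32>()?);
    Ok(())
}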