mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
add urecip op to metal backend
This commit is contained in:
@ -648,6 +648,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||||
|
("urecip", DType::F32) => contiguous::round::FLOAT,
|
||||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||||
("usin", DType::F16) => contiguous::sin::HALF,
|
("usin", DType::F16) => contiguous::sin::HALF,
|
||||||
@ -662,6 +663,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||||
("uround", DType::F16) => contiguous::round::HALF,
|
("uround", DType::F16) => contiguous::round::HALF,
|
||||||
|
("urecip", DType::F16) => contiguous::round::HALF,
|
||||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||||
};
|
};
|
||||||
|
@ -165,7 +165,7 @@ macro_rules! ops{
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub mod unary {
|
pub mod unary {
|
||||||
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh);
|
ops!(cos, sin, exp, sqr, sqrt, neg, log, gelu, ceil, floor, round, erf, gelu_erf, tanh, urecip);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
|
ops!(add, sub, mul, div, min, max, eq, ne, le, lt, ge, gt);
|
||||||
|
@ -19,7 +19,9 @@ METAL_FUNC uint get_strided_index(
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
|
||||||
|
template <typename T> METAL_FUNC T urecip(T in){ return T(1.0 / in); }
|
||||||
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
template <typename T> METAL_FUNC T neg(T in){ return -in; }
|
||||||
|
|
||||||
template <typename T> METAL_FUNC T erf(T in){
|
template <typename T> METAL_FUNC T erf(T in){
|
||||||
float x = (float) in;
|
float x = (float) in;
|
||||||
// constants
|
// constants
|
||||||
@ -57,8 +59,6 @@ template <typename T> METAL_FUNC T gelu(T x) {
|
|||||||
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
return static_cast<T>(0.5) * x * (static_cast<T>(1.0) + T(tanh(beta)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
@ -108,6 +108,8 @@ UNARY_OP(round)
|
|||||||
UNARY_OP(gelu_erf)
|
UNARY_OP(gelu_erf)
|
||||||
UNARY_OP(erf)
|
UNARY_OP(erf)
|
||||||
UNARY_OP(tanh)
|
UNARY_OP(tanh)
|
||||||
|
UNARY_OP(urecip)
|
||||||
|
|
||||||
UNARY(id, float, copy_f32, copy_f32_strided)
|
UNARY(id, float, copy_f32, copy_f32_strided)
|
||||||
UNARY(id, half, copy_f16, copy_f16_strided)
|
UNARY(id, half, copy_f16, copy_f16_strided)
|
||||||
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
UNARY(id, uint8_t, copy_u8, copy_u8_strided)
|
||||||
@ -128,6 +130,7 @@ BFLOAT_UNARY_OP(round)
|
|||||||
BFLOAT_UNARY_OP(gelu_erf)
|
BFLOAT_UNARY_OP(gelu_erf)
|
||||||
BFLOAT_UNARY_OP(erf)
|
BFLOAT_UNARY_OP(erf)
|
||||||
BFLOAT_UNARY_OP(tanh)
|
BFLOAT_UNARY_OP(tanh)
|
||||||
|
BFLOAT_UNARY_OP(urecip)
|
||||||
|
|
||||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||||
#endif
|
#endif
|
||||||
|
BIN
sd_final.png
Normal file
BIN
sd_final.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 329 KiB |
Reference in New Issue
Block a user