mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add missing bfloat unary strided kernels and fix typo (#2058)
This commit is contained in:
@ -540,6 +540,7 @@ impl BackendStorage for MetalStorage {
|
|||||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||||
("uround", DType::F32) => strided::round::FLOAT,
|
("uround", DType::F32) => strided::round::FLOAT,
|
||||||
("utanh", DType::F32) => strided::tanh::FLOAT,
|
("utanh", DType::F32) => strided::tanh::FLOAT,
|
||||||
|
|
||||||
("ucos", DType::F16) => strided::cos::HALF,
|
("ucos", DType::F16) => strided::cos::HALF,
|
||||||
("usin", DType::F16) => strided::sin::HALF,
|
("usin", DType::F16) => strided::sin::HALF,
|
||||||
("usqr", DType::F16) => strided::sqr::HALF,
|
("usqr", DType::F16) => strided::sqr::HALF,
|
||||||
@ -557,6 +558,25 @@ impl BackendStorage for MetalStorage {
|
|||||||
("urelu", DType::F16) => strided::relu::HALF,
|
("urelu", DType::F16) => strided::relu::HALF,
|
||||||
("uround", DType::F16) => strided::round::HALF,
|
("uround", DType::F16) => strided::round::HALF,
|
||||||
("utanh", DType::F16) => strided::tanh::HALF,
|
("utanh", DType::F16) => strided::tanh::HALF,
|
||||||
|
|
||||||
|
("ucos", DType::BF16) => strided::cos::BFLOAT,
|
||||||
|
("usin", DType::BF16) => strided::sin::BFLOAT,
|
||||||
|
("usqr", DType::BF16) => strided::sqr::BFLOAT,
|
||||||
|
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
|
||||||
|
("uneg", DType::BF16) => strided::neg::BFLOAT,
|
||||||
|
("uexp", DType::BF16) => strided::exp::BFLOAT,
|
||||||
|
("ulog", DType::BF16) => strided::log::BFLOAT,
|
||||||
|
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
|
||||||
|
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
|
||||||
|
("uerf", DType::BF16) => strided::erf::BFLOAT,
|
||||||
|
("usilu", DType::BF16) => strided::silu::BFLOAT,
|
||||||
|
("uabs", DType::BF16) => strided::abs::BFLOAT,
|
||||||
|
("uceil", DType::BF16) => strided::ceil::BFLOAT,
|
||||||
|
("ufloor", DType::BF16) => strided::floor::BFLOAT,
|
||||||
|
("urelu", DType::BF16) => strided::relu::BFLOAT,
|
||||||
|
("uround", DType::BF16) => strided::round::BFLOAT,
|
||||||
|
("utanh", DType::BF16) => strided::tanh::BFLOAT,
|
||||||
|
|
||||||
(name, dtype) => {
|
(name, dtype) => {
|
||||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||||
}
|
}
|
||||||
|
@ -175,5 +175,5 @@ BFLOAT_UNARY_OP(sign)
|
|||||||
|
|
||||||
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
||||||
|
|
||||||
COPY2D(copy2d_bf64, bfloat)
|
COPY2D(copy2d_bf16, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
Reference in New Issue
Block a user