mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Add missing bfloat unary strided kernels and fix typo (#2058)
This commit is contained in:
@ -540,6 +540,7 @@ impl BackendStorage for MetalStorage {
|
||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("utanh", DType::F32) => strided::tanh::FLOAT,
|
||||
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
("usqr", DType::F16) => strided::sqr::HALF,
|
||||
@ -557,6 +558,25 @@ impl BackendStorage for MetalStorage {
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
("utanh", DType::F16) => strided::tanh::HALF,
|
||||
|
||||
("ucos", DType::BF16) => strided::cos::BFLOAT,
|
||||
("usin", DType::BF16) => strided::sin::BFLOAT,
|
||||
("usqr", DType::BF16) => strided::sqr::BFLOAT,
|
||||
("usqrt", DType::BF16) => strided::sqrt::BFLOAT,
|
||||
("uneg", DType::BF16) => strided::neg::BFLOAT,
|
||||
("uexp", DType::BF16) => strided::exp::BFLOAT,
|
||||
("ulog", DType::BF16) => strided::log::BFLOAT,
|
||||
("ugelu", DType::BF16) => strided::gelu::BFLOAT,
|
||||
("ugelu_erf", DType::BF16) => strided::gelu_erf::BFLOAT,
|
||||
("uerf", DType::BF16) => strided::erf::BFLOAT,
|
||||
("usilu", DType::BF16) => strided::silu::BFLOAT,
|
||||
("uabs", DType::BF16) => strided::abs::BFLOAT,
|
||||
("uceil", DType::BF16) => strided::ceil::BFLOAT,
|
||||
("ufloor", DType::BF16) => strided::floor::BFLOAT,
|
||||
("urelu", DType::BF16) => strided::relu::BFLOAT,
|
||||
("uround", DType::BF16) => strided::round::BFLOAT,
|
||||
("utanh", DType::BF16) => strided::tanh::BFLOAT,
|
||||
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
|
Reference in New Issue
Block a user