mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
update dtypes checks for several metal operations (#2010)
This commit is contained in:
@ -443,42 +443,60 @@ impl BackendStorage for MetalStorage {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
||||
("uneg", DType::F16) => contiguous::neg::HALF,
|
||||
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||
("ulog", DType::F16) => contiguous::log::HALF,
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||
("uabs", DType::F16) => contiguous::abs::HALF,
|
||||
("uabs", DType::F32) => contiguous::abs::FLOAT,
|
||||
("uabs", DType::BF16) => contiguous::abs::BFLOAT,
|
||||
("uceil", DType::F16) => contiguous::ceil::HALF,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("uceil", DType::BF16) => contiguous::ceil::BFLOAT,
|
||||
("ucos", DType::F16) => contiguous::cos::HALF,
|
||||
("ucos", DType::F32) => contiguous::cos::FLOAT,
|
||||
("ucos", DType::BF16) => contiguous::cos::BFLOAT,
|
||||
("uerf", DType::F16) => contiguous::erf::HALF,
|
||||
("uerf", DType::F32) => contiguous::erf::FLOAT,
|
||||
("uerf", DType::BF16) => contiguous::erf::BFLOAT,
|
||||
("uexp", DType::F16) => contiguous::exp::HALF,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("uexp", DType::BF16) => contiguous::exp::BFLOAT,
|
||||
("ufloor", DType::F16) => contiguous::floor::HALF,
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
("ufloor", DType::BF16) => contiguous::floor::BFLOAT,
|
||||
("ugelu_erf", DType::F16) => contiguous::gelu_erf::HALF,
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu_erf::FLOAT,
|
||||
("ugelu_erf", DType::BF16) => contiguous::gelu_erf::BFLOAT,
|
||||
("ugelu", DType::F16) => contiguous::gelu::HALF,
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("ugelu", DType::BF16) => contiguous::gelu::BFLOAT,
|
||||
("ulog", DType::F16) => contiguous::log::HALF,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
("ulog", DType::BF16) => contiguous::log::BFLOAT,
|
||||
("uneg", DType::F16) => contiguous::neg::HALF,
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uneg", DType::BF16) => contiguous::neg::BFLOAT,
|
||||
("urecip", DType::F16) => contiguous::recip::HALF,
|
||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||
("urecip", DType::F32) => contiguous::recip::FLOAT,
|
||||
("urecip", DType::BF16) => contiguous::recip::BFLOAT,
|
||||
("urelu", DType::F16) => contiguous::relu::HALF,
|
||||
("urelu", DType::F32) => contiguous::relu::FLOAT,
|
||||
("urelu", DType::BF16) => contiguous::relu::BFLOAT,
|
||||
("uround", DType::F16) => contiguous::round::HALF,
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
("uround", DType::BF16) => contiguous::round::BFLOAT,
|
||||
("usilu", DType::F16) => contiguous::silu::HALF,
|
||||
("usilu", DType::F32) => contiguous::silu::FLOAT,
|
||||
("usilu", DType::BF16) => contiguous::silu::BFLOAT,
|
||||
("usin", DType::F16) => contiguous::sin::HALF,
|
||||
("usin", DType::F32) => contiguous::sin::FLOAT,
|
||||
("usin", DType::BF16) => contiguous::sin::BFLOAT,
|
||||
("usqr", DType::F16) => contiguous::sqr::HALF,
|
||||
("usqr", DType::F32) => contiguous::sqr::FLOAT,
|
||||
("usqr", DType::BF16) => contiguous::sqr::BFLOAT,
|
||||
("usqrt", DType::F16) => contiguous::sqrt::HALF,
|
||||
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
|
||||
("usqrt", DType::BF16) => contiguous::sqrt::BFLOAT,
|
||||
("utanh", DType::F16) => contiguous::tanh::HALF,
|
||||
("utanh", DType::F32) => contiguous::tanh::FLOAT,
|
||||
("utanh", DType::BF16) => contiguous::tanh::BFLOAT,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
|
Reference in New Issue
Block a user