mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 20:09:50 +00:00
Starting to fix some tests.
This commit is contained in:
@ -293,6 +293,12 @@ impl BackendStorage for MetalStorage {
|
||||
("uneg", DType::F32) => contiguous::neg::FLOAT,
|
||||
("uexp", DType::F32) => contiguous::exp::FLOAT,
|
||||
("ulog", DType::F32) => contiguous::log::FLOAT,
|
||||
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
|
||||
// TODO erf does not exist in metal
|
||||
("ugelu_erf", DType::F32) => contiguous::gelu::FLOAT,
|
||||
("uceil", DType::F32) => contiguous::ceil::FLOAT,
|
||||
("ufloor", DType::F32) => contiguous::floor::FLOAT,
|
||||
("uround", DType::F32) => contiguous::round::FLOAT,
|
||||
(name, dtype) => todo!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_contiguous(
|
||||
@ -519,7 +525,6 @@ impl BackendStorage for MetalStorage {
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let mut buffer = device.new_buffer(dst_el, dtype);
|
||||
let out = self.to_cpu_storage().unwrap();
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(left, right) => todo!("index select metal {left:?} {right:?}"),
|
||||
@ -690,6 +695,7 @@ impl BackendStorage for MetalStorage {
|
||||
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
|
||||
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
|
||||
DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
|
||||
dtype => todo!("copy_strided not implemented for {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_unary_strided(
|
||||
|
Reference in New Issue
Block a user