Starting to fix some tests.

This commit is contained in:
Nicolas Patry
2023-11-11 01:02:15 +01:00
parent 4f39695465
commit 3ad02147e4
3 changed files with 42 additions and 15 deletions

View File

@ -293,6 +293,12 @@ impl BackendStorage for MetalStorage {
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
("ulog", DType::F32) => contiguous::log::FLOAT,
("ugelu", DType::F32) => contiguous::gelu::FLOAT,
// TODO erf does not exist in metal
("ugelu_erf", DType::F32) => contiguous::gelu::FLOAT,
("uceil", DType::F32) => contiguous::ceil::FLOAT,
("ufloor", DType::F32) => contiguous::floor::FLOAT,
("uround", DType::F32) => contiguous::round::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
@ -519,7 +525,6 @@ impl BackendStorage for MetalStorage {
let dtype = self.dtype;
let device = self.device();
let mut buffer = device.new_buffer(dst_el, dtype);
let out = self.to_cpu_storage().unwrap();
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "is_u32_f32",
(left, right) => todo!("index select metal {left:?} {right:?}"),
@ -690,6 +695,7 @@ impl BackendStorage for MetalStorage {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
DType::U32 => candle_metal_kernels::unary::strided::copy::U32,
dtype => todo!("copy_strided not implemented for {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(