mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
@ -537,6 +537,11 @@ impl BackendStorage for MetalStorage {
|
||||
(ReduceOp::Max, DType::I64) => ("fast_max_i64_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64_strided", true, true),
|
||||
(ReduceOp::Sum, DType::U8) => ("fast_sum_u8_strided", false, false),
|
||||
(ReduceOp::Min, DType::U8) => ("fast_min_u8_strided", true, false),
|
||||
(ReduceOp::Max, DType::U8) => ("fast_max_u8_strided", true, false),
|
||||
(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8_strided", true, true),
|
||||
(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8_strided", true, true),
|
||||
(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
|
||||
};
|
||||
if check_empty && layout.shape().elem_count() == 0 {
|
||||
@ -779,6 +784,8 @@ impl BackendStorage for MetalStorage {
|
||||
(DType::U8, DType::F32) => "where_u8_f32",
|
||||
(DType::U8, DType::F16) => "where_u8_f16",
|
||||
(DType::U8, DType::I64) => "where_u8_i64",
|
||||
(DType::U8, DType::U32) => "where_u8_u32",
|
||||
(DType::U8, DType::U8) => "where_u8_u8",
|
||||
(left, right) => crate::bail!("Metal where_cond {left:?} {right:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_where_cond_strided(
|
||||
@ -1323,6 +1330,26 @@ impl MetalStorage {
|
||||
("lt", DType::I64) => (contiguous::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (contiguous::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (contiguous::gt::I64, DType::U8),
|
||||
("add", DType::U32) => (contiguous::add::U32, self.dtype),
|
||||
("sub", DType::U32) => (contiguous::sub::U32, self.dtype),
|
||||
("mul", DType::U32) => (contiguous::mul::U32, self.dtype),
|
||||
("div", DType::U32) => (contiguous::div::U32, self.dtype),
|
||||
("eq", DType::U32) => (contiguous::eq::U32, DType::U8),
|
||||
("ne", DType::U32) => (contiguous::ne::U32, DType::U8),
|
||||
("le", DType::U32) => (contiguous::le::U32, DType::U8),
|
||||
("lt", DType::U32) => (contiguous::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (contiguous::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (contiguous::gt::U32, DType::U8),
|
||||
("add", DType::U8) => (contiguous::add::U8, self.dtype),
|
||||
("sub", DType::U8) => (contiguous::sub::U8, self.dtype),
|
||||
("mul", DType::U8) => (contiguous::mul::U8, self.dtype),
|
||||
("div", DType::U8) => (contiguous::div::U8, self.dtype),
|
||||
("eq", DType::U8) => (contiguous::eq::U8, DType::U8),
|
||||
("ne", DType::U8) => (contiguous::ne::U8, DType::U8),
|
||||
("le", DType::U8) => (contiguous::le::U8, DType::U8),
|
||||
("lt", DType::U8) => (contiguous::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (contiguous::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (contiguous::gt::U8, DType::U8),
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal contiguous binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -1380,6 +1407,30 @@ impl MetalStorage {
|
||||
("lt", DType::I64) => (strided::lt::I64, DType::U8),
|
||||
("ge", DType::I64) => (strided::ge::I64, DType::U8),
|
||||
("gt", DType::I64) => (strided::gt::I64, DType::U8),
|
||||
("badd", DType::U32) => (strided::add::U32, self.dtype),
|
||||
("bsub", DType::U32) => (strided::sub::U32, self.dtype),
|
||||
("bmul", DType::U32) => (strided::mul::U32, self.dtype),
|
||||
("bdiv", DType::U32) => (strided::div::U32, self.dtype),
|
||||
("bminimum", DType::U32) => (strided::min::U32, self.dtype),
|
||||
("bmaximum", DType::U32) => (strided::max::U32, self.dtype),
|
||||
("eq", DType::U32) => (strided::eq::U32, DType::U8),
|
||||
("ne", DType::U32) => (strided::ne::U32, DType::U8),
|
||||
("le", DType::U32) => (strided::le::U32, DType::U8),
|
||||
("lt", DType::U32) => (strided::lt::U32, DType::U8),
|
||||
("ge", DType::U32) => (strided::ge::U32, DType::U8),
|
||||
("gt", DType::U32) => (strided::gt::U32, DType::U8),
|
||||
("badd", DType::U8) => (strided::add::U8, self.dtype),
|
||||
("bsub", DType::U8) => (strided::sub::U8, self.dtype),
|
||||
("bmul", DType::U8) => (strided::mul::U8, self.dtype),
|
||||
("bdiv", DType::U8) => (strided::div::U8, self.dtype),
|
||||
("bminimum", DType::U8) => (strided::min::U8, self.dtype),
|
||||
("bmaximum", DType::U8) => (strided::max::U8, self.dtype),
|
||||
("eq", DType::U8) => (strided::eq::U8, DType::U8),
|
||||
("ne", DType::U8) => (strided::ne::U8, DType::U8),
|
||||
("le", DType::U8) => (strided::le::U8, DType::U8),
|
||||
("lt", DType::U8) => (strided::lt::U8, DType::U8),
|
||||
("ge", DType::U8) => (strided::ge::U8, DType::U8),
|
||||
("gt", DType::U8) => (strided::gt::U8, DType::U8),
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided binary {name} {dtype:?} not implemented")
|
||||
}
|
||||
|
Reference in New Issue
Block a user