Add some missing index-select metal kernels. (#2613)

* Add some missing index-select metal kernels.

* Make some matrix contiguous pre-matmul.
This commit is contained in:
Laurent Mazare
2024-11-12 17:10:12 +01:00
committed by GitHub
parent 9453cc3095
commit 06350c31c7
3 changed files with 16 additions and 2 deletions

View File

@ -1237,7 +1237,7 @@ impl BackendStorage for MetalStorage {
let dst_el = ids_l.shape().elem_count();
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let buffer = device.new_buffer(dst_el, dtype, "gather")?;
let name = match (ids.dtype, self.dtype) {
(DType::U32, DType::F32) => "gather_u32_f32",
(DType::U32, DType::F16) => "gather_u32_f16",
@ -1324,14 +1324,23 @@ impl BackendStorage for MetalStorage {
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
let name = match (ids.dtype, self.dtype) {
(DType::U8, DType::U8) => "is_u8_u8",
(DType::U8, DType::U32) => "is_u8_u32",
(DType::U8, DType::I64) => "is_u8_i64",
(DType::U8, DType::BF16) => "is_u8_bf16",
(DType::U8, DType::F32) => "is_u8_f32",
(DType::U8, DType::F16) => "is_u8_f16",
(DType::U32, DType::U8) => "is_u32_u8",
(DType::U32, DType::U32) => "is_u32_u32",
(DType::U32, DType::I64) => "is_u32_i64",
(DType::U32, DType::F32) => "is_u32_f32",
(DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16",
(DType::I64, DType::U8) => "is_i64_u8",
(DType::I64, DType::U32) => "is_i64_u32",
(DType::I64, DType::I64) => "is_i64_i64",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",