mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Add some missing index-select metal kernels. (#2613)
* Add some missing index-select metal kernels.
* Make some matrices contiguous pre-matmul.
This commit is contained in:
@ -1237,7 +1237,7 @@ impl BackendStorage for MetalStorage {
|
||||
let dst_el = ids_l.shape().elem_count();
|
||||
let dtype = self.dtype;
|
||||
let device = self.device();
|
||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||
let buffer = device.new_buffer(dst_el, dtype, "gather")?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U32, DType::F32) => "gather_u32_f32",
|
||||
(DType::U32, DType::F16) => "gather_u32_f16",
|
||||
@ -1324,14 +1324,23 @@ impl BackendStorage for MetalStorage {
|
||||
let device = self.device();
|
||||
let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::U8) => "is_u8_u8",
|
||||
(DType::U8, DType::U32) => "is_u8_u32",
|
||||
(DType::U8, DType::I64) => "is_u8_i64",
|
||||
(DType::U8, DType::BF16) => "is_u8_bf16",
|
||||
(DType::U8, DType::F32) => "is_u8_f32",
|
||||
(DType::U8, DType::F16) => "is_u8_f16",
|
||||
|
||||
(DType::U32, DType::U8) => "is_u32_u8",
|
||||
(DType::U32, DType::U32) => "is_u32_u32",
|
||||
(DType::U32, DType::I64) => "is_u32_i64",
|
||||
(DType::U32, DType::F32) => "is_u32_f32",
|
||||
(DType::U32, DType::F16) => "is_u32_f16",
|
||||
(DType::U32, DType::BF16) => "is_u32_bf16",
|
||||
|
||||
(DType::I64, DType::U8) => "is_i64_u8",
|
||||
(DType::I64, DType::U32) => "is_i64_u32",
|
||||
(DType::I64, DType::I64) => "is_i64_i64",
|
||||
(DType::I64, DType::F32) => "is_i64_f32",
|
||||
(DType::I64, DType::F16) => "is_i64_f16",
|
||||
(DType::I64, DType::BF16) => "is_i64_bf16",
|
||||
|
@ -193,12 +193,16 @@ INDEX_OP(is_i64_f16, int64_t, half)
|
||||
INDEX_OP(is_i64_bf16, int64_t, bfloat)
|
||||
#endif
|
||||
|
||||
INDEX_OP(is_u32_u8, uint32_t, uint8_t)
|
||||
INDEX_OP(is_u32_u32, uint32_t, uint32_t)
|
||||
INDEX_OP(is_u32_f32, uint32_t, float)
|
||||
INDEX_OP(is_u32_f16, uint32_t, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||
#endif
|
||||
|
||||
INDEX_OP(is_u8_u8, uint8_t, uint8_t)
|
||||
INDEX_OP(is_u8_u32, uint8_t, uint32_t)
|
||||
INDEX_OP(is_u8_f32, uint8_t, float)
|
||||
INDEX_OP(is_u8_f16, uint8_t, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
|
@ -171,7 +171,8 @@ impl ChineseClipModel {
|
||||
) -> Result<Tensor> {
|
||||
let output = self
|
||||
.text_model
|
||||
.forward(input_ids, token_type_ids, attention_mask)?;
|
||||
.forward(input_ids, token_type_ids, attention_mask)?
|
||||
.contiguous()?;
|
||||
self.text_projection.forward(&output)
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user