diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index b9e761f6..fed7db13 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -1391,6 +1391,10 @@ impl BackendStorage for MetalStorage { (DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::BF16) => "is_u32_bf16", + (DType::I64, DType::F32) => "is_i64_f32", + (DType::I64, DType::F16) => "is_i64_f16", + (DType::I64, DType::BF16) => "is_i64_bf16", + (left, right) => { crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") } diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index af28c1c1..8aacc05d 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -707,6 +707,8 @@ fn embeddings(device: &Device) -> Result<()> { assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); let hs = t.index_select(&ids, 0)?; assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); + let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?; + assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); Ok(()) } @@ -734,44 +736,47 @@ fn index_select(device: &Device) -> Result<()> { [9.0, 10.0, 11.0] ] ); - let hs = t.index_select(&ids, 1)?; - assert_eq!( - hs.to_vec2::<f32>()?, - &[ - [0.0, 2.0, 1.0], - [3.0, 5.0, 4.0], - [6.0, 8.0, 7.0], - [9.0, 11.0, 10.0] - ] - ); - let hs = t.index_select(&ids, 0)?; - assert_eq!( - hs.to_vec2::<f32>()?, - &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] - ); - // Prior to https://github.com/huggingface/candle/pull/1022 - // There would be a bug where the last values in the result tensor would be set to 0. 
- let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; - let hs = t.index_select(&ids, 0)?; - assert_eq!( - hs.to_vec2::<f32>()?, - &[ - [0.0, 1.0, 2.0], - [6.0, 7.0, 8.0], - [3.0, 4.0, 5.0], - [0.0, 1.0, 2.0], - [6.0, 7.0, 8.0], - [3.0, 4.0, 5.0], - ] - ); + for dtype in [DType::U8, DType::U32, DType::I64] { + let ids = ids.to_dtype(dtype)?; + let hs = t.index_select(&ids, 1)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[ + [0.0, 2.0, 1.0], + [3.0, 5.0, 4.0], + [6.0, 8.0, 7.0], + [9.0, 11.0, 10.0] + ] + ); + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] + ); + // Prior to https://github.com/huggingface/candle/pull/1022 + // There would be a bug where the last values in the result tensor would be set to 0. + let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; + let hs = t.index_select(&ids, 0)?; + assert_eq!( + hs.to_vec2::<f32>()?, + &[ + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [6.0, 7.0, 8.0], + [3.0, 4.0, 5.0], + ] + ); - // Test when selecting dim > 0 with ids size different from elem count of - // target dim in source/input. - let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; - let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; - assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]); - let hs = t.index_select(&ids, 1)?; - assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + // Test when selecting dim > 0 with ids size different from elem count of + // target dim in source/input. 
+ let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; + let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; + assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]); + let hs = t.index_select(&ids, 1)?; + assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); + } Ok(()) } diff --git a/candle-metal-kernels/src/indexing.metal b/candle-metal-kernels/src/indexing.metal index ad4a8605..762b42be 100644 --- a/candle-metal-kernels/src/indexing.metal +++ b/candle-metal-kernels/src/indexing.metal @@ -187,6 +187,12 @@ kernel void NAME( \ } +INDEX_OP(is_i64_f32, int64_t, float) +INDEX_OP(is_i64_f16, int64_t, half) +#if defined(__HAVE_BFLOAT__) +INDEX_OP(is_i64_bf16, int64_t, bfloat) +#endif + INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f16, uint32_t, half) #if defined(__HAVE_BFLOAT__) @@ -242,4 +248,4 @@ INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t) INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) #if defined(__HAVE_BFLOAT__) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) -#endif \ No newline at end of file +#endif