Support i64 in index-select on metal. (#1951)

* Support i64 in index-select on metal.

* Add some testing of index-select for all dtypes.
This commit is contained in:
Laurent Mazare
2024-03-27 16:30:07 +01:00
committed by GitHub
parent a9abde5f93
commit ab86cd37c8
3 changed files with 53 additions and 38 deletions

View File

@ -1391,6 +1391,10 @@ impl BackendStorage for MetalStorage {
(DType::U32, DType::F16) => "is_u32_f16", (DType::U32, DType::F16) => "is_u32_f16",
(DType::U32, DType::BF16) => "is_u32_bf16", (DType::U32, DType::BF16) => "is_u32_bf16",
(DType::I64, DType::F32) => "is_i64_f32",
(DType::I64, DType::F16) => "is_i64_f16",
(DType::I64, DType::BF16) => "is_i64_bf16",
(left, right) => { (left, right) => {
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented") crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
} }

View File

@ -707,6 +707,8 @@ fn embeddings(device: &Device) -> Result<()> {
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids, 0)?; let hs = t.index_select(&ids, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
Ok(()) Ok(())
} }
@ -734,44 +736,47 @@ fn index_select(device: &Device) -> Result<()> {
[9.0, 10.0, 11.0] [9.0, 10.0, 11.0]
] ]
); );
let hs = t.index_select(&ids, 1)?; for dtype in [DType::U8, DType::U32, DType::I64] {
assert_eq!( let ids = ids.to_dtype(dtype)?;
hs.to_vec2::<f32>()?, let hs = t.index_select(&ids, 1)?;
&[ assert_eq!(
[0.0, 2.0, 1.0], hs.to_vec2::<f32>()?,
[3.0, 5.0, 4.0], &[
[6.0, 8.0, 7.0], [0.0, 2.0, 1.0],
[9.0, 11.0, 10.0] [3.0, 5.0, 4.0],
] [6.0, 8.0, 7.0],
); [9.0, 11.0, 10.0]
let hs = t.index_select(&ids, 0)?; ]
assert_eq!( );
hs.to_vec2::<f32>()?, let hs = t.index_select(&ids, 0)?;
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]] assert_eq!(
); hs.to_vec2::<f32>()?,
// Prior to https://github.com/huggingface/candle/pull/1022 &[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
// There would be a bug where the last values in the result tensor would be set to 0. );
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?; // Prior to https://github.com/huggingface/candle/pull/1022
let hs = t.index_select(&ids, 0)?; // There would be a bug where the last values in the result tensor would be set to 0.
assert_eq!( let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
hs.to_vec2::<f32>()?, let hs = t.index_select(&ids, 0)?;
&[ assert_eq!(
[0.0, 1.0, 2.0], hs.to_vec2::<f32>()?,
[6.0, 7.0, 8.0], &[
[3.0, 4.0, 5.0], [0.0, 1.0, 2.0],
[0.0, 1.0, 2.0], [6.0, 7.0, 8.0],
[6.0, 7.0, 8.0], [3.0, 4.0, 5.0],
[3.0, 4.0, 5.0], [0.0, 1.0, 2.0],
] [6.0, 7.0, 8.0],
); [3.0, 4.0, 5.0],
]
);
// Test when selecting dim > 0 with ids size different from elem count of // Test when selecting dim > 0 with ids size different from elem count of
// target dim in source/input. // target dim in source/input.
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?; let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?; let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]); assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
let hs = t.index_select(&ids, 1)?; let hs = t.index_select(&ids, 1)?;
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]); assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
}
Ok(()) Ok(())
} }

View File

@ -187,6 +187,12 @@ kernel void NAME( \
} }
INDEX_OP(is_i64_f32, int64_t, float)
INDEX_OP(is_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_i64_bf16, int64_t, bfloat)
#endif
INDEX_OP(is_u32_f32, uint32_t, float) INDEX_OP(is_u32_f32, uint32_t, float)
INDEX_OP(is_u32_f16, uint32_t, half) INDEX_OP(is_u32_f16, uint32_t, half)
#if defined(__HAVE_BFLOAT__) #if defined(__HAVE_BFLOAT__)
@ -242,4 +248,4 @@ INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t) INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
#if defined(__HAVE_BFLOAT__) #if defined(__HAVE_BFLOAT__)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat) INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
#endif #endif