mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Support i64 in index-select on metal. (#1951)
* Support i64 in index-select on metal. * Add some testing of index-select for all dtypes.
This commit is contained in:
@ -1391,6 +1391,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
(DType::U32, DType::F16) => "is_u32_f16",
|
(DType::U32, DType::F16) => "is_u32_f16",
|
||||||
(DType::U32, DType::BF16) => "is_u32_bf16",
|
(DType::U32, DType::BF16) => "is_u32_bf16",
|
||||||
|
|
||||||
|
(DType::I64, DType::F32) => "is_i64_f32",
|
||||||
|
(DType::I64, DType::F16) => "is_i64_f16",
|
||||||
|
(DType::I64, DType::BF16) => "is_i64_bf16",
|
||||||
|
|
||||||
(left, right) => {
|
(left, right) => {
|
||||||
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
|
crate::bail!("Metal contiguous index_select {left:?} {right:?} not implemented")
|
||||||
}
|
}
|
||||||
|
@ -707,6 +707,8 @@ fn embeddings(device: &Device) -> Result<()> {
|
|||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
let hs = t.index_select(&ids, 0)?;
|
let hs = t.index_select(&ids, 0)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
|
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
||||||
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -734,44 +736,47 @@ fn index_select(device: &Device) -> Result<()> {
|
|||||||
[9.0, 10.0, 11.0]
|
[9.0, 10.0, 11.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
let hs = t.index_select(&ids, 1)?;
|
for dtype in [DType::U8, DType::U32, DType::I64] {
|
||||||
assert_eq!(
|
let ids = ids.to_dtype(dtype)?;
|
||||||
hs.to_vec2::<f32>()?,
|
let hs = t.index_select(&ids, 1)?;
|
||||||
&[
|
assert_eq!(
|
||||||
[0.0, 2.0, 1.0],
|
hs.to_vec2::<f32>()?,
|
||||||
[3.0, 5.0, 4.0],
|
&[
|
||||||
[6.0, 8.0, 7.0],
|
[0.0, 2.0, 1.0],
|
||||||
[9.0, 11.0, 10.0]
|
[3.0, 5.0, 4.0],
|
||||||
]
|
[6.0, 8.0, 7.0],
|
||||||
);
|
[9.0, 11.0, 10.0]
|
||||||
let hs = t.index_select(&ids, 0)?;
|
]
|
||||||
assert_eq!(
|
);
|
||||||
hs.to_vec2::<f32>()?,
|
let hs = t.index_select(&ids, 0)?;
|
||||||
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
assert_eq!(
|
||||||
);
|
hs.to_vec2::<f32>()?,
|
||||||
// Prior to https://github.com/huggingface/candle/pull/1022
|
&[[0.0, 1.0, 2.0], [6.0, 7.0, 8.0], [3.0, 4.0, 5.0]]
|
||||||
// There would be a bug where the last values in the result tensor would be set to 0.
|
);
|
||||||
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
// Prior to https://github.com/huggingface/candle/pull/1022
|
||||||
let hs = t.index_select(&ids, 0)?;
|
// There would be a bug where the last values in the result tensor would be set to 0.
|
||||||
assert_eq!(
|
let ids = Tensor::new(&[0u32, 2u32, 1u32, 0u32, 2u32, 1u32], device)?;
|
||||||
hs.to_vec2::<f32>()?,
|
let hs = t.index_select(&ids, 0)?;
|
||||||
&[
|
assert_eq!(
|
||||||
[0.0, 1.0, 2.0],
|
hs.to_vec2::<f32>()?,
|
||||||
[6.0, 7.0, 8.0],
|
&[
|
||||||
[3.0, 4.0, 5.0],
|
[0.0, 1.0, 2.0],
|
||||||
[0.0, 1.0, 2.0],
|
[6.0, 7.0, 8.0],
|
||||||
[6.0, 7.0, 8.0],
|
[3.0, 4.0, 5.0],
|
||||||
[3.0, 4.0, 5.0],
|
[0.0, 1.0, 2.0],
|
||||||
]
|
[6.0, 7.0, 8.0],
|
||||||
);
|
[3.0, 4.0, 5.0],
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
// Test when selecting dim > 0 with ids size different from elem count of
|
// Test when selecting dim > 0 with ids size different from elem count of
|
||||||
// target dim in source/input.
|
// target dim in source/input.
|
||||||
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
let ids = Tensor::new(&[1u32, 0u32, 1u32], device)?;
|
||||||
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
let t = Tensor::arange(1f32, 5f32, device)?.reshape((2, 2))?;
|
||||||
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
assert_eq!(t.to_vec2::<f32>()?, &[[1.0, 2.0], [3.0, 4.0]]);
|
||||||
let hs = t.index_select(&ids, 1)?;
|
let hs = t.index_select(&ids, 1)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[2.0, 1.0, 2.0], [4.0, 3.0, 4.0]]);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -187,6 +187,12 @@ kernel void NAME( \
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
INDEX_OP(is_i64_f32, int64_t, float)
|
||||||
|
INDEX_OP(is_i64_f16, int64_t, half)
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
INDEX_OP(is_i64_bf16, int64_t, bfloat)
|
||||||
|
#endif
|
||||||
|
|
||||||
INDEX_OP(is_u32_f32, uint32_t, float)
|
INDEX_OP(is_u32_f32, uint32_t, float)
|
||||||
INDEX_OP(is_u32_f16, uint32_t, half)
|
INDEX_OP(is_u32_f16, uint32_t, half)
|
||||||
#if defined(__HAVE_BFLOAT__)
|
#if defined(__HAVE_BFLOAT__)
|
||||||
@ -242,4 +248,4 @@ INDEX_ADD_OP(ia_u8_u32, uint8_t, uint32_t)
|
|||||||
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
|
INDEX_ADD_OP(ia_u8_u8, uint8_t, uint8_t)
|
||||||
#if defined(__HAVE_BFLOAT__)
|
#if defined(__HAVE_BFLOAT__)
|
||||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
Reference in New Issue
Block a user