mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
Indexing with max-value results in zero/no-op. (#2940)
* Indexing with max-value results in zero/no-op. * Add some testing. * Also adapt the metal kernels. * Another test. * Fix.
This commit is contained in:
@ -845,6 +845,9 @@ fn embeddings(device: &Device) -> Result<()> {
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||
let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?;
|
||||
let hs = t.index_select(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1087,6 +1090,31 @@ fn scatter(device: &Device) -> Result<()> {
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
|
||||
let hs = {
|
||||
let ids = Tensor::new(
|
||||
&[
|
||||
[0u32, u32::MAX, 2],
|
||||
[3, 4, u32::MAX],
|
||||
[3, 3, 1],
|
||||
[u32::MAX, u32::MAX, 4],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
init.scatter(&ids, &t, 0)?
|
||||
};
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 1.0, 1.0],
|
||||
[1.0, 1.0, 8.0],
|
||||
[1.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 1.0],
|
||||
[1.0, 4.0, 11.0],
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
|
||||
init.scatter_set(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
init.to_vec2::<f32>()?,
|
||||
@ -1099,6 +1127,7 @@ fn scatter(device: &Device) -> Result<()> {
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1132,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> {
|
||||
let hs = t.gather(&ids, 0)?;
|
||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
|
||||
|
||||
let hs = {
|
||||
let ids = Tensor::new(
|
||||
&[
|
||||
[0u32, 0u32],
|
||||
[2u32, u32::MAX],
|
||||
[u32::MAX, 1u32],
|
||||
[0u32, 2u32],
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
t.gather(&ids, 1)?
|
||||
};
|
||||
assert_eq!(
|
||||
hs.to_vec2::<f32>()?,
|
||||
&[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]]
|
||||
);
|
||||
|
||||
// Random data
|
||||
|
||||
// Dim: 0
|
||||
|
Reference in New Issue
Block a user