Add the scatter op. (#2921)

* Add the scatter op.

* Backprop support.

* Cuda support.
This commit is contained in:
Laurent Mazare
2025-04-25 21:46:58 +02:00
committed by GitHub
parent 3aeb9575c7
commit 3827685524
15 changed files with 429 additions and 19 deletions

View File

@ -1027,7 +1027,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
Ok(())
}
fn scatter_add(device: &Device) -> Result<()> {
fn scatter(device: &Device) -> Result<()> {
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
assert_eq!(
t.to_vec2::<f32>()?,
@ -1051,6 +1051,17 @@ fn scatter_add(device: &Device) -> Result<()> {
]
);
let hs = init.scatter(&ids, &t, 1)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 1.0, 2.0, 1.0, 1.0],
[5.0, 1.0, 1.0, 3.0, 4.0],
[1.0, 8.0, 1.0, 7.0, 1.0],
[10.0, 1.0, 9.0, 1.0, 11.0]
]
);
let init = Tensor::ones((6, 3), DType::F32, device)?;
let hs = init.scatter_add(&ids, &t, 0)?;
assert_eq!(
@ -1064,6 +1075,18 @@ fn scatter_add(device: &Device) -> Result<()> {
[1.0, 1.0, 1.0]
]
);
let hs = init.scatter(&ids, &t, 0)?;
assert_eq!(
hs.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
Ok(())
}
@ -1563,12 +1586,7 @@ test_device!(
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
scatter_add,
scatter_add_cpu,
scatter_add_gpu,
scatter_add_metal
);
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
test_device!(
slice_scatter,
slice_scatter_cpu,