Add the scatter in place ops. (#2923)

* Add the scatter_set op.

* Metal op.

* Cuda version.

* Merge the checks.

* Add the actual ops.
This commit is contained in:
Laurent Mazare
2025-04-26 07:36:49 +02:00
committed by GitHub
parent 3827685524
commit a2e925462c
12 changed files with 208 additions and 141 deletions

View File

@ -1087,6 +1087,18 @@ fn scatter(device: &Device) -> Result<()> {
[1.0, 1.0, 1.0]
]
);
init.scatter_set(&ids, &t, 0)?;
assert_eq!(
init.to_vec2::<f32>()?,
&[
[0.0, 10.0, 5.0],
[1.0, 1.0, 8.0],
[9.0, 1.0, 2.0],
[6.0, 7.0, 1.0],
[1.0, 4.0, 11.0],
[1.0, 1.0, 1.0]
]
);
Ok(())
}