mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the scatter in place ops. (#2923)
* Add the scatter_set op. * Metal op. * Cuda version. * Merge the checks. * Add the actual ops.
This commit is contained in:
@ -1087,6 +1087,18 @@ fn scatter(device: &Device) -> Result<()> {
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
init.scatter_set(&ids, &t, 0)?;
|
||||
assert_eq!(
|
||||
init.to_vec2::<f32>()?,
|
||||
&[
|
||||
[0.0, 10.0, 5.0],
|
||||
[1.0, 1.0, 8.0],
|
||||
[9.0, 1.0, 2.0],
|
||||
[6.0, 7.0, 1.0],
|
||||
[1.0, 4.0, 11.0],
|
||||
[1.0, 1.0, 1.0]
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user