Add a custom softmax implementation. (#744)

* Add a custom softmax implementation.

* Add softmax_last_dim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on cuda.

* Add a todo for the cuda kernel.
This commit is contained in:
Laurent Mazare
2023-09-05 15:20:23 +02:00
committed by GitHub
parent a8410bf35e
commit 1c9e5394a5
5 changed files with 109 additions and 18 deletions

View File

@@ -41,6 +41,16 @@ fn softmax() -> Result<()> {
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
]
);
let t2 = candle_nn::ops::softmax_last_dim(&tensor.log()?)?;
assert_eq!(
to_vec3_round(&t2, 4)?,
&[
// (3, 1, 4) / 8, (1, 5, 9) / 15
[[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
// (2, 1, 7) / 10, (8, 2, 8) / 18
[[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
]
);
Ok(())
}