mirror of https://github.com/huggingface/candle.git
Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.
* Add softmax_last_dim to the benchmarks.
* And add a test.
* Support more dtypes.
* Polish the code.
* Use the slow implementation on cuda.
* Add a todo for the cuda kernel.
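For readers unfamiliar with the new op, here is a minimal usage sketch (assumptions: the published crate names `candle_core` and `candle_nn`, and a CPU device; none of this code is part of the commit). Applying `softmax_last_dim` to `log(x)` rescales each row of `x` by its sum along the last dimension, which is exactly the property the added test below exercises.

// Hedged usage sketch, not part of this commit.
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    // A (2, 2, 3) tensor; softmax_last_dim normalizes over the last axis (size 3).
    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
    let t = Tensor::new(data, &Device::Cpu)?;
    // softmax(log(x)) == x / sum(x) along the last dim, e.g. (3, 1, 4) / 8.
    let s = candle_nn::ops::softmax_last_dim(&t.log()?)?;
    println!("{s}");
    Ok(())
}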
@@ -41,6 +41,16 @@ fn softmax() -> Result<()> {
             [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
         ]
     );
+    let t2 = candle_nn::ops::softmax_last_dim(&tensor.log()?)?;
+    assert_eq!(
+        to_vec3_round(&t2, 4)?,
+        &[
+            // (3, 1, 4) / 8, (1, 5, 9) / 15
+            [[0.375, 0.125, 0.5], [0.0667, 0.3333, 0.6]],
+            // (2, 1, 7) / 10, (8, 2, 8) / 18
+            [[0.2, 0.1, 0.7], [0.4444, 0.1111, 0.4444]]
+        ]
+    );
     Ok(())
 }
 
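The `(3, 1, 4) / 8`-style comments in the test follow from the usual per-row softmax formula. Below is a dependency-free sketch of that per-row computation for illustration only; it is not the kernel added by this commit, and the function names are hypothetical.

// y_i = exp(x_i - max(x)) / sum_j exp(x_j - max(x)), computed over the last dimension.
fn softmax_row(xs: &[f32]) -> Vec<f32> {
    let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = xs.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

fn main() {
    // softmax(ln 3, ln 1, ln 4) = (3, 1, 4) / 8, matching the first test comment above.
    println!("{:?}", softmax_row(&[3f32.ln(), 1f32.ln(), 4f32.ln()])); // ~ [0.375, 0.125, 0.5]
}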