Add a custom softmax implementation. (#744)

* Add a custom softmax implementation (sketched below).

* Add softmaxlastdim to the benchmarks.

* And add a test.

* Support more dtypes.

* Polish the code.

* Use the slow implementation on CUDA.

* Add a TODO for the CUDA kernel.
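
At a high level, the custom op computes a numerically stable softmax over the last dimension: subtract the row max, exponentiate, then normalize. Below is a minimal sketch of that technique on plain f32 slices — illustrative only, not candle's actual custom-op code, which operates on tensor storage and, per the notes above, still falls back to the slow path on CUDA:

```rust
/// Numerically stable softmax over the last dimension of a row-major
/// f32 buffer, where `dim` is the size of the last dimension.
fn softmax_last_dim(data: &mut [f32], dim: usize) {
    for row in data.chunks_mut(dim) {
        // Subtract the row max before exponentiating to avoid overflow.
        let max = row.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let mut sum = 0f32;
        for v in row.iter_mut() {
            *v = (*v - max).exp();
            sum += *v;
        }
        // Normalize so each row sums to one.
        for v in row.iter_mut() {
            *v /= sum;
        }
    }
}

fn main() {
    // Two rows of three logits each.
    let mut xs = vec![1f32, 2., 3., 10., 10., 10.];
    softmax_last_dim(&mut xs, 3);
    println!("{xs:?}"); // each row now sums to ~1.0
}
```

Avoiding intermediate tensors for the max/exp/normalize passes is the presumed win over the generic composed softmax.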
Author: Laurent Mazare
Date: 2023-09-05 15:20:23 +02:00
Committed by: GitHub
Parent: a8410bf35e
Commit: 1c9e5394a5
5 changed files with 109 additions and 18 deletions


@@ -198,7 +198,7 @@ impl CrossAttention {
     let xs = query.matmul(&(key.t()? * self.scale)?)?;
     let xs = {
         let _enter = self.span_softmax.enter();
-        nn::ops::softmax(&xs, D::Minus1)?
+        nn::ops::softmax_last_dim(&xs)?
     };
     xs.matmul(&value)?.to_dtype(in_dtype)?
 };
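
For context, a hypothetical sanity check that this call-site change is behavior-preserving, assuming the candle / candle-nn APIs used elsewhere in this diff (names like `Tensor::randn` and `flatten_all` come from candle-core, not from this commit):

```rust
use candle::{Device, Result, Tensor, D};
use candle_nn as nn;

fn main() -> Result<()> {
    let xs = Tensor::randn(0f32, 1f32, (2, 3, 8), &Device::Cpu)?;
    // Both calls normalize over the last dimension; softmax_last_dim
    // dispatches to the new custom op where one is available.
    let reference = nn::ops::softmax(&xs, D::Minus1)?;
    let fused = nn::ops::softmax_last_dim(&xs)?;
    // The two results should agree up to floating-point rounding.
    let max_abs_diff = (&reference - &fused)?
        .abs()?
        .flatten_all()?
        .max(0)?
        .to_scalar::<f32>()?;
    assert!(max_abs_diff < 1e-5);
    Ok(())
}
```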