mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add a custom softmax implementation. (#744)
* Add a custom softmax implementation.
* Add softmax_last_dim to the benchmarks.
* And add a test.
* Support more dtypes.
* Polish the code.
* Use the slow implementation on cuda.
* Add a todo for the cuda kernel.
This commit is contained in:
@@ -198,7 +198,7 @@ impl CrossAttention {
         let xs = query.matmul(&(key.t()? * self.scale)?)?;
         let xs = {
             let _enter = self.span_softmax.enter();
-            nn::ops::softmax(&xs, D::Minus1)?
+            nn::ops::softmax_last_dim(&xs)?
         };
         xs.matmul(&value)?.to_dtype(in_dtype)?
     };
|
Reference in New Issue
Block a user