Softmax numerical stability. (#267)

* Softmax numerical stability.

* Fix the flash-attn test.
This commit is contained in:
Laurent Mazare
2023-07-28 13:13:01 +01:00
committed by GitHub
parent 68eab38de6
commit 3eb2bc6d07
28 changed files with 117 additions and 188 deletions

View File

@@ -90,7 +90,6 @@ impl Tensor {
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Narrow(node, _, _, _)
| Op::Softmax(node, _)
| Op::Unary(node, _)
| Op::Elu(node, _)
| Op::CustomOp1(node, _) => {
@@ -324,7 +323,6 @@ impl Tensor {
}
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
Op::Reshape(arg) => {
let arg_grad = grad.reshape(arg.dims())?;
let sum_grad = grads.or_insert(arg)?;