mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Softmax numerical stability. (#267)
* Softmax numerical stability. * Fix the flash-attn test.
This commit is contained in:
@ -90,7 +90,6 @@ impl Tensor {
|
||||
| Op::ToDevice(node)
|
||||
| Op::Transpose(node, _, _)
|
||||
| Op::Narrow(node, _, _, _)
|
||||
| Op::Softmax(node, _)
|
||||
| Op::Unary(node, _)
|
||||
| Op::Elu(node, _)
|
||||
| Op::CustomOp1(node, _) => {
|
||||
@ -324,7 +323,6 @@ impl Tensor {
|
||||
}
|
||||
Op::Reduce(_, ReduceOp::ArgMin, _) => {}
|
||||
Op::Reduce(_, ReduceOp::ArgMax, _) => {}
|
||||
Op::Softmax(_arg, _) => Err(Error::BackwardNotSupported { op: "softmax" })?,
|
||||
Op::Reshape(arg) => {
|
||||
let arg_grad = grad.reshape(arg.dims())?;
|
||||
let sum_grad = grads.or_insert(arg)?;
|
||||
|
Reference in New Issue
Block a user