onnx: workaround pow with negative base (#2439)

* onnx: workaround pow with negative base

Rather than fully defining pow in the cpu backend (as in #2318),
this implements a much smaller change which is sufficient to evaluate silero-vad
onnx models. Specifically, it checks whether Pow is run with a 2.0 exponent, and
if so evaluates it as simply `x*x` instead of the cpu backend's `e^(2.0 * ln(x))`.

* PR: use Tensor::powf instead

powf correctly handles a negative base; the sketch below shows why the
log-exp formulation does not.
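
For reference, here is a minimal standalone sketch (plain `f64` arithmetic rather than candle tensors; the sample values are made up) of why `e^(p * ln(x))` breaks down for a negative base while both `x*x` and `powf` give the expected result:

```rust
fn main() {
    let x: f64 = -3.0; // negative base, as silero-vad produces internally
    let p: f64 = 2.0; // scalar exponent read from the ONNX graph

    // ln(x) is NaN for x < 0, so the log-exp formulation of pow is NaN
    // even though (-3)^2 is perfectly well defined.
    let log_exp = (p * x.ln()).exp();
    let squared = x * x; // the first workaround in this PR
    let powf = x.powf(p); // what Tensor::powf computes per element

    println!("e^(p*ln(x)) = {log_exp}"); // NaN
    println!("x*x         = {squared}"); // 9
    println!("x.powf(p)   = {powf}"); // 9
}
```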
Author: shua
Date:   2024-08-22 13:34:53 +02:00
Committed by: GitHub
Parent: 6070278a31
Commit: a8288b7a72

@@ -352,8 +352,15 @@ fn simple_eval_(
         "Pow" => {
             let input0 = get(&node.input[0])?;
             let input1 = get(&node.input[1])?;
-            let output = input0.broadcast_pow(input1)?;
-            values.insert(node.output[0].clone(), output);
+            // HACK: current implementation of broadcast_pow cannot handle negative base,
+            // so we use powf where we can, which *does* correctly handle negative base.
+            if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
+                let output = input0.powf(exp as f64)?;
+                values.insert(node.output[0].clone(), output);
+            } else {
+                let output = input0.broadcast_pow(input1)?;
+                values.insert(node.output[0].clone(), output);
+            }
         }
         "Exp" => {
             let xs = get(&node.input[0])?;
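
For completeness, a hypothetical standalone reproduction of the fallback logic above, assuming candle_core's `Tensor` API as used in the diff (the helper name `pow_like_the_evaluator` and the sample values are illustrative, not part of the crate):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn pow_like_the_evaluator(base: &Tensor, exponent: &Tensor) -> Result<Tensor> {
    // Try to read the exponent as a single f64; to_scalar only succeeds for
    // rank-0 tensors. The immediately-invoked closure lets `?` short-circuit
    // over both fallible calls without returning from this function.
    if let Ok(exp) = (|| exponent.to_dtype(DType::F64)?.to_scalar::<f64>())() {
        base.powf(exp) // element-wise x^exp, correct for negative x
    } else {
        base.broadcast_pow(exponent) // tensor-shaped exponent: original path
    }
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let base = Tensor::new(&[-3.0f32, 2.0, -0.5], &dev)?;
    let exponent = Tensor::new(2.0f64, &dev)?; // rank-0, so the powf path is taken
    let out = pow_like_the_evaluator(&base, &exponent)?;
    println!("{:?}", out.to_vec1::<f32>()?); // [9.0, 4.0, 0.25]
    Ok(())
}
```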