From a8288b7a7233d6b3b80eb074a03a345732d20578 Mon Sep 17 00:00:00 2001 From: shua Date: Thu, 22 Aug 2024 13:34:53 +0200 Subject: [PATCH] onnx: workaround pow with negative base (#2439) * onnx: workaround pow with negative base rather than fully defining pow in the cpu backend (as in #2318), this implements a much smaller change which is sufficient to evaluate silero-vad onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`. * PR: use Tensor::powf instead powf correctly handles a negative base. --- candle-onnx/src/eval.rs | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 036f5838..fca51ef7 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -352,8 +352,15 @@ fn simple_eval_( "Pow" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; - let output = input0.broadcast_pow(input1)?; - values.insert(node.output[0].clone(), output); + // HACK: current implementation of broadcast_pow cannot handle negative base, + // so we use powf where we can, which *does* correctly handle negative base. + if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() { + let output = input0.powf(exp as f64)?; + values.insert(node.output[0].clone(), output); + } else { + let output = input0.broadcast_pow(input1)?; + values.insert(node.output[0].clone(), output); + } } "Exp" => { let xs = get(&node.input[0])?;