mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
onnx: workaround pow with negative base (#2439)
* onnx: workaround pow with negative base rather than fully defining pow in the cpu backend (as in #2318), this implements a much smaller change which is sufficient to evaluate silero-vad onnx models. Specifically, checking if pow is run with 2.0 exponent, and if so evaluate as simply `x*x` instead of the cpu backend of `e^(2.0 * ln(x))`. * PR: use Tensor::powf insead powf correctly handles a negative base.
This commit is contained in:
@ -352,8 +352,15 @@ fn simple_eval_(
|
||||
"Pow" => {
|
||||
let input0 = get(&node.input[0])?;
|
||||
let input1 = get(&node.input[1])?;
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
// HACK: current implementation of broadcast_pow cannot handle negative base,
|
||||
// so we use powf where we can, which *does* correctly handle negative base.
|
||||
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
|
||||
let output = input0.powf(exp as f64)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
} else {
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
}
|
||||
"Exp" => {
|
||||
let xs = get(&node.input[0])?;
|
||||
|
Reference in New Issue
Block a user