diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 036f5838..fca51ef7 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -352,8 +352,15 @@ fn simple_eval_( "Pow" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?; - let output = input0.broadcast_pow(input1)?; - values.insert(node.output[0].clone(), output); + // HACK: current implementation of broadcast_pow cannot handle negative base, + // so we use powf where we can, which *does* correctly handle negative base. + if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::())() { + let output = input0.powf(exp as f64)?; + values.insert(node.output[0].clone(), output); + } else { + let output = input0.broadcast_pow(input1)?; + values.insert(node.output[0].clone(), output); + } } "Exp" => { let xs = get(&node.input[0])?;