From 21055b569752a473a0f6b6e2a9c5f53b5bcfe933 Mon Sep 17 00:00:00 2001
From: A2va <49582555+A2va@users.noreply.github.com>
Date: Sat, 19 Apr 2025 07:24:10 +0200
Subject: [PATCH] Add PRelu operation (#2904)

* Add PRelu operation

* Apply rustfmt.

---------

Co-authored-by: Laurent
---
 candle-nn/src/activation.rs |  4 ++-
 candle-onnx/src/eval.rs     | 10 +++++++
 candle-onnx/tests/ops.rs    | 58 +++++++++++++++++++++++++++++++++++++
 3 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/candle-nn/src/activation.rs b/candle-nn/src/activation.rs
index 30f65de0..cc995442 100644
--- a/candle-nn/src/activation.rs
+++ b/candle-nn/src/activation.rs
@@ -71,6 +71,8 @@ impl candle::Module for PReLU {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let weight = if self.is_scalar {
             self.weight.reshape(())?
+        } else if xs.shape() == self.weight.shape() {
+            self.weight.clone()
         } else if xs.rank() >= 2 {
             let num_channels = xs.dim(1)?;
             let num_weights = self.weight.elem_count();
@@ -78,7 +80,7 @@ impl candle::Module for PReLU {
                 candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
             }
             let mut s = vec![1; xs.rank()];
-            s[1] = self.weight.elem_count();
+            s[1] = num_weights;
             self.weight.reshape(s)?
         } else {
             self.weight.clone()
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 2c60ed2f..f1255172 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -1,7 +1,9 @@
 use crate::onnx::attribute_proto::AttributeType;
 use crate::onnx::tensor_proto::DataType;
 use crate::onnx::{self, GraphProto};
+use candle::Module;
 use candle::{bail, DType, Device, Result, Tensor};
+use candle_nn::activation::PReLU;
 use std::collections::{HashMap, HashSet};
 
 pub type Value = Tensor;
@@ -991,6 +993,14 @@ fn simple_eval_(
                 let output = input.relu()?;
                 values.insert(node.output[0].clone(), output);
             }
+            "PRelu" => {
+                // https://onnx.ai/onnx/operators/onnx__PRelu.html
+                let input = get(&node.input[0])?;
+                let slope = get(&node.input[1])?;
+
+                let output = PReLU::new(slope.clone(), false).forward(input)?;
+                values.insert(node.output[0].clone(), output);
+            }
             "Ceil" => {
                 let input = get(&node.input[0])?;
                 let output = input.ceil()?;
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index 3586bfbd..dffb79b7 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -1846,6 +1846,64 @@ fn test_relu_operation() -> Result<()> {
     Ok(())
 }
 
+// "PRelu"
+#[test]
+fn test_prelu_operation() -> Result<()> {
+    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+        node: vec![NodeProto {
+            op_type: "PRelu".to_string(),
+            domain: "".to_string(),
+            attribute: vec![],
+            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+            output: vec![OUTPUT_Z.to_string()],
+            name: "".to_string(),
+            doc_string: "".to_string(),
+        }],
+        name: "".to_string(),
+        initializer: vec![],
+        input: vec![
+            ValueInfoProto {
+                name: INPUT_X.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            },
+            ValueInfoProto {
+                name: INPUT_Y.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            },
+        ],
+        output: vec![ValueInfoProto {
+            name: OUTPUT_Z.to_string(),
+            doc_string: "".to_string(),
+            r#type: None,
+        }],
+        value_info: vec![],
+        doc_string: "".to_string(),
+        sparse_initializer: vec![],
+        quantization_annotation: vec![],
+    }));
+    let x: Tensor = Tensor::from_vec(
+        vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
+        &[2, 2],
+        &Device::Cpu,
+    )?;
+
+    let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;
+
+    let mut inputs: HashMap<String, Tensor> = HashMap::new();
+    inputs.insert(INPUT_X.to_string(), x);
+    inputs.insert(INPUT_Y.to_string(), y);
+
+    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+    assert_eq!(eval.len(), 1);
+
+    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+    let results = z.to_vec2::<f32>()?;
+    assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);
+
+    Ok(())
+}
 // "Constant"
 // #[test]
 