Add PRelu operation (#2904)

* Add PRelu operation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
A2va
2025-04-19 07:24:10 +02:00
committed by GitHub
parent 9dbaf958dc
commit 21055b5697
3 changed files with 71 additions and 1 deletions

View File

@ -1846,6 +1846,64 @@ fn test_relu_operation() -> Result<()> {
Ok(())
}
// "PRelu"
#[test]
fn test_prelu_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "PRelu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x: Tensor = Tensor::from_vec(
vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
&[2, 2],
&Device::Cpu,
)?;
let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), y);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);
Ok(())
}
// "Constant"
// #[test]