diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 47f75949..e300fae7 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -2808,3 +2808,81 @@ fn test_argmax() -> Result<()> { Ok(()) } + +// "LeakyRelu" +#[test] +fn test_leakyrelu() -> Result<()> { + // tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80 + // leakyrelu + test( + &[-1.0, 0.0, 1.0], + Some(0.1), + &[-0.1, 0.0, 1.0] + )?; + fn test(data: impl NdArray, alpha: Option, expected: impl NdArray) -> Result<()> { + let att_alpha = AttributeProto { + name: "alpha".to_string(), + ref_attr_name: "alpha".to_string(), + i: 0, + doc_string: "alpha".to_string(), + r#type: 1, // FLOAT + f: alpha.unwrap_or(0.01), + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + let attrs = { + let mut mut_attrs = vec![]; + if alpha.is_some() { + mut_attrs.push(att_alpha); + } + mut_attrs + }; + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "LeakyRelu".to_string(), + domain: "".to_string(), + attribute: attrs, + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + for both in z.to_vec1::()?.iter().zip(expected.to_vec1::()?.iter()) { + let (act, exp) = both; + assert!(f64::abs(act - exp) < f32::EPSILON.into()); + } + + Ok(()) + } + + Ok(()) +}