Added a test for LeakyRelu

This commit is contained in:
b1rtek
2024-05-10 00:50:05 +02:00
parent 91b0d526ee
commit 2ced31b530

View File

@ -2808,3 +2808,81 @@ fn test_argmax() -> Result<()> {
Ok(()) Ok(())
} }
// "LeakyRelu"
#[test]
fn test_leakyrelu() -> Result<()> {
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80
// leakyrelu
test(
&[-1.0, 0.0, 1.0],
Some(0.1),
&[-0.1, 0.0, 1.0]
)?;
fn test(data: impl NdArray, alpha: Option<f32>, expected: impl NdArray) -> Result<()> {
let att_alpha = AttributeProto {
name: "alpha".to_string(),
ref_attr_name: "alpha".to_string(),
i: 0,
doc_string: "alpha".to_string(),
r#type: 1, // FLOAT
f: alpha.unwrap_or(0.01),
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let attrs = {
let mut mut_attrs = vec![];
if alpha.is_some() {
mut_attrs.push(att_alpha);
}
mut_attrs
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "LeakyRelu".to_string(),
domain: "".to_string(),
attribute: attrs,
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
for both in z.to_vec1::<f64>()?.iter().zip(expected.to_vec1::<f64>()?.iter()) {
let (act, exp) = both;
assert!(f64::abs(act - exp) < f32::EPSILON.into());
}
Ok(())
}
Ok(())
}