Onnx Support for Sign operation #2641 (#2642)

* Support for Sign operation #2641

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Ionut Mihalcea
2024-11-26 23:10:09 +01:00
committed by GitHub
parent b4deb5c5a9
commit 21c686387c
2 changed files with 47 additions and 0 deletions

View File

@ -5869,3 +5869,44 @@ fn test_xor() -> Result<()> {
}
Ok(())
}
#[test]
fn test_sign_operation() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Sign".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(
INPUT_X.to_string(),
Tensor::new(vec![-2f32, -1., 0., 1., 2.], &Device::Cpu)?,
);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
assert_eq!(
z.to_dtype(candle::DType::I64)?.to_vec1::<i64>()?.to_vec(),
vec![-1, -1, 0, 1, 1]
);
Ok(())
}