From 03344d3c19887f6e357d3667fc8e519dfd58b23a Mon Sep 17 00:00:00 2001 From: mokulus <36231852+mokulus@users.noreply.github.com> Date: Sun, 2 Jun 2024 21:45:20 +0200 Subject: [PATCH] ONNX: Add Floor and Ceil (#2235) --- candle-onnx/src/eval.rs | 10 +++ candle-onnx/tests/ops.rs | 152 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 162 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 65fb6d77..df258290 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -877,6 +877,16 @@ pub fn simple_eval( let output = input.relu()?; values.insert(node.output[0].clone(), output); } + "Ceil" => { + let input = get(&node.input[0])?; + let output = input.ceil()?; + values.insert(node.output[0].clone(), output); + } + "Floor" => { + let input = get(&node.input[0])?; + let output = input.floor()?; + values.insert(node.output[0].clone(), output); + } // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant "Constant" => { let value = match node.attribute.iter().find(|attr| attr.name == "value") { diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index a53ad8c5..f58aeccf 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -2556,3 +2556,155 @@ fn test_where() -> Result<()> { Ok(()) } + +#[test] +fn test_floor() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Floor".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec( + // some values taken from https://numpy.org/doc/stable/reference/generated/numpy.floor.html + vec![ + f64::NAN, + f64::INFINITY, + f64::NEG_INFINITY, + -1.7, + -1.5, + -0.2, + 0.2, + 1.5, + 1.7, + 2.0, + ], + &[10], + &Device::Cpu, + )?; + + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec1::()?; + + assert!(results[0].is_nan()); + assert_eq!( + results[1..], + vec![ + f64::INFINITY, + f64::NEG_INFINITY, + -2., + -2., + -1., + 0., + 1., + 1., + 2. + ] + ); + + Ok(()) +} + +#[test] +fn test_ceil() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Ceil".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec( + // some values taken from https://numpy.org/doc/stable/reference/generated/numpy.ceil.html + vec![ + f64::NAN, + f64::INFINITY, + f64::NEG_INFINITY, + -1.7, + -1.5, + -0.2, + 0.2, + 1.5, + 1.7, + 2.0, + ], + &[10], + &Device::Cpu, + )?; + + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec1::()?; + + assert!(results[0].is_nan()); + assert_eq!( + results[1..], + vec![ + f64::INFINITY, + f64::NEG_INFINITY, + -1., + -1., + -0., + 1., + 2., + 2., + 2. + ] + ); + + Ok(()) +}