From 5cdd84e0f6365df832a9dbb062ad3a9a34bb65b3 Mon Sep 17 00:00:00 2001
From: wanglong001 <406134592@qq.com>
Date: Sat, 3 Feb 2024 23:28:47 +0800
Subject: [PATCH] onnx: add the Flatten operator. (#1638)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* onnx: add the Flatten operator.

* onnx flatten: merge axis condition

---------

Co-authored-by: 王泽龙
---
 candle-onnx/src/eval.rs  |  10 +++
 candle-onnx/tests/ops.rs | 130 ++++++++++++++++++++++++++++++++++++++-
 2 files changed, 139 insertions(+), 1 deletion(-)

diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index c0ad8668..cacb56ca 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -766,6 +766,16 @@ pub fn simple_eval(
                 let output = input.cumsum(axis as usize)?;
                 values.insert(node.output[0].clone(), output);
             }
+            // https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
+            "Flatten" => {
+                let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
+                let input = get(&node.input[0])?;
+                let first_part: usize = input.shape().dims().iter().take(axis).product();
+                let end_index = input.shape().dims().iter().product::<usize>();
+                let new_shape = (first_part, end_index / first_part);
+                let output = input.reshape(new_shape)?;
+                values.insert(node.output[0].clone(), output);
+            }
             op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
         }
     }
diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs
index a7beeaf0..74b5aad2 100644
--- a/candle-onnx/tests/ops.rs
+++ b/candle-onnx/tests/ops.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use candle::{Device, Result, Tensor};
-use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
+use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
 use std::collections::HashMap;
 
 const INPUT_X: &str = "x";
@@ -677,6 +677,134 @@ fn test_dropout_operation() -> Result<()> {
     Ok(())
 }
 
+// "Flatten"
+#[test]
+fn test_flatten_operation() -> Result<()> {
+    let mut att_axis = AttributeProto {
+        name: "axis".to_string(),
+        ref_attr_name: "axis".to_string(),
+        i: 0,
+        doc_string: "axis".to_string(),
+        r#type: 2,
+        f: 0.0,
+        s: vec![],
+        t: None,
+        g: None,
+        sparse_tensor: None,
+        tp: None,
+        floats: vec![],
+        ints: vec![],
+        strings: vec![],
+        tensors: vec![],
+        graphs: vec![],
+        sparse_tensors: vec![],
+        type_protos: vec![],
+    };
+    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+        node: vec![NodeProto {
+            op_type: "Flatten".to_string(),
+            domain: "".to_string(),
+            attribute: vec![att_axis.clone()],
+            input: vec![INPUT_X.to_string()],
+            output: vec![OUTPUT_Z.to_string()],
+            name: "".to_string(),
+            doc_string: "".to_string(),
+        }],
+        name: "".to_string(),
+        initializer: vec![],
+        input: vec![
+            ValueInfoProto {
+                name: INPUT_X.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            },
+            ValueInfoProto {
+                name: INPUT_Y.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            },
+        ],
+        output: vec![ValueInfoProto {
+            name: OUTPUT_Z.to_string(),
+            doc_string: "".to_string(),
+            r#type: None,
+        }],
+        value_info: vec![],
+        doc_string: "".to_string(),
+        sparse_initializer: vec![],
+        quantization_annotation: vec![],
+    }));
+    let x = Tensor::from_vec(
+        vec![
+            1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32,
+        ],
+        &[2, 2, 2],
+        &Device::Cpu,
+    )?;
+
+    let mut inputs: HashMap<String, Tensor> = HashMap::new();
+    inputs.insert(INPUT_X.to_string(), x);
+
+    let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
+    assert_eq!(eval.len(), 1);
+
+    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+    let results = z.to_vec2::<f32>()?;
+
+    assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]);
+
+    att_axis.i = 1;
+    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+        node: vec![NodeProto {
+            op_type: "Flatten".to_string(),
+            domain: "".to_string(),
+            attribute: vec![att_axis.clone()],
+            input: vec![INPUT_X.to_string()],
+            output: vec![OUTPUT_Z.to_string()],
+            name: "".to_string(),
+            doc_string: "".to_string(),
+        }],
+        name: "".to_string(),
+        initializer: vec![],
+        input: vec![
+            ValueInfoProto {
+                name: INPUT_X.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            },
+            ValueInfoProto {
+                name: INPUT_Y.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            },
+        ],
+        output: vec![ValueInfoProto {
+            name: OUTPUT_Z.to_string(),
+            doc_string: "".to_string(),
+            r#type: None,
+        }],
+        value_info: vec![],
+        doc_string: "".to_string(),
+        sparse_initializer: vec![],
+        quantization_annotation: vec![],
+    }));
+
+    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+    assert_eq!(eval.len(), 1);
+
+    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+
+    let results = z.to_vec2::<f32>()?;
+
+    assert_eq!(
+        results,
+        vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]
+    );
+
+    Ok(())
+}
+
 // Below are ops that are implemented but not tested yet
 // "MaxPool"
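
The new "Flatten" arm in eval.rs implements ONNX Flatten as a single reshape to 2D: the dimensions before `axis` are multiplied into the first output dimension and everything else goes into the second, with the empty product giving 1 when axis is 0. The following is a minimal, self-contained sketch of that shape arithmetic, kept separate from the patch; `flatten_shape` is a hypothetical helper written here for illustration and is not part of candle-onnx.

    // Sketch of the shape arithmetic performed by the "Flatten" arm above.
    fn flatten_shape(dims: &[usize], axis: usize) -> (usize, usize) {
        // Product of the leading `axis` dims; an empty product is 1,
        // so axis = 0 yields a (1, total) shape as the ONNX spec requires.
        let first: usize = dims.iter().take(axis).product();
        let total: usize = dims.iter().product();
        (first, total / first)
    }

    fn main() {
        // Mirrors the two cases exercised by test_flatten_operation on a [2, 2, 2] input.
        assert_eq!(flatten_shape(&[2, 2, 2], 0), (1, 8));
        assert_eq!(flatten_shape(&[2, 2, 2], 1), (2, 4));
        println!("flatten shape checks passed");
    }

Like the patch, this sketch assumes a non-negative axis; the ONNX specification also permits negative axis values, which are not exercised by the tests here.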