onnx: add the Flatten operator. (#1638)

* onnx: add the Flatten operator.

* onnx flatten: merge axis condition

---------

Co-authored-by: 王泽龙 <wangzelong@shenqishen.com>
This commit is contained in:
wanglong001
2024-02-03 23:28:47 +08:00
committed by GitHub
parent a510ddec4e
commit 5cdd84e0f6
2 changed files with 139 additions and 1 deletions

View File

@ -766,6 +766,16 @@ pub fn simple_eval(
let output = input.cumsum(axis as usize)?; let output = input.cumsum(axis as usize)?;
values.insert(node.output[0].clone(), output); values.insert(node.output[0].clone(), output);
} }
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
"Flatten" => {
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
let input = get(&node.input[0])?;
let first_part: usize = input.shape().dims().iter().take(axis).product();
let end_index = input.shape().dims().iter().product::<usize>();
let new_shape = (first_part, end_index / first_part);
let output = input.reshape(new_shape)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"), op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
} }
} }

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src; extern crate accelerate_src;
use candle::{Device, Result, Tensor}; use candle::{Device, Result, Tensor};
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto}; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
use std::collections::HashMap; use std::collections::HashMap;
const INPUT_X: &str = "x"; const INPUT_X: &str = "x";
@ -677,6 +677,134 @@ fn test_dropout_operation() -> Result<()> {
Ok(()) Ok(())
} }
// "Flatten"
#[test]
fn test_flatten_operation() -> Result<()> {
let mut att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
i: 0,
doc_string: "axis".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Flatten".to_string(),
domain: "".to_string(),
attribute: vec![att_axis.clone()],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32,
],
&[2, 2, 2],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]);
att_axis.i = 1;
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Flatten".to_string(),
domain: "".to_string(),
attribute: vec![att_axis.clone()],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<f32>()?;
assert_eq!(
results,
vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]
);
Ok(())
}
// Below are ops that are implemented but not tested yet // Below are ops that are implemented but not tested yet
// "MaxPool" // "MaxPool"