mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 19:58:35 +00:00
onnx: add the Flatten operator. (#1638)
* onnx: add the Flatten operator. * onnx flatten: merge axis condition --------- Co-authored-by: 王泽龙 <wangzelong@shenqishen.com>
This commit is contained in:
@ -766,6 +766,16 @@ pub fn simple_eval(
|
|||||||
let output = input.cumsum(axis as usize)?;
|
let output = input.cumsum(axis as usize)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
|
||||||
|
"Flatten" => {
|
||||||
|
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let first_part: usize = input.shape().dims().iter().take(axis).product();
|
||||||
|
let end_index = input.shape().dims().iter().product::<usize>();
|
||||||
|
let new_shape = (first_part, end_index / first_part);
|
||||||
|
let output = input.reshape(new_shape)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5,7 +5,7 @@ extern crate intel_mkl_src;
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::{Device, Result, Tensor};
|
use candle::{Device, Result, Tensor};
|
||||||
use candle_onnx::onnx::{GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
const INPUT_X: &str = "x";
|
const INPUT_X: &str = "x";
|
||||||
@ -677,6 +677,134 @@ fn test_dropout_operation() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "Flatten"
|
||||||
|
#[test]
|
||||||
|
fn test_flatten_operation() -> Result<()> {
|
||||||
|
let mut att_axis = AttributeProto {
|
||||||
|
name: "axis".to_string(),
|
||||||
|
ref_attr_name: "axis".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "axis".to_string(),
|
||||||
|
r#type: 2,
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Flatten".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![att_axis.clone()],
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_X.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_Y.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32, 7.0f32, 8.0f32,
|
||||||
|
],
|
||||||
|
&[2, 2, 2],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let results = z.to_vec2::<f32>()?;
|
||||||
|
|
||||||
|
assert_eq!(results, vec![vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]]);
|
||||||
|
|
||||||
|
att_axis.i = 1;
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Flatten".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![att_axis.clone()],
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_X.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_Y.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let results = z.to_vec2::<f32>()?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
results,
|
||||||
|
vec![vec![1.0, 2.0, 3.0, 4.0], vec![5.0, 6.0, 7.0, 8.0]]
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// Below are ops that are implemented but not tested yet
|
// Below are ops that are implemented but not tested yet
|
||||||
|
|
||||||
// "MaxPool"
|
// "MaxPool"
|
||||||
|
Reference in New Issue
Block a user