onnx: add the Flatten operator. (#1638)

* onnx: add the Flatten operator.

* onnx flatten: merge axis condition

---------

Co-authored-by: 王泽龙 <wangzelong@shenqishen.com>
This commit is contained in:
wanglong001
2024-02-03 23:28:47 +08:00
committed by GitHub
parent a510ddec4e
commit 5cdd84e0f6
2 changed files with 139 additions and 1 deletions

View File

@ -766,6 +766,16 @@ pub fn simple_eval(
let output = input.cumsum(axis as usize)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
"Flatten" => {
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
let input = get(&node.input[0])?;
let first_part: usize = input.shape().dims().iter().take(axis).product();
let end_index = input.shape().dims().iter().product::<usize>();
let new_shape = (first_part, end_index / first_part);
let output = input.reshape(new_shape)?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}