mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
onnx: add the Flatten operator. (#1638)
* onnx: add the Flatten operator. * onnx flatten: merge axis condition --------- Co-authored-by: 王泽龙 <wangzelong@shenqishen.com>
This commit is contained in:
@ -766,6 +766,16 @@ pub fn simple_eval(
|
||||
let output = input.cumsum(axis as usize)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
|
||||
"Flatten" => {
|
||||
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
|
||||
let input = get(&node.input[0])?;
|
||||
let first_part: usize = input.shape().dims().iter().take(axis).product();
|
||||
let end_index = input.shape().dims().iter().product::<usize>();
|
||||
let new_shape = (first_part, end_index / first_part);
|
||||
let output = input.reshape(new_shape)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user