mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
Support for CumSum in ONNX models. (#1340)
This commit is contained in:
@ -741,6 +741,25 @@ pub fn simple_eval(
|
|||||||
let output = input.to_dtype(dtype)?;
|
let output = input.to_dtype(dtype)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
|
||||||
|
"CumSum" => {
|
||||||
|
let exclusive = get_attr_opt::<i64>(node, "exclusive")?
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0);
|
||||||
|
let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
|
||||||
|
if exclusive != 0 {
|
||||||
|
bail!("only exclusive == 0 is supported in CumSum")
|
||||||
|
}
|
||||||
|
if reverse != 0 {
|
||||||
|
bail!("only reverse == 0 is supported in CumSum")
|
||||||
|
}
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let axis = get(&node.input[1])?
|
||||||
|
.to_dtype(DType::U32)?
|
||||||
|
.to_vec0::<u32>()?;
|
||||||
|
let output = input.cumsum(axis as usize)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user