diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 123e4c19..684776c2 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -741,6 +741,25 @@ pub fn simple_eval( let output = input.to_dtype(dtype)?; values.insert(node.output[0].clone(), output); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum + "CumSum" => { + let exclusive = get_attr_opt::(node, "exclusive")? + .copied() + .unwrap_or(0); + let reverse = get_attr_opt::(node, "reverse")?.copied().unwrap_or(0); + if exclusive != 0 { + bail!("only exclusive == 0 is supported in CumSum") + } + if reverse != 0 { + bail!("only reverse == 0 is supported in CumSum") + } + let input = get(&node.input[0])?; + let axis = get(&node.input[1])? + .to_dtype(DType::U32)? + .to_vec0::()?; + let output = input.cumsum(axis as usize)?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } }