From d31f11035fb69d06bf22194d918a5efc45c8ab37 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Fri, 17 Nov 2023 22:03:40 +0000 Subject: [PATCH] Support for CumSum in ONNX models. (#1340) --- candle-onnx/src/eval.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 123e4c19..684776c2 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -741,6 +741,25 @@ pub fn simple_eval( let output = input.to_dtype(dtype)?; values.insert(node.output[0].clone(), output); } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum + "CumSum" => { + let exclusive = get_attr_opt::(node, "exclusive")? + .copied() + .unwrap_or(0); + let reverse = get_attr_opt::(node, "reverse")?.copied().unwrap_or(0); + if exclusive != 0 { + bail!("only exclusive == 0 is supported in CumSum") + } + if reverse != 0 { + bail!("only reverse == 0 is supported in CumSum") + } + let input = get(&node.input[0])?; + let axis = get(&node.input[1])? + .to_dtype(DType::U32)? + .to_vec0::()?; + let output = input.cumsum(axis as usize)?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } }