From 2b10aaa05d3752186899bd5b5364d92164edc7ef Mon Sep 17 00:00:00 2001 From: shua Date: Wed, 12 Jun 2024 08:15:32 +0200 Subject: [PATCH] implement Slice op (#2260) --- candle-onnx/src/eval.rs | 80 +++++++++++++++++++++++ candle-onnx/tests/ops.rs | 135 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 215 insertions(+) diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index f52e6c5c..10a3b937 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option { DataType::Float16 => Some(DType::F16), DataType::Float => Some(DType::F32), DataType::Double => Some(DType::F64), + DataType::Bool => Some(DType::U8), _ => None, } } @@ -1053,6 +1054,85 @@ fn simple_eval_( ), } } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice + "Slice" => { + let data = get(&node.input[0])?; + let starts = get(&node.input[1])?; + let ends = get(&node.input[2])?; + let default_axes; + let default_steps; + let axes: &Tensor; + let steps: &Tensor; + // If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted, + // they are set to [1, ..., 1] of length len(starts) + match node.input.len() { + 3 => { + let len = starts.dims()[0]; + default_axes = Some(Tensor::arange(0, len as i64, starts.device())?); + axes = default_axes.as_ref().unwrap(); + default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?); + steps = default_steps.as_ref().unwrap(); + } + 4 => { + let len = starts.dims()[0]; + axes = get(&node.input[3])?; + default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?); + steps = default_steps.as_ref().unwrap(); + } + 5 => { + steps = get(&node.input[4])?; + axes = get(&node.input[3])?; + } + _ => bail!( + "Slice node is invalid, expected 3-5 inputs, got {}: {:?}", + node.input.len(), + node + ), + } + + let mut out = data.clone(); + for (i, axis) in axes.to_vec1::()?.into_iter().enumerate() { + // All negative elements of axes are made non-negative by + // adding r to them, where r = rank(input). + let axis = if axis < 0 { + axis + data.rank() as i64 + } else { + axis + } as usize; + + let data_dim = data.dims()[axis] as i64; + let mut s = starts.get(i)?.to_scalar::()?; + let mut e = ends.get(i)?.to_scalar::()?; + // All negative values in starts[i] and ends[i] have + // dims[axes[i]] added to them, where dims are the + // dimensions of input. + if s < 0 { + s += data_dim; + } + if e < 0 { + e += data_dim; + } + + let p = steps.get(i)?.to_scalar::()?; + // starts[i] is clamped into the range [0, dims[axes[i]]] + // for positive stepping and [0, dims[axes[i]]-1] for + // negative stepping. + // for positive stepping ends[axes[i]] is clamped to + // [0, dims[axes[i]]], while for negative stepping it is + // clamped to [-1, dims[axes[i]]-1]. + if p >= 0 { + s = s.clamp(0, data_dim); + e = e.clamp(0, data_dim); + } else { + s = s.clamp(0, data_dim - 1); + e = e.clamp(-1, data_dim - 1); + } + + let indexes = Tensor::arange_step(s, e, p, data.device())?; + out = out.index_select(&indexes, axis)? + } + values.insert(node.output[0].clone(), out); + } // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13 // TODO: This version is only compatible with ReduceMean V13 and below. "ReduceMean" => { diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index b4299af1..82d38aa4 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -3272,3 +3272,138 @@ fn test_pad() -> Result<()> { assert_eq!(actual.to_vec2::()?, expected.to_vec2::()?); Ok(()) } + +#[test] +fn test_slice() -> Result<()> { + let model = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Slice".to_string(), + input: vec![ + "data".to_string(), + "starts".to_string(), + "ends".to_string(), + "axes".to_string(), + "steps".to_string(), + ], + output: vec!["result".to_string()], + ..NodeProto::default() + }], + input: ["data", "starts", "ends", "axes", "steps"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + output: ["result"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + ..GraphProto::default() + })); + + /* + data = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + axes = [0, 1] + starts = [1, 0] + ends = [2, 3] + steps = [1, 2] + result = [ + [5, 7], + ] + */ + + let outputs = candle_onnx::simple_eval( + &model, + HashMap::from_iter([ + ( + "data".to_string(), + Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?, + ), + ( + "starts".to_string(), + Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?, + ), + ( + "ends".to_string(), + Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?, + ), + ( + "axes".to_string(), + Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?, + ), + ( + "steps".to_string(), + Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?, + ), + ]), + )?; + let actual = outputs.get("result").unwrap().to_vec2::()?; + assert_eq!(actual, vec![vec![5i64, 7]]); + + /* + data = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + starts = [0, 1] + ends = [-1, 1000] + result = [ + [2, 3, 4], + ] + */ + let model = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Slice".to_string(), + input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()], + output: vec!["result".to_string()], + ..NodeProto::default() + }], + input: ["data", "starts", "ends"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + output: ["result"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + r#type: None, + doc_string: "".to_string(), + }) + .collect(), + ..GraphProto::default() + })); + let outputs = candle_onnx::simple_eval( + &model, + HashMap::from_iter([ + ( + "data".to_string(), + Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?, + ), + ( + "starts".to_string(), + Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?, + ), + ( + "ends".to_string(), + Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?, + ), + ]), + )?; + let actual = outputs.get("result").unwrap().to_vec2::()?; + assert_eq!(actual, vec![vec![2i64, 3, 4]]); + + Ok(()) +}