implement Slice op (#2260)

shua
2024-06-12 08:15:32 +02:00
committed by GitHub
parent 9f804af29d
commit 2b10aaa05d
2 changed files with 215 additions and 0 deletions


@@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
DataType::Float16 => Some(DType::F16),
DataType::Float => Some(DType::F32),
DataType::Double => Some(DType::F64),
DataType::Bool => Some(DType::U8),
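// candle has no boolean dtype, so ONNX bool tensors are represented as u8.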
_ => None,
}
}
@@ -1053,6 +1054,85 @@ fn simple_eval_(
),
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
"Slice" => {
let data = get(&node.input[0])?;
let starts = get(&node.input[1])?;
let ends = get(&node.input[2])?;
let default_axes;
let default_steps;
let axes: &Tensor;
let steps: &Tensor;
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
// they are set to [1, ..., 1] of length len(starts)
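// e.g. for starts = [1, 0] on a rank-2 input: axes defaults to [0, 1] and steps to [1, 1].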
match node.input.len() {
3 => {
let len = starts.dims()[0];
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
axes = default_axes.as_ref().unwrap();
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
4 => {
let len = starts.dims()[0];
axes = get(&node.input[3])?;
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
5 => {
steps = get(&node.input[4])?;
axes = get(&node.input[3])?;
}
_ => bail!(
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
node.input.len(),
node
),
}
let mut out = data.clone();
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
// All negative elements of axes are made non-negative by
// adding r to them, where r = rank(input).
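// e.g. axis = -1 on a rank-2 input becomes axis = 1.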
let axis = if axis < 0 {
axis + data.rank() as i64
} else {
axis
} as usize;
let data_dim = data.dims()[axis] as i64;
let mut s = starts.get(i)?.to_scalar::<i64>()?;
let mut e = ends.get(i)?.to_scalar::<i64>()?;
// All negative values in starts[i] and ends[i] have
// dims[axes[i]] added to them, where dims are the
// dimensions of input.
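// e.g. starts = [-1] on an axis of size 4 becomes starts = [3].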
if s < 0 {
s += data_dim;
}
if e < 0 {
e += data_dim;
}
let p = steps.get(i)?.to_scalar::<i64>()?;
// starts[i] is clamped into the range [0, dims[axes[i]]]
// for positive stepping and [0, dims[axes[i]]-1] for
// negative stepping.
// for positive stepping ends[axes[i]] is clamped to
// [0, dims[axes[i]]], while for negative stepping it is
// clamped to [-1, dims[axes[i]]-1].
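// e.g. dims[axes[i]] = 4 with starts = [-1], ends = [-5], steps = [-1]:
// after the adjustment above, s = 3 and e = -1, selecting indices 3, 2, 1, 0 (a full reverse).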
if p >= 0 {
s = s.clamp(0, data_dim);
e = e.clamp(0, data_dim);
} else {
s = s.clamp(0, data_dim - 1);
e = e.clamp(-1, data_dim - 1);
}
let indexes = Tensor::arange_step(s, e, p, data.device())?;
out = out.index_select(&indexes, axis)?
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {


@@ -3272,3 +3272,138 @@ fn test_pad() -> Result<()> {
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
Ok(())
}
#[test]
fn test_slice() -> Result<()> {
let model = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Slice".to_string(),
input: vec![
"data".to_string(),
"starts".to_string(),
"ends".to_string(),
"axes".to_string(),
"steps".to_string(),
],
output: vec!["result".to_string()],
..NodeProto::default()
}],
input: ["data", "starts", "ends", "axes", "steps"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
output: ["result"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
..GraphProto::default()
}));
/*
data = [
[1, 2, 3, 4],
[5, 6, 7, 8],
]
axes = [0, 1]
starts = [1, 0]
ends = [2, 3]
steps = [1, 2]
result = [
[5, 7],
]
*/
let outputs = candle_onnx::simple_eval(
&model,
HashMap::from_iter([
(
"data".to_string(),
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
),
(
"starts".to_string(),
Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?,
),
(
"ends".to_string(),
Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?,
),
(
"axes".to_string(),
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
),
(
"steps".to_string(),
Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?,
),
]),
)?;
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
assert_eq!(actual, vec![vec![5i64, 7]]);
/*
data = [
[1, 2, 3, 4],
[5, 6, 7, 8],
]
starts = [0, 1]
ends = [-1, 1000]
result = [
[2, 3, 4],
]
*/
let model = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Slice".to_string(),
input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()],
output: vec!["result".to_string()],
..NodeProto::default()
}],
input: ["data", "starts", "ends"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
output: ["result"]
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
r#type: None,
doc_string: "".to_string(),
})
.collect(),
..GraphProto::default()
}));
let outputs = candle_onnx::simple_eval(
&model,
HashMap::from_iter([
(
"data".to_string(),
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
),
(
"starts".to_string(),
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
),
(
"ends".to_string(),
Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?,
),
]),
)?;
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
assert_eq!(actual, vec![vec![2i64, 3, 4]]);
Ok(())
}
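Neither case above exercises negative stepping. A hypothetical third case (not part of this commit), with model rebound to the 5-input proto from the first case, would insert the following before the final Ok(()):

let outputs = candle_onnx::simple_eval(
    &model,
    HashMap::from_iter([
        (
            "data".to_string(),
            Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
        ),
        (
            "starts".to_string(),
            Tensor::from_vec(vec![-1i64, -1], (2,), &Device::Cpu)?,
        ),
        (
            "ends".to_string(),
            Tensor::from_vec(vec![-5i64, -5], (2,), &Device::Cpu)?,
        ),
        (
            "axes".to_string(),
            Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
        ),
        (
            "steps".to_string(),
            Tensor::from_vec(vec![-1i64, -1], (2,), &Device::Cpu)?,
        ),
    ]),
)?;
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
// After adding each dim size, both ends land at or clamp to -1, so each axis reverses fully.
assert_eq!(actual, vec![vec![8i64, 7, 6, 5], vec![4, 3, 2, 1]]);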