mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
implement Slice op (#2260)
This commit is contained in:
@ -14,6 +14,7 @@ pub fn dtype(dt: DataType) -> Option<DType> {
|
|||||||
DataType::Float16 => Some(DType::F16),
|
DataType::Float16 => Some(DType::F16),
|
||||||
DataType::Float => Some(DType::F32),
|
DataType::Float => Some(DType::F32),
|
||||||
DataType::Double => Some(DType::F64),
|
DataType::Double => Some(DType::F64),
|
||||||
|
DataType::Bool => Some(DType::U8),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1053,6 +1054,85 @@ fn simple_eval_(
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
|
||||||
|
"Slice" => {
|
||||||
|
let data = get(&node.input[0])?;
|
||||||
|
let starts = get(&node.input[1])?;
|
||||||
|
let ends = get(&node.input[2])?;
|
||||||
|
let default_axes;
|
||||||
|
let default_steps;
|
||||||
|
let axes: &Tensor;
|
||||||
|
let steps: &Tensor;
|
||||||
|
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
|
||||||
|
// they are set to [1, ..., 1] of length len(starts)
|
||||||
|
match node.input.len() {
|
||||||
|
3 => {
|
||||||
|
let len = starts.dims()[0];
|
||||||
|
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
|
||||||
|
axes = default_axes.as_ref().unwrap();
|
||||||
|
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
|
||||||
|
steps = default_steps.as_ref().unwrap();
|
||||||
|
}
|
||||||
|
4 => {
|
||||||
|
let len = starts.dims()[0];
|
||||||
|
axes = get(&node.input[3])?;
|
||||||
|
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
|
||||||
|
steps = default_steps.as_ref().unwrap();
|
||||||
|
}
|
||||||
|
5 => {
|
||||||
|
steps = get(&node.input[4])?;
|
||||||
|
axes = get(&node.input[3])?;
|
||||||
|
}
|
||||||
|
_ => bail!(
|
||||||
|
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
|
||||||
|
node.input.len(),
|
||||||
|
node
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut out = data.clone();
|
||||||
|
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
|
||||||
|
// All negative elements of axes are made non-negative by
|
||||||
|
// adding r to them, where r = rank(input).
|
||||||
|
let axis = if axis < 0 {
|
||||||
|
axis + data.rank() as i64
|
||||||
|
} else {
|
||||||
|
axis
|
||||||
|
} as usize;
|
||||||
|
|
||||||
|
let data_dim = data.dims()[axis] as i64;
|
||||||
|
let mut s = starts.get(i)?.to_scalar::<i64>()?;
|
||||||
|
let mut e = ends.get(i)?.to_scalar::<i64>()?;
|
||||||
|
// All negative values in starts[i] and ends[i] have
|
||||||
|
// dims[axes[i]] added to them, where dims are the
|
||||||
|
// dimensions of input.
|
||||||
|
if s < 0 {
|
||||||
|
s += data_dim;
|
||||||
|
}
|
||||||
|
if e < 0 {
|
||||||
|
e += data_dim;
|
||||||
|
}
|
||||||
|
|
||||||
|
let p = steps.get(i)?.to_scalar::<i64>()?;
|
||||||
|
// starts[i] is clamped into the range [0, dims[axes[i]]]
|
||||||
|
// for positive stepping and [0, dims[axes[i]]-1] for
|
||||||
|
// negative stepping.
|
||||||
|
// for positive stepping ends[axes[i]] is clamped to
|
||||||
|
// [0, dims[axes[i]]], while for negative stepping it is
|
||||||
|
// clamped to [-1, dims[axes[i]]-1].
|
||||||
|
if p >= 0 {
|
||||||
|
s = s.clamp(0, data_dim);
|
||||||
|
e = e.clamp(0, data_dim);
|
||||||
|
} else {
|
||||||
|
s = s.clamp(0, data_dim - 1);
|
||||||
|
e = e.clamp(-1, data_dim - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
let indexes = Tensor::arange_step(s, e, p, data.device())?;
|
||||||
|
out = out.index_select(&indexes, axis)?
|
||||||
|
}
|
||||||
|
values.insert(node.output[0].clone(), out);
|
||||||
|
}
|
||||||
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
|
||||||
// TODO: This version is only compatible with ReduceMean V13 and below.
|
// TODO: This version is only compatible with ReduceMean V13 and below.
|
||||||
"ReduceMean" => {
|
"ReduceMean" => {
|
||||||
|
@ -3272,3 +3272,138 @@ fn test_pad() -> Result<()> {
|
|||||||
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
|
assert_eq!(actual.to_vec2::<f64>()?, expected.to_vec2::<f64>()?);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_slice() -> Result<()> {
|
||||||
|
let model = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Slice".to_string(),
|
||||||
|
input: vec![
|
||||||
|
"data".to_string(),
|
||||||
|
"starts".to_string(),
|
||||||
|
"ends".to_string(),
|
||||||
|
"axes".to_string(),
|
||||||
|
"steps".to_string(),
|
||||||
|
],
|
||||||
|
output: vec!["result".to_string()],
|
||||||
|
..NodeProto::default()
|
||||||
|
}],
|
||||||
|
input: ["data", "starts", "ends", "axes", "steps"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
r#type: None,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
output: ["result"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
r#type: None,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
..GraphProto::default()
|
||||||
|
}));
|
||||||
|
|
||||||
|
/*
|
||||||
|
data = [
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
]
|
||||||
|
axes = [0, 1]
|
||||||
|
starts = [1, 0]
|
||||||
|
ends = [2, 3]
|
||||||
|
steps = [1, 2]
|
||||||
|
result = [
|
||||||
|
[5, 7],
|
||||||
|
]
|
||||||
|
*/
|
||||||
|
|
||||||
|
let outputs = candle_onnx::simple_eval(
|
||||||
|
&model,
|
||||||
|
HashMap::from_iter([
|
||||||
|
(
|
||||||
|
"data".to_string(),
|
||||||
|
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"starts".to_string(),
|
||||||
|
Tensor::from_vec(vec![1i64, 0], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ends".to_string(),
|
||||||
|
Tensor::from_vec(vec![2i64, 3], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"axes".to_string(),
|
||||||
|
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"steps".to_string(),
|
||||||
|
Tensor::from_vec(vec![1i64, 2], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
)?;
|
||||||
|
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
|
||||||
|
assert_eq!(actual, vec![vec![5i64, 7]]);
|
||||||
|
|
||||||
|
/*
|
||||||
|
data = [
|
||||||
|
[1, 2, 3, 4],
|
||||||
|
[5, 6, 7, 8],
|
||||||
|
]
|
||||||
|
starts = [0, 1]
|
||||||
|
ends = [-1, 1000]
|
||||||
|
result = [
|
||||||
|
[2, 3, 4],
|
||||||
|
]
|
||||||
|
*/
|
||||||
|
let model = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Slice".to_string(),
|
||||||
|
input: vec!["data".to_string(), "starts".to_string(), "ends".to_string()],
|
||||||
|
output: vec!["result".to_string()],
|
||||||
|
..NodeProto::default()
|
||||||
|
}],
|
||||||
|
input: ["data", "starts", "ends"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
r#type: None,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
output: ["result"]
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
r#type: None,
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
..GraphProto::default()
|
||||||
|
}));
|
||||||
|
let outputs = candle_onnx::simple_eval(
|
||||||
|
&model,
|
||||||
|
HashMap::from_iter([
|
||||||
|
(
|
||||||
|
"data".to_string(),
|
||||||
|
Tensor::from_vec(vec![1i64, 2, 3, 4, 5, 6, 7, 8], (2, 4), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"starts".to_string(),
|
||||||
|
Tensor::from_vec(vec![0i64, 1], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"ends".to_string(),
|
||||||
|
Tensor::from_vec(vec![-1i64, 1000], (2,), &Device::Cpu)?,
|
||||||
|
),
|
||||||
|
]),
|
||||||
|
)?;
|
||||||
|
let actual = outputs.get("result").unwrap().to_vec2::<i64>()?;
|
||||||
|
assert_eq!(actual, vec![vec![2i64, 3, 4]]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user