Expand split ops (#2505)

* candle-onnx: Add Split and Expand operators, Fix Where Op

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining Split examples as tests
TODO: Add.test case that motivates Where fix

* candle-onnx: Add ReduceSum operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Add ReduceL2 operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Fix Clip operator empty string as default arg issue

Optional input args may be signified by an empty string. The length of the input array is not enough because non optional args may follow optional ones.

I encountered this when trying to use the ONNX model found at https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 for example.

The LSTM op has a utility which I factored to be more generally accessible, and I have used it in the ops I have recently created or debugged.

I believe it is likely that this issue may also manifest in other ops, but I didn't want to change anything that I'm not testing.

* fix formatting

* fix small mistake made during refactor
This commit is contained in:
Steven Lovegrove
2024-09-26 13:57:55 -07:00
committed by GitHub
parent ad8a4c5e5a
commit ed48f54b54
2 changed files with 528 additions and 14 deletions

View File

@ -323,6 +323,13 @@ fn simple_eval_(
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
};
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
@ -608,15 +615,13 @@ fn simple_eval_(
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if node.input.len() >= 2 {
let mins = get(&node.input[1])?;
xs.broadcast_maximum(mins)?
let xs = if let Some(mins) = get_opt(1) {
xs.broadcast_maximum(mins?)?
} else {
xs.clone()
};
let xs = if node.input.len() >= 3 {
let maxs = get(&node.input[2])?;
xs.broadcast_minimum(maxs)?
let xs = if let Some(maxs) = get_opt(2) {
xs.broadcast_minimum(maxs?)?
} else {
xs.clone()
};
@ -759,7 +764,14 @@ fn simple_eval_(
let cond = get(&node.input[0])?;
let a = get(&node.input[1])?;
let b = get(&node.input[2])?;
let output = cond.where_cond(a, b)?;
// where_cond requires that all inputs are the same shape.
// In contrast, the Where op in ONNX only requires that they are broadcastable.
let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?;
let cond = cond.broadcast_as(shape.clone())?;
let a = a.broadcast_as(shape.clone())?;
let b = b.broadcast_as(shape)?;
let output = cond.where_cond(&a, &b)?;
values.insert(node.output[0].clone(), output);
}
"Conv" => {
@ -962,6 +974,7 @@ fn simple_eval_(
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
@ -1199,6 +1212,152 @@ fn simple_eval_(
};
values.insert(node.output[0].clone(), output);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
// Version 18 impl
"Split" => {
let input_tensor = get(&node.input[0])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = input_tensor.normalize_axis(axis)?;
// Determine split sizes
let splits = if node.input.len() > 1 {
// If the split tensor is provided, use it to determine sizes
let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?;
split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>()
} else {
let num_outputs = if let Some(&num_outputs_attrib) =
get_attr_opt::<i64>(node, "num_outputs")?
{
num_outputs_attrib as usize
} else {
node.output.len()
};
let input_dim = input_tensor.dim(axis)?;
let mut split_sizes =
vec![input_dim / num_outputs as usize; num_outputs as usize];
let remainder = input_dim % num_outputs as usize;
if remainder > 0 {
// If there's a remainder, add it to the last split size
split_sizes[num_outputs as usize - 1] += remainder;
}
split_sizes
};
// Perform the split operation
let mut outputs = vec![];
let mut start = 0;
for &size in &splits {
let end = start + size;
let slice = input_tensor.narrow(axis, start, size)?;
outputs.push(slice);
start = end;
}
// Insert the split outputs into the values map
for (output, slice) in node.output.iter().zip(outputs.into_iter()) {
values.insert(output.clone(), slice);
}
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand
// Version 13 impl
"Expand" => {
// unlike broadcast_to, expand allows for the output shape to
// be different from the specified shape.
let input_tensor = get(&node.input[0])?;
let input_shape = get(&node.input[1])?;
// Check that the shape tensor is 1D
if input_shape.rank() != 1 {
bail!(
"Expand expects 'shape' input to be 1D tensor: {:?}",
input_shape
);
}
let input_tensor_dims = input_tensor.dims();
let input_shape_dims = input_shape
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>();
let target_shape =
broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?;
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
values.insert(node.output[0].clone(), expanded_tensor);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum
// Version 13 impl
"ReduceSum" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
.copied()
.unwrap_or(0);
let axes = match axes {
Some(axes) => axes?
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
(0..input.rank()).collect()
}
}
};
let output = if keepdims == 1 {
input.sum_keepdim(axes)?
} else {
input.sum(axes)?
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2
// Version 18 impl
"ReduceL2" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
.copied()
.unwrap_or(0);
let input_sq = input.sqr()?;
let axes = match axes {
Some(axes) => axes?
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
(0..input_sq.rank()).collect()
}
}
};
let output = if keepdims == 1 {
input_sq.sum_keepdim(axes)?.sqrt()?
} else {
input_sq.sum(axes)?.sqrt()?
};
values.insert(node.output[0].clone(), output);
}
random_type @ ("RandomUniform" | "RandomNormal") => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
@ -1395,13 +1554,6 @@ fn simple_eval_(
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
let r = get(&node.input[2])?;
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// The bias tensor for input gate.
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 8*hidden_size]`.
@ -1580,3 +1732,36 @@ fn simple_eval_(
})
.collect()
}
fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
let (longest, shortest) = if shape_a.len() > shape_b.len() {
(shape_a, shape_b)
} else {
(shape_b, shape_a)
};
let diff = longest.len() - shortest.len();
let mut target_shape = longest[0..diff].to_vec();
for (dim1, dim2) in longest[diff..].iter().zip(shortest.iter()) {
if *dim1 == *dim2 || *dim2 == 1 || *dim1 == 1 {
target_shape.push(usize::max(*dim1, *dim2));
} else {
bail!(
"Expand: incompatible shapes for broadcast, {:?} and {:?}",
shape_a,
shape_b
);
}
}
Ok(target_shape)
}
fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> {
if shapes.is_empty() {
return Ok(Vec::new());
}
let mut shape_out = shapes[0].to_vec();
for shape in shapes[1..].iter() {
shape_out = broadcast_shape(&shape_out, shape)?;
}
Ok(shape_out)
}

View File

@ -3980,3 +3980,332 @@ fn test_lstm() -> Result<()> {
Ok(())
}
#[test]
fn test_expand_dim_changed() -> Result<()> {
// Create a manual graph for the Expand operation
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Expand".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec!["data".to_string(), "new_shape".to_string()],
output: vec!["expanded".to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
input: vec![
ValueInfoProto {
name: "data".to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: "new_shape".to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: "expanded".to_string(),
doc_string: "".to_string(),
r#type: None,
}],
..GraphProto::default()
}));
// Input tensor with shape [3, 1]
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
// New shape tensor: [2, 1, 6]
let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?;
// Expected output after expansion
let expected = Tensor::from_vec(
vec![
1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32,
2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32,
1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
3.0f32, 3.0f32, 3.0f32,
],
(2, 3, 6),
&Device::Cpu,
)?;
// Execute the model evaluation
let inputs = HashMap::from_iter([
("data".to_string(), data),
("new_shape".to_string(), new_shape),
]);
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Retrieve and compare the result
let expanded = result.get("expanded").expect("Output 'expanded' not found");
assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
Ok(())
}
fn make_graph_helper(
op_name: &str,
inputs: &[&str],
outputs: &[&str],
attribs: Vec<AttributeProto>,
) -> ModelProto {
create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: op_name.to_string(),
domain: "".to_string(),
attribute: attribs,
input: inputs.iter().map(|s| s.to_string()).collect(),
output: outputs.iter().map(|s| s.to_string()).collect(),
name: "".to_string(),
doc_string: "".to_string(),
}],
input: inputs
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
output: outputs
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
..GraphProto::default()
}))
}
#[test]
fn test_expand_dim_unchanged() -> Result<()> {
// Create a manual graph for the Expand operation
let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);
// Input tensor with shape [3, 1] and dtype f32
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
// New shape tensor: [3, 4]
let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?;
// Expected output after expansion, dtype f32
let expected = Tensor::from_vec(
vec![
1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
3.0f32,
],
(3, 4),
&Device::Cpu,
)?;
// Execute the model evaluation
let inputs = HashMap::from_iter([
("data".to_string(), data),
("new_shape".to_string(), new_shape),
]);
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
// Retrieve and compare the result
let expanded = result.get("expanded").expect("Output 'expanded' not found");
assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
Ok(())
}
fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto {
let attribs = vec![AttributeProto {
name: "axis".to_string(),
r#type: AttributeType::Int.into(),
i: axis,
..AttributeProto::default()
}];
make_graph_helper("Split", inputs, outputs, attribs)
}
#[test]
fn test_split_equal_parts_1d_opset13() -> Result<()> {
let input = Tensor::from_vec(
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32],
(6,),
&Device::Cpu,
)?;
let mut inputs = HashMap::new();
inputs.insert("input".to_string(), input);
{
let manual_graph =
make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0);
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
assert_eq!(eval.len(), 3);
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
let out3 = eval.get("output_3").expect("Output 'output_3' not found");
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]);
assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]);
}
{
let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?;
inputs.insert("split".to_string(), splits);
let manual_graph =
make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 2);
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]);
}
Ok(())
}
fn make_reduce_sum_graph_helper(
inputs: &[&str],
outputs: &[&str],
keepdims: Option<i64>,
noop_with_empty_axes: Option<i64>,
) -> ModelProto {
let mut attribs = vec![];
if let Some(keepdims) = keepdims {
attribs.push(AttributeProto {
name: "keepdims".to_string(),
r#type: AttributeType::Int.into(),
i: keepdims,
..AttributeProto::default()
});
}
if let Some(noop_with_empty_axes) = noop_with_empty_axes {
attribs.push(AttributeProto {
name: "noop_with_empty_axes".to_string(),
r#type: AttributeType::Ints.into(),
i: noop_with_empty_axes,
..AttributeProto::default()
});
}
make_graph_helper("ReduceSum", inputs, outputs, attribs)
}
#[test]
fn test_reduce_sum_default_axes_keepdims() -> Result<()> {
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None);
// Test with example data
{
let data = Tensor::from_vec(
vec![
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
(3, 2, 2),
&Device::Cpu,
)?;
// let axes = Tensor::from_vec(Vec::<i64>::new(), (0,), &Device::Cpu)?;
let mut inputs = HashMap::new();
inputs.insert("data".to_string(), data);
// inputs.insert("axes".to_string(), axes);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?;
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
}
{
let data = Tensor::from_vec(
vec![
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
],
(3, 2, 2),
&Device::Cpu,
)?;
let mut inputs = HashMap::new();
inputs.insert("data".to_string(), data.clone());
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
let expected = data.sum_all()?.reshape((1, 1, 1))?;
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
}
Ok(())
}
#[test]
fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None);
// Test with example data
{
let data = Tensor::from_vec(
vec![
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
],
(3, 2, 2),
&Device::Cpu,
)?;
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
let mut inputs = HashMap::new();
inputs.insert("data".to_string(), data);
inputs.insert("axes".to_string(), axes);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
let expected = Tensor::from_vec(
vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0],
(3, 2),
&Device::Cpu,
)?;
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
}
// Test with random data
{
let shape = (3, 2, 2);
let data = Tensor::from_vec(
vec![
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
],
(3, 2, 2),
&Device::Cpu,
)?;
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
let mut inputs = HashMap::new();
inputs.insert("data".to_string(), data.clone());
inputs.insert("axes".to_string(), axes);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
// Calculate expected result
let expected = data.sum(1)?;
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
}
Ok(())
}