Expand split ops (#2505)

* candle-onnx: Add Split and Expand operators, Fix Where Op

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining Split examples as tests
TODO: Add test case that motivates Where fix

* candle-onnx: Add ReduceSum operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceSum examples as tests

* candle-onnx: Add ReduceL2 operator

Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md
Test cases based on those examples.

TODO: Should add the remaining ReduceL2 examples as tests

* candle-onnx: Fix Clip operator empty string as default arg issue

Optional input args may be signified by an empty string. The length of the input array is not enough because non optional args may follow optional ones.

I encountered this when trying to use the ONNX model found at https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 for example.

The LSTM op has a utility which I factored to be more generally accessible, and I have used it in the ops I have recently created or debugged.

I believe it is likely that this issue may also manifest in other ops, but I didn't want to change anything that I'm not testing.

* fix formatting

* fix small mistake made during refactor
This commit is contained in:
Steven Lovegrove
2024-09-26 13:57:55 -07:00
committed by GitHub
parent ad8a4c5e5a
commit ed48f54b54
2 changed files with 528 additions and 14 deletions

View File

@ -3980,3 +3980,332 @@ fn test_lstm() -> Result<()> {
Ok(())
}
#[test]
fn test_expand_dim_changed() -> Result<()> {
    // Build a single-node Expand graph via the shared helper, for consistency
    // with the other tests in this file (e.g. test_expand_dim_unchanged).
    let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);

    // Input tensor with shape [3, 1].
    let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;

    // Target shape [2, 1, 6]: adds a leading dim and broadcasts the size-1 axis.
    let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?;

    // Expected output after expansion: each input row broadcast to width 6,
    // the whole [3, 6] block repeated twice along the new leading axis.
    let expected = Tensor::from_vec(
        vec![
            1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32,
            2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32,
            1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
            3.0f32, 3.0f32, 3.0f32,
        ],
        (2, 3, 6),
        &Device::Cpu,
    )?;

    // Execute the model evaluation.
    let inputs = HashMap::from_iter([
        ("data".to_string(), data),
        ("new_shape".to_string(), new_shape),
    ]);
    let result = candle_onnx::simple_eval(&manual_graph, inputs)?;

    // Retrieve and compare the result.
    let expanded = result.get("expanded").expect("Output 'expanded' not found");
    assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);

    Ok(())
}
fn make_graph_helper(
op_name: &str,
inputs: &[&str],
outputs: &[&str],
attribs: Vec<AttributeProto>,
) -> ModelProto {
create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: op_name.to_string(),
domain: "".to_string(),
attribute: attribs,
input: inputs.iter().map(|s| s.to_string()).collect(),
output: outputs.iter().map(|s| s.to_string()).collect(),
name: "".to_string(),
doc_string: "".to_string(),
}],
input: inputs
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
output: outputs
.into_iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
..GraphProto::default()
}))
}
#[test]
fn test_expand_dim_unchanged() -> Result<()> {
    // Expand a [3, 1] f32 tensor to [3, 4]: rank is preserved, only the
    // size-1 axis is broadcast.
    let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);

    let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
    let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?;

    // Each input value is repeated 4 times along the broadcast axis.
    let expected = Tensor::from_vec(
        vec![
            1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
            3.0f32,
        ],
        (3, 4),
        &Device::Cpu,
    )?;

    let mut inputs = HashMap::new();
    inputs.insert("data".to_string(), data);
    inputs.insert("new_shape".to_string(), new_shape);

    let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
    let expanded = result.get("expanded").expect("Output 'expanded' not found");
    assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
    Ok(())
}
/// Builds a single-node `Split` model with the given `axis` attribute.
fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto {
    let axis_attrib = AttributeProto {
        name: "axis".to_string(),
        r#type: AttributeType::Int.into(),
        i: axis,
        ..AttributeProto::default()
    };
    make_graph_helper("Split", inputs, outputs, vec![axis_attrib])
}
#[test]
fn test_split_equal_parts_1d_opset13() -> Result<()> {
    // A 1-D tensor of 6 elements, split along axis 0.
    let input = Tensor::from_vec(
        vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32],
        (6,),
        &Device::Cpu,
    )?;
    let mut inputs = HashMap::new();
    inputs.insert("input".to_string(), input);

    // Case 1: no explicit split sizes — three outputs implies three equal parts.
    {
        let graph =
            make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0);
        let eval = candle_onnx::simple_eval(&graph, inputs.clone())?;
        assert_eq!(eval.len(), 3);
        let first = eval.get("output_1").expect("Output 'output_1' not found");
        let second = eval.get("output_2").expect("Output 'output_2' not found");
        let third = eval.get("output_3").expect("Output 'output_3' not found");
        assert_eq!(first.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
        assert_eq!(second.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]);
        assert_eq!(third.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]);
    }

    // Case 2: explicit split sizes [2, 4] supplied as a second input.
    {
        let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?;
        inputs.insert("split".to_string(), splits);
        let graph = make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0);
        let eval = candle_onnx::simple_eval(&graph, inputs)?;
        assert_eq!(eval.len(), 2);
        let first = eval.get("output_1").expect("Output 'output_1' not found");
        let second = eval.get("output_2").expect("Output 'output_2' not found");
        assert_eq!(first.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
        assert_eq!(second.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]);
    }
    Ok(())
}
/// Builds a single-node `ReduceSum` model with optional `keepdims` and
/// `noop_with_empty_axes` attributes (both scalar INTs per the ONNX spec).
fn make_reduce_sum_graph_helper(
    inputs: &[&str],
    outputs: &[&str],
    keepdims: Option<i64>,
    noop_with_empty_axes: Option<i64>,
) -> ModelProto {
    let mut attribs = vec![];
    if let Some(keepdims) = keepdims {
        attribs.push(AttributeProto {
            name: "keepdims".to_string(),
            r#type: AttributeType::Int.into(),
            i: keepdims,
            ..AttributeProto::default()
        });
    }
    if let Some(noop_with_empty_axes) = noop_with_empty_axes {
        attribs.push(AttributeProto {
            name: "noop_with_empty_axes".to_string(),
            // Per the ONNX ReduceSum spec this is a scalar INT attribute, and
            // the value is stored in the scalar `i` field — the previous
            // `AttributeType::Ints` tag mismatched the payload.
            r#type: AttributeType::Int.into(),
            i: noop_with_empty_axes,
            ..AttributeProto::default()
        });
    }
    make_graph_helper("ReduceSum", inputs, outputs, attribs)
}
#[test]
fn test_reduce_sum_default_axes_keepdims() -> Result<()> {
    // ReduceSum with keepdims=1 and no axes input: reduce over all axes,
    // keeping each as size 1.
    let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None);

    // Case 1: known data, hand-computed total.
    {
        let data = Tensor::from_vec(
            vec![
                1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            (3, 2, 2),
            &Device::Cpu,
        )?;
        let mut inputs = HashMap::new();
        inputs.insert("data".to_string(), data);
        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);
        let reduced = eval.get("reduced").expect("Output 'reduced' not found");
        // 1 + 2 + ... + 12 = 78, with all dims kept as size 1.
        let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?;
        assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
    }

    // Case 2: arbitrary data, expected value computed with sum_all.
    {
        let data = Tensor::from_vec(
            vec![
                -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
            ],
            (3, 2, 2),
            &Device::Cpu,
        )?;
        let mut inputs = HashMap::new();
        inputs.insert("data".to_string(), data.clone());
        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);
        let reduced = eval.get("reduced").expect("Output 'reduced' not found");
        let expected = data.sum_all()?.reshape((1, 1, 1))?;
        assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
    }
    Ok(())
}
#[test]
fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
    // ReduceSum with keepdims=0: the reduced axis (here axis 1) is dropped
    // from the output shape.
    let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None);

    // Case 1: known data, hand-computed sums over axis 1.
    {
        let data = Tensor::from_vec(
            vec![
                1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            (3, 2, 2),
            &Device::Cpu,
        )?;
        let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
        let mut inputs = HashMap::new();
        inputs.insert("data".to_string(), data);
        inputs.insert("axes".to_string(), axes);
        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);
        let reduced = eval.get("reduced").expect("Output 'reduced' not found");
        // [3, 2, 2] reduced over axis 1 -> [3, 2].
        let expected = Tensor::from_vec(
            vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0],
            (3, 2),
            &Device::Cpu,
        )?;
        assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
    }

    // Case 2: arbitrary data, expected value computed with Tensor::sum.
    // (Removed an unused `shape` local that shadowed the literal (3, 2, 2).)
    {
        let data = Tensor::from_vec(
            vec![
                -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
            ],
            (3, 2, 2),
            &Device::Cpu,
        )?;
        let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
        let mut inputs = HashMap::new();
        inputs.insert("data".to_string(), data.clone());
        inputs.insert("axes".to_string(), axes);
        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
        assert_eq!(eval.len(), 1);
        let reduced = eval.get("reduced").expect("Output 'reduced' not found");
        let expected = data.sum(1)?;
        assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
    }
    Ok(())
}