mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Expand split ops (#2505)
* candle-onnx: Add Split and Expand operators, Fix Where Op Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md Test cases based on those examples. TODO: Should add the remaining Split examples as tests TODO: Add.test case that motivates Where fix * candle-onnx: Add ReduceSum operator Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md Test cases based on those examples. TODO: Should add the remaining ReduceSum examples as tests * candle-onnx: Add ReduceL2 operator Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md Test cases based on those examples. TODO: Should add the remaining ReduceSum examples as tests * candle-onnx: Fix Clip operator empty string as default arg issue Optional input args may be signified by an empty string. The length of the input array is not enough because non optional args may follow optional ones. I encountered this when trying to use the ONNX model found at https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 for example. The LSTM op has a utility which I factored to be more generally accessible, and I have used it in the ops I have recently created or debugged. I believe it is likely that this issue may also manifest in other ops, but I didn't want to change anything that I'm not testing. * fix formatting * fix small mistake made during refactor
This commit is contained in:
@ -3980,3 +3980,332 @@ fn test_lstm() -> Result<()> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_dim_changed() -> Result<()> {
|
||||
// Create a manual graph for the Expand operation
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Expand".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec!["data".to_string(), "new_shape".to_string()],
|
||||
output: vec!["expanded".to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: "data".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: "new_shape".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: "expanded".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
..GraphProto::default()
|
||||
}));
|
||||
|
||||
// Input tensor with shape [3, 1]
|
||||
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
|
||||
|
||||
// New shape tensor: [2, 1, 6]
|
||||
let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?;
|
||||
|
||||
// Expected output after expansion
|
||||
let expected = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32,
|
||||
2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32,
|
||||
1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
|
||||
3.0f32, 3.0f32, 3.0f32,
|
||||
],
|
||||
(2, 3, 6),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
// Execute the model evaluation
|
||||
let inputs = HashMap::from_iter([
|
||||
("data".to_string(), data),
|
||||
("new_shape".to_string(), new_shape),
|
||||
]);
|
||||
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
|
||||
// Retrieve and compare the result
|
||||
let expanded = result.get("expanded").expect("Output 'expanded' not found");
|
||||
|
||||
assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn make_graph_helper(
|
||||
op_name: &str,
|
||||
inputs: &[&str],
|
||||
outputs: &[&str],
|
||||
attribs: Vec<AttributeProto>,
|
||||
) -> ModelProto {
|
||||
create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: op_name.to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: attribs,
|
||||
input: inputs.iter().map(|s| s.to_string()).collect(),
|
||||
output: outputs.iter().map(|s| s.to_string()).collect(),
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
input: inputs
|
||||
.into_iter()
|
||||
.map(|name| ValueInfoProto {
|
||||
name: name.to_string(),
|
||||
..ValueInfoProto::default()
|
||||
})
|
||||
.collect(),
|
||||
output: outputs
|
||||
.into_iter()
|
||||
.map(|name| ValueInfoProto {
|
||||
name: name.to_string(),
|
||||
..ValueInfoProto::default()
|
||||
})
|
||||
.collect(),
|
||||
..GraphProto::default()
|
||||
}))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_expand_dim_unchanged() -> Result<()> {
|
||||
// Create a manual graph for the Expand operation
|
||||
let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);
|
||||
|
||||
// Input tensor with shape [3, 1] and dtype f32
|
||||
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
|
||||
|
||||
// New shape tensor: [3, 4]
|
||||
let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?;
|
||||
|
||||
// Expected output after expansion, dtype f32
|
||||
let expected = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
|
||||
3.0f32,
|
||||
],
|
||||
(3, 4),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
// Execute the model evaluation
|
||||
let inputs = HashMap::from_iter([
|
||||
("data".to_string(), data),
|
||||
("new_shape".to_string(), new_shape),
|
||||
]);
|
||||
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
|
||||
// Retrieve and compare the result
|
||||
let expanded = result.get("expanded").expect("Output 'expanded' not found");
|
||||
assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto {
|
||||
let attribs = vec![AttributeProto {
|
||||
name: "axis".to_string(),
|
||||
r#type: AttributeType::Int.into(),
|
||||
i: axis,
|
||||
..AttributeProto::default()
|
||||
}];
|
||||
|
||||
make_graph_helper("Split", inputs, outputs, attribs)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_split_equal_parts_1d_opset13() -> Result<()> {
|
||||
let input = Tensor::from_vec(
|
||||
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32],
|
||||
(6,),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("input".to_string(), input);
|
||||
|
||||
{
|
||||
let manual_graph =
|
||||
make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0);
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
|
||||
assert_eq!(eval.len(), 3);
|
||||
|
||||
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
|
||||
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
|
||||
let out3 = eval.get("output_3").expect("Output 'output_3' not found");
|
||||
|
||||
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
|
||||
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]);
|
||||
assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]);
|
||||
}
|
||||
|
||||
{
|
||||
let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?;
|
||||
inputs.insert("split".to_string(), splits);
|
||||
|
||||
let manual_graph =
|
||||
make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0);
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 2);
|
||||
|
||||
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
|
||||
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
|
||||
|
||||
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
|
||||
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn make_reduce_sum_graph_helper(
|
||||
inputs: &[&str],
|
||||
outputs: &[&str],
|
||||
keepdims: Option<i64>,
|
||||
noop_with_empty_axes: Option<i64>,
|
||||
) -> ModelProto {
|
||||
let mut attribs = vec![];
|
||||
if let Some(keepdims) = keepdims {
|
||||
attribs.push(AttributeProto {
|
||||
name: "keepdims".to_string(),
|
||||
r#type: AttributeType::Int.into(),
|
||||
i: keepdims,
|
||||
..AttributeProto::default()
|
||||
});
|
||||
}
|
||||
if let Some(noop_with_empty_axes) = noop_with_empty_axes {
|
||||
attribs.push(AttributeProto {
|
||||
name: "noop_with_empty_axes".to_string(),
|
||||
r#type: AttributeType::Ints.into(),
|
||||
i: noop_with_empty_axes,
|
||||
..AttributeProto::default()
|
||||
});
|
||||
}
|
||||
make_graph_helper("ReduceSum", inputs, outputs, attribs)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduce_sum_default_axes_keepdims() -> Result<()> {
|
||||
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None);
|
||||
|
||||
// Test with example data
|
||||
{
|
||||
let data = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||
],
|
||||
(3, 2, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
// let axes = Tensor::from_vec(Vec::<i64>::new(), (0,), &Device::Cpu)?;
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("data".to_string(), data);
|
||||
// inputs.insert("axes".to_string(), axes);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||
let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?;
|
||||
|
||||
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||
}
|
||||
|
||||
{
|
||||
let data = Tensor::from_vec(
|
||||
vec![
|
||||
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
|
||||
],
|
||||
(3, 2, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("data".to_string(), data.clone());
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||
let expected = data.sum_all()?.reshape((1, 1, 1))?;
|
||||
|
||||
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
|
||||
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None);
|
||||
|
||||
// Test with example data
|
||||
{
|
||||
let data = Tensor::from_vec(
|
||||
vec![
|
||||
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||
],
|
||||
(3, 2, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("data".to_string(), data);
|
||||
inputs.insert("axes".to_string(), axes);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||
let expected = Tensor::from_vec(
|
||||
vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0],
|
||||
(3, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||
}
|
||||
|
||||
// Test with random data
|
||||
{
|
||||
let shape = (3, 2, 2);
|
||||
let data = Tensor::from_vec(
|
||||
vec![
|
||||
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
|
||||
],
|
||||
(3, 2, 2),
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
|
||||
|
||||
let mut inputs = HashMap::new();
|
||||
inputs.insert("data".to_string(), data.clone());
|
||||
inputs.insert("axes".to_string(), axes);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||
|
||||
// Calculate expected result
|
||||
let expected = data.sum(1)?;
|
||||
|
||||
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user