mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Expand split ops (#2505)
* candle-onnx: Add Split and Expand operators, Fix Where Op Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md Test cases based on those examples. TODO: Should add the remaining Split examples as tests TODO: Add.test case that motivates Where fix * candle-onnx: Add ReduceSum operator Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md Test cases based on those examples. TODO: Should add the remaining ReduceSum examples as tests * candle-onnx: Add ReduceL2 operator Implemented based on https://github.com/onnx/onnx/blob/main/docs/Operators.md Test cases based on those examples. TODO: Should add the remaining ReduceSum examples as tests * candle-onnx: Fix Clip operator empty string as default arg issue Optional input args may be signified by an empty string. The length of the input array is not enough because non optional args may follow optional ones. I encountered this when trying to use the ONNX model found at https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2 for example. The LSTM op has a utility which I factored to be more generally accessible, and I have used it in the ops I have recently created or debugged. I believe it is likely that this issue may also manifest in other ops, but I didn't want to change anything that I'm not testing. * fix formatting * fix small mistake made during refactor
This commit is contained in:
@ -323,6 +323,13 @@ fn simple_eval_(
|
|||||||
Some(value) => Ok(value),
|
Some(value) => Ok(value),
|
||||||
None => bail!("cannot find {input_name} for op {}", node.name),
|
None => bail!("cannot find {input_name} for op {}", node.name),
|
||||||
};
|
};
|
||||||
|
let get_opt = |i: usize| {
|
||||||
|
node.input
|
||||||
|
.get(i)
|
||||||
|
.filter(|s: &&String| !s.is_empty())
|
||||||
|
.map(|s| get(s))
|
||||||
|
};
|
||||||
|
|
||||||
// TODO: Validate node.input for each operator.
|
// TODO: Validate node.input for each operator.
|
||||||
match node.op_type.as_str() {
|
match node.op_type.as_str() {
|
||||||
"Add" => {
|
"Add" => {
|
||||||
@ -608,15 +615,13 @@ fn simple_eval_(
|
|||||||
}
|
}
|
||||||
"Clip" => {
|
"Clip" => {
|
||||||
let xs = get(&node.input[0])?;
|
let xs = get(&node.input[0])?;
|
||||||
let xs = if node.input.len() >= 2 {
|
let xs = if let Some(mins) = get_opt(1) {
|
||||||
let mins = get(&node.input[1])?;
|
xs.broadcast_maximum(mins?)?
|
||||||
xs.broadcast_maximum(mins)?
|
|
||||||
} else {
|
} else {
|
||||||
xs.clone()
|
xs.clone()
|
||||||
};
|
};
|
||||||
let xs = if node.input.len() >= 3 {
|
let xs = if let Some(maxs) = get_opt(2) {
|
||||||
let maxs = get(&node.input[2])?;
|
xs.broadcast_minimum(maxs?)?
|
||||||
xs.broadcast_minimum(maxs)?
|
|
||||||
} else {
|
} else {
|
||||||
xs.clone()
|
xs.clone()
|
||||||
};
|
};
|
||||||
@ -759,7 +764,14 @@ fn simple_eval_(
|
|||||||
let cond = get(&node.input[0])?;
|
let cond = get(&node.input[0])?;
|
||||||
let a = get(&node.input[1])?;
|
let a = get(&node.input[1])?;
|
||||||
let b = get(&node.input[2])?;
|
let b = get(&node.input[2])?;
|
||||||
let output = cond.where_cond(a, b)?;
|
|
||||||
|
// where_cond requires that all inputs are the same shape.
|
||||||
|
// In contrast, the Where op in ONNX only requires that they are broadcastable.
|
||||||
|
let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?;
|
||||||
|
let cond = cond.broadcast_as(shape.clone())?;
|
||||||
|
let a = a.broadcast_as(shape.clone())?;
|
||||||
|
let b = b.broadcast_as(shape)?;
|
||||||
|
let output = cond.where_cond(&a, &b)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
"Conv" => {
|
"Conv" => {
|
||||||
@ -962,6 +974,7 @@ fn simple_eval_(
|
|||||||
}
|
}
|
||||||
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
|
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
|
||||||
};
|
};
|
||||||
|
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
|
||||||
@ -1199,6 +1212,152 @@ fn simple_eval_(
|
|||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
|
||||||
|
// Version 18 impl
|
||||||
|
"Split" => {
|
||||||
|
let input_tensor = get(&node.input[0])?;
|
||||||
|
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
|
||||||
|
let axis = input_tensor.normalize_axis(axis)?;
|
||||||
|
|
||||||
|
// Determine split sizes
|
||||||
|
let splits = if node.input.len() > 1 {
|
||||||
|
// If the split tensor is provided, use it to determine sizes
|
||||||
|
let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?;
|
||||||
|
split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>()
|
||||||
|
} else {
|
||||||
|
let num_outputs = if let Some(&num_outputs_attrib) =
|
||||||
|
get_attr_opt::<i64>(node, "num_outputs")?
|
||||||
|
{
|
||||||
|
num_outputs_attrib as usize
|
||||||
|
} else {
|
||||||
|
node.output.len()
|
||||||
|
};
|
||||||
|
|
||||||
|
let input_dim = input_tensor.dim(axis)?;
|
||||||
|
|
||||||
|
let mut split_sizes =
|
||||||
|
vec![input_dim / num_outputs as usize; num_outputs as usize];
|
||||||
|
let remainder = input_dim % num_outputs as usize;
|
||||||
|
if remainder > 0 {
|
||||||
|
// If there's a remainder, add it to the last split size
|
||||||
|
split_sizes[num_outputs as usize - 1] += remainder;
|
||||||
|
}
|
||||||
|
|
||||||
|
split_sizes
|
||||||
|
};
|
||||||
|
|
||||||
|
// Perform the split operation
|
||||||
|
let mut outputs = vec![];
|
||||||
|
let mut start = 0;
|
||||||
|
for &size in &splits {
|
||||||
|
let end = start + size;
|
||||||
|
let slice = input_tensor.narrow(axis, start, size)?;
|
||||||
|
outputs.push(slice);
|
||||||
|
start = end;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert the split outputs into the values map
|
||||||
|
for (output, slice) in node.output.iter().zip(outputs.into_iter()) {
|
||||||
|
values.insert(output.clone(), slice);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand
|
||||||
|
// Version 13 impl
|
||||||
|
"Expand" => {
|
||||||
|
// unlike broadcast_to, expand allows for the output shape to
|
||||||
|
// be different from the specified shape.
|
||||||
|
let input_tensor = get(&node.input[0])?;
|
||||||
|
let input_shape = get(&node.input[1])?;
|
||||||
|
|
||||||
|
// Check that the shape tensor is 1D
|
||||||
|
if input_shape.rank() != 1 {
|
||||||
|
bail!(
|
||||||
|
"Expand expects 'shape' input to be 1D tensor: {:?}",
|
||||||
|
input_shape
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let input_tensor_dims = input_tensor.dims();
|
||||||
|
let input_shape_dims = input_shape
|
||||||
|
.to_vec1::<i64>()?
|
||||||
|
.into_iter()
|
||||||
|
.map(|x| x as usize)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let target_shape =
|
||||||
|
broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?;
|
||||||
|
|
||||||
|
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), expanded_tensor);
|
||||||
|
}
|
||||||
|
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum
|
||||||
|
// Version 13 impl
|
||||||
|
"ReduceSum" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let axes = get_opt(1);
|
||||||
|
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
|
||||||
|
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
let axes = match axes {
|
||||||
|
Some(axes) => axes?
|
||||||
|
.to_vec1::<i64>()?
|
||||||
|
.into_iter()
|
||||||
|
.map(|x| x as usize)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => {
|
||||||
|
if noop_with_empty_axes == 1 {
|
||||||
|
vec![]
|
||||||
|
} else {
|
||||||
|
(0..input.rank()).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = if keepdims == 1 {
|
||||||
|
input.sum_keepdim(axes)?
|
||||||
|
} else {
|
||||||
|
input.sum(axes)?
|
||||||
|
};
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2
|
||||||
|
// Version 18 impl
|
||||||
|
"ReduceL2" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let axes = get_opt(1);
|
||||||
|
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
|
||||||
|
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
|
||||||
|
.copied()
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
let input_sq = input.sqr()?;
|
||||||
|
|
||||||
|
let axes = match axes {
|
||||||
|
Some(axes) => axes?
|
||||||
|
.to_vec1::<i64>()?
|
||||||
|
.into_iter()
|
||||||
|
.map(|x| x as usize)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
None => {
|
||||||
|
if noop_with_empty_axes == 1 {
|
||||||
|
vec![]
|
||||||
|
} else {
|
||||||
|
(0..input_sq.rank()).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = if keepdims == 1 {
|
||||||
|
input_sq.sum_keepdim(axes)?.sqrt()?
|
||||||
|
} else {
|
||||||
|
input_sq.sum(axes)?.sqrt()?
|
||||||
|
};
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
random_type @ ("RandomUniform" | "RandomNormal") => {
|
random_type @ ("RandomUniform" | "RandomNormal") => {
|
||||||
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
|
||||||
// type by
|
// type by
|
||||||
@ -1395,13 +1554,6 @@ fn simple_eval_(
|
|||||||
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
|
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
|
||||||
let r = get(&node.input[2])?;
|
let r = get(&node.input[2])?;
|
||||||
|
|
||||||
let get_opt = |i: usize| {
|
|
||||||
node.input
|
|
||||||
.get(i)
|
|
||||||
.filter(|s: &&String| !s.is_empty())
|
|
||||||
.map(|s| get(s))
|
|
||||||
};
|
|
||||||
|
|
||||||
// The bias tensor for input gate.
|
// The bias tensor for input gate.
|
||||||
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
|
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
|
||||||
// This tensor has shape `[num_directions, 8*hidden_size]`.
|
// This tensor has shape `[num_directions, 8*hidden_size]`.
|
||||||
@ -1580,3 +1732,36 @@ fn simple_eval_(
|
|||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
|
||||||
|
let (longest, shortest) = if shape_a.len() > shape_b.len() {
|
||||||
|
(shape_a, shape_b)
|
||||||
|
} else {
|
||||||
|
(shape_b, shape_a)
|
||||||
|
};
|
||||||
|
let diff = longest.len() - shortest.len();
|
||||||
|
let mut target_shape = longest[0..diff].to_vec();
|
||||||
|
for (dim1, dim2) in longest[diff..].iter().zip(shortest.iter()) {
|
||||||
|
if *dim1 == *dim2 || *dim2 == 1 || *dim1 == 1 {
|
||||||
|
target_shape.push(usize::max(*dim1, *dim2));
|
||||||
|
} else {
|
||||||
|
bail!(
|
||||||
|
"Expand: incompatible shapes for broadcast, {:?} and {:?}",
|
||||||
|
shape_a,
|
||||||
|
shape_b
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(target_shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> {
|
||||||
|
if shapes.is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
let mut shape_out = shapes[0].to_vec();
|
||||||
|
for shape in shapes[1..].iter() {
|
||||||
|
shape_out = broadcast_shape(&shape_out, shape)?;
|
||||||
|
}
|
||||||
|
Ok(shape_out)
|
||||||
|
}
|
||||||
|
@ -3980,3 +3980,332 @@ fn test_lstm() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_expand_dim_changed() -> Result<()> {
|
||||||
|
// Create a manual graph for the Expand operation
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Expand".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec!["data".to_string(), "new_shape".to_string()],
|
||||||
|
output: vec!["expanded".to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
input: vec![
|
||||||
|
ValueInfoProto {
|
||||||
|
name: "data".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
ValueInfoProto {
|
||||||
|
name: "new_shape".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: "expanded".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
..GraphProto::default()
|
||||||
|
}));
|
||||||
|
|
||||||
|
// Input tensor with shape [3, 1]
|
||||||
|
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
|
||||||
|
|
||||||
|
// New shape tensor: [2, 1, 6]
|
||||||
|
let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?;
|
||||||
|
|
||||||
|
// Expected output after expansion
|
||||||
|
let expected = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32,
|
||||||
|
2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32,
|
||||||
|
1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
|
||||||
|
3.0f32, 3.0f32, 3.0f32,
|
||||||
|
],
|
||||||
|
(2, 3, 6),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Execute the model evaluation
|
||||||
|
let inputs = HashMap::from_iter([
|
||||||
|
("data".to_string(), data),
|
||||||
|
("new_shape".to_string(), new_shape),
|
||||||
|
]);
|
||||||
|
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
|
||||||
|
// Retrieve and compare the result
|
||||||
|
let expanded = result.get("expanded").expect("Output 'expanded' not found");
|
||||||
|
|
||||||
|
assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_graph_helper(
|
||||||
|
op_name: &str,
|
||||||
|
inputs: &[&str],
|
||||||
|
outputs: &[&str],
|
||||||
|
attribs: Vec<AttributeProto>,
|
||||||
|
) -> ModelProto {
|
||||||
|
create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: op_name.to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: attribs,
|
||||||
|
input: inputs.iter().map(|s| s.to_string()).collect(),
|
||||||
|
output: outputs.iter().map(|s| s.to_string()).collect(),
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
input: inputs
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
..ValueInfoProto::default()
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
output: outputs
|
||||||
|
.into_iter()
|
||||||
|
.map(|name| ValueInfoProto {
|
||||||
|
name: name.to_string(),
|
||||||
|
..ValueInfoProto::default()
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
..GraphProto::default()
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_expand_dim_unchanged() -> Result<()> {
|
||||||
|
// Create a manual graph for the Expand operation
|
||||||
|
let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]);
|
||||||
|
|
||||||
|
// Input tensor with shape [3, 1] and dtype f32
|
||||||
|
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?;
|
||||||
|
|
||||||
|
// New shape tensor: [3, 4]
|
||||||
|
let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?;
|
||||||
|
|
||||||
|
// Expected output after expansion, dtype f32
|
||||||
|
let expected = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32,
|
||||||
|
3.0f32,
|
||||||
|
],
|
||||||
|
(3, 4),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Execute the model evaluation
|
||||||
|
let inputs = HashMap::from_iter([
|
||||||
|
("data".to_string(), data),
|
||||||
|
("new_shape".to_string(), new_shape),
|
||||||
|
]);
|
||||||
|
let result = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
|
||||||
|
// Retrieve and compare the result
|
||||||
|
let expanded = result.get("expanded").expect("Output 'expanded' not found");
|
||||||
|
assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto {
|
||||||
|
let attribs = vec![AttributeProto {
|
||||||
|
name: "axis".to_string(),
|
||||||
|
r#type: AttributeType::Int.into(),
|
||||||
|
i: axis,
|
||||||
|
..AttributeProto::default()
|
||||||
|
}];
|
||||||
|
|
||||||
|
make_graph_helper("Split", inputs, outputs, attribs)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_split_equal_parts_1d_opset13() -> Result<()> {
|
||||||
|
let input = Tensor::from_vec(
|
||||||
|
vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32],
|
||||||
|
(6,),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
let mut inputs = HashMap::new();
|
||||||
|
inputs.insert("input".to_string(), input);
|
||||||
|
|
||||||
|
{
|
||||||
|
let manual_graph =
|
||||||
|
make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0);
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?;
|
||||||
|
assert_eq!(eval.len(), 3);
|
||||||
|
|
||||||
|
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
|
||||||
|
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
|
||||||
|
let out3 = eval.get("output_3").expect("Output 'output_3' not found");
|
||||||
|
|
||||||
|
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
|
||||||
|
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]);
|
||||||
|
assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?;
|
||||||
|
inputs.insert("split".to_string(), splits);
|
||||||
|
|
||||||
|
let manual_graph =
|
||||||
|
make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0);
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 2);
|
||||||
|
|
||||||
|
let out1 = eval.get("output_1").expect("Output 'output_1' not found");
|
||||||
|
let out2 = eval.get("output_2").expect("Output 'output_2' not found");
|
||||||
|
|
||||||
|
assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]);
|
||||||
|
assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]);
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_reduce_sum_graph_helper(
|
||||||
|
inputs: &[&str],
|
||||||
|
outputs: &[&str],
|
||||||
|
keepdims: Option<i64>,
|
||||||
|
noop_with_empty_axes: Option<i64>,
|
||||||
|
) -> ModelProto {
|
||||||
|
let mut attribs = vec![];
|
||||||
|
if let Some(keepdims) = keepdims {
|
||||||
|
attribs.push(AttributeProto {
|
||||||
|
name: "keepdims".to_string(),
|
||||||
|
r#type: AttributeType::Int.into(),
|
||||||
|
i: keepdims,
|
||||||
|
..AttributeProto::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if let Some(noop_with_empty_axes) = noop_with_empty_axes {
|
||||||
|
attribs.push(AttributeProto {
|
||||||
|
name: "noop_with_empty_axes".to_string(),
|
||||||
|
r#type: AttributeType::Ints.into(),
|
||||||
|
i: noop_with_empty_axes,
|
||||||
|
..AttributeProto::default()
|
||||||
|
});
|
||||||
|
}
|
||||||
|
make_graph_helper("ReduceSum", inputs, outputs, attribs)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_reduce_sum_default_axes_keepdims() -> Result<()> {
|
||||||
|
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None);
|
||||||
|
|
||||||
|
// Test with example data
|
||||||
|
{
|
||||||
|
let data = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||||
|
],
|
||||||
|
(3, 2, 2),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
// let axes = Tensor::from_vec(Vec::<i64>::new(), (0,), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs = HashMap::new();
|
||||||
|
inputs.insert("data".to_string(), data);
|
||||||
|
// inputs.insert("axes".to_string(), axes);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||||
|
let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?;
|
||||||
|
|
||||||
|
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
let data = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
|
||||||
|
],
|
||||||
|
(3, 2, 2),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut inputs = HashMap::new();
|
||||||
|
inputs.insert("data".to_string(), data.clone());
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||||
|
let expected = data.sum_all()?.reshape((1, 1, 1))?;
|
||||||
|
|
||||||
|
assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
|
||||||
|
let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None);
|
||||||
|
|
||||||
|
// Test with example data
|
||||||
|
{
|
||||||
|
let data = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
|
||||||
|
],
|
||||||
|
(3, 2, 2),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs = HashMap::new();
|
||||||
|
inputs.insert("data".to_string(), data);
|
||||||
|
inputs.insert("axes".to_string(), axes);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||||
|
let expected = Tensor::from_vec(
|
||||||
|
vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0],
|
||||||
|
(3, 2),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with random data
|
||||||
|
{
|
||||||
|
let shape = (3, 2, 2);
|
||||||
|
let data = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,
|
||||||
|
],
|
||||||
|
(3, 2, 2),
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs = HashMap::new();
|
||||||
|
inputs.insert("data".to_string(), data.clone());
|
||||||
|
inputs.insert("axes".to_string(), axes);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let reduced = eval.get("reduced").expect("Output 'reduced' not found");
|
||||||
|
|
||||||
|
// Calculate expected result
|
||||||
|
let expected = data.sum(1)?;
|
||||||
|
|
||||||
|
assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user