diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 5b66a743..9c22eeab 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -323,6 +323,13 @@ fn simple_eval_( Some(value) => Ok(value), None => bail!("cannot find {input_name} for op {}", node.name), }; + let get_opt = |i: usize| { + node.input + .get(i) + .filter(|s: &&String| !s.is_empty()) + .map(|s| get(s)) + }; + // TODO: Validate node.input for each operator. match node.op_type.as_str() { "Add" => { @@ -608,15 +615,13 @@ fn simple_eval_( } "Clip" => { let xs = get(&node.input[0])?; - let xs = if node.input.len() >= 2 { - let mins = get(&node.input[1])?; - xs.broadcast_maximum(mins)? + let xs = if let Some(mins) = get_opt(1) { + xs.broadcast_maximum(mins?)? } else { xs.clone() }; - let xs = if node.input.len() >= 3 { - let maxs = get(&node.input[2])?; - xs.broadcast_minimum(maxs)? + let xs = if let Some(maxs) = get_opt(2) { + xs.broadcast_minimum(maxs?)? } else { xs.clone() }; @@ -759,7 +764,14 @@ fn simple_eval_( let cond = get(&node.input[0])?; let a = get(&node.input[1])?; let b = get(&node.input[2])?; - let output = cond.where_cond(a, b)?; + + // where_cond requires that all inputs are the same shape. + // In contrast, the Where op in ONNX only requires that they are broadcastable. 
+ let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?; + let cond = cond.broadcast_as(shape.clone())?; + let a = a.broadcast_as(shape.clone())?; + let b = b.broadcast_as(shape)?; + let output = cond.where_cond(&a, &b)?; values.insert(node.output[0].clone(), output); } "Conv" => { @@ -962,6 +974,7 @@ fn simple_eval_( } rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name), }; + values.insert(node.output[0].clone(), output); } // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast @@ -1199,6 +1212,152 @@ fn simple_eval_( }; values.insert(node.output[0].clone(), output); } + //https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split + // Version 18 impl + "Split" => { + let input_tensor = get(&node.input[0])?; + let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0); + let axis = input_tensor.normalize_axis(axis)?; + + // Determine split sizes + let splits = if node.input.len() > 1 { + // If the split tensor is provided, use it to determine sizes + let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?; + split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>() + } else { + let num_outputs = if let Some(&num_outputs_attrib) = + get_attr_opt::<i64>(node, "num_outputs")?
+ { + num_outputs_attrib as usize + } else { + node.output.len() + }; + + let input_dim = input_tensor.dim(axis)?; + + let mut split_sizes = + vec![input_dim / num_outputs as usize; num_outputs as usize]; + let remainder = input_dim % num_outputs as usize; + if remainder > 0 { + // If there's a remainder, add it to the last split size + split_sizes[num_outputs as usize - 1] += remainder; + } + + split_sizes + }; + + // Perform the split operation + let mut outputs = vec![]; + let mut start = 0; + for &size in &splits { + let end = start + size; + let slice = input_tensor.narrow(axis, start, size)?; + outputs.push(slice); + start = end; + } + + // Insert the split outputs into the values map + for (output, slice) in node.output.iter().zip(outputs.into_iter()) { + values.insert(output.clone(), slice); + } + } + //https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand + // Version 13 impl + "Expand" => { + // unlike broadcast_to, expand allows for the output shape to + // be different from the specified shape. + let input_tensor = get(&node.input[0])?; + let input_shape = get(&node.input[1])?; + + // Check that the shape tensor is 1D + if input_shape.rank() != 1 { + bail!( + "Expand expects 'shape' input to be 1D tensor: {:?}", + input_shape + ); + } + let input_tensor_dims = input_tensor.dims(); + let input_shape_dims = input_shape + .to_vec1::<i64>()? + .into_iter() + .map(|x| x as usize) + .collect::<Vec<_>>(); + + let target_shape = + broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?; + + let expanded_tensor = input_tensor.broadcast_as(target_shape)?; + + values.insert(node.output[0].clone(), expanded_tensor); + } + //https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum + // Version 13 impl + "ReduceSum" => { + let input = get(&node.input[0])?; + let axes = get_opt(1); + let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1); + let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
+ .copied() + .unwrap_or(0); + + let axes = match axes { + Some(axes) => axes? + .to_vec1::<i64>()? + .into_iter() + .map(|x| x as usize) + .collect::<Vec<_>>(), + None => { + if noop_with_empty_axes == 1 { + vec![] + } else { + (0..input.rank()).collect() + } + } + }; + + let output = if keepdims == 1 { + input.sum_keepdim(axes)? + } else { + input.sum(axes)? + }; + + values.insert(node.output[0].clone(), output); + } + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2 + // Version 18 impl + "ReduceL2" => { + let input = get(&node.input[0])?; + let axes = get_opt(1); + let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1); + let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")? + .copied() + .unwrap_or(0); + + let input_sq = input.sqr()?; + + let axes = match axes { + Some(axes) => axes? + .to_vec1::<i64>()? + .into_iter() + .map(|x| x as usize) + .collect::<Vec<_>>(), + None => { + if noop_with_empty_axes == 1 { + vec![] + } else { + (0..input_sq.rank()).collect() + } + } + }; + + let output = if keepdims == 1 { + input_sq.sum_keepdim(axes)?.sqrt()? + } else { + input_sq.sum(axes)?.sqrt()? + }; + + values.insert(node.output[0].clone(), output); + } random_type @ ("RandomUniform" | "RandomNormal") => { let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float // type by @@ -1395,13 +1554,6 @@ fn simple_eval_( // This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`. let r = get(&node.input[2])?; - let get_opt = |i: usize| { - node.input - .get(i) - .filter(|s: &&String| !s.is_empty()) - .map(|s| get(s)) - }; - // The bias tensor for input gate. // Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0. // This tensor has shape `[num_directions, 8*hidden_size]`.
@@ -1580,3 +1732,36 @@ fn simple_eval_( }) .collect() } + +fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> { + let (longest, shortest) = if shape_a.len() > shape_b.len() { + (shape_a, shape_b) + } else { + (shape_b, shape_a) + }; + let diff = longest.len() - shortest.len(); + let mut target_shape = longest[0..diff].to_vec(); + for (dim1, dim2) in longest[diff..].iter().zip(shortest.iter()) { + if *dim1 == *dim2 || *dim2 == 1 || *dim1 == 1 { + target_shape.push(usize::max(*dim1, *dim2)); + } else { + bail!( + "Expand: incompatible shapes for broadcast, {:?} and {:?}", + shape_a, + shape_b + ); + } + } + Ok(target_shape) +} + +fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> { + if shapes.is_empty() { + return Ok(Vec::new()); + } + let mut shape_out = shapes[0].to_vec(); + for shape in shapes[1..].iter() { + shape_out = broadcast_shape(&shape_out, shape)?; + } + Ok(shape_out) +} diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 51ee037e..55d6fb86 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -3980,3 +3980,332 @@ fn test_lstm() -> Result<()> { Ok(()) } + +#[test] +fn test_expand_dim_changed() -> Result<()> { + // Create a manual graph for the Expand operation + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Expand".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec!["data".to_string(), "new_shape".to_string()], + output: vec!["expanded".to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + input: vec![ + ValueInfoProto { + name: "data".to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ValueInfoProto { + name: "new_shape".to_string(), + doc_string: "".to_string(), + r#type: None, + }, + ], + output: vec![ValueInfoProto { + name: "expanded".to_string(), + doc_string: "".to_string(), + r#type: None, + }], + ..GraphProto::default() + })); + + // Input tensor with shape [3, 1] +
let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?; + + // New shape tensor: [2, 1, 6] + let new_shape = Tensor::from_vec(vec![2i64, 1, 6], (3,), &Device::Cpu)?; + + // Expected output after expansion + let expected = Tensor::from_vec( + vec![ + 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, + 2.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 3.0f32, 1.0f32, 1.0f32, 1.0f32, 1.0f32, + 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32, + 3.0f32, 3.0f32, 3.0f32, + ], + (2, 3, 6), + &Device::Cpu, + )?; + + // Execute the model evaluation + let inputs = HashMap::from_iter([ + ("data".to_string(), data), + ("new_shape".to_string(), new_shape), + ]); + let result = candle_onnx::simple_eval(&manual_graph, inputs)?; + + // Retrieve and compare the result + let expanded = result.get("expanded").expect("Output 'expanded' not found"); + + assert_eq!(expanded.to_vec3::<f32>()?, expected.to_vec3::<f32>()?); + + Ok(()) +} + +fn make_graph_helper( + op_name: &str, + inputs: &[&str], + outputs: &[&str], + attribs: Vec<AttributeProto>, +) -> ModelProto { + create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: op_name.to_string(), + domain: "".to_string(), + attribute: attribs, + input: inputs.iter().map(|s| s.to_string()).collect(), + output: outputs.iter().map(|s| s.to_string()).collect(), + name: "".to_string(), + doc_string: "".to_string(), + }], + input: inputs + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + ..ValueInfoProto::default() + }) + .collect(), + output: outputs + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + ..ValueInfoProto::default() + }) + .collect(), + ..GraphProto::default() + })) +} + +#[test] +fn test_expand_dim_unchanged() -> Result<()> { + // Create a manual graph for the Expand operation + let manual_graph = make_graph_helper("Expand", &["data", "new_shape"], &["expanded"], vec![]); + + // Input
tensor with shape [3, 1] and dtype f32 + let data = Tensor::from_vec(vec![1.0f32, 2.0f32, 3.0f32], (3, 1), &Device::Cpu)?; + + // New shape tensor: [3, 4] + let new_shape = Tensor::from_vec(vec![3i64, 4], (2,), &Device::Cpu)?; + + // Expected output after expansion, dtype f32 + let expected = Tensor::from_vec( + vec![ + 1.0f32, 1.0f32, 1.0f32, 1.0f32, 2.0f32, 2.0f32, 2.0f32, 2.0f32, 3.0f32, 3.0f32, 3.0f32, + 3.0f32, + ], + (3, 4), + &Device::Cpu, + )?; + + // Execute the model evaluation + let inputs = HashMap::from_iter([ + ("data".to_string(), data), + ("new_shape".to_string(), new_shape), + ]); + let result = candle_onnx::simple_eval(&manual_graph, inputs)?; + + // Retrieve and compare the result + let expanded = result.get("expanded").expect("Output 'expanded' not found"); + assert_eq!(expanded.to_vec2::<f32>()?, expected.to_vec2::<f32>()?); + + Ok(()) +} + +fn make_split_graph_helper(inputs: &[&str], outputs: &[&str], axis: i64) -> ModelProto { + let attribs = vec![AttributeProto { + name: "axis".to_string(), + r#type: AttributeType::Int.into(), + i: axis, + ..AttributeProto::default() + }]; + + make_graph_helper("Split", inputs, outputs, attribs) +} + +#[test] +fn test_split_equal_parts_1d_opset13() -> Result<()> { + let input = Tensor::from_vec( + vec![1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32, 6.0f32], + (6,), + &Device::Cpu, + )?; + let mut inputs = HashMap::new(); + inputs.insert("input".to_string(), input); + + { + let manual_graph = + make_split_graph_helper(&["input"], &["output_1", "output_2", "output_3"], 0); + let eval = candle_onnx::simple_eval(&manual_graph, inputs.clone())?; + assert_eq!(eval.len(), 3); + + let out1 = eval.get("output_1").expect("Output 'output_1' not found"); + let out2 = eval.get("output_2").expect("Output 'output_2' not found"); + let out3 = eval.get("output_3").expect("Output 'output_3' not found"); + + assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]); + assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32]); +
assert_eq!(out3.to_vec1::<f32>()?, vec![5.0f32, 6.0f32]); + } + + { + let splits = Tensor::from_vec(vec![2i64, 4], (2,), &Device::Cpu)?; + inputs.insert("split".to_string(), splits); + + let manual_graph = + make_split_graph_helper(&["input", "split"], &["output_1", "output_2"], 0); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 2); + + let out1 = eval.get("output_1").expect("Output 'output_1' not found"); + let out2 = eval.get("output_2").expect("Output 'output_2' not found"); + + assert_eq!(out1.to_vec1::<f32>()?, vec![1.0f32, 2.0f32]); + assert_eq!(out2.to_vec1::<f32>()?, vec![3.0f32, 4.0f32, 5.0f32, 6.0f32]); + } + Ok(()) +} + +fn make_reduce_sum_graph_helper( + inputs: &[&str], + outputs: &[&str], + keepdims: Option<i64>, + noop_with_empty_axes: Option<i64>, +) -> ModelProto { + let mut attribs = vec![]; + if let Some(keepdims) = keepdims { + attribs.push(AttributeProto { + name: "keepdims".to_string(), + r#type: AttributeType::Int.into(), + i: keepdims, + ..AttributeProto::default() + }); + } + if let Some(noop_with_empty_axes) = noop_with_empty_axes { + attribs.push(AttributeProto { + name: "noop_with_empty_axes".to_string(), + r#type: AttributeType::Int.into(), + i: noop_with_empty_axes, + ..AttributeProto::default() + }); + } + make_graph_helper("ReduceSum", inputs, outputs, attribs) +} + +#[test] +fn test_reduce_sum_default_axes_keepdims() -> Result<()> { + let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(1), None); + + // Test with example data + { + let data = Tensor::from_vec( + vec![ + 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + (3, 2, 2), + &Device::Cpu, + )?; + // let axes = Tensor::from_vec(Vec::<i64>::new(), (0,), &Device::Cpu)?; + + let mut inputs = HashMap::new(); + inputs.insert("data".to_string(), data); + // inputs.insert("axes".to_string(), axes); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let reduced =
eval.get("reduced").expect("Output 'reduced' not found"); + let expected = Tensor::from_vec(vec![78.0f32], (1, 1, 1), &Device::Cpu)?; + + assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?); + } + + { + let data = Tensor::from_vec( + vec![ + -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6, + ], + (3, 2, 2), + &Device::Cpu, + )?; + + let mut inputs = HashMap::new(); + inputs.insert("data".to_string(), data.clone()); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let reduced = eval.get("reduced").expect("Output 'reduced' not found"); + let expected = data.sum_all()?.reshape((1, 1, 1))?; + + assert_eq!(reduced.to_vec3::<f32>()?, expected.to_vec3::<f32>()?); + } + + Ok(()) +} + +#[test] +fn test_reduce_sum_do_not_keep_dims() -> Result<()> { + let manual_graph = make_reduce_sum_graph_helper(&["data", "axes"], &["reduced"], Some(0), None); + + // Test with example data + { + let data = Tensor::from_vec( + vec![ + 1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + ], + (3, 2, 2), + &Device::Cpu, + )?; + let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?; + + let mut inputs = HashMap::new(); + inputs.insert("data".to_string(), data); + inputs.insert("axes".to_string(), axes); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let reduced = eval.get("reduced").expect("Output 'reduced' not found"); + let expected = Tensor::from_vec( + vec![4.0f32, 6.0, 12.0, 14.0, 20.0, 22.0], + (3, 2), + &Device::Cpu, + )?; + + assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?); + } + + // Test with random data + { + let shape = (3, 2, 2); + let data = Tensor::from_vec( + vec![ + -5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6, + ], + (3, 2, 2), + &Device::Cpu, + )?; + let axes = Tensor::from_vec(vec![1i64], (1,), &Device::Cpu)?; + + let mut inputs = HashMap::new(); + inputs.insert("data".to_string(), data.clone()); +
inputs.insert("axes".to_string(), axes); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let reduced = eval.get("reduced").expect("Output 'reduced' not found"); + + // Calculate expected result + let expected = data.sum(1)?; + + assert_eq!(reduced.to_vec2::<f32>()?, expected.to_vec2::<f32>()?); + } + + Ok(()) +}