candle-onnx: Implement Trilu and ScatterND ops (#2952)

* onnx attention

* set up an example, adding and fixing onnx ops bit by bit

* model working, output is garbage data

* trilu working

* close but not quite, issues still with scatterND

* closer but the outputs are still slightly wrong

* added tests for trilu and scatterND

* lint

* readme

* clippy

* removed unnecessary comments

* changed device selection, took hyperparameters from model config

Author: Kyle Birnbaum
Committed: 2025-05-29 22:36:09 -07:00 (via GitHub)
Parent: 5aed817f1b
Commit: cd7b877d6b
5 changed files with 950 additions and 5 deletions


@@ -583,7 +583,13 @@ fn simple_eval_(
         &Device::Cpu,
     )?);
-    let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
+    let shape_vec: Vec<usize> = input
+        .to_vec1::<i64>()?
+        .iter()
+        .map(|&x| x as usize)
+        .collect();
+    let xs = Tensor::ones(shape_vec, value.dtype(), input.device())?
         .broadcast_mul(&value)?;
     values.insert(node.output[0].clone(), xs);
 }
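
In short, the fix makes ConstantOfShape read the values of its 1-D input as the output shape, where the old code passed input.shape(), i.e. the shape of the shape tensor itself. A minimal sketch of the corrected semantics, assuming candle_core as elsewhere in the crate (constant_of_shape_demo is an illustrative name):

use candle_core::{DType, Device, Result, Tensor};

fn constant_of_shape_demo() -> Result<()> {
    // The ONNX input is a 1-D i64 tensor holding the target shape.
    let shape_input = Tensor::new(&[4i64, 3, 2], &Device::Cpu)?;
    // Old behavior: shape_input.shape() is [3]. Fixed behavior: read the values.
    let dims: Vec<usize> = shape_input
        .to_vec1::<i64>()?
        .iter()
        .map(|&d| d as usize)
        .collect();
    let out = Tensor::ones(dims, DType::F32, &Device::Cpu)?;
    assert_eq!(out.dims(), &[4, 3, 2]);
    Ok(())
}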
@@ -1238,7 +1244,7 @@ fn simple_eval_(
         }
         let indexes = Tensor::arange_step(s, e, p, data.device())?;
-        out = out.index_select(&indexes, axis)?
+        out = out.contiguous()?.index_select(&indexes, axis)?
     }
     values.insert(node.output[0].clone(), out);
 }
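
The Slice change guards index_select against non-contiguous inputs: the tensor reaching this point can be a lazy view (a transpose, for instance), and forcing a copy first avoids layout errors. A minimal sketch of the pattern, under the same candle_core assumption (slice_rows is an illustrative name):

use candle_core::{Device, Result, Tensor};

fn slice_rows(data: &Tensor) -> Result<Tensor> {
    let rows = Tensor::new(&[0u32, 2], data.device())?;
    // data may be a non-contiguous view; copying it first keeps
    // index_select's layout assumptions valid.
    data.contiguous()?.index_select(&rows, 0)
}

fn main() -> Result<()> {
    // A transpose is a stride-swapped view, not a copy.
    let data = Tensor::arange(0f32, 24., &Device::Cpu)?
        .reshape((4, 6))?
        .t()?;
    assert_eq!(slice_rows(&data)?.dims(), &[2, 4]);
    Ok(())
}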
@@ -2030,6 +2036,203 @@ fn simple_eval_(
values.insert(node.output[0].clone(), output);
}
"Trilu" => {
let input = get(&node.input[0])?;
// Get the diagonal offset 'k' from the second input if provided
let k = if node.input.len() > 1 && !node.input[1].is_empty() {
get(&node.input[1])?.to_vec0::<i64>()?
} else {
0
};
// Get the 'upper' attribute
let upper = get_attr_opt::<i64>(node, "upper")?.copied().unwrap_or(1);
// For batched inputs, we need to handle each matrix separately
let dims = input.dims();
if dims.len() < 2 {
bail!("Trilu expects input with at least 2 dimensions: {:?}", dims);
}
// Get the last two dimensions which represent the matrix
let n = dims[dims.len() - 2];
let m = dims[dims.len() - 1];
let max_dim = std::cmp::max(n, m);
// Handle the diagonal offset k
let mask = if k != 0 {
let mut data = vec![0u32; n * m];
for i in 0..n {
for j in 0..m {
if (upper != 0 && (j as i64) >= (i as i64) + k)
|| (upper == 0 && (j as i64) <= (i as i64) + k)
{
data[i * m + j] = 1u32;
}
}
}
Tensor::from_vec(data, (n, m), input.device())?.to_dtype(input.dtype())?
} else if upper == 0 {
Tensor::tril2(max_dim, input.dtype(), input.device())?
} else {
Tensor::triu2(max_dim, input.dtype(), input.device())?
};
let final_mask = if n != m {
mask.narrow(0, 0, n)?.narrow(1, 0, m)?
} else {
mask
};
let output = (input * &final_mask)?;
values.insert(node.output[0].clone(), output);
}
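// Worked example: for a 3x3 input with upper = 1 and k = 1, the mask keeps
// entries with j >= i + 1, so the main diagonal is zeroed as well:
//   [[0, 1, 1],
//    [0, 0, 1],
//    [0, 0, 0]]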
"ScatterND" => {
let data = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let indices = indices.to_dtype(DType::I64)?;
let updates = get(&node.input[2])?;
let reduction = get_attr_opt::<str>(node, "reduction")?.unwrap_or("none");
let indices_shape = indices.dims();
let data_shape = data.dims();
let updates_shape = updates.dims();
// Last dimension of indices represents the depth of indexing
let k = *indices_shape.last().unwrap();
if k > data.rank() {
bail!("ScatterND expects k (indices.shape[-1]) to be at most the rank of data");
}
let num_updates = indices_shape[..indices_shape.len() - 1]
.iter()
.product::<usize>();
let flat_indices = if indices.rank() == 1 && k == 1 {
indices.unsqueeze(0)?
} else {
indices.reshape((num_updates, k))?
};
// Calculate the shape of each update element
let update_element_shape = if k < data_shape.len() {
data_shape[k..].to_vec()
} else {
vec![]
};
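// e.g. data_shape = [4, 4, 4] with k = 2 gives update_element_shape = [4]:
// each index addresses a row of 4 elements to write.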
// Expected shape for updates based on indices and target tensor
let expected_updates_shape = {
let mut shape = indices_shape[..indices_shape.len() - 1].to_vec();
shape.extend(&update_element_shape);
shape
};
// Validate or reshape updates to expected shape
let updates = if updates.dims() != expected_updates_shape {
if updates.rank() == 0 {
// Handle scalar updates
let mut target_shape = vec![num_updates];
target_shape.extend(&update_element_shape);
updates.broadcast_as(target_shape)?
} else {
// Try to broadcast or reshape updates to expected shape
let flat_shape =
vec![num_updates, update_element_shape.iter().product::<usize>()];
let flattened = updates.reshape(flat_shape)?;
flattened.reshape(expected_updates_shape)?
}
} else {
updates.clone()
};
let mut output = data.clone();
// convert indices to flat indices
let mut flat_output = output.flatten_all()?;
let flat_updates = if update_element_shape.is_empty() {
updates.reshape(num_updates)?
} else {
let product = update_element_shape.iter().product::<usize>();
updates.reshape((num_updates, product))?
};
// Calculate strides for the output tensor
let mut strides: Vec<usize> = vec![1];
for i in (0..data_shape.len() - 1).rev() {
strides.push(strides.last().unwrap() * data_shape[i + 1]);
}
strides.reverse();
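// e.g. data_shape = [2, 3, 4] gives strides = [12, 4, 1], so an index
// [i, j, l] maps to the flat offset i * 12 + j * 4 + l.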
// Process each update
for i in 0..num_updates {
let index_slice = flat_indices.narrow(0, i, 1)?;
let indices_vec = index_slice.squeeze(0)?.to_vec1::<i64>()?;
// Convert multi-dimensional indices to flat index
let mut flat_idx: usize = 0;
for (dim, &idx) in indices_vec.iter().enumerate() {
let dim_size = data_shape[dim] as i64;
let norm_idx = if idx < 0 { dim_size + idx } else { idx };
if norm_idx < 0 || norm_idx >= dim_size {
bail!(
"Index {} out of bounds for dimension {} with size {}",
idx,
dim,
dim_size
);
}
flat_idx += (norm_idx as usize) * strides[dim];
}
// Extract current update
let update_slice = if update_element_shape.is_empty() {
flat_updates.narrow(0, i, 1)?.squeeze(0)?
} else {
flat_updates.narrow(0, i, 1)?
};
match reduction {
"add" => {
if update_element_shape.is_empty() {
let existing = flat_output.narrow(0, flat_idx, 1)?;
let new_value = existing.add(&update_slice.unsqueeze(0)?)?;
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
} else {
let slice_size = update_element_shape.iter().product::<usize>();
let existing = flat_output.narrow(0, flat_idx, slice_size)?;
let new_value = existing.add(&update_slice)?;
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
}
}
"none" | _ => {
if update_element_shape.is_empty() {
flat_output = flat_output.slice_scatter(
&update_slice.unsqueeze(0)?,
0,
flat_idx,
)?;
} else {
flat_output =
flat_output.slice_scatter(&update_slice, 0, flat_idx)?;
}
}
}
}
// Reshape flat output back to original shape
output = flat_output.reshape(data_shape.to_vec())?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
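
The heart of the ScatterND arm is the mapping from multi-dimensional indices to offsets into the flattened data via row-major strides. A dependency-free sketch of that mapping (flat_index is an illustrative name, not part of the diff):

fn flat_index(data_shape: &[usize], index: &[i64]) -> usize {
    // Row-major strides, e.g. shape [3, 2] -> strides [2, 1].
    let mut strides = vec![1usize; data_shape.len()];
    for d in (0..data_shape.len().saturating_sub(1)).rev() {
        strides[d] = strides[d + 1] * data_shape[d + 1];
    }
    index
        .iter()
        .enumerate()
        .map(|(d, &i)| {
            // Negative indices wrap around, matching the bounds check above.
            let i = if i < 0 { i + data_shape[d] as i64 } else { i };
            i as usize * strides[d]
        })
        .sum()
}

fn main() {
    // Mirrors the 2-D test below: shape [3, 2], index [1, 0] -> offset 2,
    // which is where the update 20.0 lands.
    assert_eq!(flat_index(&[3, 2], &[1, 0]), 2);
    assert_eq!(flat_index(&[2, 2], &[0, -1]), 1);
}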


@@ -842,13 +842,22 @@ fn test_flatten_operation() -> Result<()> {
 #[test]
 fn test_constant_of_shape() -> Result<()> {
     // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
-    test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
+    test(
+        &[4i64, 3, 2],
+        Some(1.),
+        &[
+            [[1., 1.], [1., 1.], [1., 1.]],
+            [[1., 1.], [1., 1.], [1., 1.]],
+            [[1., 1.], [1., 1.], [1., 1.]],
+            [[1., 1.], [1., 1.], [1., 1.]],
+        ],
+    )?;
     // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
-    test(&[0.], Some(0i64), &[0i64])?;
+    test(&[1i64], Some(0i64), &[0i64])?;
     // "value" defaults to 0 f32
-    test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
+    test(&[4i64], None as Option<i64>, &[0., 0., 0., 0.])?;
fn test(
input: impl NdArray,
@@ -5968,3 +5977,512 @@ fn test_sign_operation() -> Result<()>
);
Ok(())
}
#[test]
fn test_scatternd_operation() -> Result<()> {
// Example 1 based on ONNX documentation
test(
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
&[[4i64], [3], [1], [7]],
&[9.0f32, 10.0, 11.0, 12.0],
&[1.0f32, 11.0, 3.0, 10.0, 9.0, 6.0, 7.0, 12.0],
)?;
// A more complex example with 2D data
test(
&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
&[[0i64, 1], [1, 0]],
&[10.0f32, 20.0],
&[[1.0f32, 10.0], [20.0, 4.0], [5.0, 6.0]],
)?;
// 3D example with indices pointing to specific locations
test(
&[
[[1.0f32, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
],
&[[0i64, 0, 1], [1, 1, 0]],
&[100.0f32, 200.0],
&[
[[1.0f32, 100.0], [3.0, 4.0]],
[[5.0, 6.0], [200.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
],
)?;
fn test(
data: impl NdArray,
indices: impl NdArray,
updates: impl NdArray,
expected: impl NdArray,
) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "ScatterND".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![
INPUT_X.to_string(),
INPUT_Y.to_string(),
INPUT_A.to_string(),
],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
inputs.insert(INPUT_A.to_string(), Tensor::new(updates, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
1 => assert_eq!(z.to_vec1::<f32>()?, expected.to_vec1::<f32>()?),
2 => assert_eq!(z.to_vec2::<f32>()?, expected.to_vec2::<f32>()?),
3 => assert_eq!(z.to_vec3::<f32>()?, expected.to_vec3::<f32>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
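
All cases above use the default reduction="none", i.e. plain overwrite. The implementation also handles reduction="add", where updates that hit the same location accumulate; a dependency-free sketch of that semantics (scatter_add_1d is an illustrative name, not part of the test suite):

fn scatter_add_1d(data: &mut [f32], indices: &[usize], updates: &[f32]) {
    for (&i, &u) in indices.iter().zip(updates) {
        data[i] += u; // duplicate indices accumulate instead of overwriting
    }
}

fn main() {
    let mut data = [1.0f32, 2.0, 3.0, 4.0];
    // Index 1 appears twice, so both updates are summed into it.
    scatter_add_1d(&mut data, &[1, 1, 3], &[10.0, 100.0, 5.0]);
    assert_eq!(data, [1.0, 112.0, 3.0, 9.0]);
}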
#[test]
fn test_trilu_operation() -> Result<()> {
// Test 1: Upper triangular matrix (default behavior with upper=true)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![], // empty attribute means default upper=true
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 7, 9],
vec![0, 2, 8, 6, 9],
vec![0, 0, 0, 8, 7],
vec![0, 0, 0, 2, 4]
]
);
}
// Test 2: Upper triangular with positive k=1 (diagonal above main)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![1i64, 4, 9, 7, 1, 9, 2, 8, 8, 4, 3, 9, 7, 4, 2],
&[3, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![0, 4, 9, 7, 1],
vec![0, 0, 8, 8, 4],
vec![0, 0, 0, 4, 2]
]
);
}
// Test 3: Upper triangular with negative k=-1 (one diagonal below main)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 7, 9],
vec![1, 2, 8, 6, 9],
vec![0, 4, 0, 8, 7],
vec![0, 0, 4, 2, 4]
]
);
}
// Test 4: Lower triangular matrix (upper=0)
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0, // 0 means false, use lower triangular
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
// Lower triangular matrix (default k=0)
assert_eq!(
results,
vec![
vec![4, 0, 0, 0, 0],
vec![1, 2, 0, 0, 0],
vec![9, 4, 1, 0, 0],
vec![4, 3, 4, 2, 0]
]
);
}
// Test 5: Lower triangular with negative k=-1
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0,
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![0, 0, 0, 0, 0],
vec![1, 0, 0, 0, 0],
vec![9, 4, 0, 0, 0],
vec![4, 3, 4, 0, 0]
]
);
}
// Test 6: Lower triangular with positive k=2
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0,
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![2i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 0, 0],
vec![1, 2, 8, 6, 0],
vec![9, 4, 1, 8, 7],
vec![4, 3, 4, 2, 4]
]
);
}
Ok(())
}
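
Both suites run through the usual cargo workflow; assuming the workspace layout of this repo:

cargo test -p candle-onnx test_trilu_operation
cargo test -p candle-onnx test_scatternd_operation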