mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
candle-onnx: Implement Trilu and ScatterND ops (#2952)
* onnx attention * setup an example, adding and fixing onnx ops bit by bit * model working, output is garbage data * trilu working * close but not quite, Issues still with scatterND * closer but the outputs are still slightly wrong * added tests for trilu and scatterND * lint * readme * clippy * removed unnessisary comments * changed device selection, took hyperparameters from model config
This commit is contained in:
@ -583,7 +583,13 @@ fn simple_eval_(
|
||||
&Device::Cpu,
|
||||
)?);
|
||||
|
||||
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
|
||||
let shape_vec: Vec<usize> = input
|
||||
.to_vec1::<i64>()?
|
||||
.iter()
|
||||
.map(|&x| x as usize)
|
||||
.collect();
|
||||
|
||||
let xs = Tensor::ones(shape_vec, value.dtype(), input.device())?
|
||||
.broadcast_mul(&value)?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
@ -1238,7 +1244,7 @@ fn simple_eval_(
|
||||
}
|
||||
|
||||
let indexes = Tensor::arange_step(s, e, p, data.device())?;
|
||||
out = out.index_select(&indexes, axis)?
|
||||
out = out.contiguous()?.index_select(&indexes, axis)?
|
||||
}
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
@ -2030,6 +2036,203 @@ fn simple_eval_(
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Trilu" => {
|
||||
let input = get(&node.input[0])?;
|
||||
|
||||
// Get the diagonal offset 'k' from the second input if provided
|
||||
let k = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||
get(&node.input[1])?.to_vec0::<i64>()?
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Get the 'upper' attribute
|
||||
let upper = get_attr_opt::<i64>(node, "upper")?.copied().unwrap_or(1);
|
||||
|
||||
// For batched inputs, we need to handle each matrix separately
|
||||
let dims = input.dims();
|
||||
if dims.len() < 2 {
|
||||
bail!("Trilu expects input with at least 2 dimensions: {:?}", dims);
|
||||
}
|
||||
|
||||
// Get the last two dimensions which represent the matrix
|
||||
let n = dims[dims.len() - 2];
|
||||
let m = dims[dims.len() - 1];
|
||||
let max_dim = std::cmp::max(n, m);
|
||||
|
||||
// Handle the diagonal offset k
|
||||
let mask = if k != 0 {
|
||||
let mut data = vec![0u32; n * m];
|
||||
for i in 0..n {
|
||||
for j in 0..m {
|
||||
if (upper != 0 && (j as i64) >= (i as i64) + k)
|
||||
|| (upper == 0 && (j as i64) <= (i as i64) + k)
|
||||
{
|
||||
data[i * m + j] = 1u32;
|
||||
}
|
||||
}
|
||||
}
|
||||
Tensor::from_vec(data, (n, m), input.device())?.to_dtype(input.dtype())?
|
||||
} else if upper == 0 {
|
||||
Tensor::tril2(max_dim, input.dtype(), input.device())?
|
||||
} else {
|
||||
Tensor::triu2(max_dim, input.dtype(), input.device())?
|
||||
};
|
||||
|
||||
let final_mask = if n != m {
|
||||
mask.narrow(0, 0, n)?.narrow(1, 0, m)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
|
||||
let output = (input * &final_mask)?;
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"ScatterND" => {
|
||||
let data = get(&node.input[0])?;
|
||||
|
||||
let indices = get(&node.input[1])?;
|
||||
let indices = indices.to_dtype(DType::I64)?;
|
||||
|
||||
let updates = get(&node.input[2])?;
|
||||
|
||||
let reduction = get_attr_opt::<str>(node, "reduction")?.unwrap_or("none");
|
||||
|
||||
let indices_shape = indices.dims();
|
||||
let data_shape = data.dims();
|
||||
let updates_shape = updates.dims();
|
||||
|
||||
// Last dimension of indices represents the depth of indexing
|
||||
let k = indices_shape.last().unwrap().clone();
|
||||
|
||||
if k > data.rank() {
|
||||
bail!("ScatterND expects k (indices.shape[-1]) to be at most the rank of data");
|
||||
}
|
||||
|
||||
let num_updates = indices_shape[..indices_shape.len() - 1]
|
||||
.iter()
|
||||
.product::<usize>();
|
||||
|
||||
let flat_indices = if indices.rank() == 1 && k == 1 {
|
||||
indices.unsqueeze(0)?
|
||||
} else {
|
||||
indices.reshape((num_updates, k))?
|
||||
};
|
||||
|
||||
// Calculate the shape of each update element
|
||||
let update_element_shape = if k < data_shape.len() {
|
||||
data_shape[k..].to_vec()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Expected shape for updates based on indices and target tensor
|
||||
let expected_updates_shape = {
|
||||
let mut shape = indices_shape[..indices_shape.len() - 1].to_vec();
|
||||
shape.extend(&update_element_shape);
|
||||
shape
|
||||
};
|
||||
|
||||
// Validate or reshape updates to expected shape
|
||||
let updates = if updates.dims() != expected_updates_shape {
|
||||
if updates.rank() == 0 {
|
||||
// Handle scalar updates
|
||||
let mut target_shape = vec![num_updates];
|
||||
target_shape.extend(&update_element_shape);
|
||||
updates.broadcast_as(target_shape)?
|
||||
} else {
|
||||
// Try to broadcast or reshape updates to expected shape
|
||||
let flat_shape =
|
||||
vec![num_updates, update_element_shape.iter().product::<usize>()];
|
||||
let flattened = updates.reshape(flat_shape)?;
|
||||
flattened.reshape(expected_updates_shape)?
|
||||
}
|
||||
} else {
|
||||
updates.clone()
|
||||
};
|
||||
|
||||
let mut output = data.clone();
|
||||
|
||||
// convert indices to flat indices
|
||||
let mut flat_output = output.flatten_all()?;
|
||||
let flat_updates = if update_element_shape.is_empty() {
|
||||
updates.reshape(num_updates)?
|
||||
} else {
|
||||
let product = update_element_shape.iter().product::<usize>();
|
||||
updates.reshape((num_updates, product))?
|
||||
};
|
||||
|
||||
// Calculate strides for the output tensor
|
||||
let mut strides: Vec<usize> = vec![1];
|
||||
for i in (0..data_shape.len() - 1).rev() {
|
||||
strides.push(strides.last().unwrap() * data_shape[i + 1]);
|
||||
}
|
||||
strides.reverse();
|
||||
|
||||
// Process each update
|
||||
for i in 0..num_updates {
|
||||
let index_slice = flat_indices.narrow(0, i, 1)?;
|
||||
let indices_vec = index_slice.squeeze(0)?.to_vec1::<i64>()?;
|
||||
|
||||
// Convert multi-dimensional indices to flat index
|
||||
let mut flat_idx: usize = 0;
|
||||
for (dim, &idx) in indices_vec.iter().enumerate() {
|
||||
let dim_size = data_shape[dim] as i64;
|
||||
let norm_idx = if idx < 0 { dim_size + idx } else { idx };
|
||||
|
||||
if norm_idx < 0 || norm_idx >= dim_size {
|
||||
bail!(
|
||||
"Index {} out of bounds for dimension {} with size {}",
|
||||
idx,
|
||||
dim,
|
||||
dim_size
|
||||
);
|
||||
}
|
||||
|
||||
flat_idx += (norm_idx as usize) * strides[dim];
|
||||
}
|
||||
|
||||
// Extract current update
|
||||
let update_slice = if update_element_shape.is_empty() {
|
||||
flat_updates.narrow(0, i, 1)?.squeeze(0)?
|
||||
} else {
|
||||
flat_updates.narrow(0, i, 1)?
|
||||
};
|
||||
|
||||
match reduction {
|
||||
"add" => {
|
||||
if update_element_shape.is_empty() {
|
||||
let existing = flat_output.narrow(0, flat_idx, 1)?;
|
||||
let new_value = existing.add(&update_slice.unsqueeze(0)?)?;
|
||||
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
|
||||
} else {
|
||||
let slice_size = update_element_shape.iter().product::<usize>();
|
||||
let existing = flat_output.narrow(0, flat_idx, slice_size)?;
|
||||
let new_value = existing.add(&update_slice)?;
|
||||
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
|
||||
}
|
||||
}
|
||||
"none" | _ => {
|
||||
if update_element_shape.is_empty() {
|
||||
flat_output = flat_output.slice_scatter(
|
||||
&update_slice.unsqueeze(0)?,
|
||||
0,
|
||||
flat_idx,
|
||||
)?;
|
||||
} else {
|
||||
flat_output =
|
||||
flat_output.slice_scatter(&update_slice, 0, flat_idx)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reshape flat output back to original shape
|
||||
output = flat_output.reshape(data_shape.to_vec())?;
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
@ -842,13 +842,22 @@ fn test_flatten_operation() -> Result<()> {
|
||||
#[test]
|
||||
fn test_constant_of_shape() -> Result<()> {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||
test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
|
||||
test(
|
||||
&[4i64, 3, 2],
|
||||
Some(1.),
|
||||
&[
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
],
|
||||
)?;
|
||||
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||
test(&[0.], Some(0i64), &[0i64])?;
|
||||
test(&[1i64], Some(0i64), &[0i64])?;
|
||||
|
||||
// "value" defaults to 0 f32
|
||||
test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
|
||||
test(&[4i64], None as Option<i64>, &[0., 0., 0., 0.])?;
|
||||
|
||||
fn test(
|
||||
input: impl NdArray,
|
||||
@ -5968,3 +5977,512 @@ fn test_sign_operation() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatternd_operation() -> Result<()> {
|
||||
// Example 1 based on ONNX documentation
|
||||
test(
|
||||
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
|
||||
&[[4i64], [3], [1], [7]],
|
||||
&[9.0f32, 10.0, 11.0, 12.0],
|
||||
&[1.0f32, 11.0, 3.0, 10.0, 9.0, 6.0, 7.0, 12.0],
|
||||
)?;
|
||||
|
||||
// A more complex example with 2D data
|
||||
test(
|
||||
&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
||||
&[[0i64, 1], [1, 0]],
|
||||
&[10.0f32, 20.0],
|
||||
&[[1.0f32, 10.0], [20.0, 4.0], [5.0, 6.0]],
|
||||
)?;
|
||||
|
||||
// 3D example with indices pointing to specific locations
|
||||
test(
|
||||
&[
|
||||
[[1.0f32, 2.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
],
|
||||
&[[0i64, 0, 1], [1, 1, 0]],
|
||||
&[100.0f32, 200.0],
|
||||
&[
|
||||
[[1.0f32, 100.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [200.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
],
|
||||
)?;
|
||||
|
||||
fn test(
|
||||
data: impl NdArray,
|
||||
indices: impl NdArray,
|
||||
updates: impl NdArray,
|
||||
expected: impl NdArray,
|
||||
) -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "ScatterND".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![
|
||||
INPUT_X.to_string(),
|
||||
INPUT_Y.to_string(),
|
||||
INPUT_A.to_string(),
|
||||
],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
|
||||
inputs.insert(INPUT_A.to_string(), Tensor::new(updates, &Device::Cpu)?);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||
|
||||
match expected.dims().len() {
|
||||
1 => assert_eq!(z.to_vec1::<f32>()?, expected.to_vec1::<f32>()?),
|
||||
2 => assert_eq!(z.to_vec2::<f32>()?, expected.to_vec2::<f32>()?),
|
||||
3 => assert_eq!(z.to_vec3::<f32>()?, expected.to_vec3::<f32>()?),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trilu_operation() -> Result<()> {
|
||||
// Test 1: Upper triangular matrix (default behavior with upper=true)
|
||||
{
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![], // empty attribute means default upper=true
|
||||
input: vec![INPUT_X.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![4, 7, 3, 7, 9],
|
||||
vec![0, 2, 8, 6, 9],
|
||||
vec![0, 0, 0, 8, 7],
|
||||
vec![0, 0, 0, 2, 4]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 2: Upper triangular with positive k=1 (diagonal above main)
|
||||
{
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: INPUT_Y.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![1i64, 4, 9, 7, 1, 9, 2, 8, 8, 4, 3, 9, 7, 4, 2],
|
||||
&[3, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let k = Tensor::from_vec(vec![1i64], (), &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), k);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![0, 4, 9, 7, 1],
|
||||
vec![0, 0, 8, 8, 4],
|
||||
vec![0, 0, 0, 4, 2]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 3: Upper triangular with negative k=-1 (one diagonal below main)
|
||||
{
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), k);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![4, 7, 3, 7, 9],
|
||||
vec![1, 2, 8, 6, 9],
|
||||
vec![0, 4, 0, 8, 7],
|
||||
vec![0, 0, 4, 2, 4]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 4: Lower triangular matrix (upper=0)
|
||||
{
|
||||
let att_upper = AttributeProto {
|
||||
name: "upper".to_string(),
|
||||
ref_attr_name: "upper".to_string(),
|
||||
i: 0, // 0 means false, use lower triangular
|
||||
doc_string: "upper".to_string(),
|
||||
r#type: 2,
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_upper],
|
||||
input: vec![INPUT_X.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
// Lower triangular matrix (default k=0)
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![4, 0, 0, 0, 0],
|
||||
vec![1, 2, 0, 0, 0],
|
||||
vec![9, 4, 1, 0, 0],
|
||||
vec![4, 3, 4, 2, 0]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 5: Lower triangular with negative k=-1
|
||||
{
|
||||
let att_upper = AttributeProto {
|
||||
name: "upper".to_string(),
|
||||
ref_attr_name: "upper".to_string(),
|
||||
i: 0,
|
||||
doc_string: "upper".to_string(),
|
||||
r#type: 2,
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_upper],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), k);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![0, 0, 0, 0, 0],
|
||||
vec![1, 0, 0, 0, 0],
|
||||
vec![9, 4, 0, 0, 0],
|
||||
vec![4, 3, 4, 0, 0]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 6: Lower triangular with positive k=2
|
||||
{
|
||||
let att_upper = AttributeProto {
|
||||
name: "upper".to_string(),
|
||||
ref_attr_name: "upper".to_string(),
|
||||
i: 0,
|
||||
doc_string: "upper".to_string(),
|
||||
r#type: 2,
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_upper],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let k = Tensor::from_vec(vec![2i64], (), &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), k);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![4, 7, 3, 0, 0],
|
||||
vec![1, 2, 8, 6, 0],
|
||||
vec![9, 4, 1, 8, 7],
|
||||
vec![4, 3, 4, 2, 4]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user