candle-onnx: Implement Trilu and ScatterND ops (#2952)

* onnx attention

* setup an example, adding and fixing onnx ops bit by bit

* model working, output is garbage data

* trilu working

* close but not quite, Issues still with scatterND

* closer but the outputs are still slightly wrong

* added tests for trilu and scatterND

* lint

* readme

* clippy

* removed unnessisary comments

* changed device selection, took hyperparameters from model config
This commit is contained in:
Kyle Birnbaum
2025-05-29 22:36:09 -07:00
committed by GitHub
parent 5aed817f1b
commit cd7b877d6b
5 changed files with 950 additions and 5 deletions

View File

@ -84,6 +84,10 @@ required-features = ["pyo3"]
name = "onnx"
required-features = ["onnx"]
[[example]]
name = "onnx-llm"
required-features = ["onnx"]
[[example]]
name = "onnx_basics"
required-features = ["onnx"]

View File

@ -0,0 +1,11 @@
## Using ONNX models in Candle
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based LLM models in Candle.
This script only implements SmolLM-135M right now.
You can run the examples with following commands:
```bash
cargo run --example onnx-llm --features onnx
```

View File

@ -0,0 +1,209 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Result;
use candle::{DType, Tensor};
use candle_transformers::generation::{LogitsProcessor, Sampling};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
use serde::Deserialize;
use std::io::Write;
use tokenizers::Tokenizer;
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub hidden_size: usize,
pub num_attention_heads: usize,
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
SmolLM135M,
}
#[derive(Parser)]
struct Args {
/// The prompt to be used.
#[arg(long, default_value = "My favorite theorem is ")]
prompt: String,
/// The model to be used.
#[arg(value_enum, long, default_value_t = Which::SmolLM135M)]
which: Which,
/// Run on CPU rather than GPU.
#[arg(long)]
cpu: bool,
/// The number of tokens to generate.
#[arg(long, default_value_t = 100)]
max_tokens: usize,
/// The temperature used for sampling.
#[arg(long, default_value_t = 0.8)]
temperature: f32,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// Only sample among the top K samples.
#[arg(long)]
top_k: Option<usize>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
}
pub fn main() -> Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let (model_id, tokenizer_id) = match args.which {
Which::SmolLM135M => ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-135M"),
};
let api = Api::new()?;
let model_repo = api.model(model_id.to_string());
let tokenizer_repo = api.model(tokenizer_id.to_string());
let model_path = model_repo.get("onnx/model.onnx")?;
let config_file = model_repo.get("config.json")?;
let config: Config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
let tokens_u32 = tokenizer
.encode(args.prompt.as_str(), true)
.map_err(anyhow::Error::msg)?
.get_ids()
.to_vec();
let tokens: Vec<i64> = tokens_u32.iter().map(|&t| t as i64).collect();
println!("Loading ONNX model from {:?}", model_path);
let model = candle_onnx::read_file(model_path)?;
let mut generated_tokens = tokens.clone();
print!("{}", args.prompt);
std::io::stdout().flush()?;
let mut logits_processor = {
let temperature = args.temperature as f64;
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
match (args.top_k, args.top_p) {
(None, None) => Sampling::All { temperature },
(Some(k), None) => Sampling::TopK { k, temperature },
(None, Some(p)) => Sampling::TopP { p, temperature },
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
}
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
let mut past_key_values: Option<Vec<(Tensor, Tensor)>> = None;
let num_layers = config.num_hidden_layers;
for _ in 0..args.max_tokens {
let mut inputs = std::collections::HashMap::new();
if let Some(past_kv) = &past_key_values {
let last_token = vec![generated_tokens[generated_tokens.len() - 1]];
let input_tensor = Tensor::new(last_token, &device)?.unsqueeze(0)?;
inputs.insert("input_ids".to_string(), input_tensor);
let seq_len = generated_tokens.len();
let attention_mask = vec![vec![1i64; seq_len]];
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
let position_ids = vec![vec![(seq_len - 1) as i64]];
let position_ids_tensor = Tensor::new(position_ids, &device)?;
inputs.insert("position_ids".to_string(), position_ids_tensor);
for (i, (key, value)) in past_kv.iter().enumerate() {
inputs.insert(format!("past_key_values.{}.key", i), key.clone());
inputs.insert(format!("past_key_values.{}.value", i), value.clone());
}
} else {
let input_tensor = Tensor::new(generated_tokens.clone(), &device)?.unsqueeze(0)?;
inputs.insert("input_ids".to_string(), input_tensor);
let seq_len = generated_tokens.len();
let attention_mask = vec![vec![1i64; seq_len]];
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
let position_ids: Vec<i64> = (0..seq_len as i64).collect();
let position_ids_tensor = Tensor::new(position_ids, &device)?.unsqueeze(0)?;
inputs.insert("position_ids".to_string(), position_ids_tensor);
// Create empty key and value tensors
for i in 0..num_layers {
let batch_size = 1;
let num_heads = config.num_key_value_heads;
let head_dim = config.hidden_size / config.num_attention_heads;
let seq_len = 0;
let empty_key = Tensor::zeros(
&[batch_size, num_heads, seq_len, head_dim],
DType::F32,
&device,
)?;
let empty_value = Tensor::zeros(
&[batch_size, num_heads, seq_len, head_dim],
DType::F32,
&device,
)?;
inputs.insert(format!("past_key_values.{}.key", i), empty_key);
inputs.insert(format!("past_key_values.{}.value", i), empty_value);
}
}
let outputs = candle_onnx::simple_eval(&model, inputs)?;
let logits = outputs.get("logits").unwrap();
let mut new_past_kv = Vec::with_capacity(num_layers);
for i in 0..num_layers {
let key = outputs
.get(&format!("present.{}.key", i))
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.key", i))?;
let value = outputs
.get(&format!("present.{}.value", i))
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.value", i))?;
new_past_kv.push((key.clone(), value.clone()));
}
past_key_values = Some(new_past_kv);
let logits_dim = logits.dims();
let seq_len = logits_dim[1];
let next_token_id = logits_processor.sample(&logits.get(0)?.get(seq_len - 1)?)?;
generated_tokens.push(next_token_id as i64);
if let Some(token_str) = tokenizer.decode(&[next_token_id], true).ok() {
print!("{}", token_str);
std::io::stdout().flush()?;
}
if let Some(eos_id) = tokenizer.token_to_id("<|endoftext|>") {
if next_token_id == eos_id {
break;
}
}
}
println!("\nGeneration complete!");
Ok(())
}

View File

@ -583,7 +583,13 @@ fn simple_eval_(
&Device::Cpu,
)?);
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
let shape_vec: Vec<usize> = input
.to_vec1::<i64>()?
.iter()
.map(|&x| x as usize)
.collect();
let xs = Tensor::ones(shape_vec, value.dtype(), input.device())?
.broadcast_mul(&value)?;
values.insert(node.output[0].clone(), xs);
}
@ -1238,7 +1244,7 @@ fn simple_eval_(
}
let indexes = Tensor::arange_step(s, e, p, data.device())?;
out = out.index_select(&indexes, axis)?
out = out.contiguous()?.index_select(&indexes, axis)?
}
values.insert(node.output[0].clone(), out);
}
@ -2030,6 +2036,203 @@ fn simple_eval_(
values.insert(node.output[0].clone(), output);
}
"Trilu" => {
let input = get(&node.input[0])?;
// Get the diagonal offset 'k' from the second input if provided
let k = if node.input.len() > 1 && !node.input[1].is_empty() {
get(&node.input[1])?.to_vec0::<i64>()?
} else {
0
};
// Get the 'upper' attribute
let upper = get_attr_opt::<i64>(node, "upper")?.copied().unwrap_or(1);
// For batched inputs, we need to handle each matrix separately
let dims = input.dims();
if dims.len() < 2 {
bail!("Trilu expects input with at least 2 dimensions: {:?}", dims);
}
// Get the last two dimensions which represent the matrix
let n = dims[dims.len() - 2];
let m = dims[dims.len() - 1];
let max_dim = std::cmp::max(n, m);
// Handle the diagonal offset k
let mask = if k != 0 {
let mut data = vec![0u32; n * m];
for i in 0..n {
for j in 0..m {
if (upper != 0 && (j as i64) >= (i as i64) + k)
|| (upper == 0 && (j as i64) <= (i as i64) + k)
{
data[i * m + j] = 1u32;
}
}
}
Tensor::from_vec(data, (n, m), input.device())?.to_dtype(input.dtype())?
} else if upper == 0 {
Tensor::tril2(max_dim, input.dtype(), input.device())?
} else {
Tensor::triu2(max_dim, input.dtype(), input.device())?
};
let final_mask = if n != m {
mask.narrow(0, 0, n)?.narrow(1, 0, m)?
} else {
mask
};
let output = (input * &final_mask)?;
values.insert(node.output[0].clone(), output);
}
"ScatterND" => {
let data = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let indices = indices.to_dtype(DType::I64)?;
let updates = get(&node.input[2])?;
let reduction = get_attr_opt::<str>(node, "reduction")?.unwrap_or("none");
let indices_shape = indices.dims();
let data_shape = data.dims();
let updates_shape = updates.dims();
// Last dimension of indices represents the depth of indexing
let k = indices_shape.last().unwrap().clone();
if k > data.rank() {
bail!("ScatterND expects k (indices.shape[-1]) to be at most the rank of data");
}
let num_updates = indices_shape[..indices_shape.len() - 1]
.iter()
.product::<usize>();
let flat_indices = if indices.rank() == 1 && k == 1 {
indices.unsqueeze(0)?
} else {
indices.reshape((num_updates, k))?
};
// Calculate the shape of each update element
let update_element_shape = if k < data_shape.len() {
data_shape[k..].to_vec()
} else {
vec![]
};
// Expected shape for updates based on indices and target tensor
let expected_updates_shape = {
let mut shape = indices_shape[..indices_shape.len() - 1].to_vec();
shape.extend(&update_element_shape);
shape
};
// Validate or reshape updates to expected shape
let updates = if updates.dims() != expected_updates_shape {
if updates.rank() == 0 {
// Handle scalar updates
let mut target_shape = vec![num_updates];
target_shape.extend(&update_element_shape);
updates.broadcast_as(target_shape)?
} else {
// Try to broadcast or reshape updates to expected shape
let flat_shape =
vec![num_updates, update_element_shape.iter().product::<usize>()];
let flattened = updates.reshape(flat_shape)?;
flattened.reshape(expected_updates_shape)?
}
} else {
updates.clone()
};
let mut output = data.clone();
// convert indices to flat indices
let mut flat_output = output.flatten_all()?;
let flat_updates = if update_element_shape.is_empty() {
updates.reshape(num_updates)?
} else {
let product = update_element_shape.iter().product::<usize>();
updates.reshape((num_updates, product))?
};
// Calculate strides for the output tensor
let mut strides: Vec<usize> = vec![1];
for i in (0..data_shape.len() - 1).rev() {
strides.push(strides.last().unwrap() * data_shape[i + 1]);
}
strides.reverse();
// Process each update
for i in 0..num_updates {
let index_slice = flat_indices.narrow(0, i, 1)?;
let indices_vec = index_slice.squeeze(0)?.to_vec1::<i64>()?;
// Convert multi-dimensional indices to flat index
let mut flat_idx: usize = 0;
for (dim, &idx) in indices_vec.iter().enumerate() {
let dim_size = data_shape[dim] as i64;
let norm_idx = if idx < 0 { dim_size + idx } else { idx };
if norm_idx < 0 || norm_idx >= dim_size {
bail!(
"Index {} out of bounds for dimension {} with size {}",
idx,
dim,
dim_size
);
}
flat_idx += (norm_idx as usize) * strides[dim];
}
// Extract current update
let update_slice = if update_element_shape.is_empty() {
flat_updates.narrow(0, i, 1)?.squeeze(0)?
} else {
flat_updates.narrow(0, i, 1)?
};
match reduction {
"add" => {
if update_element_shape.is_empty() {
let existing = flat_output.narrow(0, flat_idx, 1)?;
let new_value = existing.add(&update_slice.unsqueeze(0)?)?;
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
} else {
let slice_size = update_element_shape.iter().product::<usize>();
let existing = flat_output.narrow(0, flat_idx, slice_size)?;
let new_value = existing.add(&update_slice)?;
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
}
}
"none" | _ => {
if update_element_shape.is_empty() {
flat_output = flat_output.slice_scatter(
&update_slice.unsqueeze(0)?,
0,
flat_idx,
)?;
} else {
flat_output =
flat_output.slice_scatter(&update_slice, 0, flat_idx)?;
}
}
}
}
// Reshape flat output back to original shape
output = flat_output.reshape(data_shape.to_vec())?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}

View File

@ -842,13 +842,22 @@ fn test_flatten_operation() -> Result<()> {
#[test]
fn test_constant_of_shape() -> Result<()> {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
test(
&[4i64, 3, 2],
Some(1.),
&[
[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]],
[[1., 1.], [1., 1.], [1., 1.]],
],
)?;
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
test(&[0.], Some(0i64), &[0i64])?;
test(&[1i64], Some(0i64), &[0i64])?;
// "value" defaults to 0 f32
test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
test(&[4i64], None as Option<i64>, &[0., 0., 0., 0.])?;
fn test(
input: impl NdArray,
@ -5968,3 +5977,512 @@ fn test_sign_operation() -> Result<()> {
);
Ok(())
}
#[test]
fn test_scatternd_operation() -> Result<()> {
// Example 1 based on ONNX documentation
test(
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
&[[4i64], [3], [1], [7]],
&[9.0f32, 10.0, 11.0, 12.0],
&[1.0f32, 11.0, 3.0, 10.0, 9.0, 6.0, 7.0, 12.0],
)?;
// A more complex example with 2D data
test(
&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
&[[0i64, 1], [1, 0]],
&[10.0f32, 20.0],
&[[1.0f32, 10.0], [20.0, 4.0], [5.0, 6.0]],
)?;
// 3D example with indices pointing to specific locations
test(
&[
[[1.0f32, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
],
&[[0i64, 0, 1], [1, 1, 0]],
&[100.0f32, 200.0],
&[
[[1.0f32, 100.0], [3.0, 4.0]],
[[5.0, 6.0], [200.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
],
)?;
fn test(
data: impl NdArray,
indices: impl NdArray,
updates: impl NdArray,
expected: impl NdArray,
) -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "ScatterND".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![
INPUT_X.to_string(),
INPUT_Y.to_string(),
INPUT_A.to_string(),
],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
inputs.insert(INPUT_A.to_string(), Tensor::new(updates, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
1 => assert_eq!(z.to_vec1::<f32>()?, expected.to_vec1::<f32>()?),
2 => assert_eq!(z.to_vec2::<f32>()?, expected.to_vec2::<f32>()?),
3 => assert_eq!(z.to_vec3::<f32>()?, expected.to_vec3::<f32>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
#[test]
fn test_trilu_operation() -> Result<()> {
// Test 1: Upper triangular matrix (default behavior with upper=true)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![], // empty attribute means default upper=true
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 7, 9],
vec![0, 2, 8, 6, 9],
vec![0, 0, 0, 8, 7],
vec![0, 0, 0, 2, 4]
]
);
}
// Test 2: Upper triangular with positive k=1 (diagonal above main)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![
ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
},
ValueInfoProto {
name: INPUT_Y.to_string(),
doc_string: "".to_string(),
r#type: None,
},
],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![1i64, 4, 9, 7, 1, 9, 2, 8, 8, 4, 3, 9, 7, 4, 2],
&[3, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![0, 4, 9, 7, 1],
vec![0, 0, 8, 8, 4],
vec![0, 0, 0, 4, 2]
]
);
}
// Test 3: Upper triangular with negative k=-1 (one diagonal below main)
{
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 7, 9],
vec![1, 2, 8, 6, 9],
vec![0, 4, 0, 8, 7],
vec![0, 0, 4, 2, 4]
]
);
}
// Test 4: Lower triangular matrix (upper=0)
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0, // 0 means false, use lower triangular
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
// Lower triangular matrix (default k=0)
assert_eq!(
results,
vec![
vec![4, 0, 0, 0, 0],
vec![1, 2, 0, 0, 0],
vec![9, 4, 1, 0, 0],
vec![4, 3, 4, 2, 0]
]
);
}
// Test 5: Lower triangular with negative k=-1
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0,
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![0, 0, 0, 0, 0],
vec![1, 0, 0, 0, 0],
vec![9, 4, 0, 0, 0],
vec![4, 3, 4, 0, 0]
]
);
}
// Test 6: Lower triangular with positive k=2
{
let att_upper = AttributeProto {
name: "upper".to_string(),
ref_attr_name: "upper".to_string(),
i: 0,
doc_string: "upper".to_string(),
r#type: 2,
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Trilu".to_string(),
domain: "".to_string(),
attribute: vec![att_upper],
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
vec![
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
],
&[4, 5],
&Device::Cpu,
)?;
let k = Tensor::from_vec(vec![2i64], (), &Device::Cpu)?;
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);
inputs.insert(INPUT_Y.to_string(), k);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let results = z.to_vec2::<i64>()?;
assert_eq!(
results,
vec![
vec![4, 7, 3, 0, 0],
vec![1, 2, 8, 6, 0],
vec![9, 4, 1, 8, 7],
vec![4, 3, 4, 2, 4]
]
);
}
Ok(())
}