From cd7b877d6b5f2072e77b3bedc310dd8df0091257 Mon Sep 17 00:00:00 2001
From: Kyle Birnbaum
Date: Thu, 29 May 2025 22:36:09 -0700
Subject: [PATCH] candle-onnx: Implement Trilu and ScatterND ops (#2952)

* onnx attention

* setup an example, adding and fixing onnx ops bit by bit

* model working, output is garbage data

* trilu working

* close but not quite, issues still with ScatterND

* closer but the outputs are still slightly wrong

* added tests for Trilu and ScatterND

* lint

* readme

* clippy

* removed unnecessary comments

* changed device selection, took hyperparameters from model config
---
 candle-examples/Cargo.toml                  |   4 +
 candle-examples/examples/onnx-llm/README.md |  11 +
 candle-examples/examples/onnx-llm/main.rs   | 209 ++++++++
 candle-onnx/src/eval.rs                     | 207 +++++++-
 candle-onnx/tests/ops.rs                    | 524 +++++++++++++++++++-
 5 files changed, 950 insertions(+), 5 deletions(-)
 create mode 100644 candle-examples/examples/onnx-llm/README.md
 create mode 100644 candle-examples/examples/onnx-llm/main.rs

diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 0d5f3cb6..83d1d6b4 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -84,6 +84,10 @@ required-features = ["pyo3"]
 name = "onnx"
 required-features = ["onnx"]
 
+[[example]]
+name = "onnx-llm"
+required-features = ["onnx"]
+
 [[example]]
 name = "onnx_basics"
 required-features = ["onnx"]
diff --git a/candle-examples/examples/onnx-llm/README.md b/candle-examples/examples/onnx-llm/README.md
new file mode 100644
index 00000000..506acd3a
--- /dev/null
+++ b/candle-examples/examples/onnx-llm/README.md
@@ -0,0 +1,11 @@
+## Using ONNX models in Candle
+
+This example demonstrates how to run [ONNX](https://github.com/onnx/onnx)-based LLMs in Candle.
+
+This example currently only implements SmolLM-135M.
+
+You can run the example with the following command:
+
+```bash
+cargo run --example onnx-llm --features onnx
+```
\ No newline at end of file
diff --git a/candle-examples/examples/onnx-llm/main.rs b/candle-examples/examples/onnx-llm/main.rs
new file mode 100644
index 00000000..6cdb8d17
--- /dev/null
+++ b/candle-examples/examples/onnx-llm/main.rs
@@ -0,0 +1,209 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
+use anyhow::Result;
+use candle::{DType, Tensor};
+use candle_transformers::generation::{LogitsProcessor, Sampling};
+use clap::{Parser, ValueEnum};
+use hf_hub::api::sync::Api;
+use serde::Deserialize;
+use std::io::Write;
+use tokenizers::Tokenizer;
+
+#[derive(Debug, Clone, PartialEq, Deserialize)]
+pub struct Config {
+    pub num_hidden_layers: usize,
+    pub num_key_value_heads: usize,
+    pub hidden_size: usize,
+    pub num_attention_heads: usize,
+}
+
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Which {
+    SmolLM135M,
+}
+
+#[derive(Parser)]
+struct Args {
+    /// The prompt to be used.
+    #[arg(long, default_value = "My favorite theorem is ")]
+    prompt: String,
+
+    /// The model to be used.
+    #[arg(value_enum, long, default_value_t = Which::SmolLM135M)]
+    which: Which,
+
+    /// Run on CPU rather than GPU.
+    #[arg(long)]
+    cpu: bool,
+
+    /// The number of tokens to generate.
+    #[arg(long, default_value_t = 100)]
+    max_tokens: usize,
+
+    /// The temperature used for sampling.
+    #[arg(long, default_value_t = 0.8)]
+    temperature: f32,
+
+    /// Nucleus sampling probability cutoff.
+    #[arg(long)]
+    top_p: Option<f64>,
+
+    /// Only sample among the top K samples.
+    #[arg(long)]
+    top_k: Option<usize>,
+
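+    // When both top-k and top-p are set, sampling applies top-k first and
+    // then top-p (see Sampling::TopKThenTopP below).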
+    /// The seed to use when generating random samples.
+    #[arg(long, default_value_t = 299792458)]
+    seed: u64,
+}
+
+pub fn main() -> Result<()> {
+    let args = Args::parse();
+    let device = candle_examples::device(args.cpu)?;
+
+    let (model_id, tokenizer_id) = match args.which {
+        Which::SmolLM135M => ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-135M"),
+    };
+
+    let api = Api::new()?;
+    let model_repo = api.model(model_id.to_string());
+    let tokenizer_repo = api.model(tokenizer_id.to_string());
+
+    let model_path = model_repo.get("onnx/model.onnx")?;
+    let config_file = model_repo.get("config.json")?;
+    let config: Config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
+
+    let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
+    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
+
+    let tokens_u32 = tokenizer
+        .encode(args.prompt.as_str(), true)
+        .map_err(anyhow::Error::msg)?
+        .get_ids()
+        .to_vec();
+
+    let tokens: Vec<i64> = tokens_u32.iter().map(|&t| t as i64).collect();
+
+    println!("Loading ONNX model from {:?}", model_path);
+    let model = candle_onnx::read_file(model_path)?;
+
+    let mut generated_tokens = tokens.clone();
+    print!("{}", args.prompt);
+    std::io::stdout().flush()?;
+
+    let mut logits_processor = {
+        let temperature = args.temperature as f64;
+        let sampling = if temperature <= 0. {
+            Sampling::ArgMax
+        } else {
+            match (args.top_k, args.top_p) {
+                (None, None) => Sampling::All { temperature },
+                (Some(k), None) => Sampling::TopK { k, temperature },
+                (None, Some(p)) => Sampling::TopP { p, temperature },
+                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
+            }
+        };
+        LogitsProcessor::from_sampling(args.seed, sampling)
+    };
+
+    let mut past_key_values: Option<Vec<(Tensor, Tensor)>> = None;
+    let num_layers = config.num_hidden_layers;
+
+    for _ in 0..args.max_tokens {
+        let mut inputs = std::collections::HashMap::new();
+
+        if let Some(past_kv) = &past_key_values {
+            let last_token = vec![generated_tokens[generated_tokens.len() - 1]];
+            let input_tensor = Tensor::new(last_token, &device)?.unsqueeze(0)?;
+            inputs.insert("input_ids".to_string(), input_tensor);
+
+            let seq_len = generated_tokens.len();
+            let attention_mask = vec![vec![1i64; seq_len]];
+            let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
+            inputs.insert("attention_mask".to_string(), attention_mask_tensor);
+
+            let position_ids = vec![vec![(seq_len - 1) as i64]];
+            let position_ids_tensor = Tensor::new(position_ids, &device)?;
+            inputs.insert("position_ids".to_string(), position_ids_tensor);
+
+            for (i, (key, value)) in past_kv.iter().enumerate() {
+                inputs.insert(format!("past_key_values.{}.key", i), key.clone());
+                inputs.insert(format!("past_key_values.{}.value", i), value.clone());
+            }
+        } else {
+            let input_tensor = Tensor::new(generated_tokens.clone(), &device)?.unsqueeze(0)?;
+            inputs.insert("input_ids".to_string(), input_tensor);
+
+            let seq_len = generated_tokens.len();
+            let attention_mask = vec![vec![1i64; seq_len]];
+            let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
+            inputs.insert("attention_mask".to_string(), attention_mask_tensor);
+
+            let position_ids: Vec<i64> = (0..seq_len as i64).collect();
+            let position_ids_tensor = Tensor::new(position_ids, &device)?.unsqueeze(0)?;
+            inputs.insert("position_ids".to_string(), position_ids_tensor);
+
+            // Create empty key and value tensors
+            for i in 0..num_layers {
+                let batch_size = 1;
+                let num_heads = config.num_key_value_heads;
+                let head_dim = config.hidden_size / config.num_attention_heads;
+                let seq_len = 0;
+
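+                // Assumed ONNX Runtime KV-cache convention: caches are rank-4
+                // tensors of shape [batch, num_kv_heads, seq_len, head_dim].
+                // A zero-length seq_len dimension stands for an empty cache on
+                // the first (prefill) step; the model then returns the updated
+                // caches as `present.{i}.key` / `present.{i}.value` below.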
+                let empty_key = Tensor::zeros(
+                    &[batch_size, num_heads, seq_len, head_dim],
+                    DType::F32,
+                    &device,
+                )?;
+                let empty_value = Tensor::zeros(
+                    &[batch_size, num_heads, seq_len, head_dim],
+                    DType::F32,
+                    &device,
+                )?;
+
+                inputs.insert(format!("past_key_values.{}.key", i), empty_key);
+                inputs.insert(format!("past_key_values.{}.value", i), empty_value);
+            }
+        }
+
+        let outputs = candle_onnx::simple_eval(&model, inputs)?;
+
+        let logits = outputs.get("logits").unwrap();
+
+        let mut new_past_kv = Vec::with_capacity(num_layers);
+        for i in 0..num_layers {
+            let key = outputs
+                .get(&format!("present.{}.key", i))
+                .ok_or_else(|| anyhow::anyhow!("Missing present.{}.key", i))?;
+            let value = outputs
+                .get(&format!("present.{}.value", i))
+                .ok_or_else(|| anyhow::anyhow!("Missing present.{}.value", i))?;
+            new_past_kv.push((key.clone(), value.clone()));
+        }
+        past_key_values = Some(new_past_kv);
+
+        let logits_dim = logits.dims();
+        let seq_len = logits_dim[1];
+
+        let next_token_id = logits_processor.sample(&logits.get(0)?.get(seq_len - 1)?)?;
+        generated_tokens.push(next_token_id as i64);
+
+        if let Ok(token_str) = tokenizer.decode(&[next_token_id], true) {
+            print!("{}", token_str);
+            std::io::stdout().flush()?;
+        }
+
+        if let Some(eos_id) = tokenizer.token_to_id("<|endoftext|>") {
+            if next_token_id == eos_id {
+                break;
+            }
+        }
+    }
+
+    println!("\nGeneration complete!");
+    Ok(())
+}
diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs
index 56a916fe..8af0c645 100644
--- a/candle-onnx/src/eval.rs
+++ b/candle-onnx/src/eval.rs
@@ -583,7 +583,13 @@ fn simple_eval_(
                     &Device::Cpu,
                 )?);
 
-                let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
+                let shape_vec: Vec<usize> = input
+                    .to_vec1::<i64>()?
+                    .iter()
+                    .map(|&x| x as usize)
+                    .collect();
+
+                let xs = Tensor::ones(shape_vec, value.dtype(), input.device())?
                     .broadcast_mul(&value)?;
                 values.insert(node.output[0].clone(), xs);
             }
@@ -1238,7 +1244,7 @@ fn simple_eval_(
                 }
 
                 let indexes = Tensor::arange_step(s, e, p, data.device())?;
-                out = out.index_select(&indexes, axis)?
+                out = out.contiguous()?.index_select(&indexes, axis)?
             }
             values.insert(node.output[0].clone(), out);
         }
@@ -2030,6 +2036,203 @@ fn simple_eval_(
 
             values.insert(node.output[0].clone(), output);
         }
+        "Trilu" => {
+            let input = get(&node.input[0])?;
+
+            // Get the diagonal offset 'k' from the second input if provided
+            let k = if node.input.len() > 1 && !node.input[1].is_empty() {
+                get(&node.input[1])?.to_vec0::<i64>()?
+            } else {
+                0
+            };
+
+            // Get the 'upper' attribute
+            let upper = get_attr_opt::<i64>(node, "upper")?.copied().unwrap_or(1);
+
+            // Build a 2-D mask over the trailing two dimensions; the same mask
+            // applies to every matrix in a batched input.
+            let dims = input.dims();
+            if dims.len() < 2 {
+                bail!("Trilu expects input with at least 2 dimensions: {:?}", dims);
+            }
+
+            // Get the last two dimensions which represent the matrix
+            let n = dims[dims.len() - 2];
+            let m = dims[dims.len() - 1];
+            let max_dim = std::cmp::max(n, m);
+
+            // Handle the diagonal offset k
+            let mask = if k != 0 {
+                let mut data = vec![0u32; n * m];
+                for i in 0..n {
+                    for j in 0..m {
+                        if (upper != 0 && (j as i64) >= (i as i64) + k)
+                            || (upper == 0 && (j as i64) <= (i as i64) + k)
+                        {
+                            data[i * m + j] = 1u32;
+                        }
+                    }
+                }
+                Tensor::from_vec(data, (n, m), input.device())?.to_dtype(input.dtype())?
+            } else if upper == 0 {
+                Tensor::tril2(max_dim, input.dtype(), input.device())?
+            } else {
+                Tensor::triu2(max_dim, input.dtype(), input.device())?
+            };
+
+            let final_mask = if n != m {
+                mask.narrow(0, 0, n)?.narrow(1, 0, m)?
+            } else {
+                mask
+            };
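+            // tril2 / triu2 build square masks, so for rectangular inputs the
+            // (max_dim, max_dim) mask is cropped back to (n, m) above; the
+            // broadcast multiply below lets the 2-D mask cover batched inputs.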
+
+            let output = input.broadcast_mul(&final_mask)?;
+
+            values.insert(node.output[0].clone(), output);
+        }
+        "ScatterND" => {
+            let data = get(&node.input[0])?;
+
+            let indices = get(&node.input[1])?;
+            let indices = indices.to_dtype(DType::I64)?;
+
+            let updates = get(&node.input[2])?;
+
+            let reduction = get_attr_opt::<str>(node, "reduction")?.unwrap_or("none");
+
+            let indices_shape = indices.dims();
+            let data_shape = data.dims();
+
+            // The last dimension of indices gives the depth of indexing
+            let k = *indices_shape.last().unwrap();
+
+            if k > data.rank() {
+                bail!("ScatterND expects k (indices.shape[-1]) to be at most the rank of data");
+            }
+
+            let num_updates = indices_shape[..indices_shape.len() - 1]
+                .iter()
+                .product::<usize>();
+
+            let flat_indices = if indices.rank() == 1 && k == 1 {
+                indices.unsqueeze(0)?
+            } else {
+                indices.reshape((num_updates, k))?
+            };
+
+            // Calculate the shape of each update element
+            let update_element_shape = if k < data_shape.len() {
+                data_shape[k..].to_vec()
+            } else {
+                vec![]
+            };
+
+            // Expected shape for updates based on indices and the target tensor
+            let expected_updates_shape = {
+                let mut shape = indices_shape[..indices_shape.len() - 1].to_vec();
+                shape.extend(&update_element_shape);
+                shape
+            };
+
+            // Validate or reshape updates to the expected shape
+            let updates = if updates.dims() != expected_updates_shape {
+                if updates.rank() == 0 {
+                    // Handle scalar updates
+                    let mut target_shape = vec![num_updates];
+                    target_shape.extend(&update_element_shape);
+                    updates.broadcast_as(target_shape)?
+                } else {
+                    // Try to reshape updates to the expected shape
+                    let flat_shape =
+                        vec![num_updates, update_element_shape.iter().product::<usize>()];
+                    let flattened = updates.reshape(flat_shape)?;
+                    flattened.reshape(expected_updates_shape)?
+                }
+            } else {
+                updates.clone()
+            };
+
+            let mut output = data.clone();
+
+            // Work on a flattened view of the output tensor
+            let mut flat_output = output.flatten_all()?;
+            let flat_updates = if update_element_shape.is_empty() {
+                updates.reshape(num_updates)?
+            } else {
+                let product = update_element_shape.iter().product::<usize>();
+                updates.reshape((num_updates, product))?
+            };
+
+            // Calculate row-major strides for the output tensor
+            let mut strides: Vec<usize> = vec![1];
+            for i in (0..data_shape.len() - 1).rev() {
+                strides.push(strides.last().unwrap() * data_shape[i + 1]);
+            }
+            strides.reverse();
+
+            // Process each update
+            for i in 0..num_updates {
+                let index_slice = flat_indices.narrow(0, i, 1)?;
+                let indices_vec = index_slice.squeeze(0)?.to_vec1::<i64>()?;
+
+                // Convert multi-dimensional indices to a flat index
+                let mut flat_idx: usize = 0;
+                for (dim, &idx) in indices_vec.iter().enumerate() {
+                    let dim_size = data_shape[dim] as i64;
+                    let norm_idx = if idx < 0 { dim_size + idx } else { idx };
+
+                    if norm_idx < 0 || norm_idx >= dim_size {
+                        bail!(
+                            "Index {} out of bounds for dimension {} with size {}",
+                            idx,
+                            dim,
+                            dim_size
+                        );
+                    }
+
+                    flat_idx += (norm_idx as usize) * strides[dim];
+                }
+
+                // Extract the current update, squeezing the leading dim so the
+                // slice lines up with the flattened output view.
+                let update_slice = flat_updates.narrow(0, i, 1)?.squeeze(0)?;
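+                // flat_idx above follows row-major order: e.g. for data_shape
+                // [4, 2] the strides are [2, 1], so the multi-index [3, 1]
+                // maps to the flat offset 3 * 2 + 1 = 7.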
+ }; + + match reduction { + "add" => { + if update_element_shape.is_empty() { + let existing = flat_output.narrow(0, flat_idx, 1)?; + let new_value = existing.add(&update_slice.unsqueeze(0)?)?; + flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?; + } else { + let slice_size = update_element_shape.iter().product::(); + let existing = flat_output.narrow(0, flat_idx, slice_size)?; + let new_value = existing.add(&update_slice)?; + flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?; + } + } + "none" | _ => { + if update_element_shape.is_empty() { + flat_output = flat_output.slice_scatter( + &update_slice.unsqueeze(0)?, + 0, + flat_idx, + )?; + } else { + flat_output = + flat_output.slice_scatter(&update_slice, 0, flat_idx)?; + } + } + } + } + + // Reshape flat output back to original shape + output = flat_output.reshape(data_shape.to_vec())?; + + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index dffb79b7..ccd0a0e9 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -842,13 +842,22 @@ fn test_flatten_operation() -> Result<()> { #[test] fn test_constant_of_shape() -> Result<()> { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31 - test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?; + test( + &[4i64, 3, 2], + Some(1.), + &[ + [[1., 1.], [1., 1.], [1., 1.]], + [[1., 1.], [1., 1.], [1., 1.]], + [[1., 1.], [1., 1.], [1., 1.]], + [[1., 1.], [1., 1.], [1., 1.]], + ], + )?; // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31 - test(&[0.], Some(0i64), &[0i64])?; + test(&[1i64], Some(0i64), &[0i64])?; // "value" defaults to 0 f32 - test(&[1i64, 2, 3, 4], None as Option, &[0., 0., 0., 0.])?; + test(&[4i64], None as Option, &[0., 0., 0., 0.])?; fn test( input: impl NdArray, @@ -5968,3 +5977,512 @@ fn test_sign_operation() -> Result<()> { ); Ok(()) } + +#[test] +fn test_scatternd_operation() -> Result<()> { + // Example 1 based on ONNX documentation + test( + &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + &[[4i64], [3], [1], [7]], + &[9.0f32, 10.0, 11.0, 12.0], + &[1.0f32, 11.0, 3.0, 10.0, 9.0, 6.0, 7.0, 12.0], + )?; + + // A more complex example with 2D data + test( + &[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]], + &[[0i64, 1], [1, 0]], + &[10.0f32, 20.0], + &[[1.0f32, 10.0], [20.0, 4.0], [5.0, 6.0]], + )?; + + // 3D example with indices pointing to specific locations + test( + &[ + [[1.0f32, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + ], + &[[0i64, 0, 1], [1, 1, 0]], + &[100.0f32, 200.0], + &[ + [[1.0f32, 100.0], [3.0, 4.0]], + [[5.0, 6.0], [200.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + ], + )?; + + fn test( + data: impl NdArray, + indices: impl NdArray, + updates: impl NdArray, + expected: impl NdArray, + ) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "ScatterND".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![ + INPUT_X.to_string(), + INPUT_Y.to_string(), + INPUT_A.to_string(), + ], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + 
+            quantization_annotation: vec![],
+        }));
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
+        inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
+        inputs.insert(INPUT_A.to_string(), Tensor::new(updates, &Device::Cpu)?);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let expected = Tensor::new(expected, &Device::Cpu)?;
+
+        match expected.dims().len() {
+            1 => assert_eq!(z.to_vec1::<f32>()?, expected.to_vec1::<f32>()?),
+            2 => assert_eq!(z.to_vec2::<f32>()?, expected.to_vec2::<f32>()?),
+            3 => assert_eq!(z.to_vec3::<f32>()?, expected.to_vec3::<f32>()?),
+            _ => unreachable!(),
+        };
+
+        Ok(())
+    }
+
+    Ok(())
+}
+
+#[test]
+fn test_trilu_operation() -> Result<()> {
+    // Test 1: Upper triangular matrix (default behavior with upper=true)
+    {
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "Trilu".to_string(),
+                domain: "".to_string(),
+                attribute: vec![], // empty attribute means default upper=true
+                input: vec![INPUT_X.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![ValueInfoProto {
+                name: INPUT_X.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+
+        let x = Tensor::from_vec(
+            vec![
+                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
+            ],
+            &[4, 5],
+            &Device::Cpu,
+        )?;
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), x);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let results = z.to_vec2::<i64>()?;
+
+        assert_eq!(
+            results,
+            vec![
+                vec![4, 7, 3, 7, 9],
+                vec![0, 2, 8, 6, 9],
+                vec![0, 0, 0, 8, 7],
+                vec![0, 0, 0, 2, 4]
+            ]
+        );
+    }
+
+    // Test 2: Upper triangular with positive k=1 (diagonal above main)
+    {
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "Trilu".to_string(),
+                domain: "".to_string(),
+                attribute: vec![],
+                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![
+                ValueInfoProto {
+                    name: INPUT_X.to_string(),
+                    doc_string: "".to_string(),
+                    r#type: None,
+                },
+                ValueInfoProto {
+                    name: INPUT_Y.to_string(),
+                    doc_string: "".to_string(),
+                    r#type: None,
+                },
+            ],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+
+        let x = Tensor::from_vec(
+            vec![1i64, 4, 9, 7, 1, 9, 2, 8, 8, 4, 3, 9, 7, 4, 2],
+            &[3, 5],
+            &Device::Cpu,
+        )?;
+
+        let k = Tensor::from_vec(vec![1i64], (), &Device::Cpu)?;
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), x);
+        inputs.insert(INPUT_Y.to_string(), k);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
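+        // With upper=1 (the default) and k=1, entries on and above the first
+        // superdiagonal are kept; everything below it is zeroed.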
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let results = z.to_vec2::<i64>()?;
+
+        assert_eq!(
+            results,
+            vec![
+                vec![0, 4, 9, 7, 1],
+                vec![0, 0, 8, 8, 4],
+                vec![0, 0, 0, 4, 2]
+            ]
+        );
+    }
+
+    // Test 3: Upper triangular with negative k=-1 (one diagonal below main)
+    {
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "Trilu".to_string(),
+                domain: "".to_string(),
+                attribute: vec![],
+                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+
+        let x = Tensor::from_vec(
+            vec![
+                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
+            ],
+            &[4, 5],
+            &Device::Cpu,
+        )?;
+
+        let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), x);
+        inputs.insert(INPUT_Y.to_string(), k);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let results = z.to_vec2::<i64>()?;
+
+        assert_eq!(
+            results,
+            vec![
+                vec![4, 7, 3, 7, 9],
+                vec![1, 2, 8, 6, 9],
+                vec![0, 4, 0, 8, 7],
+                vec![0, 0, 4, 2, 4]
+            ]
+        );
+    }
+
+    // Test 4: Lower triangular matrix (upper=0)
+    {
+        let att_upper = AttributeProto {
+            name: "upper".to_string(),
+            ref_attr_name: "upper".to_string(),
+            i: 0, // 0 means false, use lower triangular
+            doc_string: "upper".to_string(),
+            r#type: 2,
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "Trilu".to_string(),
+                domain: "".to_string(),
+                attribute: vec![att_upper],
+                input: vec![INPUT_X.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+
+        let x = Tensor::from_vec(
+            vec![
+                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
+            ],
+            &[4, 5],
+            &Device::Cpu,
+        )?;
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), x);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let results = z.to_vec2::<i64>()?;
+
+        // Lower triangular matrix (default k=0)
+        assert_eq!(
+            results,
+            vec![
+                vec![4, 0, 0, 0, 0],
+                vec![1, 2, 0, 0, 0],
+                vec![9, 4, 1, 0, 0],
+                vec![4, 3, 4, 2, 0]
+            ]
+        );
+    }
+
+    // Test 5: Lower triangular with negative k=-1
+    {
+        let att_upper = AttributeProto {
+            name: "upper".to_string(),
+            ref_attr_name: "upper".to_string(),
+            i: 0,
+            doc_string: "upper".to_string(),
+            r#type: 2,
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
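+            // r#type: 2 selects the INT variant of the ONNX attribute-type
+            // enum, so the `i` field above carries the value; the remaining
+            // fields stay at their protobuf defaults.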
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "Trilu".to_string(),
+                domain: "".to_string(),
+                attribute: vec![att_upper],
+                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+
+        let x = Tensor::from_vec(
+            vec![
+                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
+            ],
+            &[4, 5],
+            &Device::Cpu,
+        )?;
+
+        let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), x);
+        inputs.insert(INPUT_Y.to_string(), k);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let results = z.to_vec2::<i64>()?;
+
+        assert_eq!(
+            results,
+            vec![
+                vec![0, 0, 0, 0, 0],
+                vec![1, 0, 0, 0, 0],
+                vec![9, 4, 0, 0, 0],
+                vec![4, 3, 4, 0, 0]
+            ]
+        );
+    }
+
+    // Test 6: Lower triangular with positive k=2
+    {
+        let att_upper = AttributeProto {
+            name: "upper".to_string(),
+            ref_attr_name: "upper".to_string(),
+            i: 0,
+            doc_string: "upper".to_string(),
+            r#type: 2,
+            f: 0.0,
+            s: vec![],
+            t: None,
+            g: None,
+            sparse_tensor: None,
+            tp: None,
+            floats: vec![],
+            ints: vec![],
+            strings: vec![],
+            tensors: vec![],
+            graphs: vec![],
+            sparse_tensors: vec![],
+            type_protos: vec![],
+        };
+
+        let manual_graph = create_model_proto_with_graph(Some(GraphProto {
+            node: vec![NodeProto {
+                op_type: "Trilu".to_string(),
+                domain: "".to_string(),
+                attribute: vec![att_upper],
+                input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
+                output: vec![OUTPUT_Z.to_string()],
+                name: "".to_string(),
+                doc_string: "".to_string(),
+            }],
+            name: "".to_string(),
+            initializer: vec![],
+            input: vec![],
+            output: vec![ValueInfoProto {
+                name: OUTPUT_Z.to_string(),
+                doc_string: "".to_string(),
+                r#type: None,
+            }],
+            value_info: vec![],
+            doc_string: "".to_string(),
+            sparse_initializer: vec![],
+            quantization_annotation: vec![],
+        }));
+
+        let x = Tensor::from_vec(
+            vec![
+                4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
+            ],
+            &[4, 5],
+            &Device::Cpu,
+        )?;
+
+        let k = Tensor::from_vec(vec![2i64], (), &Device::Cpu)?;
+
+        let mut inputs: HashMap<String, Tensor> = HashMap::new();
+        inputs.insert(INPUT_X.to_string(), x);
+        inputs.insert(INPUT_Y.to_string(), k);
+
+        let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
+        assert_eq!(eval.len(), 1);
+
+        let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
+        let results = z.to_vec2::<i64>()?;
+
+        assert_eq!(
+            results,
+            vec![
+                vec![4, 7, 3, 0, 0],
+                vec![1, 2, 8, 6, 0],
+                vec![9, 4, 1, 8, 7],
+                vec![4, 3, 4, 2, 4]
+            ]
+        );
+    }
+
+    Ok(())
+}