mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 01:48:08 +00:00
candle-onnx: Implement Trilu and ScatterND ops (#2952)
* onnx attention * setup an example, adding and fixing onnx ops bit by bit * model working, output is garbage data * trilu working * close but not quite, Issues still with scatterND * closer but the outputs are still slightly wrong * added tests for trilu and scatterND * lint * readme * clippy * removed unnessisary comments * changed device selection, took hyperparameters from model config
This commit is contained in:
@ -84,6 +84,10 @@ required-features = ["pyo3"]
|
||||
name = "onnx"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx-llm"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "onnx_basics"
|
||||
required-features = ["onnx"]
|
||||
|
11
candle-examples/examples/onnx-llm/README.md
Normal file
11
candle-examples/examples/onnx-llm/README.md
Normal file
@ -0,0 +1,11 @@
|
||||
## Using ONNX models in Candle
|
||||
|
||||
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based LLM models in Candle.
|
||||
|
||||
This script only implements SmolLM-135M right now.
|
||||
|
||||
You can run the examples with following commands:
|
||||
|
||||
```bash
|
||||
cargo run --example onnx-llm --features onnx
|
||||
```
|
209
candle-examples/examples/onnx-llm/main.rs
Normal file
209
candle-examples/examples/onnx-llm/main.rs
Normal file
@ -0,0 +1,209 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Tensor};
|
||||
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::api::sync::Api;
|
||||
use serde::Deserialize;
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
pub num_hidden_layers: usize,
|
||||
pub num_key_value_heads: usize,
|
||||
pub hidden_size: usize,
|
||||
pub num_attention_heads: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
SmolLM135M,
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
/// The prompt to be used.
|
||||
#[arg(long, default_value = "My favorite theorem is ")]
|
||||
prompt: String,
|
||||
|
||||
/// The model to be used.
|
||||
#[arg(value_enum, long, default_value_t = Which::SmolLM135M)]
|
||||
which: Which,
|
||||
|
||||
/// Run on CPU rather than GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The number of tokens to generate.
|
||||
#[arg(long, default_value_t = 100)]
|
||||
max_tokens: usize,
|
||||
|
||||
/// The temperature used for sampling.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
temperature: f32,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// Only sample among the top K samples.
|
||||
#[arg(long)]
|
||||
top_k: Option<usize>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
}
|
||||
|
||||
pub fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let (model_id, tokenizer_id) = match args.which {
|
||||
Which::SmolLM135M => ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-135M"),
|
||||
};
|
||||
|
||||
let api = Api::new()?;
|
||||
let model_repo = api.model(model_id.to_string());
|
||||
let tokenizer_repo = api.model(tokenizer_id.to_string());
|
||||
|
||||
let model_path = model_repo.get("onnx/model.onnx")?;
|
||||
let config_file = model_repo.get("config.json")?;
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
|
||||
|
||||
let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
|
||||
|
||||
let tokens_u32 = tokenizer
|
||||
.encode(args.prompt.as_str(), true)
|
||||
.map_err(anyhow::Error::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let tokens: Vec<i64> = tokens_u32.iter().map(|&t| t as i64).collect();
|
||||
|
||||
println!("Loading ONNX model from {:?}", model_path);
|
||||
let model = candle_onnx::read_file(model_path)?;
|
||||
|
||||
let mut generated_tokens = tokens.clone();
|
||||
print!("{}", args.prompt);
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut logits_processor = {
|
||||
let temperature = args.temperature as f64;
|
||||
let sampling = if temperature <= 0. {
|
||||
Sampling::ArgMax
|
||||
} else {
|
||||
match (args.top_k, args.top_p) {
|
||||
(None, None) => Sampling::All { temperature },
|
||||
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||
}
|
||||
};
|
||||
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||
};
|
||||
|
||||
let mut past_key_values: Option<Vec<(Tensor, Tensor)>> = None;
|
||||
let num_layers = config.num_hidden_layers;
|
||||
|
||||
for _ in 0..args.max_tokens {
|
||||
let mut inputs = std::collections::HashMap::new();
|
||||
|
||||
if let Some(past_kv) = &past_key_values {
|
||||
let last_token = vec![generated_tokens[generated_tokens.len() - 1]];
|
||||
let input_tensor = Tensor::new(last_token, &device)?.unsqueeze(0)?;
|
||||
inputs.insert("input_ids".to_string(), input_tensor);
|
||||
|
||||
let seq_len = generated_tokens.len();
|
||||
let attention_mask = vec![vec![1i64; seq_len]];
|
||||
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
|
||||
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
|
||||
|
||||
let position_ids = vec![vec![(seq_len - 1) as i64]];
|
||||
let position_ids_tensor = Tensor::new(position_ids, &device)?;
|
||||
inputs.insert("position_ids".to_string(), position_ids_tensor);
|
||||
|
||||
for (i, (key, value)) in past_kv.iter().enumerate() {
|
||||
inputs.insert(format!("past_key_values.{}.key", i), key.clone());
|
||||
inputs.insert(format!("past_key_values.{}.value", i), value.clone());
|
||||
}
|
||||
} else {
|
||||
let input_tensor = Tensor::new(generated_tokens.clone(), &device)?.unsqueeze(0)?;
|
||||
inputs.insert("input_ids".to_string(), input_tensor);
|
||||
|
||||
let seq_len = generated_tokens.len();
|
||||
let attention_mask = vec![vec![1i64; seq_len]];
|
||||
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
|
||||
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
|
||||
|
||||
let position_ids: Vec<i64> = (0..seq_len as i64).collect();
|
||||
let position_ids_tensor = Tensor::new(position_ids, &device)?.unsqueeze(0)?;
|
||||
inputs.insert("position_ids".to_string(), position_ids_tensor);
|
||||
|
||||
// Create empty key and value tensors
|
||||
for i in 0..num_layers {
|
||||
let batch_size = 1;
|
||||
let num_heads = config.num_key_value_heads;
|
||||
let head_dim = config.hidden_size / config.num_attention_heads;
|
||||
let seq_len = 0;
|
||||
|
||||
let empty_key = Tensor::zeros(
|
||||
&[batch_size, num_heads, seq_len, head_dim],
|
||||
DType::F32,
|
||||
&device,
|
||||
)?;
|
||||
let empty_value = Tensor::zeros(
|
||||
&[batch_size, num_heads, seq_len, head_dim],
|
||||
DType::F32,
|
||||
&device,
|
||||
)?;
|
||||
|
||||
inputs.insert(format!("past_key_values.{}.key", i), empty_key);
|
||||
inputs.insert(format!("past_key_values.{}.value", i), empty_value);
|
||||
}
|
||||
}
|
||||
|
||||
let outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||
|
||||
let logits = outputs.get("logits").unwrap();
|
||||
|
||||
let mut new_past_kv = Vec::with_capacity(num_layers);
|
||||
for i in 0..num_layers {
|
||||
let key = outputs
|
||||
.get(&format!("present.{}.key", i))
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.key", i))?;
|
||||
let value = outputs
|
||||
.get(&format!("present.{}.value", i))
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.value", i))?;
|
||||
new_past_kv.push((key.clone(), value.clone()));
|
||||
}
|
||||
past_key_values = Some(new_past_kv);
|
||||
|
||||
let logits_dim = logits.dims();
|
||||
let seq_len = logits_dim[1];
|
||||
|
||||
let next_token_id = logits_processor.sample(&logits.get(0)?.get(seq_len - 1)?)?;
|
||||
generated_tokens.push(next_token_id as i64);
|
||||
|
||||
if let Some(token_str) = tokenizer.decode(&[next_token_id], true).ok() {
|
||||
print!("{}", token_str);
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
|
||||
if let Some(eos_id) = tokenizer.token_to_id("<|endoftext|>") {
|
||||
if next_token_id == eos_id {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("\nGeneration complete!");
|
||||
Ok(())
|
||||
}
|
@ -583,7 +583,13 @@ fn simple_eval_(
|
||||
&Device::Cpu,
|
||||
)?);
|
||||
|
||||
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
|
||||
let shape_vec: Vec<usize> = input
|
||||
.to_vec1::<i64>()?
|
||||
.iter()
|
||||
.map(|&x| x as usize)
|
||||
.collect();
|
||||
|
||||
let xs = Tensor::ones(shape_vec, value.dtype(), input.device())?
|
||||
.broadcast_mul(&value)?;
|
||||
values.insert(node.output[0].clone(), xs);
|
||||
}
|
||||
@ -1238,7 +1244,7 @@ fn simple_eval_(
|
||||
}
|
||||
|
||||
let indexes = Tensor::arange_step(s, e, p, data.device())?;
|
||||
out = out.index_select(&indexes, axis)?
|
||||
out = out.contiguous()?.index_select(&indexes, axis)?
|
||||
}
|
||||
values.insert(node.output[0].clone(), out);
|
||||
}
|
||||
@ -2030,6 +2036,203 @@ fn simple_eval_(
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"Trilu" => {
|
||||
let input = get(&node.input[0])?;
|
||||
|
||||
// Get the diagonal offset 'k' from the second input if provided
|
||||
let k = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||
get(&node.input[1])?.to_vec0::<i64>()?
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Get the 'upper' attribute
|
||||
let upper = get_attr_opt::<i64>(node, "upper")?.copied().unwrap_or(1);
|
||||
|
||||
// For batched inputs, we need to handle each matrix separately
|
||||
let dims = input.dims();
|
||||
if dims.len() < 2 {
|
||||
bail!("Trilu expects input with at least 2 dimensions: {:?}", dims);
|
||||
}
|
||||
|
||||
// Get the last two dimensions which represent the matrix
|
||||
let n = dims[dims.len() - 2];
|
||||
let m = dims[dims.len() - 1];
|
||||
let max_dim = std::cmp::max(n, m);
|
||||
|
||||
// Handle the diagonal offset k
|
||||
let mask = if k != 0 {
|
||||
let mut data = vec![0u32; n * m];
|
||||
for i in 0..n {
|
||||
for j in 0..m {
|
||||
if (upper != 0 && (j as i64) >= (i as i64) + k)
|
||||
|| (upper == 0 && (j as i64) <= (i as i64) + k)
|
||||
{
|
||||
data[i * m + j] = 1u32;
|
||||
}
|
||||
}
|
||||
}
|
||||
Tensor::from_vec(data, (n, m), input.device())?.to_dtype(input.dtype())?
|
||||
} else if upper == 0 {
|
||||
Tensor::tril2(max_dim, input.dtype(), input.device())?
|
||||
} else {
|
||||
Tensor::triu2(max_dim, input.dtype(), input.device())?
|
||||
};
|
||||
|
||||
let final_mask = if n != m {
|
||||
mask.narrow(0, 0, n)?.narrow(1, 0, m)?
|
||||
} else {
|
||||
mask
|
||||
};
|
||||
|
||||
let output = (input * &final_mask)?;
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
"ScatterND" => {
|
||||
let data = get(&node.input[0])?;
|
||||
|
||||
let indices = get(&node.input[1])?;
|
||||
let indices = indices.to_dtype(DType::I64)?;
|
||||
|
||||
let updates = get(&node.input[2])?;
|
||||
|
||||
let reduction = get_attr_opt::<str>(node, "reduction")?.unwrap_or("none");
|
||||
|
||||
let indices_shape = indices.dims();
|
||||
let data_shape = data.dims();
|
||||
let updates_shape = updates.dims();
|
||||
|
||||
// Last dimension of indices represents the depth of indexing
|
||||
let k = indices_shape.last().unwrap().clone();
|
||||
|
||||
if k > data.rank() {
|
||||
bail!("ScatterND expects k (indices.shape[-1]) to be at most the rank of data");
|
||||
}
|
||||
|
||||
let num_updates = indices_shape[..indices_shape.len() - 1]
|
||||
.iter()
|
||||
.product::<usize>();
|
||||
|
||||
let flat_indices = if indices.rank() == 1 && k == 1 {
|
||||
indices.unsqueeze(0)?
|
||||
} else {
|
||||
indices.reshape((num_updates, k))?
|
||||
};
|
||||
|
||||
// Calculate the shape of each update element
|
||||
let update_element_shape = if k < data_shape.len() {
|
||||
data_shape[k..].to_vec()
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Expected shape for updates based on indices and target tensor
|
||||
let expected_updates_shape = {
|
||||
let mut shape = indices_shape[..indices_shape.len() - 1].to_vec();
|
||||
shape.extend(&update_element_shape);
|
||||
shape
|
||||
};
|
||||
|
||||
// Validate or reshape updates to expected shape
|
||||
let updates = if updates.dims() != expected_updates_shape {
|
||||
if updates.rank() == 0 {
|
||||
// Handle scalar updates
|
||||
let mut target_shape = vec![num_updates];
|
||||
target_shape.extend(&update_element_shape);
|
||||
updates.broadcast_as(target_shape)?
|
||||
} else {
|
||||
// Try to broadcast or reshape updates to expected shape
|
||||
let flat_shape =
|
||||
vec![num_updates, update_element_shape.iter().product::<usize>()];
|
||||
let flattened = updates.reshape(flat_shape)?;
|
||||
flattened.reshape(expected_updates_shape)?
|
||||
}
|
||||
} else {
|
||||
updates.clone()
|
||||
};
|
||||
|
||||
let mut output = data.clone();
|
||||
|
||||
// convert indices to flat indices
|
||||
let mut flat_output = output.flatten_all()?;
|
||||
let flat_updates = if update_element_shape.is_empty() {
|
||||
updates.reshape(num_updates)?
|
||||
} else {
|
||||
let product = update_element_shape.iter().product::<usize>();
|
||||
updates.reshape((num_updates, product))?
|
||||
};
|
||||
|
||||
// Calculate strides for the output tensor
|
||||
let mut strides: Vec<usize> = vec![1];
|
||||
for i in (0..data_shape.len() - 1).rev() {
|
||||
strides.push(strides.last().unwrap() * data_shape[i + 1]);
|
||||
}
|
||||
strides.reverse();
|
||||
|
||||
// Process each update
|
||||
for i in 0..num_updates {
|
||||
let index_slice = flat_indices.narrow(0, i, 1)?;
|
||||
let indices_vec = index_slice.squeeze(0)?.to_vec1::<i64>()?;
|
||||
|
||||
// Convert multi-dimensional indices to flat index
|
||||
let mut flat_idx: usize = 0;
|
||||
for (dim, &idx) in indices_vec.iter().enumerate() {
|
||||
let dim_size = data_shape[dim] as i64;
|
||||
let norm_idx = if idx < 0 { dim_size + idx } else { idx };
|
||||
|
||||
if norm_idx < 0 || norm_idx >= dim_size {
|
||||
bail!(
|
||||
"Index {} out of bounds for dimension {} with size {}",
|
||||
idx,
|
||||
dim,
|
||||
dim_size
|
||||
);
|
||||
}
|
||||
|
||||
flat_idx += (norm_idx as usize) * strides[dim];
|
||||
}
|
||||
|
||||
// Extract current update
|
||||
let update_slice = if update_element_shape.is_empty() {
|
||||
flat_updates.narrow(0, i, 1)?.squeeze(0)?
|
||||
} else {
|
||||
flat_updates.narrow(0, i, 1)?
|
||||
};
|
||||
|
||||
match reduction {
|
||||
"add" => {
|
||||
if update_element_shape.is_empty() {
|
||||
let existing = flat_output.narrow(0, flat_idx, 1)?;
|
||||
let new_value = existing.add(&update_slice.unsqueeze(0)?)?;
|
||||
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
|
||||
} else {
|
||||
let slice_size = update_element_shape.iter().product::<usize>();
|
||||
let existing = flat_output.narrow(0, flat_idx, slice_size)?;
|
||||
let new_value = existing.add(&update_slice)?;
|
||||
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
|
||||
}
|
||||
}
|
||||
"none" | _ => {
|
||||
if update_element_shape.is_empty() {
|
||||
flat_output = flat_output.slice_scatter(
|
||||
&update_slice.unsqueeze(0)?,
|
||||
0,
|
||||
flat_idx,
|
||||
)?;
|
||||
} else {
|
||||
flat_output =
|
||||
flat_output.slice_scatter(&update_slice, 0, flat_idx)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reshape flat output back to original shape
|
||||
output = flat_output.reshape(data_shape.to_vec())?;
|
||||
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
@ -842,13 +842,22 @@ fn test_flatten_operation() -> Result<()> {
|
||||
#[test]
|
||||
fn test_constant_of_shape() -> Result<()> {
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||
test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
|
||||
test(
|
||||
&[4i64, 3, 2],
|
||||
Some(1.),
|
||||
&[
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
[[1., 1.], [1., 1.], [1., 1.]],
|
||||
],
|
||||
)?;
|
||||
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||
test(&[0.], Some(0i64), &[0i64])?;
|
||||
test(&[1i64], Some(0i64), &[0i64])?;
|
||||
|
||||
// "value" defaults to 0 f32
|
||||
test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
|
||||
test(&[4i64], None as Option<i64>, &[0., 0., 0., 0.])?;
|
||||
|
||||
fn test(
|
||||
input: impl NdArray,
|
||||
@ -5968,3 +5977,512 @@ fn test_sign_operation() -> Result<()> {
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scatternd_operation() -> Result<()> {
|
||||
// Example 1 based on ONNX documentation
|
||||
test(
|
||||
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
|
||||
&[[4i64], [3], [1], [7]],
|
||||
&[9.0f32, 10.0, 11.0, 12.0],
|
||||
&[1.0f32, 11.0, 3.0, 10.0, 9.0, 6.0, 7.0, 12.0],
|
||||
)?;
|
||||
|
||||
// A more complex example with 2D data
|
||||
test(
|
||||
&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
||||
&[[0i64, 1], [1, 0]],
|
||||
&[10.0f32, 20.0],
|
||||
&[[1.0f32, 10.0], [20.0, 4.0], [5.0, 6.0]],
|
||||
)?;
|
||||
|
||||
// 3D example with indices pointing to specific locations
|
||||
test(
|
||||
&[
|
||||
[[1.0f32, 2.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
],
|
||||
&[[0i64, 0, 1], [1, 1, 0]],
|
||||
&[100.0f32, 200.0],
|
||||
&[
|
||||
[[1.0f32, 100.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [200.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
],
|
||||
)?;
|
||||
|
||||
fn test(
|
||||
data: impl NdArray,
|
||||
indices: impl NdArray,
|
||||
updates: impl NdArray,
|
||||
expected: impl NdArray,
|
||||
) -> Result<()> {
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "ScatterND".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![
|
||||
INPUT_X.to_string(),
|
||||
INPUT_Y.to_string(),
|
||||
INPUT_A.to_string(),
|
||||
],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
|
||||
inputs.insert(INPUT_A.to_string(), Tensor::new(updates, &Device::Cpu)?);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||
|
||||
match expected.dims().len() {
|
||||
1 => assert_eq!(z.to_vec1::<f32>()?, expected.to_vec1::<f32>()?),
|
||||
2 => assert_eq!(z.to_vec2::<f32>()?, expected.to_vec2::<f32>()?),
|
||||
3 => assert_eq!(z.to_vec3::<f32>()?, expected.to_vec3::<f32>()?),
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trilu_operation() -> Result<()> {
|
||||
// Test 1: Upper triangular matrix (default behavior with upper=true)
|
||||
{
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![], // empty attribute means default upper=true
|
||||
input: vec![INPUT_X.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![4, 7, 3, 7, 9],
|
||||
vec![0, 2, 8, 6, 9],
|
||||
vec![0, 0, 0, 8, 7],
|
||||
vec![0, 0, 0, 2, 4]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 2: Upper triangular with positive k=1 (diagonal above main)
|
||||
{
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![
|
||||
ValueInfoProto {
|
||||
name: INPUT_X.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
ValueInfoProto {
|
||||
name: INPUT_Y.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
},
|
||||
],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![1i64, 4, 9, 7, 1, 9, 2, 8, 8, 4, 3, 9, 7, 4, 2],
|
||||
&[3, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let k = Tensor::from_vec(vec![1i64], (), &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), k);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![0, 4, 9, 7, 1],
|
||||
vec![0, 0, 8, 8, 4],
|
||||
vec![0, 0, 0, 4, 2]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 3: Upper triangular with negative k=-1 (one diagonal below main)
|
||||
{
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), k);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![4, 7, 3, 7, 9],
|
||||
vec![1, 2, 8, 6, 9],
|
||||
vec![0, 4, 0, 8, 7],
|
||||
vec![0, 0, 4, 2, 4]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 4: Lower triangular matrix (upper=0)
|
||||
{
|
||||
let att_upper = AttributeProto {
|
||||
name: "upper".to_string(),
|
||||
ref_attr_name: "upper".to_string(),
|
||||
i: 0, // 0 means false, use lower triangular
|
||||
doc_string: "upper".to_string(),
|
||||
r#type: 2,
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_upper],
|
||||
input: vec![INPUT_X.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
// Lower triangular matrix (default k=0)
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![4, 0, 0, 0, 0],
|
||||
vec![1, 2, 0, 0, 0],
|
||||
vec![9, 4, 1, 0, 0],
|
||||
vec![4, 3, 4, 2, 0]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 5: Lower triangular with negative k=-1
|
||||
{
|
||||
let att_upper = AttributeProto {
|
||||
name: "upper".to_string(),
|
||||
ref_attr_name: "upper".to_string(),
|
||||
i: 0,
|
||||
doc_string: "upper".to_string(),
|
||||
r#type: 2,
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_upper],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), k);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![0, 0, 0, 0, 0],
|
||||
vec![1, 0, 0, 0, 0],
|
||||
vec![9, 4, 0, 0, 0],
|
||||
vec![4, 3, 4, 0, 0]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
// Test 6: Lower triangular with positive k=2
|
||||
{
|
||||
let att_upper = AttributeProto {
|
||||
name: "upper".to_string(),
|
||||
ref_attr_name: "upper".to_string(),
|
||||
i: 0,
|
||||
doc_string: "upper".to_string(),
|
||||
r#type: 2,
|
||||
f: 0.0,
|
||||
s: vec![],
|
||||
t: None,
|
||||
g: None,
|
||||
sparse_tensor: None,
|
||||
tp: None,
|
||||
floats: vec![],
|
||||
ints: vec![],
|
||||
strings: vec![],
|
||||
tensors: vec![],
|
||||
graphs: vec![],
|
||||
sparse_tensors: vec![],
|
||||
type_protos: vec![],
|
||||
};
|
||||
|
||||
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||
node: vec![NodeProto {
|
||||
op_type: "Trilu".to_string(),
|
||||
domain: "".to_string(),
|
||||
attribute: vec![att_upper],
|
||||
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||
output: vec![OUTPUT_Z.to_string()],
|
||||
name: "".to_string(),
|
||||
doc_string: "".to_string(),
|
||||
}],
|
||||
name: "".to_string(),
|
||||
initializer: vec![],
|
||||
input: vec![],
|
||||
output: vec![ValueInfoProto {
|
||||
name: OUTPUT_Z.to_string(),
|
||||
doc_string: "".to_string(),
|
||||
r#type: None,
|
||||
}],
|
||||
value_info: vec![],
|
||||
doc_string: "".to_string(),
|
||||
sparse_initializer: vec![],
|
||||
quantization_annotation: vec![],
|
||||
}));
|
||||
|
||||
let x = Tensor::from_vec(
|
||||
vec![
|
||||
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||
],
|
||||
&[4, 5],
|
||||
&Device::Cpu,
|
||||
)?;
|
||||
|
||||
let k = Tensor::from_vec(vec![2i64], (), &Device::Cpu)?;
|
||||
|
||||
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||
inputs.insert(INPUT_X.to_string(), x);
|
||||
inputs.insert(INPUT_Y.to_string(), k);
|
||||
|
||||
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||
assert_eq!(eval.len(), 1);
|
||||
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
let results = z.to_vec2::<i64>()?;
|
||||
|
||||
assert_eq!(
|
||||
results,
|
||||
vec![
|
||||
vec![4, 7, 3, 0, 0],
|
||||
vec![1, 2, 8, 6, 0],
|
||||
vec![9, 4, 1, 8, 7],
|
||||
vec![4, 3, 4, 2, 4]
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user