candle-onnx: Implement Trilu and ScatterND ops (#2952)
* onnx attention
* setup an example, adding and fixing onnx ops bit by bit
* model working, output is garbage data
* trilu working
* close but not quite, issues still with ScatterND
* closer but the outputs are still slightly wrong
* added tests for Trilu and ScatterND
* lint
* readme
* clippy
* removed unnecessary comments
* changed device selection, took hyperparameters from model config
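For context on the two ops named in the title: ONNX `Trilu` returns the lower or upper triangular part of a matrix relative to a diagonal offset `k`, zeroing the rest, and `ScatterND` writes a set of `updates` into a copy of `data` at the positions given by an `indices` tensor. The snippet below is only a plain-Rust sketch of those semantics on tiny cases; the `trilu` and `scatter_nd` helpers are illustrative and are not the candle-onnx implementation added by this commit.

```rust
/// Sketch of ONNX Trilu on a 2-D matrix: keep entries on/below (lower) or
/// on/above (upper) the k-th diagonal and zero out everything else.
fn trilu(data: &[Vec<f32>], k: i64, upper: bool) -> Vec<Vec<f32>> {
    data.iter()
        .enumerate()
        .map(|(i, row)| {
            row.iter()
                .enumerate()
                .map(|(j, &v)| {
                    let keep = if upper {
                        j as i64 >= i as i64 + k
                    } else {
                        j as i64 <= i as i64 + k
                    };
                    if keep { v } else { 0.0 }
                })
                .collect()
        })
        .collect()
}

/// Sketch of ONNX ScatterND for the simplest case (1-D data, one index per
/// update): the output starts as a copy of `data`, then each update
/// overwrites the element at the matching index.
fn scatter_nd(data: &[f32], indices: &[usize], updates: &[f32]) -> Vec<f32> {
    let mut out = data.to_vec();
    for (&idx, &upd) in indices.iter().zip(updates.iter()) {
        out[idx] = upd;
    }
    out
}

fn main() {
    let m = vec![vec![1., 2., 3.], vec![4., 5., 6.], vec![7., 8., 9.]];
    // Lower triangle, k = 0: [[1, 0, 0], [4, 5, 0], [7, 8, 9]]
    println!("{:?}", trilu(&m, 0, false));
    // data [1, 2, 3, 4] with updates [9, 8] at indices [3, 1] -> [1, 8, 3, 9]
    println!("{:?}", scatter_nd(&[1., 2., 3., 4.], &[3, 1], &[9., 8.]));
}
```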
@@ -84,6 +84,10 @@ required-features = ["pyo3"]
 name = "onnx"
 required-features = ["onnx"]
 
+[[example]]
+name = "onnx-llm"
+required-features = ["onnx"]
+
 [[example]]
 name = "onnx_basics"
 required-features = ["onnx"]
candle-examples/examples/onnx-llm/README.md (new file, 11 lines)
@@ -0,0 +1,11 @@
## Using ONNX models in Candle

This example demonstrates how to run [ONNX](https://github.com/onnx/onnx)-based LLM models in Candle.

This script only implements SmolLM-135M right now.

You can run the example with the following command:

```bash
cargo run --example onnx-llm --features onnx
```
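Under the hood the example boils down to loading the ONNX graph once with `candle_onnx::read_file` and then evaluating it with `candle_onnx::simple_eval`, passing inputs by name. A minimal sketch, assuming a local `model.onnx` whose graph only takes `input_ids` and produces `logits` (the real SmolLM export also expects `attention_mask`, `position_ids` and per-layer KV-cache tensors, as `main.rs` below shows):

```rust
use std::collections::HashMap;

use anyhow::Result;
use candle::{Device, Tensor};

fn main() -> Result<()> {
    // Load the ONNX graph; "model.onnx" is a placeholder path.
    let model = candle_onnx::read_file("model.onnx")?;

    // Inputs are passed by name and must match the names in the exported graph.
    let mut inputs = HashMap::new();
    let input_ids = Tensor::new(vec![1i64, 2, 3], &Device::Cpu)?.unsqueeze(0)?;
    inputs.insert("input_ids".to_string(), input_ids);

    // Evaluate the graph and read the outputs back by name.
    let outputs = candle_onnx::simple_eval(&model, inputs)?;
    let logits = outputs
        .get("logits")
        .ok_or_else(|| anyhow::anyhow!("missing logits output"))?;
    println!("logits shape: {:?}", logits.dims());
    Ok(())
}
```

The full example layers prompt handling and sampling on top of this; flags such as `--prompt`, `--max-tokens`, `--temperature`, `--top-p`, `--top-k`, `--seed` and `--cpu` map onto the `Args` struct in `main.rs`.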
candle-examples/examples/onnx-llm/main.rs (new file, 209 lines)
@@ -0,0 +1,209 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::Result;
use candle::{DType, Tensor};
use candle_transformers::generation::{LogitsProcessor, Sampling};
use clap::{Parser, ValueEnum};
use hf_hub::api::sync::Api;
use serde::Deserialize;
use std::io::Write;
use tokenizers::Tokenizer;

#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub num_hidden_layers: usize,
    pub num_key_value_heads: usize,
    pub hidden_size: usize,
    pub num_attention_heads: usize,
}

#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    SmolLM135M,
}

#[derive(Parser)]
struct Args {
    /// The prompt to be used.
    #[arg(long, default_value = "My favorite theorem is ")]
    prompt: String,

    /// The model to be used.
    #[arg(value_enum, long, default_value_t = Which::SmolLM135M)]
    which: Which,

    /// Run on CPU rather than GPU.
    #[arg(long)]
    cpu: bool,

    /// The number of tokens to generate.
    #[arg(long, default_value_t = 100)]
    max_tokens: usize,

    /// The temperature used for sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f32,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,
}

pub fn main() -> Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    let (model_id, tokenizer_id) = match args.which {
        Which::SmolLM135M => ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-135M"),
    };

    let api = Api::new()?;
    let model_repo = api.model(model_id.to_string());
    let tokenizer_repo = api.model(tokenizer_id.to_string());

    let model_path = model_repo.get("onnx/model.onnx")?;
    let config_file = model_repo.get("config.json")?;
    let config: Config = serde_json::from_reader(std::fs::File::open(config_file)?)?;

    let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
    let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;

    let tokens_u32 = tokenizer
        .encode(args.prompt.as_str(), true)
        .map_err(anyhow::Error::msg)?
        .get_ids()
        .to_vec();

    let tokens: Vec<i64> = tokens_u32.iter().map(|&t| t as i64).collect();

    println!("Loading ONNX model from {:?}", model_path);
    let model = candle_onnx::read_file(model_path)?;

    let mut generated_tokens = tokens.clone();
    print!("{}", args.prompt);
    std::io::stdout().flush()?;

    let mut logits_processor = {
        let temperature = args.temperature as f64;
        let sampling = if temperature <= 0. {
            Sampling::ArgMax
        } else {
            match (args.top_k, args.top_p) {
                (None, None) => Sampling::All { temperature },
                (Some(k), None) => Sampling::TopK { k, temperature },
                (None, Some(p)) => Sampling::TopP { p, temperature },
                (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
            }
        };
        LogitsProcessor::from_sampling(args.seed, sampling)
    };

    let mut past_key_values: Option<Vec<(Tensor, Tensor)>> = None;
    let num_layers = config.num_hidden_layers;

    for _ in 0..args.max_tokens {
        let mut inputs = std::collections::HashMap::new();

        if let Some(past_kv) = &past_key_values {
            let last_token = vec![generated_tokens[generated_tokens.len() - 1]];
            let input_tensor = Tensor::new(last_token, &device)?.unsqueeze(0)?;
            inputs.insert("input_ids".to_string(), input_tensor);

            let seq_len = generated_tokens.len();
            let attention_mask = vec![vec![1i64; seq_len]];
            let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
            inputs.insert("attention_mask".to_string(), attention_mask_tensor);

            let position_ids = vec![vec![(seq_len - 1) as i64]];
            let position_ids_tensor = Tensor::new(position_ids, &device)?;
            inputs.insert("position_ids".to_string(), position_ids_tensor);

            for (i, (key, value)) in past_kv.iter().enumerate() {
                inputs.insert(format!("past_key_values.{}.key", i), key.clone());
                inputs.insert(format!("past_key_values.{}.value", i), value.clone());
            }
        } else {
            let input_tensor = Tensor::new(generated_tokens.clone(), &device)?.unsqueeze(0)?;
            inputs.insert("input_ids".to_string(), input_tensor);

            let seq_len = generated_tokens.len();
            let attention_mask = vec![vec![1i64; seq_len]];
            let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
            inputs.insert("attention_mask".to_string(), attention_mask_tensor);

            let position_ids: Vec<i64> = (0..seq_len as i64).collect();
            let position_ids_tensor = Tensor::new(position_ids, &device)?.unsqueeze(0)?;
            inputs.insert("position_ids".to_string(), position_ids_tensor);

            // Create empty key and value tensors
            for i in 0..num_layers {
                let batch_size = 1;
                let num_heads = config.num_key_value_heads;
                let head_dim = config.hidden_size / config.num_attention_heads;
                let seq_len = 0;

                let empty_key = Tensor::zeros(
                    &[batch_size, num_heads, seq_len, head_dim],
                    DType::F32,
                    &device,
                )?;
                let empty_value = Tensor::zeros(
                    &[batch_size, num_heads, seq_len, head_dim],
                    DType::F32,
                    &device,
                )?;

                inputs.insert(format!("past_key_values.{}.key", i), empty_key);
                inputs.insert(format!("past_key_values.{}.value", i), empty_value);
            }
        }

        let outputs = candle_onnx::simple_eval(&model, inputs)?;

        let logits = outputs.get("logits").unwrap();

        let mut new_past_kv = Vec::with_capacity(num_layers);
        for i in 0..num_layers {
            let key = outputs
                .get(&format!("present.{}.key", i))
                .ok_or_else(|| anyhow::anyhow!("Missing present.{}.key", i))?;
            let value = outputs
                .get(&format!("present.{}.value", i))
                .ok_or_else(|| anyhow::anyhow!("Missing present.{}.value", i))?;
            new_past_kv.push((key.clone(), value.clone()));
        }
        past_key_values = Some(new_past_kv);

        let logits_dim = logits.dims();
        let seq_len = logits_dim[1];

        let next_token_id = logits_processor.sample(&logits.get(0)?.get(seq_len - 1)?)?;
        generated_tokens.push(next_token_id as i64);

        if let Some(token_str) = tokenizer.decode(&[next_token_id], true).ok() {
            print!("{}", token_str);
            std::io::stdout().flush()?;
        }

        if let Some(eos_id) = tokenizer.token_to_id("<|endoftext|>") {
            if next_token_id == eos_id {
                break;
            }
        }
    }

    println!("\nGeneration complete!");
    Ok(())
}