mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Add some preliminary ONNX support (#1260)
* Add the onnx protos. * Move the reading bits. * Install protoc on the CI. * Install protoc on the cuda CI too. * Use clap for the onnx tool. * Tweak the CI protoc install. * Add some simple evaluation function. * Add some binary operator support.
This commit is contained in:
56
candle-onnx/examples/onnx_basics.rs
Normal file
56
candle-onnx/examples/onnx_basics.rs
Normal file
@ -0,0 +1,56 @@
|
||||
use anyhow::Result;
|
||||
use candle::{Device, Tensor};
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
// Subcommands understood by this example binary.
// NOTE: plain `//` comments are used on purpose here — `///` doc comments on
// clap-derived items are picked up as help text and would change the binary's
// `--help` output.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    // Dump the raw model proto and then each node of its graph.
    Print {
        // Path of the ONNX model file to read, passed as `--file <PATH>`.
        #[arg(long)]
        file: String,
    },
    // Evaluate the model's graph with dummy inputs and print the outputs.
    SimpleEval {
        // Path of the ONNX model file to read, passed as `--file <PATH>`.
        #[arg(long)]
        file: String,
    },
}
|
||||
|
||||
// Top-level CLI arguments: nothing but the subcommand to run.
// Plain `//` comments on purpose — `///` would become clap help text and
// change the `--help` output.
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    // Which action to perform; see `Command` for the variants.
    #[command(subcommand)]
    command: Command,
}
|
||||
|
||||
pub fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
match args.command {
|
||||
Command::Print { file } => {
|
||||
let model = candle_onnx::read_file(file)?;
|
||||
println!("{model:?}");
|
||||
let graph = model.graph.unwrap();
|
||||
for node in graph.node.iter() {
|
||||
println!("{node:?}");
|
||||
}
|
||||
}
|
||||
Command::SimpleEval { file } => {
|
||||
let model = candle_onnx::read_file(file)?;
|
||||
let inputs = model
|
||||
.graph
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.input
|
||||
.iter()
|
||||
.map(|name| {
|
||||
let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
|
||||
Ok((name.name.clone(), value))
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
let outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||
for (name, value) in outputs.iter() {
|
||||
println!("{name}: {value:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
Reference in New Issue
Block a user