mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 19:47:12 +00:00
[ONNX] Do not generate values for constants. (#1272)
* Do not generate values for constants. * Add an onnx based example using squeezenet.
This commit is contained in:
@ -35,33 +35,34 @@ pub fn main() -> Result<()> {
|
||||
}
|
||||
Command::SimpleEval { file } => {
|
||||
let model = candle_onnx::read_file(file)?;
|
||||
let inputs = model
|
||||
.graph
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.input
|
||||
.iter()
|
||||
.map(|input| {
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
let graph = model.graph.as_ref().unwrap();
|
||||
let constants: std::collections::HashSet<_> =
|
||||
graph.initializer.iter().map(|i| i.name.as_str()).collect();
|
||||
let mut inputs = std::collections::HashMap::new();
|
||||
for input in graph.input.iter() {
|
||||
use candle_onnx::onnx::tensor_proto::DataType;
|
||||
if constants.contains(input.name.as_str()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let type_ = input.r#type.as_ref().expect("no type for input");
|
||||
let type_ = type_.value.as_ref().expect("no type.value for input");
|
||||
let value = match type_ {
|
||||
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
|
||||
let dt = match DataType::try_from(tt.elem_type) {
|
||||
Ok(dt) => match candle_onnx::dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"unsupported 'value' data-type {dt:?} for {}",
|
||||
input.name
|
||||
)
|
||||
}
|
||||
},
|
||||
type_ => anyhow::bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
|
||||
let dims = shape
|
||||
let type_ = input.r#type.as_ref().expect("no type for input");
|
||||
let type_ = type_.value.as_ref().expect("no type.value for input");
|
||||
let value = match type_ {
|
||||
candle_onnx::onnx::type_proto::Value::TensorType(tt) => {
|
||||
let dt = match DataType::try_from(tt.elem_type) {
|
||||
Ok(dt) => match candle_onnx::dtype(dt) {
|
||||
Some(dt) => dt,
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"unsupported 'value' data-type {dt:?} for {}",
|
||||
input.name
|
||||
)
|
||||
}
|
||||
},
|
||||
type_ => anyhow::bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
let shape = tt.shape.as_ref().expect("no tensortype.shape for input");
|
||||
let dims = shape
|
||||
.dim
|
||||
.iter()
|
||||
.map(|dim| match dim.value.as_ref().expect("no dim value") {
|
||||
@ -69,16 +70,16 @@ pub fn main() -> Result<()> {
|
||||
candle_onnx::onnx::tensor_shape_proto::dimension::Value::DimParam(_) => anyhow::bail!("DimParam is unsupported for input {}", input.name),
|
||||
})
|
||||
.collect::<Result<Vec<usize>>>()?;
|
||||
Tensor::zeros(dims, dt, &Device::Cpu)?
|
||||
}
|
||||
type_ => anyhow::bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
Ok::<_, anyhow::Error>((input.name.clone(), value))
|
||||
})
|
||||
.collect::<Result<_>>()?;
|
||||
Tensor::zeros(dims, dt, &Device::Cpu)?
|
||||
}
|
||||
type_ => anyhow::bail!("unsupported input type {type_:?}"),
|
||||
};
|
||||
println!("input {}: {value:?}", input.name);
|
||||
inputs.insert(input.name.clone(), value);
|
||||
}
|
||||
let outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||
for (name, value) in outputs.iter() {
|
||||
println!("{name}: {value:?}")
|
||||
println!("output {name}: {value:?}")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user