Files
candle/candle-onnx/src/eval.rs
2024-10-17 20:22:35 +02:00

1992 lines
86 KiB
Rust

use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::collections::{HashMap, HashSet};
/// Alias for the tensor type stored in the name-to-value maps that `simple_eval` takes and returns.
pub type Value = Tensor;
/// Translates an ONNX element type into the candle `DType` used to hold it.
///
/// `Bool` shares the `U8` representation with `Uint8`. Element types that have
/// no entry in the table below yield `None`.
pub fn dtype(dt: DataType) -> Option<DType> {
    let mapped = match dt {
        DataType::Uint8 | DataType::Bool => DType::U8,
        DataType::Uint32 => DType::U32,
        DataType::Int64 => DType::I64,
        DataType::Float16 => DType::F16,
        DataType::Float => DType::F32,
        DataType::Double => DType::F64,
        _ => return None,
    };
    Some(mapped)
}
/// An attribute payload that can be borrowed straight out of an `AttributeProto`.
trait Attr {
/// The ONNX attribute type tag associated with `Self`; the `get_attr*` helpers compare an
/// attribute's tag against this constant before calling `get`.
const TYPE: AttributeType;
/// Borrows the payload for `Self` from `attr`.
fn get(attr: &onnx::AttributeProto) -> Result<&Self>;
}
/// An attribute payload that is returned by value (as an owned `Self`) when read from an
/// `AttributeProto`, as opposed to the borrowed form provided by `Attr`.
trait AttrOwned: Sized {
/// The ONNX attribute type tag associated with `Self`; `get_attr_opt_owned` compares an
/// attribute's tag against this constant before calling `get`.
const TYPE: AttributeType;
/// Builds an owned `Self` from the contents of `attr`.
fn get(attr: &onnx::AttributeProto) -> Result<Self>;
}
/// `Int` attributes are served from the proto's `i` field.
impl Attr for i64 {
    const TYPE: AttributeType = AttributeType::Int;

    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        let onnx::AttributeProto { i, .. } = attr;
        Ok(i)
    }
}
/// `Float` attributes are served from the proto's `f` field.
impl Attr for f32 {
    const TYPE: AttributeType = AttributeType::Float;

    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        let onnx::AttributeProto { f, .. } = attr;
        Ok(f)
    }
}
/// `Ints` attributes are served as a slice over the proto's `ints` field.
impl Attr for [i64] {
    const TYPE: AttributeType = AttributeType::Ints;

    fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
        let values: &[i64] = &attr.ints;
        Ok(values)
    }
}
impl Attr for str {
const TYPE: AttributeType = AttributeType::String;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
std::str::from_utf8(&attr.s).map_err(candle::Error::wrap)
}
}
impl Attr for GraphProto {
const TYPE: AttributeType = AttributeType::Graph;
fn get(attr: &onnx::AttributeProto) -> Result<&Self> {
attr.g
.as_ref()
.ok_or_else(|| candle::Error::Msg("attribute does not contain graph".to_string()))
}
}
impl AttrOwned for Vec<String> {
const TYPE: AttributeType = AttributeType::Strings;
fn get(attr: &onnx::AttributeProto) -> Result<Self> {
let mut ret = vec![];
for bytes in attr.strings.iter() {
let s = String::from_utf8(bytes.clone()).map_err(candle::Error::wrap)?;
ret.push(s);
}
Ok(ret)
}
}
/// `Tensor` attributes are materialised on the CPU from the embedded `TensorProto`.
///
/// Only the proto's `raw_data` buffer is read. The call fails when the tensor payload is
/// missing, when the data type number is unknown or has no candle equivalent, or when any
/// dimension is negative.
impl AttrOwned for Tensor {
    const TYPE: AttributeType = AttributeType::Tensor;

    fn get(attr: &onnx::AttributeProto) -> Result<Self> {
        let proto = if let Some(proto) = attr.t.as_ref() {
            proto
        } else {
            bail!(
                "attribute {} was of type TENSOR, but no tensor was found",
                attr.name
            )
        };

        // First make sure the numeric tag names a known ONNX data type...
        let onnx_type = match DataType::try_from(proto.data_type) {
            Err(_) => bail!(
                "attribute {} of type TENSOR was an invalid data_type number {}",
                attr.name,
                proto.data_type
            ),
            Ok(onnx_type) => onnx_type,
        };

        // ...then that it is one we can represent.
        let elem_dtype = match dtype(onnx_type) {
            None => bail!(
                "attribute {} of type TENSOR has an unsupported data_type {}",
                attr.name,
                onnx_type.as_str_name()
            ),
            Some(elem_dtype) => elem_dtype,
        };

        let shape = proto
            .dims
            .iter()
            .map(|&extent| {
                if extent < 0 {
                    bail!(
                        "attribute {} of type TENSOR has a negative dimension, which is unsupported",
                        attr.name
                    )
                }
                Ok(extent as usize)
            })
            .collect::<Result<Vec<usize>>>()?;

        Tensor::from_raw_buffer(&proto.raw_data, elem_dtype, &shape, &Device::Cpu)
    }
}
/// Looks up the first attribute of `node` whose name equals `name`.
///
/// Fails with a message naming the attribute, the node's op type and the node's name when
/// no such attribute exists.
fn get_attr_<'a>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a onnx::AttributeProto> {
    for candidate in node.attribute.iter() {
        if candidate.name == name {
            return Ok(candidate);
        }
    }
    bail!(
        "cannot find the '{name}' attribute in '{}' for {}",
        node.op_type,
        node.name
    )
}
/// Fetches the mandatory attribute `name` from `node` and borrows it as a `T`.
///
/// Fails when the attribute is missing, when its type tag differs from `T::TYPE`, or when
/// `T::get` itself reports an error.
fn get_attr<'a, T: Attr + ?Sized>(node: &'a onnx::NodeProto, name: &str) -> Result<&'a T> {
    let found = get_attr_(node, name)?;
    if found.r#type() == T::TYPE {
        return T::get(found);
    }
    bail!(
        "unsupported type {:?} for '{name}' attribute in '{}' for {}",
        found.r#type,
        node.op_type,
        node.name
    )
}
/// Fetches the optional attribute `name` from `node` and borrows it as a `T`.
///
/// Returns `Ok(None)` when the node has no attribute with that name. A present attribute
/// whose type tag differs from `T::TYPE` is an error, as is a failure inside `T::get`.
fn get_attr_opt<'a, T: Attr + ?Sized>(
    node: &'a onnx::NodeProto,
    name: &str,
) -> Result<Option<&'a T>> {
    let found = match node.attribute.iter().find(|candidate| candidate.name == name) {
        Some(found) => found,
        None => return Ok(None),
    };
    if found.r#type() != T::TYPE {
        bail!(
            "unsupported type {:?} for '{name}' attribute in '{}' for {}",
            found.r#type,
            node.op_type,
            node.name
        )
    }
    T::get(found).map(Some)
}
/// Fetches the optional attribute `name` from `node` and converts it into an owned `T`.
///
/// Returns `Ok(None)` when the node has no attribute with that name. A present attribute
/// whose type tag differs from `T::TYPE` is an error, as is a failure inside `T::get`.
fn get_attr_opt_owned<T: AttrOwned>(node: &onnx::NodeProto, name: &str) -> Result<Option<T>> {
    let lookup = node.attribute.iter().find(|candidate| candidate.name == name);
    if let Some(found) = lookup {
        if found.r#type() != T::TYPE {
            bail!(
                "unsupported type {:?} for '{name}' attribute in '{}' for {}",
                found.r#type,
                node.op_type,
                node.name
            )
        }
        let owned = T::get(found)?;
        Ok(Some(owned))
    } else {
        Ok(None)
    }
}
/// Builds a CPU tensor from an ONNX `TensorProto`.
///
/// `name` is only used to label error messages.
///
/// * `Int32` tensors are widened to `I64`. Values come from `int32_data` when it is
///   non-empty, otherwise from `raw_data`, which is decoded as little-endian 4-byte
///   integers (any trailing bytes that do not form a full value are ignored).
/// * `F32`, `F64` and `I64` tensors use their typed field (`float_data`, `double_data`,
///   `int64_data`) when it is non-empty.
/// * Every other supported type, and the three above when their typed field is empty, is
///   read from `raw_data`.
///
/// The returned tensor always has the shape declared in `t.dims`.
///
/// # Errors
///
/// Fails when a dimension is negative, when the data type number is unknown or has no
/// candle equivalent, or when the tensor constructor rejects the data (for instance because
/// the element count does not match the declared shape).
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
    // Reject negative extents explicitly instead of letting `as usize` wrap them into
    // enormous sizes.
    let mut dims = Vec::with_capacity(t.dims.len());
    for &dim in &t.dims {
        if dim < 0 {
            bail!("tensor {name} has a negative dimension {dim}, which is unsupported")
        }
        dims.push(dim as usize);
    }
    match DataType::try_from(t.data_type) {
        Ok(DataType::Int32) => {
            let data: Vec<i64> = if t.int32_data.is_empty() {
                // Decode byte-wise: `raw_data` is a `Vec<u8>` with no alignment guarantee,
                // so reinterpreting its pointer as `*const i32` would be undefined
                // behaviour, and decoding explicitly also keeps the result independent of
                // host endianness.
                t.raw_data
                    .chunks_exact(4)
                    .map(|b| i64::from(i32::from_le_bytes([b[0], b[1], b[2], b[3]])))
                    .collect()
            } else {
                t.int32_data.iter().map(|&v| i64::from(v)).collect()
            };
            // Use the declared shape, like every other branch, rather than flattening.
            Tensor::from_vec(data, dims.as_slice(), &Device::Cpu)
        }
        Ok(dt) => match dtype(dt) {
            Some(dt) => {
                if dt == DType::F32 && !t.float_data.is_empty() {
                    Tensor::from_slice(&t.float_data, dims.as_slice(), &Device::Cpu)
                } else if dt == DType::F64 && !t.double_data.is_empty() {
                    Tensor::from_slice(&t.double_data, dims.as_slice(), &Device::Cpu)
                } else if dt == DType::I64 && !t.int64_data.is_empty() {
                    Tensor::from_slice(&t.int64_data, dims.as_slice(), &Device::Cpu)
                } else {
                    Tensor::from_raw_buffer(
                        t.raw_data.as_slice(),
                        dt,
                        dims.as_slice(),
                        &Device::Cpu,
                    )
                }
            }
            None => {
                bail!("unsupported 'value' data-type {dt:?} for {name}")
            }
        },
        Err(_) => {
            bail!("unsupported 'value' data-type {} for {name}", t.data_type,)
        }
    }
}
// Evaluates the model proto directly, node by node.
// Longer-term, the proto should first be lowered to an intermediate representation of the
// compute graph so that repeated evaluations are cheaper; one benefit would be dropping
// intermediary values as soon as they are no longer needed.
/// Runs the graph stored in `model`, seeding the evaluation with `inputs`.
///
/// Fails when the model carries no graph, and otherwise returns whatever the graph
/// evaluation returns.
pub fn simple_eval(
    model: &onnx::ModelProto,
    mut inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
    match model.graph.as_ref() {
        Some(graph) => simple_eval_(graph, &mut inputs),
        None => bail!("no graph defined in proto"),
    }
}
fn simple_eval_(
graph: &onnx::GraphProto,
values: &mut HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
for t in graph.initializer.iter() {
let tensor = get_tensor(t, t.name.as_str())?;
values.insert(t.name.to_string(), tensor);
}
for input in graph.input.iter() {
let input_type = match &input.r#type {
Some(input_type) => input_type,
None => continue,
};
let input_type = match &input_type.value {
Some(input_type) => input_type,
None => continue,
};
let tensor_type = match input_type {
onnx::type_proto::Value::TensorType(tt) => tt,
_ => continue,
};
let tensor = match values.get(&input.name) {
None => bail!("missing input {}", input.name),
Some(tensor) => tensor,
};
let dt = match DataType::try_from(tensor_type.elem_type) {
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'value' data-type {dt:?} for {}", input.name)
}
},
type_ => bail!("unsupported input type {type_:?}"),
};
match &tensor_type.shape {
None => continue,
Some(shape) => {
if shape.dim.len() != tensor.rank() {
bail!(
"unexpected rank for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
for (idx, (d, &dim)) in shape.dim.iter().zip(tensor.dims().iter()).enumerate() {
match &d.value {
Some(onnx::tensor_shape_proto::dimension::Value::DimValue(v)) => {
if *v as usize != dim {
bail!(
"unexpected dim {idx} for {}, got {:?}, expected {:?}",
input.name,
shape.dim,
tensor.shape()
)
}
}
// We do not check equality constraints for the DimParam dimensions for now.
Some(onnx::tensor_shape_proto::dimension::Value::DimParam(_)) | None => (),
}
}
}
};
if dt != tensor.dtype() {
bail!(
"unexpected dtype for {}, got {:?}, expected {dt:?}",
input.name,
tensor.dtype()
)
}
}
// The nodes are topologically sorted so we can just process them in order.
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op '{}'", node.name),
};
let get_opt = |i: usize| {
node.input
.get(i)
.filter(|s: &&String| !s.is_empty())
.map(|s| get(s))
};
// TODO: Validate node.input for each operator.
match node.op_type.as_str() {
"Add" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_add(input1)?;
values.insert(node.output[0].clone(), output);
}
"Sub" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_sub(input1)?;
values.insert(node.output[0].clone(), output);
}
"Mul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_mul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Div" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output);
}
"Pow" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
// HACK: current implementation of broadcast_pow cannot handle negative base,
// so we use powf where we can, which *does* correctly handle negative base.
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
let output = input0.powf(exp)?;
values.insert(node.output[0].clone(), output);
} else {
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
}
}
"Exp" => {
let xs = get(&node.input[0])?;
let output = xs.exp()?;
values.insert(node.output[0].clone(), output);
}
"Equal" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_eq(input1)?;
values.insert(node.output[0].clone(), output);
}
"Not" => {
let xs = get(&node.input[0])?;
let xs = xs.eq(&xs.zeros_like()?)?;
values.insert(node.output[0].clone(), xs);
}
"MatMul" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_matmul(input1)?;
values.insert(node.output[0].clone(), output);
}
"Reshape" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?.to_vec1::<i64>()?;
// TODO: Check that there is at most a single -1 or 0, handle other neg values.
let mut other_than_minus1 = 1usize;
for &v in input1.iter() {
if v != -1 && v != 0 {
other_than_minus1 *= v as usize
}
}
let input1 = input1
.iter()
.enumerate()
.map(|(idx, &v)| match v {
-1 => Ok(input0.elem_count() / other_than_minus1),
0 => input0.dim(idx),
_ => Ok(v as usize),
})
.collect::<Result<Vec<usize>>>()?;
let output = input0.reshape(input1)?;
values.insert(node.output[0].clone(), output);
}
"LogSoftmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let axis = input.normalize_axis(axis)?;
candle_nn::ops::log_softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Softmax" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<i64>(node, "axis")? {
None => candle_nn::ops::softmax_last_dim(input)?,
Some(&axis) => {
let axis = input.normalize_axis(axis)?;
candle_nn::ops::softmax(input, axis)?
}
};
values.insert(node.output[0].clone(), output);
}
"Transpose" => {
let input = get(&node.input[0])?;
let output = match get_attr_opt::<[i64]>(node, "perm")? {
None => input.t()?,
Some(perm) => {
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
input.permute(perm)?
}
};
values.insert(node.output[0].clone(), output);
}
"Dropout" => {
let input = get(&node.input[0])?;
// Do not apply dropout at the moment, consider that we're only doing inference.
values.insert(node.output[0].clone(), input.clone());
}
"MaxPool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#MaxPool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
if let Some(d) = dilations {
if d.iter().any(|&v| v != 1) {
bail!("MaxPool with dilation != 1, {dilations:?}")
}
}
if let Some(d) = pads {
if d.iter().any(|&v| v != 0) {
bail!("MaxPool with pads != 0, {pads:?}")
}
}
let xs = get(&node.input[0])?;
let (k1, k2) = match kernel_shape {
[k1, k2] => (*k1 as usize, *k2 as usize),
_ => bail!("only 2d MaxPool is supported, kernel shape {kernel_shape:?}"),
};
let ys = match strides {
None => xs.max_pool2d((k1, k2))?,
Some([s1, s2]) => {
xs.max_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
}
Some(strides) => bail!("only 2d MaxPool is supported, strides {strides:?}"),
};
values.insert(node.output[0].clone(), ys);
}
"AveragePool" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#AveragePool
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let kernel_shape = get_attr::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
if let Some(d) = dilations {
if d.iter().any(|&v| v != 1) {
bail!("AvgPool with dilation != 1, {dilations:?}")
}
}
if let Some(d) = pads {
if d.iter().any(|&v| v != 0) {
bail!("AvgPool with pads != 0, {pads:?}")
}
}
let xs = get(&node.input[0])?;
let (k1, k2) = match kernel_shape {
[k1, k2] => (*k1 as usize, *k2 as usize),
_ => bail!("only 2d AvgPool is supported, kernel shape {kernel_shape:?}"),
};
let ys = match strides {
None => xs.avg_pool2d((k1, k2))?,
Some([s1, s2]) => {
xs.avg_pool2d_with_stride((k1, k2), (*s1 as usize, *s2 as usize))?
}
Some(strides) => bail!("only 2d AvgPool is supported, strides {strides:?}"),
};
values.insert(node.output[0].clone(), ys);
}
"BatchNormalization" => {
let training_mode = get_attr_opt::<i64>(node, "training_mode")?;
if training_mode.copied().unwrap_or(0) != 0 {
bail!("training mode is not supported for BatchNorm")
}
let eps = get_attr_opt::<f32>(node, "epsilon")?
.copied()
.unwrap_or(1e-5);
let xs = get(&node.input[0])?;
let weight = get(&node.input[1])?;
let bias = get(&node.input[2])?;
let running_mean = get(&node.input[3])?;
let running_var = get(&node.input[4])?;
let target_shape: Vec<usize> = xs
.dims()
.iter()
.enumerate()
.map(|(idx, v)| if idx == 1 { *v } else { 1 })
.collect();
let target_shape = target_shape.as_slice();
let xs = xs
.broadcast_sub(&running_mean.reshape(target_shape)?)?
.broadcast_div(&(running_var.reshape(target_shape)? + eps as f64)?.sqrt()?)?;
let weight = weight.reshape(target_shape)?;
let bias = bias.reshape(target_shape)?;
let xs = xs.broadcast_mul(&weight)?.broadcast_add(&bias)?;
values.insert(node.output[0].clone(), xs);
}
"Squeeze" => {
let xs = get(&node.input[0])?;
let mut axes = if node.input.len() <= 1 {
// contract all the dimensions with size 1 except the batch dim.
xs.dims()
.iter()
.enumerate()
.flat_map(|(idx, &s)| if s == 1 && idx > 0 { Some(idx) } else { None })
.collect()
} else {
get(&node.input[1])?
.to_vec1::<i64>()?
.iter()
.map(|&i| xs.normalize_axis(i))
.collect::<Result<Vec<_>>>()?
};
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.squeeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ConstantOfShape
"ConstantOfShape" => {
let input = get(&node.input[0])?;
let value = get_attr_opt_owned::<Tensor>(node, "value")?.unwrap_or(Tensor::zeros(
(),
DType::F32,
&Device::Cpu,
)?);
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
.broadcast_mul(&value)?;
values.insert(node.output[0].clone(), xs);
}
"Unsqueeze" => {
let xs = get(&node.input[0])?;
let axes = match get_attr_opt::<[i64]>(node, "axes")? {
Some(axis) => axis.to_vec(),
None => get(&node.input[1])?.to_vec1::<i64>()?,
};
let mut axes = axes
.iter()
.map(|&i| {
if i == xs.rank() as i64 {
Ok(xs.rank())
} else if i < 0 {
// normalize_axis doesn't work correctly here
// because we actually want normalized with respect
// to the final size, not the current (off by one)
Ok(xs.rank() - (-i as usize) + 1)
} else {
xs.normalize_axis(i)
}
})
.collect::<Result<Vec<_>>>()?;
axes.sort();
let mut xs = xs.clone();
for &axis in axes.iter().rev() {
xs = xs.unsqueeze(axis)?
}
values.insert(node.output[0].clone(), xs);
}
"Clip" => {
let xs = get(&node.input[0])?;
let xs = if let Some(mins) = get_opt(1) {
xs.broadcast_maximum(mins?)?
} else {
xs.clone()
};
let xs = if let Some(maxs) = get_opt(2) {
xs.broadcast_minimum(maxs?)?
} else {
xs.clone()
};
values.insert(node.output[0].clone(), xs);
}
"Gather" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather
let xs = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = xs.normalize_axis(axis)?;
// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(xs.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(indices)?
};
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
// tensor directly, but candle does not support tensor indexing at the moment, so
// some workarounds must be done.
let xs = match indices.dims() {
[] => {
let index = indices.to_vec0::<i64>()? as usize;
xs.narrow(axis, index, 1)?.squeeze(axis)?
}
[_] => xs.index_select(indices, axis)?,
[first, _] => {
let mut v = Vec::with_capacity(*first);
for i in 0..*first {
v.push(xs.index_select(&indices.get(i)?, axis)?)
}
Tensor::stack(&v, axis)?
}
_ => {
// TODO: Provide an op to handle the ONNX generalized gather op ideally in a
// differentiable way.
todo!("implement gather for {xs:?} {indices:?} axis {axis}")
}
};
values.insert(node.output[0].clone(), xs);
}
// https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements
// A Note to fellow lurkers:
// The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py)
// and examples is incorrect.
// Use `torch.gather` for the validating/ verifying against the proper behaviour
"GatherElements" => {
let data = get(&node.input[0])?;
let indices = get(&node.input[1])?;
let rank = data.rank();
if rank != indices.rank() {
bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank());
}
let axis = {
let axis_i64 = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = data.normalize_axis(axis_i64)?;
if axis >= rank {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]",
axis_i64,
rank - 1
)
}
axis
};
// index_select does not support negative indices, so normalize them
// to positive indices.
let indices = &{
let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?;
let max = Tensor::new(data.dims()[axis] as i64, indices.device())?
.to_dtype(indices.dtype())?;
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(indices)?
};
values.insert(node.output[0].clone(), data.gather(indices, axis)?);
}
"Shape" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape
let xs = get(&node.input[0])?;
let start = get_attr_opt::<i64>(node, "start")?.copied().unwrap_or(0);
let end = get_attr_opt::<i64>(node, "end")?.copied().unwrap_or(-1);
let start = xs.normalize_axis(start)?;
let end = xs.normalize_axis(end)?;
let mut dims = vec![];
for idx in start..=end {
dims.push(xs.dim(idx)? as i64)
}
let dims = Tensor::from_vec(dims, xs.rank(), xs.device())?;
values.insert(node.output[0].clone(), dims);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Size
"Size" => {
let data = get(&node.input[0])?;
let size: usize = data.dims().iter().product();
let output = Tensor::from_slice(&[size as i64], (), data.device())?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt
"Sqrt" => {
let xs = get(&node.input[0])?;
let output = xs.sqrt()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Range
"Range" => {
let start = get(&node.input[0])?;
let limit = get(&node.input[1])?;
let delta = get(&node.input[2])?;
macro_rules! arange_step {
($t: ty) => {
Tensor::arange_step(
start.to_vec0::<$t>()?,
limit.to_vec0::<$t>()?,
delta.to_vec0::<$t>()?,
&Device::Cpu,
)?
};
}
let output = match start.dtype() {
DType::U8 => arange_step!(u8),
DType::U32 => arange_step!(u32),
DType::I64 => arange_step!(i64),
DType::BF16 => arange_step!(f32),
DType::F16 => arange_step!(f32),
DType::F32 => arange_step!(f32),
DType::F64 => arange_step!(f64),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Greater
"Greater" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let output = a.broadcast_gt(b)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Less
"Less" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let output = a.broadcast_lt(b)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Log
"Log" => {
let a = get(&node.input[0])?;
let output = a.log()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Min
"Min" => {
let mut output = get(&node.input[0])?.clone();
for input in node.input.iter() {
let input = get(input)?;
output = output.broadcast_minimum(input)?
}
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Where
"Where" => {
let cond = get(&node.input[0])?;
let a = get(&node.input[1])?;
let b = get(&node.input[2])?;
// where_cond requires that all inputs are the same shape.
// In contrast, the Where op in ONNX only requires that they are broadcastable.
let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?;
let cond = cond.broadcast_as(shape.clone())?;
let a = a.broadcast_as(shape.clone())?;
let b = b.broadcast_as(shape)?;
let output = cond.where_cond(&a, &b)?;
values.insert(node.output[0].clone(), output);
}
"Conv" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Conv
let dilations = get_attr_opt::<[i64]>(node, "dilations")?;
let groups = get_attr_opt::<i64>(node, "group")?.copied().unwrap_or(1);
let _kernel_shape = get_attr_opt::<[i64]>(node, "kernel_shape")?;
let pads = get_attr_opt::<[i64]>(node, "pads")?;
let strides = get_attr_opt::<[i64]>(node, "strides")?;
let auto_pad = get_attr_opt::<str>(node, "auto_pad")?;
match auto_pad {
None | Some("NOTSET") => (),
Some(s) => bail!("unsupported auto_pad {s}"),
};
let xs = get(&node.input[0])?;
let ws = get(&node.input[1])?;
let ys = match ws.rank() {
3 => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some([p1, p2]) => {
if p1 != p2 {
(0usize, xs.pad_with_zeros(2, *p1 as usize, *p2 as usize)?)
} else {
(*p1 as usize, xs.clone())
}
}
Some(pads) => {
bail!("more pads than expected in conv1d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more strides than expected in conv1d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some(s) => {
bail!("more dilations than expected in conv1d {s:?} {}", node.name)
}
};
xs.conv1d(ws, pads, strides, dilations, groups as usize)?
}
4 => {
let (pads, xs) = match pads {
None => (0, xs.clone()),
Some([p]) => (*p as usize, xs.clone()),
Some(&[p1, p2, p3, p4]) => {
let p1 = p1 as usize;
let p2 = p2 as usize;
let p3 = p3 as usize;
let p4 = p4 as usize;
if p1 != p2 || p1 != p3 || p1 != p4 {
(0, xs.pad_with_zeros(2, p1, p3)?.pad_with_zeros(3, p2, p4)?)
} else {
(p1, xs.clone())
}
}
Some(pads) => {
bail!("more pads than expected in conv2d {pads:?} {}", node.name)
}
};
let strides = match strides {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"strides have to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more strides than expected in conv2d {s:?} {}", node.name)
}
};
let dilations = match dilations {
None => 1,
Some([p]) => *p as usize,
Some([p1, p2]) => {
if p1 != p2 {
bail!(
"dilations have to be the same on both axis {pads:?} {}",
node.name
)
}
*p1 as usize
}
Some(s) => {
bail!("more dilations than expected in conv2d {s:?} {}", node.name)
}
};
xs.conv2d(ws, pads, strides, dilations, groups as usize)?
}
rank => bail!(
"unsupported rank for weight matrix {rank} in conv {}",
node.name
),
};
let ys = if node.input.len() > 2 {
let bs = get(&node.input[2])?;
let mut bs_shape = vec![1; ys.rank()];
bs_shape[1] = bs.elem_count();
ys.broadcast_add(&bs.reshape(bs_shape)?)?
} else {
ys
};
values.insert(node.output[0].clone(), ys);
}
"Concat" => {
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
let inputs = node
.input
.iter()
.map(|n| Ok(get(n.as_str())?.clone()))
.collect::<Result<Vec<Value>>>()?;
let axis: i64 = *get_attr(node, "axis")?;
if inputs.is_empty() {
bail!("empty concat")
};
let axis = inputs[0].normalize_axis(axis)?;
let output = Tensor::cat(&inputs, axis)?;
values.insert(node.output[0].clone(), output);
}
"Abs" => {
let input = get(&node.input[0])?;
let output = input.abs()?;
values.insert(node.output[0].clone(), output);
}
"Cos" => {
let input = get(&node.input[0])?;
let output = input.cos()?;
values.insert(node.output[0].clone(), output);
}
"Sin" => {
let input = get(&node.input[0])?;
let output = input.sin()?;
values.insert(node.output[0].clone(), output);
}
"Neg" => {
let input = get(&node.input[0])?;
let output = input.neg()?;
values.insert(node.output[0].clone(), output);
}
"Erf" => {
let input = get(&node.input[0])?;
let output = input.erf()?;
values.insert(node.output[0].clone(), output);
}
"Tanh" => {
let input = get(&node.input[0])?;
let output = input.tanh()?;
values.insert(node.output[0].clone(), output);
}
"Sigmoid" => {
let input = get(&node.input[0])?;
let output = candle_nn::ops::sigmoid(input)?;
values.insert(node.output[0].clone(), output);
}
"Gelu" => {
let input = get(&node.input[0])?;
let output = input.gelu_erf()?;
values.insert(node.output[0].clone(), output);
}
"Relu" => {
let input = get(&node.input[0])?;
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
"Ceil" => {
let input = get(&node.input[0])?;
let output = input.ceil()?;
values.insert(node.output[0].clone(), output);
}
"Floor" => {
let input = get(&node.input[0])?;
let output = input.floor()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
"Constant" => {
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
None => {
// TODO: support sparse_value etc.
bail!("cannot find 'value' attr in 'Constant' for {}", node.name)
}
Some(value) => value,
};
let output = match value.r#type() {
AttributeType::Tensor => {
let t = value.t.as_ref().unwrap();
get_tensor(t, &node.name)?
}
rtype => bail!("unsupported 'value' type {rtype:?} for {}", node.name),
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Cast
"Cast" => {
let input = get(&node.input[0])?;
let dt: i64 = *get_attr(node, "to")?;
let dtype = match DataType::try_from(dt as i32) {
Ok(DataType::Int32) => DType::I64,
Ok(dt) => match dtype(dt) {
Some(dt) => dt,
None => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
},
Err(_) => {
bail!("unsupported 'to' value {dt:?} for cast {}", node.name)
}
};
let output = input.to_dtype(dtype)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#CumSum
"CumSum" => {
let exclusive = get_attr_opt::<i64>(node, "exclusive")?
.copied()
.unwrap_or(0);
let reverse = get_attr_opt::<i64>(node, "reverse")?.copied().unwrap_or(0);
if exclusive != 0 {
bail!("only exclusive == 0 is supported in CumSum")
}
if reverse != 0 {
bail!("only reverse == 0 is supported in CumSum")
}
let input = get(&node.input[0])?;
let axis = get(&node.input[1])?
.to_dtype(DType::U32)?
.to_vec0::<u32>()?;
let output = input.cumsum(axis as usize)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#flatten
"Flatten" => {
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(1) as usize;
let input = get(&node.input[0])?;
let first_part: usize = input.shape().dims().iter().take(axis).product();
let end_index = input.shape().dims().iter().product::<usize>();
let new_shape = (first_part, end_index / first_part);
let output = input.reshape(new_shape)?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#identity
"Identity" => {
let input = get(&node.input[0])?;
values.insert(node.output[0].clone(), input.clone());
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#if
"If" => {
// protobuf encodes boolean false as 0 and true as 1
let cond = get(&node.input[0])?.get(0)?.to_scalar::<u8>()?;
let attr_name = if cond != 0 {
"then_branch"
} else {
"else_branch"
};
let sub_graph = get_attr::<GraphProto>(node, attr_name)?;
if sub_graph.output.len() != node.output.len() {
bail!(
"If node {:?} is malformed: branch outputs ({}) don't match node outputs ({})",
node.name,
sub_graph.output.len(),
node.output.len()
);
}
let branch_out = simple_eval_(sub_graph, values)?;
for (i, out) in node.output.iter().enumerate() {
values.insert(
out.clone(),
branch_out.get(&sub_graph.output[i].name).unwrap().clone(),
);
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#pad
"Pad" => {
let mode = get_attr_opt(node, "mode")?.unwrap_or("constant");
let data = get(&node.input[0])?;
let pads = get(&node.input[1])?;
if node.input.len() > 2 {
bail!(
"unsupported number of inputs {} for Pad node {:?}, expected 2",
node.input.len(),
node.name
);
}
if pads.rank() != 1 {
bail!("Pad expects 'pads' input to be 1D vector: {pads:?}");
}
if pads.dim(0).unwrap() != 2 * data.rank() {
bail!("Pad expects 'pads' input len to be 2 * rank of 'data' input: pads: {}, data rank: {}", pads, data.rank());
}
let pads = pads.to_vec1::<i64>()?;
let (pads_pre, pads_post) = pads.split_at(pads.len() / 2);
match mode {
"reflect" => {
let mut out = data.clone();
for (i, &dim) in data.dims().iter().enumerate().rev() {
if pads_pre[i] == 0 && pads_post[i] == 0 {
continue;
}
fn zigzag(min: i64, max: i64) -> impl Iterator<Item = i64> {
std::iter::repeat((min..max).chain((min + 1..=max).rev())).flatten()
}
let idx = if dim > 1 {
let cycle_len = dim * 2 - 2;
let skip = cycle_len - ((pads_pre[i] as usize) % cycle_len);
let idx = zigzag(0, (dim - 1) as i64)
.skip(skip)
.take((pads_pre[i] as usize) + dim + (pads_post[i] as usize));
Tensor::from_iter(idx, out.device())?
} else {
Tensor::full(0i64, (dim,), out.device())?
};
out = out.index_select(&idx, i)?;
}
values.insert(node.output[0].clone(), out);
}
_ => bail!(
"unsupported 'mode' value {mode:?} for Pad node {:?}",
node.name
),
}
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#slice
"Slice" => {
let data = get(&node.input[0])?;
let starts = get(&node.input[1])?;
let ends = get(&node.input[2])?;
let default_axes;
let default_steps;
let axes: &Tensor;
let steps: &Tensor;
// If axes are omitted, they are set to [0, ..., r-1]. If steps are omitted,
// they are set to [1, ..., 1] of length len(starts)
match node.input.len() {
3 => {
let len = starts.dims()[0];
default_axes = Some(Tensor::arange(0, len as i64, starts.device())?);
axes = default_axes.as_ref().unwrap();
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
4 => {
let len = starts.dims()[0];
axes = get(&node.input[3])?;
default_steps = Some(Tensor::ones((len,), DType::I64, starts.device())?);
steps = default_steps.as_ref().unwrap();
}
5 => {
steps = get(&node.input[4])?;
axes = get(&node.input[3])?;
}
_ => bail!(
"Slice node is invalid, expected 3-5 inputs, got {}: {:?}",
node.input.len(),
node
),
}
let mut out = data.clone();
for (i, axis) in axes.to_vec1::<i64>()?.into_iter().enumerate() {
// All negative elements of axes are made non-negative by
// adding r to them, where r = rank(input).
let axis = if axis < 0 {
axis + data.rank() as i64
} else {
axis
} as usize;
let data_dim = data.dims()[axis] as i64;
let mut s = starts.get(i)?.to_scalar::<i64>()?;
let mut e = ends.get(i)?.to_scalar::<i64>()?;
// All negative values in starts[i] and ends[i] have
// dims[axes[i]] added to them, where dims are the
// dimensions of input.
if s < 0 {
s += data_dim;
}
if e < 0 {
e += data_dim;
}
let p = steps.get(i)?.to_scalar::<i64>()?;
// starts[i] is clamped into the range [0, dims[axes[i]]]
// for positive stepping and [0, dims[axes[i]]-1] for
// negative stepping.
// for positive stepping ends[axes[i]] is clamped to
// [0, dims[axes[i]]], while for negative stepping it is
// clamped to [-1, dims[axes[i]]-1].
if p >= 0 {
s = s.clamp(0, data_dim);
e = e.clamp(0, data_dim);
} else {
s = s.clamp(0, data_dim - 1);
e = e.clamp(-1, data_dim - 1);
}
let indexes = Tensor::arange_step(s, e, p, data.device())?;
out = out.index_select(&indexes, axis)?
}
values.insert(node.output[0].clone(), out);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax
"ReduceMax" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;
let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatibility with version 13 and below
Some(axes.to_vec())
} else {
None
};
let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();
let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};
axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();
if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}
if axes.len() > 1 {
axes.sort();
}
Some(axes)
} else {
None
};
// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}
let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.max_keepdim(axis)?
} else {
result.max(axis)?
}
}
result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.max_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.max(0)?
}
}
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13
// TODO: This version is only compatible with ReduceMean V13 and below.
"ReduceMean" => {
let input = get(&node.input[0])?;
let axes = get_attr_opt::<[i64]>(node, "axes")?;
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let n_dims = input.dims().len();
let axes: Vec<usize> = if let Some(axes) = axes {
axes.iter()
.map(|e| (if e < &0 { (n_dims as i64) + *e } else { *e }) as usize)
.collect()
} else {
(0..n_dims).collect()
};
let output = if keepdims == 1 {
input.mean_keepdim(axes)?
} else {
input.mean(axes)?
};
values.insert(node.output[0].clone(), output);
}
// https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin
"ReduceMin" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1) == 1;
let axes = if let Some(Ok(axes)) = axes {
// Satisfies version 18+
axes.to_vec1::<i64>().ok()
} else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") {
// Backward compatibility with version 13 and below
Some(axes.to_vec())
} else {
None
};
let axes = if let Some(axes) = axes {
let rank = input.rank();
let mut axes_set = HashSet::new();
let mut axes = axes
.iter()
.map(|a| {
let axis = if *a < 0 {
(rank as i64 + *a) as usize
} else {
*a as usize
};
axes_set.insert(axis);
axis
})
.collect::<Vec<_>>();
if axes_set.len() < axes.len() {
bail!("Duplicate value in 'axes'");
}
if axes.len() > 1 {
axes.sort();
}
Some(axes)
} else {
None
};
// TODO: Handle empty set
// Definition:
// "Reduction over an empty set of values yields positive infinity (if supported by the datatype) or the max value of the data type otherwise"
// For now, this will throw an error
if input.elem_count() == 0 {
bail!("reduction over zero-size tensor not supported");
}
let output = if let Some(axes) = axes {
let mut result = input.clone();
for &axis in axes.iter().rev() {
result = if keepdims {
result.min_keepdim(axis)?
} else {
result.min(axis)?
}
}
result
} else {
// If `axes` is empty and `noop_with_empty_axes` is set to `true (1)`
// ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor.""
if get_attr_opt::<i64>(node, "noop_with_empty_axes")?.copied() == Some(1) {
input.clone()
} else {
let mut result = input.flatten_all()?;
if keepdims {
result = result.min_keepdim(0)?;
// If keepdims is true, reshape to match input dimensions
let shape = vec![1; input.rank()];
result.reshape(shape)?
} else {
result.min(0)?
}
}
};
values.insert(node.output[0].clone(), output);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split
// Version 18 impl
"Split" => {
let input_tensor = get(&node.input[0])?;
let axis = get_attr_opt::<i64>(node, "axis")?.copied().unwrap_or(0);
let axis = input_tensor.normalize_axis(axis)?;
// Determine split sizes
let splits = if node.input.len() > 1 {
// If the split tensor is provided, use it to determine sizes
let split_tensor = get(&node.input[1])?.to_vec1::<i64>()?;
split_tensor.iter().map(|&x| x as usize).collect::<Vec<_>>()
} else {
let num_outputs = if let Some(&num_outputs_attrib) =
get_attr_opt::<i64>(node, "num_outputs")?
{
num_outputs_attrib as usize
} else {
node.output.len()
};
let input_dim = input_tensor.dim(axis)?;
let mut split_sizes =
vec![input_dim / num_outputs as usize; num_outputs as usize];
let remainder = input_dim % num_outputs as usize;
if remainder > 0 {
// If there's a remainder, add it to the last split size
split_sizes[num_outputs as usize - 1] += remainder;
}
split_sizes
};
// Perform the split operation
let mut outputs = vec![];
let mut start = 0;
for &size in &splits {
let end = start + size;
let slice = input_tensor.narrow(axis, start, size)?;
outputs.push(slice);
start = end;
}
// Insert the split outputs into the values map
for (output, slice) in node.output.iter().zip(outputs.into_iter()) {
values.insert(output.clone(), slice);
}
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#Expand
// Version 13 impl
"Expand" => {
// unlike broadcast_to, expand allows for the output shape to
// be different from the specified shape.
let input_tensor = get(&node.input[0])?;
let input_shape = get(&node.input[1])?;
// Check that the shape tensor is 1D
if input_shape.rank() != 1 {
bail!(
"Expand expects 'shape' input to be 1D tensor: {:?}",
input_shape
);
}
let input_tensor_dims = input_tensor.dims();
let input_shape_dims = input_shape
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>();
let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?;
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
values.insert(node.output[0].clone(), expanded_tensor);
}
//https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum
// Version 13 impl
"ReduceSum" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
.copied()
.unwrap_or(0);
let axes = match axes {
Some(Ok(axes)) => axes
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
Some(Err(_)) | None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
(0..input.rank()).collect()
}
}
};
let output = if keepdims == 1 {
input.sum_keepdim(axes)?
} else {
input.sum(axes)?
};
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2
// Version 18 impl
"ReduceL2" => {
let input = get(&node.input[0])?;
let axes = get_opt(1);
let keepdims = get_attr_opt::<i64>(node, "keepdims")?.copied().unwrap_or(1);
let noop_with_empty_axes = get_attr_opt::<i64>(node, "noop_with_empty_axes")?
.copied()
.unwrap_or(0);
let input_sq = input.sqr()?;
let axes = match axes {
Some(axes) => axes?
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
(0..input_sq.rank()).collect()
}
}
};
let output = if keepdims == 1 {
input_sq.sum_keepdim(axes)?.sqrt()?
} else {
input_sq.sum(axes)?.sqrt()?
};
values.insert(node.output[0].clone(), output);
}
random_type @ ("RandomUniform" | "RandomNormal") => {
let dt: i64 = get_attr_opt(node, "dtype")?.copied().unwrap_or(1); // 1 is float
// type by
// default
let dtype = match DataType::try_from(dt as i32) {
Ok(dt) => match dtype(dt) {
Some(DType::U8 | DType::U32 | DType::I64) => {
bail!(
"unsupported 'dtype' value {dt:?}, only floats are allowed, for {random_type} {}",
node.name
)
}
Some(dt) => dt,
None => {
bail!(
"unsupported 'dtype' value {dt:?} for {random_type} {}",
node.name
)
}
},
Err(_) => {
bail!(
"unsupported 'dtype' value {dt:?} for {random_type} {}",
node.name
)
}
};
let seed: Option<f32> = get_attr_opt(node, "seed")?.copied();
if seed.is_some() {
bail!("seed for {random_type} is currently not supported")
};
let shape: Vec<usize> = get_attr::<[i64]>(node, "shape")?
.iter()
.map(|x| *x as usize)
.collect();
let output = if random_type == "RandomUniform" {
let low: f32 = get_attr_opt(node, "low")?.copied().unwrap_or(0.0);
let high: f32 = get_attr_opt(node, "high")?.copied().unwrap_or(1.0);
Tensor::rand(low, high, shape, &Device::Cpu)?.to_dtype(dtype)?
} else {
let mean: f32 = get_attr_opt(node, "mean")?.copied().unwrap_or(0.0);
let scale: f32 = get_attr_opt(node, "scale")?.copied().unwrap_or(1.0);
Tensor::randn(mean, scale, shape, &Device::Cpu)?.to_dtype(dtype)?
};
values.insert(node.output[0].clone(), output);
}
"ArgMin" => {
let input = get(&node.input[0])?;
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
axis_i64,
-rank_i64,
rank_i64 - 1
)
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
.copied()
.unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
let output = if keepdims == 1 {
input.argmin_keepdim(axis)?
} else {
input.argmin(axis)?
}
.to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"ArgMax" => {
let input = get(&node.input[0])?;
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
axis_i64,
-rank_i64,
rank_i64 - 1
)
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
.copied()
.unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
let output = if keepdims == 1 {
input.argmax_keepdim(axis)?
} else {
input.argmax(axis)?
}
.to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"LeakyRelu" => {
let input = get(&node.input[0])?;
let dt = input.dtype();
match dt {
DType::U8 | DType::U32 | DType::I64 => {
bail!(
"unsupported dtype {}, only float types are allowed for LeakyRelu",
dt.as_str()
)
}
DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
}
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(0.01);
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
"Gemm" => {
let a = get(&node.input[0])?;
let b = get(&node.input[1])?;
let c = get(&node.input[2])?;
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(1.0);
let beta = get_attr_opt::<f32>(node, "beta")?.copied().unwrap_or(1.0);
let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?;
let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?;
let trans_a = get_attr_opt::<i64>(node, "transA")?.copied().unwrap_or(0);
let trans_b = get_attr_opt::<i64>(node, "transB")?.copied().unwrap_or(0);
let a = if trans_a == 0 { a.clone() } else { a.t()? };
let b = if trans_b == 0 { b.clone() } else { b.t()? };
let output = a
.broadcast_mul(&alpha)?
.broadcast_matmul(&b)?
.broadcast_add(&c.broadcast_mul(&beta)?)?;
values.insert(node.output[0].clone(), output);
}
"LSTM" => {
let direction = get_attr_opt(node, "direction")?.unwrap_or("forward");
if direction != "forward" {
bail!("LSTM currently only supports direction == \"forward\"");
}
let num_directions = if direction == "bidirectional" { 2 } else { 1 };
let hidden_size: i64 = get_attr(node, "hidden_size").copied()?;
let input_forget = get_attr_opt(node, "input_forget")?.copied().unwrap_or(0);
if input_forget != 0 {
bail!("LSTM currently only supports input_forget == 0");
}
let activations_default = vec![
"Sigmoid".to_string(),
"Tanh".to_string(),
"Tanh".to_string(),
];
let activations = get_attr_opt_owned::<Vec<String>>(node, "activations")?
.unwrap_or(activations_default.clone());
if activations != activations_default {
bail!("LSTM currently only supports default activations ({activations_default:?})");
}
// activation_alpha and activation_beta don't apply to (Sigmoid, Tanh, Tanh) so ignoring them is okay
if get_attr_opt::<f32>(node, "clip")?.is_some() {
bail!("LSTM does not currently support clip attribute");
}
// The shape format of inputs X, initial_h and outputs Y, Y_h.
// If 0, the following shapes are expected:
// X.shape = [seq_length, batch_size, input_size],
// Y.shape = [seq_length, num_directions, batch_size, hidden_size],
// initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size].
// If 1, the following shapes are expected:
// X.shape = [batch_size, seq_length, input_size],
// Y.shape = [batch_size, seq_length, num_directions, hidden_size],
// initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size].
let layout = get_attr_opt(node, "layout")?.copied().unwrap_or(0);
if layout != 0 {
bail!("LSTM currently only supports layout == 0");
}
// The input sequences packed (and potentially padded) into one 3-D tensor
// with the shape of `[seq_length, batch_size, input_size]`.
let x = get(&node.input[0])?;
// XXX: depends on layout
let (seq_length, batch_size, input_size) = x.dims3()?;
// The weight tensor for the gates.
// Concatenation of `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0.
// The tensor has shape `[num_directions, 4*hidden_size, input_size]`.
let w = get(&node.input[1])?;
// The recurrence weight tensor.
// Concatenation of `R[iofc]` and `RB[iofc]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`.
let r = get(&node.input[2])?;
// The bias tensor for input gate.
// Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0.
// This tensor has shape `[num_directions, 8*hidden_size]`.
// Optional: If not specified - assumed to be 0.
let b_default: Tensor;
let b = match get_opt(3) {
Some(n) => n?,
None => {
b_default = Tensor::zeros(
(num_directions, 8 * hidden_size as usize),
DType::F32,
x.device(),
)?;
&b_default
}
};
// Optional tensor specifying lengths of the sequences in a batch.
// If not specified - assumed all sequences in the batch to have length `seq_length`.
// It has shape `[batch_size]`.
let seq_lens_default: Tensor;
let seq_lens = match get_opt(4) {
Some(n) => n?,
None => {
seq_lens_default =
Tensor::full(seq_length as i64, (batch_size,), x.device())?;
&seq_lens_default
}
};
let seq_lens_is_default =
(seq_lens.to_vec1::<i64>()?.iter()).all(|e| *e as usize == seq_length);
if !seq_lens_is_default {
bail!("LSTM currently only supports default value of seq_lens");
}
// Optional initial value of the hidden. If not specified - assumed to be 0.
// It has shape `[num_directions, batch_size, hidden_size]`.
let initial_h_default: Tensor;
let initial_h = match get_opt(5) {
Some(n) => n?,
_ => {
initial_h_default = Tensor::zeros(
(num_directions, batch_size, hidden_size as usize),
DType::F32,
x.device(),
)?;
&initial_h_default
}
};
// Optional initial value of the cell.
// If not specified - assumed to be 0.
// It has shape `[num_directions, batch_size, hidden_size]`.
let initial_c_default: Tensor;
let initial_c = match node.input.get(6) {
Some(n) if !n.is_empty() => get(n)?,
_ => {
initial_c_default = Tensor::zeros(
(num_directions, batch_size, hidden_size as usize),
DType::F32,
x.device(),
)?;
&initial_c_default
}
};
// The weight tensor for peepholes.
// Concatenation of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0.
// It has shape `[num_directions, 3*hidden_size]`. Optional: If not specified - assumed to be 0.
let p_default = Tensor::zeros(
(num_directions, 3 * hidden_size as usize),
DType::F32,
x.device(),
)?;
let p = get_opt(7).unwrap_or(Ok(&p_default))?;
let p_is_zeros = (p.to_vec2::<f32>()?.iter()).all(|v| v.iter().all(|e| *e == 0.0));
if !p_is_zeros {
bail!(
"LSTM currently only supports default value of p (a Tensor of all zeroes)"
);
}
// these all have [num_directions, ...] shapes
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?;
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
let wb = b.index_select(&idx_wb, 0)?;
let rb = b.index_select(&idx_rb, 0)?;
let c = initial_c.get(0)?;
let h = initial_h.get(0)?;
// w, r, wb, rb are all iofc but lstm expects ifco
// so we need to move some stuff around
let idx_i = Tensor::arange(0, hidden_size, x.device())?;
let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?;
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
let w = w.index_select(&idx_ifco, 0)?;
let r = r.index_select(&idx_ifco, 0)?;
let wb = wb.index_select(&idx_ifco, 0)?;
let rb = rb.index_select(&idx_ifco, 0)?;
let vmap = candle_nn::VarMap::new();
vmap.data().lock().unwrap().extend([
("weight_ih_l0".to_string(), candle::Var::from_tensor(&w)?),
("weight_hh_l0".to_string(), candle::Var::from_tensor(&r)?),
("bias_ih_l0".to_string(), candle::Var::from_tensor(&wb)?),
("bias_hh_l0".to_string(), candle::Var::from_tensor(&rb)?),
]);
use candle_nn::rnn::RNN as _;
let lstm = candle_nn::rnn::lstm(
input_size,
hidden_size as usize,
candle_nn::rnn::LSTMConfig::default(),
candle_nn::VarBuilder::from_varmap(&vmap, w.dtype(), w.device()),
)?;
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" {
Some(vec![])
} else {
None
};
for t in 0..seq_length {
let x = x.get(t)?;
lstm_state = lstm.step(&x, &lstm_state)?;
if let Some(h_acc) = &mut h_acc {
h_acc.push(lstm_state.clone());
}
}
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
if let Some(name) = node.output.first() {
let h_acc = h_acc.as_ref().unwrap();
let h_acc = lstm.states_to_tensor(h_acc)?;
let h_acc = h_acc.reshape((
seq_length,
num_directions,
batch_size,
hidden_size as usize,
))?;
values.insert(name.clone(), h_acc);
}
if let Some(name) = node.output.get(1) {
values.insert(
name.clone(),
lstm_state.h().reshape((
num_directions,
batch_size,
hidden_size as usize,
))?,
);
}
if let Some(name) = node.output.get(2) {
values.insert(
name.clone(),
lstm_state.c().reshape((
num_directions,
batch_size,
hidden_size as usize,
))?,
);
}
}
// https://onnx.ai/onnx/operators/onnx__Xor.html
"Xor" => {
// Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True)
let a = get(&node.input[0])?.gt(0_u8)?;
let b = get(&node.input[1])?.gt(0_u8)?;
let out = a.broadcast_add(&b)?.eq(1_u8)?;
values.insert(node.output[0].clone(), out);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
graph
.output
.iter()
.map(|output| match values.remove(&output.name) {
None => bail!("cannot find output {}", output.name),
Some(value) => Ok((output.name.clone(), value)),
})
.collect()
}
// Computes the NumPy-style (multidirectional) broadcast of two shapes.
//
// The shapes are right-aligned and the extra leading dimensions of the longer
// shape are kept unchanged. Each aligned pair of dimensions must either be equal
// or contain a 1, in which case the other dimension of the pair is used.
//
// Returns an error when an aligned pair is incompatible.
fn broadcast_shape(shape_a: &[usize], shape_b: &[usize]) -> Result<Vec<usize>> {
    let (longest, shortest) = if shape_a.len() > shape_b.len() {
        (shape_a, shape_b)
    } else {
        (shape_b, shape_a)
    };
    let diff = longest.len() - shortest.len();
    let mut target_shape = Vec::with_capacity(longest.len());
    target_shape.extend_from_slice(&longest[..diff]);
    for (&dim1, &dim2) in longest[diff..].iter().zip(shortest) {
        // A dimension of 1 stretches to the *other* dimension, which may be 0.
        // Taking the maximum would wrongly turn the pair (1, 0) into 1.
        let dim = if dim1 == dim2 || dim2 == 1 {
            dim1
        } else if dim1 == 1 {
            dim2
        } else {
            bail!(
                "Expand: incompatible shapes for broadcast, {:?} and {:?}",
                shape_a,
                shape_b
            )
        };
        target_shape.push(dim);
    }
    Ok(target_shape)
}
// Broadcasts an arbitrary number of shapes together by folding
// `broadcast_shape` over them from left to right.
//
// An empty list of shapes yields an empty shape; a single shape is returned
// unchanged. The first incompatible pair aborts the fold with its error.
fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>> {
    match shapes.split_first() {
        None => Ok(Vec::new()),
        Some((first, rest)) => rest
            .iter()
            .try_fold(first.to_vec(), |acc, shape| broadcast_shape(&acc, shape)),
    }
}