Clippy fixes for onnx + fix a broken test. (#2510)

Laurent Mazare
2024-09-26 23:37:59 +02:00
committed by GitHub
parent ed48f54b54
commit 2c25754281
2 changed files with 273 additions and 281 deletions


@@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
-use std::{collections::HashMap, usize};
+use std::collections::HashMap;
pub type Value = Tensor;
@@ -321,7 +321,7 @@ fn simple_eval_(
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
-None => bail!("cannot find {input_name} for op {}", node.name),
+None => bail!("cannot find {input_name} for op '{}'", node.name),
};
let get_opt = |i: usize| {
node.input
@@ -362,7 +362,7 @@ fn simple_eval_(
// HACK: current implementation of broadcast_pow cannot handle negative base,
// so we use powf where we can, which *does* correctly handle negative base.
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
-let output = input0.powf(exp as f64)?;
+let output = input0.powf(exp)?;
values.insert(node.output[0].clone(), output);
} else {
let output = input0.broadcast_pow(input1)?;
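
A minimal standalone sketch of why the scalar powf fast path matters, assuming (not verified here) that broadcast_pow lowers to exp(y * ln(x)); naive_pow is a hypothetical stand-in, not candle code:

// Hypothetical stand-in for a pow lowered to exp(y * ln(x)).
fn naive_pow(x: f64, y: f64) -> f64 {
    // ln of a negative number is NaN, so a negative base is lost.
    (y * x.ln()).exp()
}

fn main() {
    let (x, y) = (-2.0_f64, 3.0_f64);
    assert!(naive_pow(x, y).is_nan());
    assert_eq!(x.powf(y), -8.0); // f64::powf handles negative bases with integer exponents
}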
@@ -643,7 +643,7 @@ fn simple_eval_(
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
-.add(&indices)?
+.add(indices)?
};
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
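
A hedged sketch of the index normalization above on plain slices (normalize_indices is a hypothetical helper; assumes indices lie in [-dim, dim) as in NumPy/PyTorch):

fn normalize_indices(indices: &[i64], dim: i64) -> Vec<i64> {
    // A negative index gets dim added back, mirroring
    // mask.broadcast_mul(&max)?.add(indices)? in the tensor version.
    indices.iter().map(|&i| if i < 0 { i + dim } else { i }).collect()
}

fn main() {
    assert_eq!(normalize_indices(&[-1, 0, 2, -3], 4), vec![3, 0, 2, 1]);
}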
@@ -767,7 +767,7 @@ fn simple_eval_(
// where_cond requires that all inputs are the same shape.
// In contrast, the Where op in ONNX only requires that they are broadcastable.
-let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?;
+let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?;
let cond = cond.broadcast_as(shape.clone())?;
let a = a.broadcast_as(shape.clone())?;
let b = b.broadcast_as(shape)?;
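
broadcast_shape_from_many is the crate's own helper; the sketch below reconstructs the shape rule it presumably implements (ONNX multidirectional broadcasting), not the actual implementation:

fn broadcast_shape_from_many(shapes: &[&[usize]]) -> Result<Vec<usize>, String> {
    // Align shapes on the right; each dimension must match or be 1.
    let rank = shapes.iter().map(|s| s.len()).max().unwrap_or(0);
    let mut out = vec![1usize; rank];
    for s in shapes {
        let offset = rank - s.len();
        for (i, &d) in s.iter().enumerate() {
            let o = &mut out[offset + i];
            if *o == 1 {
                *o = d;
            } else if d != 1 && d != *o {
                return Err(format!("incompatible dims {d} vs {o}"));
            }
        }
    }
    Ok(out)
}

fn main() {
    // e.g. cond [3, 1], a [1, 4], b [4] broadcast to [3, 4]
    assert_eq!(broadcast_shape_from_many(&[&[3, 1], &[1, 4], &[4]]).unwrap(), vec![3, 4]);
}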
@@ -1283,8 +1283,7 @@ fn simple_eval_(
.map(|x| x as usize)
.collect::<Vec<_>>();
-let target_shape =
-broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?;
+let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?;
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
@@ -1301,12 +1300,12 @@ fn simple_eval_(
.unwrap_or(0);
let axes = match axes {
-Some(axes) => axes?
+Some(Ok(axes)) => axes
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
-None => {
+Some(Err(_)) | None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
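
This match appears to be the broken-test fix from the commit title: an optional axes input that fails to resolve now falls through to the default instead of propagating an error. A reduced sketch of the Option<Result<_>> shape (axes_or_default is hypothetical; the all-axes default is assumed from reduce semantics):

fn axes_or_default(
    axes: Option<Result<Vec<i64>, String>>,
    noop_with_empty_axes: i64,
    rank: i64,
) -> Vec<i64> {
    match axes {
        Some(Ok(a)) => a,
        // A missing input and an unresolvable one are now treated alike.
        Some(Err(_)) | None => {
            if noop_with_empty_axes == 1 {
                vec![]
            } else {
                (0..rank).collect()
            }
        }
    }
}

fn main() {
    assert_eq!(axes_or_default(Some(Err("unresolved".into())), 0, 3), vec![0, 1, 2]);
    assert!(axes_or_default(None, 1, 3).is_empty());
}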
@@ -1640,7 +1639,7 @@ fn simple_eval_(
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
-let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
+let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?;
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
let wb = b.index_select(&idx_wb, 0)?;
let rb = b.index_select(&idx_rb, 0)?;
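
The two arange/index_select pairs above just take the first and second halves of the concatenated [wb, rb] bias; on a flat slice that is split_at (hedged sketch, split_bias is hypothetical):

fn split_bias(b: &[f32], hidden_size: usize) -> (&[f32], &[f32]) {
    // b concatenates wb (4 * hidden_size) and rb (4 * hidden_size).
    b.split_at(4 * hidden_size)
}

fn main() {
    let b = vec![0.0_f32; 16]; // hidden_size = 2 -> 8 + 8
    let (wb, rb) = split_bias(&b, 2);
    assert_eq!((wb.len(), rb.len()), (8, 8));
}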
@@ -1649,8 +1648,8 @@ fn simple_eval_(
// w, r, wb, rb are all iofc but lstm expects ifco
// so we need to move some stuff around
-let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
-let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
+let idx_i = Tensor::arange(0, hidden_size, x.device())?;
+let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?;
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
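
The arange/cat sequence builds a row permutation from the iofc gate order to the ifco order the lstm here expects; the same gather on a flat buffer (hedged sketch, iofc_to_ifco is hypothetical):

fn iofc_to_ifco(rows: &[f32], hidden_size: usize) -> Vec<f32> {
    let block = |g: usize| g * hidden_size..(g + 1) * hidden_size;
    // iofc blocks are [i, o, f, c] = [0, 1, 2, 3]; ifco wants [0, 2, 3, 1].
    [0usize, 2, 3, 1]
        .iter()
        .flat_map(|&g| rows[block(g)].iter().copied())
        .collect()
}

fn main() {
    // hidden_size = 2: gates laid out as i i o o f f c c
    let v = vec![1.0_f32, 1., 2., 2., 3., 3., 4., 4.];
    assert_eq!(iofc_to_ifco(&v, 2), vec![1., 1., 3., 3., 4., 4., 2., 2.]);
}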
@@ -1674,7 +1673,7 @@ fn simple_eval_(
)?;
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
-let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
+let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" {
Some(vec![])
} else {
None
@@ -1688,7 +1687,7 @@ fn simple_eval_(
}
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
-if let Some(name) = node.output.get(0) {
+if let Some(name) = node.output.first() {
let h_acc = h_acc.as_ref().unwrap();
let h_acc = lstm.states_to_tensor(h_acc)?;
let h_acc = h_acc.reshape((