mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 12:06:35 +00:00
Clippy fixes for onnx + fix a broken test. (#2510)
This commit is contained in:
@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
|
||||
use crate::onnx::tensor_proto::DataType;
|
||||
use crate::onnx::{self, GraphProto};
|
||||
use candle::{bail, DType, Device, Result, Tensor};
|
||||
use std::{collections::HashMap, usize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub type Value = Tensor;
|
||||
|
||||
@ -321,7 +321,7 @@ fn simple_eval_(
|
||||
for node in graph.node.iter() {
|
||||
let get = |input_name: &str| match values.get(input_name) {
|
||||
Some(value) => Ok(value),
|
||||
None => bail!("cannot find {input_name} for op {}", node.name),
|
||||
None => bail!("cannot find {input_name} for op '{}'", node.name),
|
||||
};
|
||||
let get_opt = |i: usize| {
|
||||
node.input
|
||||
@ -362,7 +362,7 @@ fn simple_eval_(
|
||||
// HACK: current implementation of broadcast_pow cannot handle negative base,
|
||||
// so we use powf where we can, which *does* correctly handle negative base.
|
||||
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
|
||||
let output = input0.powf(exp as f64)?;
|
||||
let output = input0.powf(exp)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
} else {
|
||||
let output = input0.broadcast_pow(input1)?;
|
||||
@ -643,7 +643,7 @@ fn simple_eval_(
|
||||
let mask = indices.lt(&zeros)?;
|
||||
mask.to_dtype(indices.dtype())?
|
||||
.broadcast_mul(&max)?
|
||||
.add(&indices)?
|
||||
.add(indices)?
|
||||
};
|
||||
|
||||
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
|
||||
@ -767,7 +767,7 @@ fn simple_eval_(
|
||||
|
||||
// where_cond requires that all inputs are the same shape.
|
||||
// In contrast, the Where op in ONNX only requires that they are broadcastable.
|
||||
let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?;
|
||||
let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?;
|
||||
let cond = cond.broadcast_as(shape.clone())?;
|
||||
let a = a.broadcast_as(shape.clone())?;
|
||||
let b = b.broadcast_as(shape)?;
|
||||
@ -1283,8 +1283,7 @@ fn simple_eval_(
|
||||
.map(|x| x as usize)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let target_shape =
|
||||
broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?;
|
||||
let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?;
|
||||
|
||||
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
|
||||
|
||||
@ -1301,12 +1300,12 @@ fn simple_eval_(
|
||||
.unwrap_or(0);
|
||||
|
||||
let axes = match axes {
|
||||
Some(axes) => axes?
|
||||
Some(Ok(axes)) => axes
|
||||
.to_vec1::<i64>()?
|
||||
.into_iter()
|
||||
.map(|x| x as usize)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
Some(Err(_)) | None => {
|
||||
if noop_with_empty_axes == 1 {
|
||||
vec![]
|
||||
} else {
|
||||
@ -1640,7 +1639,7 @@ fn simple_eval_(
|
||||
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
|
||||
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
|
||||
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
|
||||
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
|
||||
let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?;
|
||||
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
|
||||
let wb = b.index_select(&idx_wb, 0)?;
|
||||
let rb = b.index_select(&idx_rb, 0)?;
|
||||
@ -1649,8 +1648,8 @@ fn simple_eval_(
|
||||
|
||||
// w, r, wb, rb are all iofc but lstm expects ifco
|
||||
// so we need to move some stuff around
|
||||
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
|
||||
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
|
||||
let idx_i = Tensor::arange(0, hidden_size, x.device())?;
|
||||
let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?;
|
||||
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
|
||||
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
|
||||
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
|
||||
@ -1674,7 +1673,7 @@ fn simple_eval_(
|
||||
)?;
|
||||
|
||||
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
|
||||
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
|
||||
let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" {
|
||||
Some(vec![])
|
||||
} else {
|
||||
None
|
||||
@ -1688,7 +1687,7 @@ fn simple_eval_(
|
||||
}
|
||||
|
||||
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
|
||||
if let Some(name) = node.output.get(0) {
|
||||
if let Some(name) = node.output.first() {
|
||||
let h_acc = h_acc.as_ref().unwrap();
|
||||
let h_acc = lstm.states_to_tensor(h_acc)?;
|
||||
let h_acc = h_acc.reshape((
|
||||
|
Reference in New Issue
Block a user