Clippy fixes for onnx + fix a broken test. (#2510)

This commit is contained in:
Laurent Mazare
2024-09-26 23:37:59 +02:00
committed by GitHub
parent ed48f54b54
commit 2c25754281
2 changed files with 273 additions and 281 deletions

View File

@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::{bail, DType, Device, Result, Tensor};
use std::{collections::HashMap, usize};
use std::collections::HashMap;
pub type Value = Tensor;
@ -321,7 +321,7 @@ fn simple_eval_(
for node in graph.node.iter() {
let get = |input_name: &str| match values.get(input_name) {
Some(value) => Ok(value),
None => bail!("cannot find {input_name} for op {}", node.name),
None => bail!("cannot find {input_name} for op '{}'", node.name),
};
let get_opt = |i: usize| {
node.input
@ -362,7 +362,7 @@ fn simple_eval_(
// HACK: current implementation of broadcast_pow cannot handle negative base,
// so we use powf where we can, which *does* correctly handle negative base.
if let Ok(exp) = (|| input1.to_dtype(DType::F64)?.to_scalar::<f64>())() {
let output = input0.powf(exp as f64)?;
let output = input0.powf(exp)?;
values.insert(node.output[0].clone(), output);
} else {
let output = input0.broadcast_pow(input1)?;
@ -643,7 +643,7 @@ fn simple_eval_(
let mask = indices.lt(&zeros)?;
mask.to_dtype(indices.dtype())?
.broadcast_mul(&max)?
.add(&indices)?
.add(indices)?
};
// In Pytorch or Numpy this can be done by indexing the xs tensor using the indices
@ -767,7 +767,7 @@ fn simple_eval_(
// where_cond requires that all inputs are the same shape.
// In contrast, the Where op in ONNX only requires that they are broadcastable.
let shape = broadcast_shape_from_many(&[&cond.dims(), &a.dims(), &b.dims()])?;
let shape = broadcast_shape_from_many(&[cond.dims(), a.dims(), b.dims()])?;
let cond = cond.broadcast_as(shape.clone())?;
let a = a.broadcast_as(shape.clone())?;
let b = b.broadcast_as(shape)?;
@ -1283,8 +1283,7 @@ fn simple_eval_(
.map(|x| x as usize)
.collect::<Vec<_>>();
let target_shape =
broadcast_shape(&input_tensor_dims, input_shape_dims.as_slice())?;
let target_shape = broadcast_shape(input_tensor_dims, input_shape_dims.as_slice())?;
let expanded_tensor = input_tensor.broadcast_as(target_shape)?;
@ -1301,12 +1300,12 @@ fn simple_eval_(
.unwrap_or(0);
let axes = match axes {
Some(axes) => axes?
Some(Ok(axes)) => axes
.to_vec1::<i64>()?
.into_iter()
.map(|x| x as usize)
.collect::<Vec<_>>(),
None => {
Some(Err(_)) | None => {
if noop_with_empty_axes == 1 {
vec![]
} else {
@ -1640,7 +1639,7 @@ fn simple_eval_(
let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size]
let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size]
let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size]
let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?;
let idx_wb = Tensor::arange(0, 4 * hidden_size, x.device())?;
let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?;
let wb = b.index_select(&idx_wb, 0)?;
let rb = b.index_select(&idx_rb, 0)?;
@ -1649,8 +1648,8 @@ fn simple_eval_(
// w, r, wb, rb are all iofc but lstm expects ifco
// so we need to move some stuff around
let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?;
let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?;
let idx_i = Tensor::arange(0, hidden_size, x.device())?;
let idx_o = Tensor::arange(hidden_size, 2 * hidden_size, x.device())?;
let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?;
let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?;
let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?;
@ -1674,7 +1673,7 @@ fn simple_eval_(
)?;
let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c);
let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" {
let mut h_acc = if node.output.first().map(String::as_str).unwrap_or("") != "" {
Some(vec![])
} else {
None
@ -1688,7 +1687,7 @@ fn simple_eval_(
}
assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped");
if let Some(name) = node.output.get(0) {
if let Some(name) = node.output.first() {
let h_acc = h_acc.as_ref().unwrap();
let h_acc = lstm.states_to_tensor(h_acc)?;
let h_acc = h_acc.reshape((

View File

@ -1,12 +1,5 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::test_utils::to_vec2_round;
use candle::{DType, Device, NdArray, Result, Tensor};
use candle_onnx::eval::Value;
use candle_onnx::onnx::attribute_proto::AttributeType;
use candle_onnx::onnx::tensor_proto::DataType;
use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension};
@ -3574,312 +3567,312 @@ fn test_lstm() -> Result<()> {
let number_directions = 1;
let weight_ih_l0 = Tensor::from_vec::<_, f32>(
vec![
-1.5255959033966064,
-0.7502318024635315,
-0.6539809107780457,
-1.6094847917556763,
-0.1001671776175499,
-0.6091889142990112,
-0.9797722697257996,
-1.6090962886810303,
-0.7121446132659912,
0.30372199416160583,
-0.777314305305481,
-0.25145524740219116,
-0.22227048873901367,
1.6871134042739868,
0.22842517495155334,
0.46763551235198975,
-0.6969724297523499,
-1.1607614755630493,
0.6995424032211304,
0.1990816295146942,
0.8656923770904541,
0.2444038987159729,
-0.6629113554954529,
0.8073082566261292,
1.1016806364059448,
-0.1759360432624817,
-2.2455577850341797,
-1.4464579820632935,
0.0611552819609642,
-0.6177444458007812,
-0.7980698347091675,
-0.13162320852279663,
1.8793457746505737,
-0.07213178277015686,
0.15777060389518738,
-0.7734549045562744,
0.1990565061569214,
0.04570277780294418,
0.15295691788196564,
-0.47567880153656006,
-0.11101982742547989,
0.2927352488040924,
-0.1578451544046402,
-0.028787139803171158,
0.4532545804977417,
1.1421611309051514,
0.2486107051372528,
-1.7754007577896118,
-0.025502461940050125,
-1.023330569267273,
-0.5961851477622986,
-1.0055307149887085,
0.42854228615760803,
1.4760777950286865,
-1.7868678569793701,
1.610317587852478,
-0.703956663608551,
-0.18526579439640045,
-0.9962350726127625,
-0.8312552571296692,
-1.525_595_9,
-0.750_231_8,
-0.653_980_9,
-1.609_484_8,
-0.100_167_18,
-0.609_188_9,
-0.979_772_27,
-1.609_096_3,
-0.712_144_6,
0.303_722,
-0.777_314_3,
-0.251_455_25,
-0.222_270_49,
1.687_113_4,
0.228_425_17,
0.467_635_5,
-0.696_972_4,
-1.160_761_5,
0.699_542_4,
0.199_081_63,
0.865_692_4,
0.244_403_9,
-0.662_911_36,
0.807_308_26,
1.101_680_6,
-0.175_936_04,
-2.245_557_8,
-1.446_458,
0.061_155_282,
-0.617_744_45,
-0.798_069_83,
-0.131_623_21,
1.879_345_8,
-0.072_131_78,
0.157_770_6,
-0.773_454_9,
0.199_056_5,
0.045_702_778,
0.152_956_92,
-0.475_678_8,
-0.111_019_83,
0.292_735_25,
-0.157_845_15,
-0.028_787_14,
0.453_254_58,
1.142_161_1,
0.248_610_7,
-1.775_400_8,
-0.025_502_462,
-1.023_330_6,
-0.596_185_15,
-1.005_530_7,
0.428_542_3,
1.476_077_8,
-1.786_867_9,
1.610_317_6,
-0.703_956_66,
-0.185_265_8,
-0.996_235_1,
-0.831_255_26,
],
(20, 3),
&Device::Cpu,
)?;
let weight_hh_l0 = Tensor::from_vec::<_, f32>(
vec![
0.4099724292755127,
0.4084506630897522,
0.25786539912223816,
1.095021367073059,
-0.5064865946769714,
0.09977540373802185,
-0.653973400592804,
0.731693685054779,
-1.456732988357544,
1.6089353561401367,
0.09376997500658035,
-1.2597490549087524,
0.25463348627090454,
-0.5019572973251343,
-1.041200041770935,
0.7322672009468079,
1.3075355291366577,
-1.1627987623214722,
0.11963611096143723,
-0.1631353348493576,
0.6614453196525574,
1.1899205446243286,
0.8165339231491089,
-0.9135236144065857,
-0.3538065254688263,
0.7639270424842834,
-0.5889506936073303,
-0.7635973691940308,
1.3352056741714478,
0.6042736172676086,
-0.10344208031892776,
-0.15121692419052124,
1.2465683221817017,
0.505721390247345,
0.9505112171173096,
1.2966482639312744,
0.873796284198761,
-0.5602594017982483,
1.2857844829559326,
0.8168238401412964,
-1.464799404144287,
-1.2629283666610718,
1.122018814086914,
1.5663341283798218,
2.558138370513916,
-0.23336388170719147,
-0.013472129590809345,
1.8606348037719727,
1.549620509147644,
0.34762924909591675,
0.09300802648067474,
0.6147403120994568,
0.7123645544052124,
-1.7765072584152222,
0.3538645803928375,
1.1996132135391235,
-0.7122589349746704,
-0.620034396648407,
-0.22813494503498077,
-0.7892746329307556,
-1.6111117601394653,
-1.8716129064559937,
0.5430836081504822,
0.6606786251068115,
0.270527720451355,
0.5596919655799866,
-0.31839630007743835,
1.5117206573486328,
-1.363267183303833,
-0.9832196235656738,
1.5112667083740234,
0.6418707370758057,
-0.7474458813667297,
-0.923438549041748,
0.5733984112739563,
-0.10929951071739197,
0.5181121230125427,
0.10653535276651382,
0.26924076676368713,
1.3247679471969604,
0.037456899881362915,
-0.6378393173217773,
-0.8147554397583008,
-0.6895065307617188,
0.8436542749404907,
1.1657012701034546,
0.5269321799278259,
1.6192532777786255,
-0.963976263999939,
0.14152038097381592,
-0.1636609584093094,
-0.3582225739955902,
1.7222793102264404,
-0.3035756051540375,
0.23887419700622559,
1.3440011739730835,
0.1032256931066513,
1.1003541946411133,
-0.3416801989078522,
0.947338879108429,
0.409_972_43,
0.408_450_66,
0.257_865_4,
1.095_021_4,
-0.506_486_6,
0.099_775_404,
-0.653_973_4,
0.731_693_7,
-1.456_733,
1.608_935_4,
0.093_769_975,
-1.259_749,
0.254_633_5,
-0.501_957_3,
-1.041_2,
0.732_267_2,
1.307_535_5,
-1.162_798_8,
0.119_636_11,
-0.163_135_33,
0.661_445_3,
1.189_920_5,
0.816_533_9,
-0.913_523_6,
-0.353_806_53,
0.763_927_04,
-0.588_950_7,
-0.763_597_37,
1.335_205_7,
0.604_273_6,
-0.103_442_08,
-0.151_216_92,
1.246_568_3,
0.505_721_4,
0.950_511_2,
1.296_648_3,
0.873_796_3,
-0.560_259_4,
1.285_784_5,
0.816_823_84,
-1.464_799_4,
-1.262_928_4,
1.122_018_8,
1.566_334_1,
2.558_138_4,
-0.233_363_88,
-0.013_472_13,
1.860_634_8,
1.549_620_5,
0.347_629_25,
0.093_008_03,
0.614_740_3,
0.712_364_55,
-1.776_507_3,
0.353_864_58,
1.199_613_2,
-0.712_258_93,
-0.620_034_4,
-0.228_134_95,
-0.789_274_63,
-1.611_111_8,
-1.871_612_9,
0.543_083_6,
0.660_678_6,
0.270_527_72,
0.559_691_97,
-0.318_396_3,
1.511_720_7,
-1.363_267_2,
-0.983_219_6,
1.511_266_7,
0.641_870_74,
-0.747_445_9,
-0.923_438_55,
0.573_398_4,
-0.109_299_51,
0.518_112_1,
0.106_535_35,
0.269_240_77,
1.324_768,
0.037_456_9,
-0.637_839_3,
-0.814_755_44,
-0.689_506_53,
0.843_654_3,
1.165_701_3,
0.526_932_2,
1.619_253_3,
-0.963_976_26,
0.141_520_38,
-0.163_660_96,
-0.358_222_57,
1.722_279_3,
-0.303_575_6,
0.238_874_2,
1.344_001_2,
0.103_225_69,
1.100_354_2,
-0.341_680_2,
0.947_338_9,
],
(20, 5),
&Device::Cpu,
)?;
let bias_ih_l0 = Tensor::from_vec::<_, f32>(
vec![
-0.568515956401825,
0.8375961780548096,
1.783660650253296,
-0.1954246610403061,
0.235193133354187,
1.9142433404922485,
1.8364111185073853,
1.324532389640808,
-0.07051458209753036,
0.34697940945625305,
-0.653679609298706,
1.5586202144622803,
0.2185661494731903,
-0.5743072628974915,
1.4571250677108765,
1.7709556818008423,
-2.0172998905181885,
0.42350319027900696,
0.5730220079421997,
-1.7962429523468018,
-0.568_515_96,
0.837_596_2,
1.783_660_7,
-0.195_424_66,
0.235_193_13,
1.914_243_3,
1.836_411_1,
1.324_532_4,
-0.070_514_58,
0.346_979_4,
-0.653_679_6,
1.558_620_2,
0.218_566_15,
-0.574_307_26,
1.457_125_1,
1.770_955_7,
-2.017_3,
0.423_503_2,
0.573_022,
-1.796_243,
],
(20,),
&Device::Cpu,
)?;
let bias_hh_l0 = Tensor::from_vec::<_, f32>(
vec![
1.2470403909683228,
1.2738511562347412,
0.3909492492675781,
0.387210488319397,
0.14440394937992096,
0.7771684527397156,
-2.3381125926971436,
-0.829120397567749,
1.1661391258239746,
1.4786574840545654,
0.26760873198509216,
0.7561198472976685,
-0.5873361229896545,
-2.061920642852783,
0.4304734766483307,
0.3376566171646118,
-0.3437853455543518,
-0.6172260642051697,
1.2529692649841309,
-0.05141742154955864,
1.247_040_4,
1.273_851_2,
0.390_949_25,
0.387_210_5,
0.144_403_95,
0.777_168_45,
-2.338_112_6,
-0.829_120_4,
1.166_139_1,
1.478_657_5,
0.267_608_73,
0.756_119_85,
-0.587_336_1,
-2.061_920_6,
0.430_473_48,
0.337_656_62,
-0.343_785_35,
-0.617_226_06,
1.252_969_3,
-0.051_417_42,
],
(20,),
&Device::Cpu,
)?;
let input = Tensor::from_vec::<_, f32>(
vec![
0.6472128033638,
-0.04116716980934143,
-0.17749308049678802,
-0.500039279460907,
0.8672749400138855,
-0.27319222688674927,
-0.4607681334018707,
-0.0990937128663063,
0.47284480929374695,
1.0049484968185425,
-0.2871420383453369,
-1.1618621349334717,
0.647_212_8,
-0.041_167_17,
-0.177_493_08,
-0.500_039_3,
0.867_274_94,
-0.273_192_23,
-0.460_768_13,
-0.099_093_71,
0.472_844_8,
1.004_948_5,
-0.287_142_04,
-1.161_862_1,
],
(4, 1, 3),
&Device::Cpu,
)?;
let h0 = Tensor::from_vec::<_, f32>(
vec![
0.02758178487420082,
0.5652382373809814,
-0.011487378738820553,
0.6706400513648987,
-0.4929250478744507,
0.027_581_785,
0.565_238_24,
-0.011_487_379,
0.670_640_05,
-0.492_925_05,
],
(1, 1, 5),
&Device::Cpu,
)?;
let c0 = Tensor::from_vec::<_, f32>(
vec![
1.505028486251831,
-2.32635498046875,
1.6168899536132812,
-0.9026237726211548,
0.17366823554039001,
1.505_028_5,
-2.326_355,
1.616_89,
-0.902_623_8,
0.173_668_24,
],
(1, 1, 5),
&Device::Cpu,
)?;
let output = Tensor::from_vec::<_, f32>(
vec![
0.5956016778945923,
-0.01723279245197773,
0.11035571992397308,
-0.49323174357414246,
0.047632161527872086,
0.6358451843261719,
0.040328118950128555,
-0.3788611590862274,
-0.7464339733123779,
0.20080909132957458,
0.5840265154838562,
0.1453288197517395,
-0.7345298528671265,
-0.5214304327964783,
0.21903817355632782,
0.7420451641082764,
0.31943878531455994,
-0.04726646468043327,
-0.2823849618434906,
0.2713133990764618,
0.595_601_7,
-0.017_232_792,
0.110_355_72,
-0.493_231_74,
0.047_632_16,
0.635_845_2,
0.040_328_12,
-0.378_861_16,
-0.746_434,
0.200_809_09,
0.584_026_5,
0.145_328_82,
-0.734_529_85,
-0.521_430_43,
0.219_038_17,
0.742_045_16,
0.319_438_8,
-0.047_266_465,
-0.282_384_96,
0.271_313_4,
],
(4, 1, 5),
&Device::Cpu,
)?;
let hn = Tensor::from_vec::<_, f32>(
vec![
0.7420451641082764,
0.31943878531455994,
-0.04726646468043327,
-0.2823849618434906,
0.2713133990764618,
0.742_045_16,
0.319_438_8,
-0.047_266_465,
-0.282_384_96,
0.271_313_4,
],
(1, 1, 5),
&Device::Cpu,
)?;
let cn = Tensor::from_vec::<_, f32>(
vec![
0.9630558490753174,
1.0033069849014282,
-1.754899024963379,
-1.5967122316360474,
0.8252924680709839,
0.963_055_85,
1.003_307,
-1.754_899,
-1.596_712_2,
0.825_292_47,
],
(1, 1, 5),
&Device::Cpu,
@ -3929,8 +3922,8 @@ fn test_lstm() -> Result<()> {
let idx_iofc = {
let stride = hidden_size as i64;
let dev = weight_ih_l0.device();
let idx_i = Tensor::arange(0 * stride, 1 * stride, dev)?;
let idx_f = Tensor::arange(1 * stride, 2 * stride, dev)?;
let idx_i = Tensor::arange(0, stride, dev)?;
let idx_f = Tensor::arange(stride, 2 * stride, dev)?;
let idx_g = Tensor::arange(2 * stride, 3 * stride, dev)?;
let idx_o = Tensor::arange(3 * stride, 4 * stride, dev)?;
@ -3966,15 +3959,15 @@ fn test_lstm() -> Result<()> {
Ok(diffs.iter().all(|f| f.abs() < 0.0001))
};
assert!(
diff_close_enough(&output, &actual_output)?,
diff_close_enough(&output, actual_output)?,
"output did not match expected\n{actual_output}\n{output}",
);
assert!(
diff_close_enough(&hn, &actual_hn)?,
diff_close_enough(&hn, actual_hn)?,
"hn did not match expected\n{actual_hn}\n{hn}",
);
assert!(
diff_close_enough(&cn, &actual_cn)?,
diff_close_enough(&cn, actual_cn)?,
"cn did not match expected\n{actual_cn}\n{cn}",
);
@ -4064,14 +4057,14 @@ fn make_graph_helper(
doc_string: "".to_string(),
}],
input: inputs
.into_iter()
.iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
})
.collect(),
output: outputs
.into_iter()
.iter()
.map(|name| ValueInfoProto {
name: name.to_string(),
..ValueInfoProto::default()
@ -4282,7 +4275,7 @@ fn test_reduce_sum_do_not_keep_dims() -> Result<()> {
// Test with random data
{
let shape = (3, 2, 2);
let _shape = (3, 2, 2);
let data = Tensor::from_vec(
vec![
-5.2f32, 7.8, -3.1, 9.4, 2.6, -8.7, 4.3, -1.9, 6.5, -0.8, -7.2, 3.6,