mirror of
https://github.com/huggingface/candle.git
synced 2025-06-22 12:28:06 +00:00
Apply rustfmt. (#2247)
This commit is contained in:
@ -2715,51 +2715,31 @@ fn test_argmin() -> Result<()> {
|
||||
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7
|
||||
// default_axes_keepdims
|
||||
test(
|
||||
&[
|
||||
[2u32, 1u32],
|
||||
[3u32, 10u32]
|
||||
],
|
||||
&[[2u32, 1u32], [3u32, 10u32]],
|
||||
None,
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[0i64, 0i64],
|
||||
],
|
||||
&[[0i64, 0i64]],
|
||||
)?;
|
||||
// keepdims
|
||||
test(
|
||||
&[
|
||||
[2u32, 1u32],
|
||||
[3u32, 10u32]
|
||||
],
|
||||
&[[2u32, 1u32], [3u32, 10u32]],
|
||||
Some(1),
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[1i64],
|
||||
[0i64]
|
||||
],
|
||||
&[[1i64], [0i64]],
|
||||
)?;
|
||||
// // negative_axis_keepdims
|
||||
test(
|
||||
&[
|
||||
[2u32, 1u32],
|
||||
[3u32, 10u32]
|
||||
],
|
||||
&[[2u32, 1u32], [3u32, 10u32]],
|
||||
Some(-1),
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[1i64],
|
||||
[0i64]
|
||||
],
|
||||
&[[1i64], [0i64]],
|
||||
)?;
|
||||
// no_keepdims
|
||||
test(
|
||||
&[
|
||||
[2u32, 1u32],
|
||||
[3u32, 10u32]
|
||||
],
|
||||
&[[2u32, 1u32], [3u32, 10u32]],
|
||||
None,
|
||||
Some(0),
|
||||
None,
|
||||
@ -2771,7 +2751,7 @@ fn test_argmin() -> Result<()> {
|
||||
[0.1139, 0.2254, -0.1381, 0.3687],
|
||||
[1.0100, -1.1975, -0.0102, -0.4732],
|
||||
[-0.9240, 0.1207, -0.7506, -1.0213],
|
||||
[1.7809, -1.2960, 0.9384, 0.1438]
|
||||
[1.7809, -1.2960, 0.9384, 0.1438],
|
||||
],
|
||||
Some(1),
|
||||
Some(0),
|
||||
@ -2783,14 +2763,20 @@ fn test_argmin() -> Result<()> {
|
||||
[0.1139, 0.2254, -0.1381, 0.3687],
|
||||
[1.0100, -1.1975, -0.0102, -0.4732],
|
||||
[-0.9240, 0.1207, -0.7506, -1.0213],
|
||||
[1.7809, -1.2960, 0.9384, 0.1438]
|
||||
[1.7809, -1.2960, 0.9384, 0.1438],
|
||||
],
|
||||
Some(1),
|
||||
None,
|
||||
None,
|
||||
&[[2i64], [1i64], [3i64], [1i64]],
|
||||
)?;
|
||||
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
|
||||
fn test(
|
||||
data: impl NdArray,
|
||||
axis: Option<i64>,
|
||||
keepdims: Option<i64>,
|
||||
select_last_index: Option<i64>,
|
||||
expected: impl NdArray,
|
||||
) -> Result<()> {
|
||||
let att_axis = AttributeProto {
|
||||
name: "axis".to_string(),
|
||||
ref_attr_name: "axis".to_string(),
|
||||
@ -2911,51 +2897,31 @@ fn test_argmax() -> Result<()> {
|
||||
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
|
||||
// default_axes_keepdims
|
||||
test(
|
||||
&[
|
||||
[2u32, 1u32],
|
||||
[3u32, 10u32]
|
||||
],
|
||||
&[[2u32, 1u32], [3u32, 10u32]],
|
||||
None,
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[1i64, 1i64],
|
||||
],
|
||||
&[[1i64, 1i64]],
|
||||
)?;
|
||||
// keepdims
|
||||
test(
|
||||
&[
|
||||
[2u32, 1u32],
|
||||
[3u32, 10u32]
|
||||
],
|
||||
&[[2u32, 1u32], [3u32, 10u32]],
|
||||
Some(1),
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[0i64],
|
||||
[1i64]
|
||||
],
|
||||
&[[0i64], [1i64]],
|
||||
)?;
|
||||
// // negative_axis_keepdims
|
||||
test(
|
||||
&[
|
||||
[2u32, 1u32],
|
||||
[3u32, 10u32]
|
||||
],
|
||||
&[[2u32, 1u32], [3u32, 10u32]],
|
||||
Some(-1),
|
||||
Some(1),
|
||||
None,
|
||||
&[
|
||||
[0i64],
|
||||
[1i64]
|
||||
],
|
||||
&[[0i64], [1i64]],
|
||||
)?;
|
||||
// no_keepdims
|
||||
test(
|
||||
&[
|
||||
[2u32, 1u32],
|
||||
[3u32, 10u32]
|
||||
],
|
||||
&[[2u32, 1u32], [3u32, 10u32]],
|
||||
None,
|
||||
Some(0),
|
||||
None,
|
||||
@ -2967,7 +2933,7 @@ fn test_argmax() -> Result<()> {
|
||||
[1.3398, 0.2663, -0.2686, 0.2450],
|
||||
[-0.7401, -0.8805, -0.3402, -1.1936],
|
||||
[0.4907, -1.3948, -1.0691, -0.3132],
|
||||
[-1.6092, 0.5419, -0.2993, 0.3195]
|
||||
[-1.6092, 0.5419, -0.2993, 0.3195],
|
||||
],
|
||||
Some(1),
|
||||
Some(0),
|
||||
@ -2979,14 +2945,20 @@ fn test_argmax() -> Result<()> {
|
||||
[1.3398, 0.2663, -0.2686, 0.2450],
|
||||
[-0.7401, -0.8805, -0.3402, -1.1936],
|
||||
[0.4907, -1.3948, -1.0691, -0.3132],
|
||||
[-1.6092, 0.5419, -0.2993, 0.3195]
|
||||
[-1.6092, 0.5419, -0.2993, 0.3195],
|
||||
],
|
||||
Some(1),
|
||||
None,
|
||||
None,
|
||||
&[[0i64], [2i64], [0i64], [1i64]],
|
||||
)?;
|
||||
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> Result<()> {
|
||||
fn test(
|
||||
data: impl NdArray,
|
||||
axis: Option<i64>,
|
||||
keepdims: Option<i64>,
|
||||
select_last_index: Option<i64>,
|
||||
expected: impl NdArray,
|
||||
) -> Result<()> {
|
||||
let att_axis = AttributeProto {
|
||||
name: "axis".to_string(),
|
||||
ref_attr_name: "axis".to_string(),
|
||||
@ -3106,11 +3078,7 @@ fn test_argmax() -> Result<()> {
|
||||
fn test_leakyrelu() -> Result<()> {
|
||||
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80
|
||||
// leakyrelu
|
||||
test(
|
||||
&[-1.0, 0.0, 1.0],
|
||||
Some(0.1),
|
||||
&[-0.1, 0.0, 1.0]
|
||||
)?;
|
||||
test(&[-1.0, 0.0, 1.0], Some(0.1), &[-0.1, 0.0, 1.0])?;
|
||||
fn test(data: impl NdArray, alpha: Option<f32>, expected: impl NdArray) -> Result<()> {
|
||||
let att_alpha = AttributeProto {
|
||||
name: "alpha".to_string(),
|
||||
@ -3168,7 +3136,11 @@ fn test_leakyrelu() -> Result<()> {
|
||||
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||
|
||||
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||
for both in z.to_vec1::<f64>()?.iter().zip(expected.to_vec1::<f64>()?.iter()) {
|
||||
for both in z
|
||||
.to_vec1::<f64>()?
|
||||
.iter()
|
||||
.zip(expected.to_vec1::<f64>()?.iter())
|
||||
{
|
||||
let (act, exp) = both;
|
||||
assert!(f64::abs(act - exp) < f32::EPSILON.into());
|
||||
}
|
||||
|
Reference in New Issue
Block a user