mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Merge branch 'refs/heads/argmin-argmax' into operators-argmin-argmax-leakyrelu
# Conflicts: # candle-onnx/src/eval.rs # candle-onnx/tests/ops.rs
This commit is contained in:
@ -1027,6 +1027,46 @@ pub fn simple_eval(
|
|||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"ArgMin" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
|
||||||
|
let rank_i64: i64 = input.rank().try_into().unwrap();
|
||||||
|
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
|
||||||
|
bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1)
|
||||||
|
}
|
||||||
|
let axis = input.normalize_axis(axis_i64)?;
|
||||||
|
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
|
||||||
|
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
|
||||||
|
if select_last_index == 1 {
|
||||||
|
bail!("select_last_index for ArgMin is currently not supported")
|
||||||
|
}
|
||||||
|
let output = if keepdims == 1 {
|
||||||
|
input.argmin_keepdim(axis)?
|
||||||
|
} else {
|
||||||
|
input.argmin(axis)?
|
||||||
|
}.to_dtype(DType::I64)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
"ArgMax" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
|
||||||
|
let rank_i64: i64 = input.rank().try_into().unwrap();
|
||||||
|
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
|
||||||
|
bail!("axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]", axis_i64, -rank_i64, rank_i64-1)
|
||||||
|
}
|
||||||
|
let axis = input.normalize_axis(axis_i64)?;
|
||||||
|
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
|
||||||
|
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?.copied().unwrap_or(0);
|
||||||
|
if select_last_index == 1 {
|
||||||
|
bail!("select_last_index for ArgMin is currently not supported")
|
||||||
|
}
|
||||||
|
let output = if keepdims == 1 {
|
||||||
|
input.argmax_keepdim(axis)?
|
||||||
|
} else {
|
||||||
|
input.argmax(axis)?
|
||||||
|
}.to_dtype(DType::I64)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2708,3 +2708,395 @@ fn test_ceil() -> Result<()> {
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "ArgMin"
|
||||||
|
#[test]
|
||||||
|
fn test_argmin() -> std::result::Result<()> {
|
||||||
|
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7
|
||||||
|
// default_axes_keepdims
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[2u32, 1u32],
|
||||||
|
[3u32, 10u32]
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
[0i64, 0i64],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
// keepdims
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[2u32, 1u32],
|
||||||
|
[3u32, 10u32]
|
||||||
|
],
|
||||||
|
Some(1),
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
[1i64],
|
||||||
|
[0i64]
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
// // negative_axis_keepdims
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[2u32, 1u32],
|
||||||
|
[3u32, 10u32]
|
||||||
|
],
|
||||||
|
Some(-1),
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
[1i64],
|
||||||
|
[0i64]
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
// no_keepdims
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[2u32, 1u32],
|
||||||
|
[3u32, 10u32]
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
Some(0),
|
||||||
|
None,
|
||||||
|
&[0i64, 0i64],
|
||||||
|
)?;
|
||||||
|
// tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[0.1139, 0.2254, -0.1381, 0.3687],
|
||||||
|
[1.0100, -1.1975, -0.0102, -0.4732],
|
||||||
|
[-0.9240, 0.1207, -0.7506, -1.0213],
|
||||||
|
[1.7809, -1.2960, 0.9384, 0.1438]
|
||||||
|
],
|
||||||
|
Some(1),
|
||||||
|
Some(0),
|
||||||
|
None,
|
||||||
|
&[2i64, 1i64, 3i64, 1i64],
|
||||||
|
)?;
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[0.1139, 0.2254, -0.1381, 0.3687],
|
||||||
|
[1.0100, -1.1975, -0.0102, -0.4732],
|
||||||
|
[-0.9240, 0.1207, -0.7506, -1.0213],
|
||||||
|
[1.7809, -1.2960, 0.9384, 0.1438]
|
||||||
|
],
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
&[[2i64], [1i64], [3i64], [1i64]],
|
||||||
|
)?;
|
||||||
|
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> std::result::Result<()> {
|
||||||
|
let att_axis = AttributeProto {
|
||||||
|
name: "axis".to_string(),
|
||||||
|
ref_attr_name: "axis".to_string(),
|
||||||
|
i: axis.unwrap_or(0),
|
||||||
|
doc_string: "axis".to_string(),
|
||||||
|
r#type: 2, // INT
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_keepdims = AttributeProto {
|
||||||
|
name: "keepdims".to_string(),
|
||||||
|
ref_attr_name: "keepdims".to_string(),
|
||||||
|
i: keepdims.unwrap_or(1),
|
||||||
|
doc_string: "keepdims".to_string(),
|
||||||
|
r#type: 2, // INT
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_select_last_index = AttributeProto {
|
||||||
|
name: "select_last_index".to_string(),
|
||||||
|
ref_attr_name: "select_last_index".to_string(),
|
||||||
|
i: select_last_index.unwrap_or(0),
|
||||||
|
doc_string: "select_last_index".to_string(),
|
||||||
|
r#type: 2, // INT
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let attrs = {
|
||||||
|
let mut mut_attrs = vec![];
|
||||||
|
if axis.is_some() {
|
||||||
|
mut_attrs.push(att_axis);
|
||||||
|
}
|
||||||
|
if keepdims.is_some() {
|
||||||
|
mut_attrs.push(att_keepdims);
|
||||||
|
}
|
||||||
|
if select_last_index.is_some() {
|
||||||
|
mut_attrs.push(att_select_last_index);
|
||||||
|
}
|
||||||
|
mut_attrs
|
||||||
|
};
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "ArgMin".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: attrs,
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// "ArgMax"
|
||||||
|
#[test]
|
||||||
|
fn test_argmax() -> std::result::Result<()> {
|
||||||
|
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
|
||||||
|
// default_axes_keepdims
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[2u32, 1u32],
|
||||||
|
[3u32, 10u32]
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
[1i64, 1i64],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
// keepdims
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[2u32, 1u32],
|
||||||
|
[3u32, 10u32]
|
||||||
|
],
|
||||||
|
Some(1),
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
[0i64],
|
||||||
|
[1i64]
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
// // negative_axis_keepdims
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[2u32, 1u32],
|
||||||
|
[3u32, 10u32]
|
||||||
|
],
|
||||||
|
Some(-1),
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
&[
|
||||||
|
[0i64],
|
||||||
|
[1i64]
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
// no_keepdims
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[2u32, 1u32],
|
||||||
|
[3u32, 10u32]
|
||||||
|
],
|
||||||
|
None,
|
||||||
|
Some(0),
|
||||||
|
None,
|
||||||
|
&[1i64, 1i64],
|
||||||
|
)?;
|
||||||
|
// tests from https://pytorch.org/docs/stable/generated/torch.argmax.html
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[1.3398, 0.2663, -0.2686, 0.2450],
|
||||||
|
[-0.7401, -0.8805, -0.3402, -1.1936],
|
||||||
|
[0.4907, -1.3948, -1.0691, -0.3132],
|
||||||
|
[-1.6092, 0.5419, -0.2993, 0.3195]
|
||||||
|
],
|
||||||
|
Some(1),
|
||||||
|
Some(0),
|
||||||
|
None,
|
||||||
|
&[0i64, 2i64, 0i64, 1i64],
|
||||||
|
)?;
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[1.3398, 0.2663, -0.2686, 0.2450],
|
||||||
|
[-0.7401, -0.8805, -0.3402, -1.1936],
|
||||||
|
[0.4907, -1.3948, -1.0691, -0.3132],
|
||||||
|
[-1.6092, 0.5419, -0.2993, 0.3195]
|
||||||
|
],
|
||||||
|
Some(1),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
&[[0i64], [2i64], [0i64], [1i64]],
|
||||||
|
)?;
|
||||||
|
fn test(data: impl NdArray, axis: Option<i64>, keepdims: Option<i64>, select_last_index: Option<i64>, expected: impl NdArray) -> std::result::Result<()> {
|
||||||
|
let att_axis = AttributeProto {
|
||||||
|
name: "axis".to_string(),
|
||||||
|
ref_attr_name: "axis".to_string(),
|
||||||
|
i: axis.unwrap_or(0),
|
||||||
|
doc_string: "axis".to_string(),
|
||||||
|
r#type: 2, // INT
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_keepdims = AttributeProto {
|
||||||
|
name: "keepdims".to_string(),
|
||||||
|
ref_attr_name: "keepdims".to_string(),
|
||||||
|
i: keepdims.unwrap_or(1),
|
||||||
|
doc_string: "keepdims".to_string(),
|
||||||
|
r#type: 2, // INT
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let att_select_last_index = AttributeProto {
|
||||||
|
name: "select_last_index".to_string(),
|
||||||
|
ref_attr_name: "select_last_index".to_string(),
|
||||||
|
i: select_last_index.unwrap_or(0),
|
||||||
|
doc_string: "select_last_index".to_string(),
|
||||||
|
r#type: 2, // INT
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
let attrs = {
|
||||||
|
let mut mut_attrs = vec![];
|
||||||
|
if axis.is_some() {
|
||||||
|
mut_attrs.push(att_axis);
|
||||||
|
}
|
||||||
|
if keepdims.is_some() {
|
||||||
|
mut_attrs.push(att_keepdims);
|
||||||
|
}
|
||||||
|
if select_last_index.is_some() {
|
||||||
|
mut_attrs.push(att_select_last_index);
|
||||||
|
}
|
||||||
|
mut_attrs
|
||||||
|
};
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "ArgMax".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: attrs,
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
|
match expected.dims().len() {
|
||||||
|
1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user