Compare commits

18 Commits

SHA1 Message Date
56a1b7d97e Apply rustfmt. 2024-06-04 22:47:20 +02:00
47c7ecc948 Merge branch 'refs/heads/leaky_relu' into operators-argmin-argmax-leakyrelu 2024-06-04 21:13:38 +02:00
c441716bd2 Fix a weird automatic RustRover change 2024-06-04 21:13:30 +02:00
a5b81e2c02 Merge branch 'refs/heads/argmin-argmax' into operators-argmin-argmax-leakyrelu (conflicts: candle-onnx/src/eval.rs, candle-onnx/tests/ops.rs) 2024-06-04 21:09:59 +02:00
08fd7f7119 Typo fix 2024-05-10 00:51:01 +02:00
2ced31b530 Added a test for LeakyRelu 2024-05-10 00:50:05 +02:00
91b0d526ee Added LeakyRelu implementation 2024-05-10 00:49:54 +02:00
4de76b89a2 Added tests for ArgMax 2024-05-09 20:45:53 +02:00
8f1119b3e0 Added ArgMax operator implementation 2024-05-09 20:45:41 +02:00
c4743aa570 Added tests from pytorch examples 2024-05-09 20:22:55 +02:00
9a273196b7 ArgMin now returns a tensor with i64 2024-05-09 20:22:22 +02:00
13b88547f7 Added tests for ArgMin 2024-05-09 03:00:22 +02:00
1caf62e4a6 Added ArgMin operator implementation 2024-05-09 03:00:15 +02:00
a06b2ded28 Merge branch 'refs/heads/random' into operators-random-exp (conflicts: candle-onnx/tests/ops.rs) 2024-04-23 17:36:33 +02:00
a867d652d3 Merge branch 'refs/heads/exp' into operators-random-exp 2024-04-23 17:33:05 +02:00
70388c27b6 Added Exp operator implementation 2024-04-19 22:48:05 +02:00
0fa41a791f Use is_some to check if seed is present 2024-04-19 16:09:45 +02:00
46073c5f73 Add basic RandomUniform implementation 2024-04-19 16:06:43 +02:00
2 changed files with 514 additions and 0 deletions

candle-onnx/src/eval.rs

@@ -1027,6 +1027,78 @@ pub fn simple_eval(
                };
                values.insert(node.output[0].clone(), output);
            }
            "ArgMin" => {
                let input = get(&node.input[0])?;
                let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
                let rank_i64: i64 = input.rank().try_into().unwrap();
                if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
                    bail!(
                        "axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
                        axis_i64,
                        -rank_i64,
                        rank_i64 - 1
                    )
                }
                let axis = input.normalize_axis(axis_i64)?;
                let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
                let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
                    .copied()
                    .unwrap_or(0);
                if select_last_index == 1 {
                    bail!("select_last_index for ArgMin is currently not supported")
                }
                let output = if keepdims == 1 {
                    input.argmin_keepdim(axis)?
                } else {
                    input.argmin(axis)?
                }
                .to_dtype(DType::I64)?;
                values.insert(node.output[0].clone(), output);
            }
"ArgMax" => {
let input = get(&node.input[0])?;
let axis_i64: i64 = get_attr_opt(node, "axis")?.copied().unwrap_or(0);
let rank_i64: i64 = input.rank().try_into().unwrap();
if axis_i64 < -rank_i64 || axis_i64 >= rank_i64 {
bail!(
"axis ({}) out of accepted range [-rank, rank-1] which was [{}, {}]",
axis_i64,
-rank_i64,
rank_i64 - 1
)
}
let axis = input.normalize_axis(axis_i64)?;
let keepdims: i64 = get_attr_opt(node, "keepdims")?.copied().unwrap_or(1);
let select_last_index: i64 = get_attr_opt(node, "select_last_index")?
.copied()
.unwrap_or(0);
if select_last_index == 1 {
bail!("select_last_index for ArgMin is currently not supported")
}
let output = if keepdims == 1 {
input.argmax_keepdim(axis)?
} else {
input.argmax(axis)?
}
.to_dtype(DType::I64)?;
values.insert(node.output[0].clone(), output);
}
"LeakyRelu" => {
let input = get(&node.input[0])?;
let dt = input.dtype();
match dt {
DType::U8 | DType::U32 | DType::I64 => {
bail!(
"unsupported dtype {}, only float types are allowed for LeakyRelu",
dt.as_str()
)
}
DType::BF16 | DType::F16 | DType::F32 | DType::F64 => {}
}
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(0.01);
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
values.insert(node.output[0].clone(), output);
}
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
}
}
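
Note on the keepdims paths above: with keepdims = 1 the reduced axis is retained with size 1, otherwise it is dropped. A minimal sketch of the difference, using the same Tensor methods the new arms call (this assumes the candle crate alias used elsewhere in this repo; it is not part of the diff):

use candle::{Device, Tensor};

fn keepdims_demo() -> candle::Result<()> {
    // Same data as the ONNX spec examples exercised in the tests below.
    let t = Tensor::new(&[[2u32, 1u32], [3u32, 10u32]], &Device::Cpu)?;
    // keepdims = 1: the reduced axis stays, with size 1 -> shape [2, 1].
    assert_eq!(t.argmin_keepdim(1)?.dims(), &[2, 1]);
    // keepdims = 0: the reduced axis is dropped -> shape [2].
    assert_eq!(t.argmin(1)?.dims(), &[2]);
    Ok(())
}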

candle-onnx/tests/ops.rs

@@ -2708,3 +2708,445 @@ fn test_ceil() -> Result<()> {
    Ok(())
}
// "ArgMin"
#[test]
fn test_argmin() -> Result<()> {
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-7
// default_axes_keepdims
test(
&[[2u32, 1u32], [3u32, 10u32]],
None,
Some(1),
None,
&[[0i64, 0i64]],
)?;
// keepdims
test(
&[[2u32, 1u32], [3u32, 10u32]],
Some(1),
Some(1),
None,
&[[1i64], [0i64]],
)?;
// // negative_axis_keepdims
test(
&[[2u32, 1u32], [3u32, 10u32]],
Some(-1),
Some(1),
None,
&[[1i64], [0i64]],
)?;
// no_keepdims
test(
&[[2u32, 1u32], [3u32, 10u32]],
None,
Some(0),
None,
&[0i64, 0i64],
)?;
// tests from https://pytorch.org/docs/stable/generated/torch.argmin.html#torch.argmin
test(
&[
[0.1139, 0.2254, -0.1381, 0.3687],
[1.0100, -1.1975, -0.0102, -0.4732],
[-0.9240, 0.1207, -0.7506, -1.0213],
[1.7809, -1.2960, 0.9384, 0.1438],
],
Some(1),
Some(0),
None,
&[2i64, 1i64, 3i64, 1i64],
)?;
test(
&[
[0.1139, 0.2254, -0.1381, 0.3687],
[1.0100, -1.1975, -0.0102, -0.4732],
[-0.9240, 0.1207, -0.7506, -1.0213],
[1.7809, -1.2960, 0.9384, 0.1438],
],
Some(1),
None,
None,
&[[2i64], [1i64], [3i64], [1i64]],
)?;
fn test(
data: impl NdArray,
axis: Option<i64>,
keepdims: Option<i64>,
select_last_index: Option<i64>,
expected: impl NdArray,
) -> Result<()> {
let att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
i: axis.unwrap_or(0),
doc_string: "axis".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_keepdims = AttributeProto {
name: "keepdims".to_string(),
ref_attr_name: "keepdims".to_string(),
i: keepdims.unwrap_or(1),
doc_string: "keepdims".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_select_last_index = AttributeProto {
name: "select_last_index".to_string(),
ref_attr_name: "select_last_index".to_string(),
i: select_last_index.unwrap_or(0),
doc_string: "select_last_index".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let attrs = {
let mut mut_attrs = vec![];
if axis.is_some() {
mut_attrs.push(att_axis);
}
if keepdims.is_some() {
mut_attrs.push(att_keepdims);
}
if select_last_index.is_some() {
mut_attrs.push(att_select_last_index);
}
mut_attrs
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "ArgMin".to_string(),
domain: "".to_string(),
attribute: attrs,
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
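
The three AttributeProto literals in test_argmin (and in test_argmax below) differ only in their name and i fields. A hypothetical helper, sketched here and not part of the change, could shrink that boilerplate:

// Hypothetical helper (not in this PR): builds an INT AttributeProto with
// the same defaults the tests spell out field by field.
fn int_attr(name: &str, value: i64) -> AttributeProto {
    AttributeProto {
        name: name.to_string(),
        ref_attr_name: name.to_string(),
        i: value,
        doc_string: name.to_string(),
        r#type: 2, // INT
        f: 0.0,
        s: vec![],
        t: None,
        g: None,
        sparse_tensor: None,
        tp: None,
        floats: vec![],
        ints: vec![],
        strings: vec![],
        tensors: vec![],
        graphs: vec![],
        sparse_tensors: vec![],
        type_protos: vec![],
    }
}

// Usage sketch: keep only the attributes a test actually sets.
// let attrs: Vec<AttributeProto> = [
//     axis.map(|v| int_attr("axis", v)),
//     keepdims.map(|v| int_attr("keepdims", v)),
//     select_last_index.map(|v| int_attr("select_last_index", v)),
// ]
// .into_iter()
// .flatten()
// .collect();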
// "ArgMax"
#[test]
fn test_argmax() -> Result<()> {
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-6
// default_axes_keepdims
test(
&[[2u32, 1u32], [3u32, 10u32]],
None,
Some(1),
None,
&[[1i64, 1i64]],
)?;
// keepdims
test(
&[[2u32, 1u32], [3u32, 10u32]],
Some(1),
Some(1),
None,
&[[0i64], [1i64]],
)?;
// // negative_axis_keepdims
test(
&[[2u32, 1u32], [3u32, 10u32]],
Some(-1),
Some(1),
None,
&[[0i64], [1i64]],
)?;
// no_keepdims
test(
&[[2u32, 1u32], [3u32, 10u32]],
None,
Some(0),
None,
&[1i64, 1i64],
)?;
// tests from https://pytorch.org/docs/stable/generated/torch.argmax.html
test(
&[
[1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195],
],
Some(1),
Some(0),
None,
&[0i64, 2i64, 0i64, 1i64],
)?;
test(
&[
[1.3398, 0.2663, -0.2686, 0.2450],
[-0.7401, -0.8805, -0.3402, -1.1936],
[0.4907, -1.3948, -1.0691, -0.3132],
[-1.6092, 0.5419, -0.2993, 0.3195],
],
Some(1),
None,
None,
&[[0i64], [2i64], [0i64], [1i64]],
)?;
fn test(
data: impl NdArray,
axis: Option<i64>,
keepdims: Option<i64>,
select_last_index: Option<i64>,
expected: impl NdArray,
) -> Result<()> {
let att_axis = AttributeProto {
name: "axis".to_string(),
ref_attr_name: "axis".to_string(),
i: axis.unwrap_or(0),
doc_string: "axis".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_keepdims = AttributeProto {
name: "keepdims".to_string(),
ref_attr_name: "keepdims".to_string(),
i: keepdims.unwrap_or(1),
doc_string: "keepdims".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let att_select_last_index = AttributeProto {
name: "select_last_index".to_string(),
ref_attr_name: "select_last_index".to_string(),
i: select_last_index.unwrap_or(0),
doc_string: "select_last_index".to_string(),
r#type: 2, // INT
f: 0.0,
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let attrs = {
let mut mut_attrs = vec![];
if axis.is_some() {
mut_attrs.push(att_axis);
}
if keepdims.is_some() {
mut_attrs.push(att_keepdims);
}
if select_last_index.is_some() {
mut_attrs.push(att_select_last_index);
}
mut_attrs
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "ArgMax".to_string(),
domain: "".to_string(),
attribute: attrs,
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
match expected.dims().len() {
1 => assert_eq!(z.to_vec1::<i64>()?, expected.to_vec1::<i64>()?),
2 => assert_eq!(z.to_vec2::<i64>()?, expected.to_vec2::<i64>()?),
_ => unreachable!(),
};
Ok(())
}
Ok(())
}
// "LeakyRelu"
#[test]
fn test_leakyrelu() -> Result<()> {
// tests from https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-80
// leakyrelu
test(&[-1.0, 0.0, 1.0], Some(0.1), &[-0.1, 0.0, 1.0])?;
fn test(data: impl NdArray, alpha: Option<f32>, expected: impl NdArray) -> Result<()> {
let att_alpha = AttributeProto {
name: "alpha".to_string(),
ref_attr_name: "alpha".to_string(),
i: 0,
doc_string: "alpha".to_string(),
r#type: 1, // FLOAT
f: alpha.unwrap_or(0.01),
s: vec![],
t: None,
g: None,
sparse_tensor: None,
tp: None,
floats: vec![],
ints: vec![],
strings: vec![],
tensors: vec![],
graphs: vec![],
sparse_tensors: vec![],
type_protos: vec![],
};
let attrs = {
let mut mut_attrs = vec![];
if alpha.is_some() {
mut_attrs.push(att_alpha);
}
mut_attrs
};
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "LeakyRelu".to_string(),
domain: "".to_string(),
attribute: attrs,
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
let expected = Tensor::new(expected, &Device::Cpu)?;
for both in z
.to_vec1::<f64>()?
.iter()
.zip(expected.to_vec1::<f64>()?.iter())
{
let (act, exp) = both;
assert!(f64::abs(act - exp) < f32::EPSILON.into());
}
Ok(())
}
Ok(())
}
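
For context, LeakyRelu computes f(x) = x for x >= 0 and alpha * x otherwise; ONNX's default alpha is 0.01, matching the unwrap_or above. A minimal sketch of the op the eval path dispatches to, assuming the candle and candle_nn crates as used in this repo (not part of the diff):

use candle::{Device, Tensor};

fn leaky_relu_demo() -> candle::Result<()> {
    let x = Tensor::new(&[-1.0f32, 0.0, 1.0], &Device::Cpu)?;
    // alpha = 0.1, as in the ONNX spec example exercised by the test above.
    let y = candle_nn::ops::leaky_relu(&x, 0.1)?;
    // Compare with a small tolerance rather than exact float equality.
    for (got, want) in y.to_vec1::<f32>()?.iter().zip([-0.1f32, 0.0, 1.0].iter()) {
        assert!((got - want).abs() < 1e-6);
    }
    Ok(())
}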