mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Support more ONNX ops. (#1267)
* Add LogSoftmax. * Support for Transpose.
This commit is contained in:
@ -59,6 +59,26 @@ pub fn simple_eval(
|
|||||||
Ok(dt.i)
|
Ok(dt.i)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let get_attr_is = |name: &str| match node.attribute.iter().find(|attr| attr.name == name) {
|
||||||
|
None => {
|
||||||
|
bail!(
|
||||||
|
"cannot find the '{name}' attribute in '{}' for {}",
|
||||||
|
node.op_type,
|
||||||
|
node.name
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Some(dt) => {
|
||||||
|
match dt.r#type() {
|
||||||
|
AttributeType::Ints => (),
|
||||||
|
rtype => bail!(
|
||||||
|
"unsupported type {rtype:?} for '{name}' attribute in '{}' for {}",
|
||||||
|
node.op_type,
|
||||||
|
node.name
|
||||||
|
),
|
||||||
|
}
|
||||||
|
Ok(dt.ints.as_slice())
|
||||||
|
}
|
||||||
|
};
|
||||||
// TODO: Validate node.input for each operator.
|
// TODO: Validate node.input for each operator.
|
||||||
match node.op_type.as_str() {
|
match node.op_type.as_str() {
|
||||||
"Add" => {
|
"Add" => {
|
||||||
@ -114,6 +134,24 @@ pub fn simple_eval(
|
|||||||
let output = input0.reshape(input1)?;
|
let output = input0.reshape(input1)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"LogSoftmax" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let output = match get_attr_i("axis") {
|
||||||
|
Err(_) => candle_nn::ops::softmax_last_dim(input)?,
|
||||||
|
Ok(axis) => {
|
||||||
|
let num_axis = input.rank() as i64;
|
||||||
|
let axis = if axis >= 0 {
|
||||||
|
axis as usize
|
||||||
|
} else if axis < -num_axis {
|
||||||
|
bail!("wrong axis in concat {axis} for shape {:?}", input.shape())
|
||||||
|
} else {
|
||||||
|
(num_axis - axis) as usize
|
||||||
|
};
|
||||||
|
candle_nn::ops::log_softmax(input, axis)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
"Softmax" => {
|
"Softmax" => {
|
||||||
let input = get(&node.input[0])?;
|
let input = get(&node.input[0])?;
|
||||||
let output = match get_attr_i("axis") {
|
let output = match get_attr_i("axis") {
|
||||||
@ -132,6 +170,17 @@ pub fn simple_eval(
|
|||||||
};
|
};
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"Transpose" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let output = match get_attr_is("perm") {
|
||||||
|
Err(_) => input.t()?,
|
||||||
|
Ok(perm) => {
|
||||||
|
let perm = perm.iter().map(|&v| v as usize).collect::<Vec<_>>();
|
||||||
|
input.permute(perm)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
"Concat" => {
|
"Concat" => {
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Concat
|
||||||
let inputs = node
|
let inputs = node
|
||||||
|
Reference in New Issue
Block a user