mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Adding Gemm and ArgMax operators to candle-onnx (#2231)
* feat(gemm): implement Gemm operator in candle-onnx * feat(onnx): Add support for ArgMax operator in candle-onnx * Apply rustfmt. * Remove argmax as it was already present. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -1274,6 +1274,30 @@ fn simple_eval_(
|
||||
let output = candle_nn::ops::leaky_relu(input, alpha.into())?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm
|
||||
"Gemm" => {
|
||||
let a = get(&node.input[0])?;
|
||||
let b = get(&node.input[1])?;
|
||||
let c = get(&node.input[2])?;
|
||||
|
||||
let alpha = get_attr_opt::<f32>(node, "alpha")?.copied().unwrap_or(1.0);
|
||||
let beta = get_attr_opt::<f32>(node, "beta")?.copied().unwrap_or(1.0);
|
||||
|
||||
let alpha = Tensor::full(alpha, a.shape(), &Device::Cpu)?;
|
||||
let beta = Tensor::full(beta, c.shape(), &Device::Cpu)?;
|
||||
|
||||
let trans_a = get_attr_opt::<i64>(node, "transA")?.copied().unwrap_or(0);
|
||||
let trans_b = get_attr_opt::<i64>(node, "transB")?.copied().unwrap_or(0);
|
||||
|
||||
let a = if trans_a == 0 { a.clone() } else { a.t()? };
|
||||
let b = if trans_b == 0 { b.clone() } else { b.t()? };
|
||||
|
||||
let output = a
|
||||
.broadcast_mul(&alpha)?
|
||||
.broadcast_matmul(&b)?
|
||||
.broadcast_add(&c.broadcast_mul(&beta)?)?;
|
||||
values.insert(node.output[0].clone(), output);
|
||||
}
|
||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user