mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add the pow operator. (#1583)
* Add the pow operator. * Support the pow operation in onnx.
This commit is contained in:
@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn logsumexp() -> Result<()> {
|
||||
fn log_sum_exp() -> Result<()> {
|
||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let output = input.logsumexp(D::Minus1)?;
|
||||
let output = input.log_sum_exp(D::Minus1)?;
|
||||
// The expectations obtained from pytorch.
|
||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
||||
assert_close(&output, &expected, 0.00001)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pow() -> Result<()> {
|
||||
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||
let rhs = (&lhs - 2.)?;
|
||||
let res = lhs.pow(&rhs)?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res, 4)?,
|
||||
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user