Add the pow operator. (#1583)

* Add the pow operator.

* Support the pow operation in onnx.
This commit is contained in:
Laurent Mazare
2024-01-13 20:24:06 +01:00
committed by GitHub
parent 88618255cb
commit e6d86b0819
3 changed files with 31 additions and 3 deletions

View File

@ -2578,11 +2578,21 @@ impl Tensor {
}
/// Returns log(sum(exp(tensor), dim)).
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?;
let sum = exp.sum(sum_dims)?;
sum.log()
}
/// Pointwise pow operation.
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.mul(&self.log()?)?.exp()
}
/// Broadcasting version of `pow`.
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
}
macro_rules! bin_trait {