Move some shared functions to the nn module. (#221)

This commit is contained in:
Laurent Mazare
2023-07-22 14:25:11 +02:00
committed by GitHub
parent 43c7223292
commit 1f26042693
4 changed files with 24 additions and 19 deletions

10
candle-nn/src/ops.rs Normal file
View File

@ -0,0 +1,10 @@
use candle::{Result, Tensor};
pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
let d = d.to_index(xs.shape(), "log-softmax")?;
let max = xs.max_keepdim(d)?;
let diff = xs.broadcast_sub(&max)?;
let sum_exp = diff.exp()?.sum_keepdim(d)?;
let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
Ok(log_sm)
}