Move some shared functions to the nn module. (#221)
candle-nn/src/lib.rs
@@ -6,6 +6,8 @@ pub mod embedding;
 pub mod init;
 pub mod layer_norm;
 pub mod linear;
+pub mod loss;
+pub mod ops;
 pub mod optim;
 pub mod var_builder;
 pub mod vision;
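Registering the two new modules in the crate root makes the helpers below reachable from dependent crates. A minimal sketch of the resulting import paths, assuming the standard candle-nn crate layout:

use candle_nn::loss::nll;
use candle_nn::ops::log_softmax;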
candle-nn/src/loss.rs (new file, 8 lines)
@@ -0,0 +1,8 @@
+use candle::{Result, Tensor};
+
+pub fn nll(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
+    let b_sz = target.dim(0)?;
+    inp.gather(target, 1)?
+        .sum_all()?
+        .affine(-1f64 / b_sz as f64, 0.)
+}
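For context: nll gathers the log-probability of each sample's target class along dim 1, sums the result, and scales it by -1/batch_size. So inp is expected to already hold log-probabilities, and target should hold class indices with a trailing dimension of size 1 so the ranks line up for gather. A minimal usage sketch, assuming candle and candle-nn from around this commit are on the dependency path:

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Two samples, three classes; the values are already log-probabilities.
    let log_probs = Tensor::new(&[[-0.2f32, -1.9, -2.3], [-2.1, -0.3, -1.8]], &device)?;
    // Class indices shaped (batch, 1) so gather(target, 1) matches ranks.
    let targets = Tensor::new(&[[0u32], [1u32]], &device)?;
    let loss = candle_nn::loss::nll(&log_probs, &targets)?;
    println!("nll: {}", loss.to_scalar::<f32>()?);
    Ok(())
}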
candle-nn/src/ops.rs (new file, 10 lines)
@@ -0,0 +1,10 @@
+use candle::{Result, Tensor};
+
+pub fn log_softmax<D: candle::shape::Dim>(xs: &Tensor, d: D) -> Result<Tensor> {
+    let d = d.to_index(xs.shape(), "log-softmax")?;
+    let max = xs.max_keepdim(d)?;
+    let diff = xs.broadcast_sub(&max)?;
+    let sum_exp = diff.exp()?.sum_keepdim(d)?;
+    let log_sm = diff.broadcast_sub(&sum_exp.log()?)?;
+    Ok(log_sm)
+}
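For context: this is the numerically stable formulation log_softmax(x) = (x - max) - log(sum(exp(x - max))), which subtracts the per-row maximum so exp() cannot overflow on large logits. Composed with nll above it gives the usual cross-entropy loss; a minimal sketch under the same assumptions as the previous example:

use candle::{Device, Result, Tensor};
use candle_nn::{loss::nll, ops::log_softmax};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Raw logits; no need to normalize them by hand.
    let logits = Tensor::new(&[[2.0f32, 0.5, -1.0], [0.1, 1.5, 0.3]], &device)?;
    let targets = Tensor::new(&[[0u32], [1u32]], &device)?;
    // log_softmax over the class dimension, then the averaged NLL:
    // together this is cross-entropy computed from the logits.
    let loss = nll(&log_softmax(&logits, 1)?, &targets)?;
    println!("cross-entropy: {}", loss.to_scalar::<f32>()?);
    Ok(())
}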