mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 11:56:45 +00:00
[nn] Move the Embedding and Activation parts. (#116)
* Share the Embedding and Activation parts. * Tweak some activations.
This commit is contained in:
18
candle-nn/src/activation.rs
Normal file
18
candle-nn/src/activation.rs
Normal file
@ -0,0 +1,18 @@
|
||||
use candle::Tensor;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub enum Activation {
|
||||
Gelu,
|
||||
Relu,
|
||||
Elu(f64),
|
||||
}
|
||||
|
||||
impl Activation {
|
||||
pub fn forward(&self, xs: &Tensor) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::Gelu => xs.gelu(),
|
||||
Self::Relu => xs.relu(),
|
||||
&Self::Elu(alpha) => xs.elu(alpha),
|
||||
}
|
||||
}
|
||||
}
|
29
candle-nn/src/embedding.rs
Normal file
29
candle-nn/src/embedding.rs
Normal file
@ -0,0 +1,29 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedding {
|
||||
embeddings: Tensor,
|
||||
hidden_size: usize,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
|
||||
Self {
|
||||
embeddings,
|
||||
hidden_size,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embeddings(&self) -> &Tensor {
|
||||
&self.embeddings
|
||||
}
|
||||
|
||||
pub fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
||||
let mut final_dims = indexes.dims().to_vec();
|
||||
final_dims.push(self.hidden_size);
|
||||
let indexes = indexes.flatten_all()?;
|
||||
let values = Tensor::embedding(&indexes, &self.embeddings)?;
|
||||
let values = values.reshape(final_dims)?;
|
||||
Ok(values)
|
||||
}
|
||||
}
|
@ -1,5 +1,11 @@
|
||||
// For now this crate shares its error type with candle-core. We may introduce some separate
|
||||
// error type if needed or add some specialized cases on the candle-core side.
|
||||
mod activation;
|
||||
mod embedding;
|
||||
mod layer_norm;
|
||||
mod linear;
|
||||
|
||||
pub use activation::Activation;
|
||||
pub use embedding::Embedding;
|
||||
pub use layer_norm::LayerNorm;
|
||||
pub use linear::Linear;
|
||||
|
Reference in New Issue
Block a user