mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
50 lines
1.2 KiB
Rust
50 lines
1.2 KiB
Rust
//! Embedding Layer.
|
|
use candle::{Result, Tensor};
|
|
|
|
#[derive(Clone, Debug)]
|
|
pub struct Embedding {
|
|
embeddings: Tensor,
|
|
hidden_size: usize,
|
|
}
|
|
|
|
impl Embedding {
|
|
pub fn new(embeddings: Tensor, hidden_size: usize) -> Self {
|
|
Self {
|
|
embeddings,
|
|
hidden_size,
|
|
}
|
|
}
|
|
|
|
pub fn embeddings(&self) -> &Tensor {
|
|
&self.embeddings
|
|
}
|
|
|
|
/// Get the hidden size of the embedding matrix
|
|
pub fn hidden_size(&self) -> usize {
|
|
self.hidden_size
|
|
}
|
|
}
|
|
|
|
impl crate::Module for Embedding {
|
|
fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
|
|
let mut final_dims = indexes.dims().to_vec();
|
|
final_dims.push(self.hidden_size);
|
|
let indexes = indexes.flatten_all()?;
|
|
let values = self.embeddings.index_select(&indexes, 0)?;
|
|
let values = values.reshape(final_dims)?;
|
|
Ok(values)
|
|
}
|
|
}
|
|
|
|
pub fn embedding(in_size: usize, out_size: usize, vb: crate::VarBuilder) -> Result<Embedding> {
|
|
let embeddings = vb.get_with_hints(
|
|
(in_size, out_size),
|
|
"weight",
|
|
crate::Init::Randn {
|
|
mean: 0.,
|
|
stdev: 1.,
|
|
},
|
|
)?;
|
|
Ok(Embedding::new(embeddings, out_size))
|
|
}
|