Allow for batch dimensions in the embedding layer.

Author: laurent
Date: 2023-07-03 18:37:40 +01:00
parent 9784d1ed9f
commit b6d179cc1c


@@ -153,20 +153,28 @@ impl Config {
 struct Embedding {
     embeddings: Tensor,
+    hidden_size: usize,
 }
 
 impl Embedding {
-    fn new(embeddings: Tensor) -> Self {
-        Self { embeddings }
+    fn new(embeddings: Tensor, hidden_size: usize) -> Self {
+        Self {
+            embeddings,
+            hidden_size,
+        }
     }
 
-    fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let embeddings = vb.get((size1, size2), &format!("{p}.weight"))?;
-        Ok(Self::new(embeddings))
+    fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+        let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?;
+        Ok(Self::new(embeddings, hidden_size))
     }
 
     fn forward(&self, indexes: &Tensor) -> Result<Tensor> {
-        let values = Tensor::embedding(indexes, &self.embeddings)?;
+        let mut final_dims = indexes.dims().to_vec();
+        final_dims.push(self.hidden_size);
+        let indexes = indexes.flatten_all()?;
+        let values = Tensor::embedding(&indexes, &self.embeddings)?;
+        let values = values.reshape(final_dims)?;
         Ok(values)
     }
 }
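
The reshaping in forward works because the lookup itself only needs a flat list of token ids: the input's dimensions are recorded, the ids are flattened to 1D, the embedding rows are gathered, and the result is reshaped back with hidden_size appended as the last dimension. Below is a minimal usage sketch of the new behaviour; the candle-style imports, the CPU device, the anyhow error type, and the helper name demo are assumptions for illustration and are not part of this commit.

// Sketch only: the point is the shape handling, not the exact API surface.
// Assumes this lives in the same module as the Embedding struct above.
use candle::{Device, Tensor};

fn demo(emb: &Embedding) -> anyhow::Result<()> {
    // A (batch, seq_len) = (2, 3) tensor of token ids.
    let ids = Tensor::new(&[[1u32, 2, 3], [4, 5, 6]], &Device::Cpu)?;
    // Before this change the lookup was applied to the ids tensor directly,
    // so callers presumably had to pass 1D ids; now the batched tensor is
    // accepted and the result has shape (2, 3, hidden_size).
    let hidden = emb.forward(&ids)?;
    assert_eq!(hidden.dims()[0], 2);
    assert_eq!(hidden.dims()[1], 3);
    Ok(())
}

Flattening and reshaping around a 1D gather keeps the lookup itself unchanged, which is presumably why the batch handling lives in forward rather than in Tensor::embedding.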