Use candle_nn::embedding instead of local copies in a few models. (#1562)

This commit is contained in:
Jani Monoses
2024-01-10 22:36:27 +02:00
committed by GitHub
parent d3bdd788cf
commit 63944714f2
5 changed files with 6 additions and 31 deletions

View File

@ -1,5 +1,5 @@
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
use candle_nn::{embedding, Embedding, LayerNorm, Linear, Module, VarBuilder};
fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let weight = vb.get((size2, size1), "weight")?;
@ -11,11 +11,6 @@ fn linear(size1: usize, size2: usize, bias: bool, vb: VarBuilder) -> Result<Line
Ok(Linear::new(weight, bias))
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
let weight = vb.get(size, "weight")?;
let bias = vb.get(size, "bias")?;