Use candle_nn::embedding instead of local copies in a few models. (#1562)

This commit is contained in:
Jani Monoses
2024-01-10 22:36:27 +02:00
committed by GitHub
parent d3bdd788cf
commit 63944714f2
5 changed files with 6 additions and 31 deletions

View File

@ -1,6 +1,6 @@
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Embedding, Module, VarBuilder};
use candle_nn::{embedding, Embedding, Module, VarBuilder};
use serde::Deserialize;
pub const DTYPE: DType = DType::F32;
@ -112,11 +112,6 @@ impl Config {
}
}
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
Ok(Embedding::new(embeddings, hidden_size))
}
struct Dropout {
#[allow(dead_code)]
pr: f64,