mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Use candle_nn::embedding instead of local copies in a few models. (#1562)
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use super::with_tracing::{layer_norm, linear, LayerNorm, Linear};
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, Module, VarBuilder};
|
||||
use candle_nn::{embedding, Embedding, Module, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
@ -112,11 +112,6 @@ impl Config {
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
|
||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||
Ok(Embedding::new(embeddings, hidden_size))
|
||||
}
|
||||
|
||||
struct Dropout {
|
||||
#[allow(dead_code)]
|
||||
pr: f64,
|
||||
|
Reference in New Issue
Block a user