Mirror of https://github.com/huggingface/candle.git
Fix the layer norm to properly handle bias.
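In short: the previous LayerNorm in this BERT example normalized by the root mean square of the raw activations (RMSNorm-style) with a hard-coded 1e-5 epsilon, even though the checkpoint's LayerNorm layers carry a bias. The fix subtracts the per-row mean before normalizing and threads config.layer_norm_eps through to the layer. With the loaded weight and bias playing the usual roles of gamma and beta, the updated forward pass computes (Var here is the biased variance over the hidden dimension):

    y = \gamma \odot \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} + \beta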
@@ -209,27 +209,31 @@ impl Dropout {
     }
 }
 
+// This layer norm version handles both weight and bias so removes the mean.
 struct LayerNorm {
     weight: Tensor,
     bias: Tensor,
+    eps: f64,
 }
 
 impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor) -> Self {
-        Self { weight, bias }
+    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        Self { weight, bias, eps }
     }
 
-    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
         let weight = vb.get(size, &format!("{p}.weight"))?;
         let bias = vb.get(size, &format!("{p}.bias"))?;
-        Ok(Self { weight, bias })
+        Ok(Self { weight, bias, eps })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .broadcast_mul(&self.weight)?
             .broadcast_add(&self.bias)?;
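For a quick sanity check, a fragment like the following could be dropped into this example's main (a sketch, not part of the commit; Device::Cpu and the 2-D Tensor::new literal are assumptions, the rest mirrors the code above):

    // Build a LayerNorm with identity weight and zero bias; 1e-12 is the assumed bert-base eps.
    let device = Device::Cpu;
    let weight = Tensor::new(&[1f32, 1., 1., 1.], &device)?;
    let bias = Tensor::new(&[0f32, 0., 0., 0.], &device)?;
    let ln = LayerNorm::new(weight, bias, 1e-12);
    // A (seq_len = 2, hidden_size = 4) input; forward expects a rank-2 tensor.
    let x = Tensor::new(&[[1f32, 2., 3., 4.], [4., 3., 2., 1.]], &device)?;
    // Each row of y should come out with roughly zero mean and unit variance
    // before the affine weight/bias is applied.
    let y = ln.forward(&x)?;
    println!("{y}");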
@@ -268,7 +272,12 @@ impl BertEmbeddings {
             &format!("{p}.token_type_embeddings"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
         let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
         let token_type_ids = position_ids.zeros_like()?;
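Each call site now reads the epsilon from the model configuration instead of relying on the hard-coded 1e-5. For reference, a minimal sketch of the fields involved (the actual Config struct in this example has many more fields; the 1e-12 value shown is the usual bert-base default and is an assumption here):

    struct Config {
        hidden_size: usize,
        layer_norm_eps: f64, // typically 1e-12 in bert-base checkpoints
        // ... other fields elided ...
    }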
@@ -287,7 +296,7 @@ impl BertEmbeddings {
         let seq_len = input_ids.shape().r1()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
-        let mut embeddings = (input_embeddings + token_type_embeddings)?;
+        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
         if let Some(position_embeddings) = &self.position_embeddings {
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
@@ -372,7 +381,12 @@ impl BertSelfOutput {
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -453,7 +467,12 @@ impl BertOutput {
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -587,7 +606,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?;
+    let token_ids = Tensor::new(&tokens[..7], &device)?;
     println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
     let ys = model.forward(&token_ids, &token_type_ids)?;