Fix the layer norm to properly handle bias.

This commit is contained in:
laurent
2023-07-03 16:45:03 +01:00
parent f379b8feae
commit a7f03a7bb6

View File

@ -209,27 +209,31 @@ impl Dropout {
}
}
// This layer norm version handles both weight and bias so removes the mean.
struct LayerNorm {
weight: Tensor,
bias: Tensor,
eps: f64,
}
impl LayerNorm {
fn new(weight: Tensor, bias: Tensor) -> Self {
Self { weight, bias }
fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
Self { weight, bias, eps }
}
fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
let weight = vb.get(size, &format!("{p}.weight"))?;
let bias = vb.get(size, &format!("{p}.bias"))?;
Ok(Self { weight, bias })
Ok(Self { weight, bias, eps })
}
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let (seq_len, hidden_size) = x.shape().r2()?;
let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
let x = x.broadcast_sub(&mean_x)?;
let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
let x = x_normed
.broadcast_mul(&self.weight)?
.broadcast_add(&self.bias)?;
@ -268,7 +272,12 @@ impl BertEmbeddings {
&format!("{p}.token_type_embeddings"),
vb,
)?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let layer_norm = LayerNorm::load(
config.hidden_size,
config.layer_norm_eps,
&format!("{p}.LayerNorm"),
vb,
)?;
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
let token_type_ids = position_ids.zeros_like()?;
@ -287,7 +296,7 @@ impl BertEmbeddings {
let seq_len = input_ids.shape().r1()?;
let input_embeddings = self.word_embeddings.forward(input_ids)?;
let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
let mut embeddings = (input_embeddings + token_type_embeddings)?;
let mut embeddings = (&input_embeddings + token_type_embeddings)?;
if let Some(position_embeddings) = &self.position_embeddings {
// TODO: Proper absolute positions?
let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
@ -372,7 +381,12 @@ impl BertSelfOutput {
&format!("{p}.dense"),
vb,
)?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let layer_norm = LayerNorm::load(
config.hidden_size,
config.layer_norm_eps,
&format!("{p}.LayerNorm"),
vb,
)?;
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
@ -453,7 +467,12 @@ impl BertOutput {
&format!("{p}.dense"),
vb,
)?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
let layer_norm = LayerNorm::load(
config.hidden_size,
config.layer_norm_eps,
&format!("{p}.LayerNorm"),
vb,
)?;
let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self {
dense,
@ -587,7 +606,7 @@ fn main() -> Result<()> {
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?;
let token_ids = Tensor::new(&tokens[..7], &device)?;
println!("{token_ids}");
let token_type_ids = token_ids.zeros_like()?;
let ys = model.forward(&token_ids, &token_type_ids)?;