Fix the layer norm to properly handle bias.

Author: laurent
Date:   2023-07-03 16:45:03 +01:00
Parent: f379b8feae
Commit: a7f03a7bb6


@@ -209,27 +209,31 @@ impl Dropout {
     }
 }
 
+// This layer norm version handles both weight and bias so it removes the mean.
 struct LayerNorm {
     weight: Tensor,
     bias: Tensor,
+    eps: f64,
 }
 
 impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor) -> Self {
-        Self { weight, bias }
+    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        Self { weight, bias, eps }
     }
 
-    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
         let weight = vb.get(size, &format!("{p}.weight"))?;
         let bias = vb.get(size, &format!("{p}.bias"))?;
-        Ok(Self { weight, bias })
+        Ok(Self { weight, bias, eps })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .broadcast_mul(&self.weight)?
             .broadcast_add(&self.bias)?;
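Before this change, forward divided by the root mean square of x without subtracting the mean, so the layer behaved like an RMS norm rather than a standard layer norm, and the epsilon was hardcoded to 1e-5. The new code centers x first and reads eps from the configuration. Below is a minimal sketch of the fixed computation for a single row in plain std Rust; layer_norm_row is an illustrative helper name, not part of this example or the candle API.

// Reference layer norm over one row: y = (x - mean) / sqrt(var + eps) * w + b.
// Plain std Rust sketch; `layer_norm_row` is a hypothetical helper.
fn layer_norm_row(x: &[f32], weight: &[f32], bias: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    // Mean over the hidden dimension (the step the old code skipped).
    let mean = x.iter().sum::<f32>() / n;
    // Biased variance of the centered values, matching `(&x * &x)?.sum(..)? / n` above.
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    let denom = (var + eps).sqrt();
    x.iter()
        .zip(weight.iter().zip(bias.iter()))
        .map(|(v, (w, b))| (v - mean) / denom * w + b)
        .collect()
}

fn main() {
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    let y = layer_norm_row(&x, &[1.0; 4], &[0.0; 4], 1e-12);
    // With unit weight and zero bias the output has mean ~0 and variance ~1.
    println!("{y:?}");
}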
@ -268,7 +272,12 @@ impl BertEmbeddings {
&format!("{p}.token_type_embeddings"), &format!("{p}.token_type_embeddings"),
vb, vb,
)?; )?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?; let layer_norm = LayerNorm::load(
config.hidden_size,
config.layer_norm_eps,
&format!("{p}.LayerNorm"),
vb,
)?;
let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect(); let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?; let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
let token_type_ids = position_ids.zeros_like()?; let token_type_ids = position_ids.zeros_like()?;
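The epsilon now comes from the model configuration rather than a constant: HuggingFace-style BERT checkpoints ship layer_norm_eps = 1e-12 in config.json (bert-base-uncased does), so the hardcoded 1e-5 did not match these weights. A hypothetical sketch of the relevant configuration slice, assuming an HF-style config whose field names mirror the config.json keys; only the fields LayerNorm::load needs are shown.

// Hypothetical minimal slice of the BERT configuration (assumed, not the
// example's actual Config definition).
struct Config {
    hidden_size: usize,
    layer_norm_eps: f64, // bert-base-uncased's config.json uses 1e-12
}

impl Default for Config {
    fn default() -> Self {
        Self {
            hidden_size: 768,
            layer_norm_eps: 1e-12,
        }
    }
}

fn main() {
    let config = Config::default();
    println!("hidden = {}, eps = {}", config.hidden_size, config.layer_norm_eps);
}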
@@ -287,7 +296,7 @@ impl BertEmbeddings {
         let seq_len = input_ids.shape().r1()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
-        let mut embeddings = (input_embeddings + token_type_embeddings)?;
+        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
         if let Some(position_embeddings) = &self.position_embeddings {
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
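Taking the left operand by reference (&input_embeddings) means the addition no longer moves input_embeddings, presumably so it remains usable later in the function; an Add impl that takes self by value consumes its operand. A toy illustration of the pattern in plain Rust; Vecf is a hypothetical stand-in, not a candle type.

use std::ops::Add;

// Toy tensor-like wrapper to show why the diff borrows the left operand.
#[derive(Clone, Debug)]
struct Vecf(Vec<f32>);

// Adding by reference leaves both operands usable after the sum.
impl Add<&Vecf> for &Vecf {
    type Output = Vecf;
    fn add(self, rhs: &Vecf) -> Vecf {
        Vecf(self.0.iter().zip(&rhs.0).map(|(a, b)| a + b).collect())
    }
}

fn main() {
    let a = Vecf(vec![1.0, 2.0]);
    let b = Vecf(vec![3.0, 4.0]);
    let c = &a + &b;
    // `a` was not moved, so it can still be used here.
    println!("{a:?} {c:?}");
}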
@ -372,7 +381,12 @@ impl BertSelfOutput {
&format!("{p}.dense"), &format!("{p}.dense"),
vb, vb,
)?; )?;
let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?; let layer_norm = LayerNorm::load(
config.hidden_size,
config.layer_norm_eps,
&format!("{p}.LayerNorm"),
vb,
)?;
let dropout = Dropout::new(config.hidden_dropout_prob); let dropout = Dropout::new(config.hidden_dropout_prob);
Ok(Self { Ok(Self {
dense, dense,
@@ -453,7 +467,12 @@ impl BertOutput {
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@ -587,7 +606,7 @@ fn main() -> Result<()> {
.map_err(E::msg)? .map_err(E::msg)?
.get_ids() .get_ids()
.to_vec(); .to_vec();
let token_ids = Tensor::new(&tokens[..], &device)?; let token_ids = Tensor::new(&tokens[..7], &device)?;
println!("{token_ids}"); println!("{token_ids}");
let token_type_ids = token_ids.zeros_like()?; let token_type_ids = token_ids.zeros_like()?;
let ys = model.forward(&token_ids, &token_type_ids)?; let ys = model.forward(&token_ids, &token_type_ids)?;
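To see what the normalization fix changes numerically, compare the old behaviour (divide by the root mean square, no centering) with the new one on a row whose mean is far from zero. This is a plain-Rust sketch with illustrative names, not code from the example.

// Old behaviour: divide by sqrt(mean(x^2) + eps), i.e. an RMS norm.
fn old_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let ms = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    x.iter().map(|v| v / (ms + eps).sqrt()).collect()
}

// New behaviour: subtract the mean, then divide by sqrt(var + eps).
fn new_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    let mean = x.iter().sum::<f32>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f32>() / n;
    x.iter().map(|v| (v - mean) / (var + eps).sqrt()).collect()
}

fn main() {
    // A row with a large mean: the two normalizations differ substantially.
    let x = [10.0_f32, 11.0, 12.0, 13.0];
    println!("old: {:?}", old_norm(&x, 1e-12));
    println!("new: {:?}", new_norm(&x, 1e-12));
}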