diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index bfa4e69a..234e79d8 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -209,27 +209,31 @@ impl Dropout {
     }
 }
 
+// This layer norm version handles both weight and bias so removes the mean.
 struct LayerNorm {
     weight: Tensor,
     bias: Tensor,
+    eps: f64,
 }
 
 impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor) -> Self {
-        Self { weight, bias }
+    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        Self { weight, bias, eps }
     }
 
-    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
         let weight = vb.get(size, &format!("{p}.weight"))?;
         let bias = vb.get(size, &format!("{p}.bias"))?;
-        Ok(Self { weight, bias })
+        Ok(Self { weight, bias, eps })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .broadcast_mul(&self.weight)?
             .broadcast_add(&self.bias)?;
@@ -268,7 +272,12 @@ impl BertEmbeddings {
             &format!("{p}.token_type_embeddings"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
         let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
         let token_type_ids = position_ids.zeros_like()?;
@@ -287,7 +296,7 @@ impl BertEmbeddings {
         let seq_len = input_ids.shape().r1()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
-        let mut embeddings = (input_embeddings + token_type_embeddings)?;
+        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
         if let Some(position_embeddings) = &self.position_embeddings {
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
@@ -372,7 +381,12 @@ impl BertSelfOutput {
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -453,7 +467,12 @@ impl BertOutput {
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -587,7 +606,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?;
+    let token_ids = Tensor::new(&tokens[..7], &device)?;
     println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
     let ys = model.forward(&token_ids, &token_type_ids)?;
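For reference, the revised LayerNorm::forward now subtracts the per-row mean and divides by the square root of the variance plus a configurable eps (config.layer_norm_eps), instead of the previous RMS-style normalization with a hardcoded 1e-5. The following is a minimal standalone sketch of that computation over plain f64 slices, not part of the patch: the function name layer_norm and the use of slices rather than candle Tensors are illustrative only.

// Standalone sketch of the normalization the patch implements.
// `eps` plays the role of `config.layer_norm_eps`; `weight`/`bias` are the learned scale and shift.
fn layer_norm(x: &[f64], weight: &[f64], bias: &[f64], eps: f64) -> Vec<f64> {
    let n = x.len() as f64;
    // Mean over the hidden dimension, then variance of the mean-centered values.
    let mean = x.iter().sum::<f64>() / n;
    let var = x.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / n;
    let denom = (var + eps).sqrt();
    // Normalize, then apply the learned scale and shift element-wise.
    x.iter()
        .zip(weight.iter().zip(bias.iter()))
        .map(|(v, (w, b))| (v - mean) / denom * w + b)
        .collect()
}

fn main() {
    let x = [1.0, 2.0, 3.0, 4.0];
    let weight = [1.0; 4];
    let bias = [0.0; 4];
    println!("{:?}", layer_norm(&x, &weight, &bias, 1e-12));
}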