mirror of https://github.com/huggingface/candle.git
Fix the layer norm to properly handle bias.
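In short: the previous forward pass did not remove the mean (an RMS-norm style computation with a hard-coded 1e-5 epsilon), which does not match a standard layer norm with a learned bias; the updated version subtracts the per-row mean before computing the variance and reads the epsilon from the model config. Below is a minimal standalone sketch of the updated computation, assuming the candle API exactly as it appears in this diff (shape().r2(), sum(&[1]) and the broadcast_* helpers), not a definitive implementation:

use candle::{Result, Tensor};

struct LayerNorm {
    weight: Tensor,
    bias: Tensor,
    eps: f64,
}

impl LayerNorm {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let (seq_len, hidden_size) = x.shape().r2()?;
        // Center the activations along the hidden dimension.
        let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
        let x = x.broadcast_sub(&mean_x)?;
        // Variance of the centered activations.
        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
        let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
        // Normalize, then apply the learned scale and shift.
        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
        x_normed
            .broadcast_mul(&self.weight)?
            .broadcast_add(&self.bias)
    }
}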
@@ -209,27 +209,31 @@ impl Dropout {
     }
 }
 
+// This layer norm version handles both weight and bias so removes the mean.
 struct LayerNorm {
     weight: Tensor,
     bias: Tensor,
+    eps: f64,
 }
 
 impl LayerNorm {
-    fn new(weight: Tensor, bias: Tensor) -> Self {
-        Self { weight, bias }
+    fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        Self { weight, bias, eps }
     }
 
-    fn load(size: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
+    fn load(size: usize, eps: f64, p: &str, vb: &VarBuilder) -> Result<Self> {
         let weight = vb.get(size, &format!("{p}.weight"))?;
         let bias = vb.get(size, &format!("{p}.bias"))?;
-        Ok(Self { weight, bias })
+        Ok(Self { weight, bias, eps })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
         let (seq_len, hidden_size) = x.shape().r2()?;
-        let norm_x = ((x * x)?.sum(&[1])? / hidden_size as f64)?;
+        let mean_x = (x.sum(&[1])? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = ((&x * &x)?.sum(&[1])? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((seq_len, hidden_size))?;
-        let x_normed = (x / (norm_x + 1e-5)?.sqrt()?)?;
+        let x_normed = (x / (norm_x + self.eps)?.sqrt()?)?;
         let x = x_normed
             .broadcast_mul(&self.weight)?
             .broadcast_add(&self.bias)?;
@@ -268,7 +272,12 @@ impl BertEmbeddings {
             &format!("{p}.token_type_embeddings"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let position_ids: Vec<_> = (0..config.max_position_embeddings as u32).collect();
         let position_ids = Tensor::new(&position_ids[..], &vb.device)?.unsqueeze(0)?;
         let token_type_ids = position_ids.zeros_like()?;
@@ -287,7 +296,7 @@ impl BertEmbeddings {
         let seq_len = input_ids.shape().r1()?;
         let input_embeddings = self.word_embeddings.forward(input_ids)?;
         let token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)?;
-        let mut embeddings = (input_embeddings + token_type_embeddings)?;
+        let mut embeddings = (&input_embeddings + token_type_embeddings)?;
         if let Some(position_embeddings) = &self.position_embeddings {
             // TODO: Proper absolute positions?
             let position_ids = (0..seq_len as u32).collect::<Vec<_>>();
@@ -372,7 +381,12 @@ impl BertSelfOutput {
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -453,7 +467,12 @@ impl BertOutput {
             &format!("{p}.dense"),
             vb,
         )?;
-        let layer_norm = LayerNorm::load(config.hidden_size, &format!("{p}.LayerNorm"), vb)?;
+        let layer_norm = LayerNorm::load(
+            config.hidden_size,
+            config.layer_norm_eps,
+            &format!("{p}.LayerNorm"),
+            vb,
+        )?;
         let dropout = Dropout::new(config.hidden_dropout_prob);
         Ok(Self {
             dense,
@@ -587,7 +606,7 @@ fn main() -> Result<()> {
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
-    let token_ids = Tensor::new(&tokens[..], &device)?;
+    let token_ids = Tensor::new(&tokens[..7], &device)?;
     println!("{token_ids}");
     let token_type_ids = token_ids.zeros_like()?;
     let ys = model.forward(&token_ids, &token_type_ids)?;