Add more tracing details to bert. (#188)

Laurent Mazare
2023-07-18 08:11:05 +01:00
committed by GitHub
parent d73df74cb2
commit d6313d2447


@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, LayerNorm, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
@@ -61,6 +61,45 @@ impl Linear {
     }
 }
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+        Self {
+            weight,
+            bias,
+            eps,
+            span,
+        }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let x = x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
+        Ok(x)
+    }
+}
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
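
The LayerNorm added above mirrors the candle_nn version but owns a tracing span that is entered on every forward call. A minimal sketch of exercising it follows; the function name, shapes and values are illustrative only, and it assumes the code sits in the same file so the module's imports (DType, Device, Result, Tensor) and the LayerNorm above are in scope.

```rust
// Illustrative only, not part of the commit: run the new LayerNorm once so the
// "layer-norm" span is entered. Relies on the module's existing imports.
fn layer_norm_demo() -> Result<()> {
    let dev = Device::Cpu;
    let hidden = 8usize;
    let ln = LayerNorm::new(
        Tensor::ones(hidden, DType::F32, &dev)?,  // scale = 1
        Tensor::zeros(hidden, DType::F32, &dev)?, // shift = 0
        1e-12,
    );
    // A (batch, seq_len, hidden) input, matching the r3() call in forward().
    let data: Vec<f32> = (0..4 * hidden).map(|v| v as f32).collect();
    let x = Tensor::from_vec(data, (1, 4, hidden), &dev)?;
    let y = ln.forward(&x)?; // the "layer-norm" span covers this whole call
    assert_eq!(y.shape().dims(), &[1, 4, hidden]);
    Ok(())
}
```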
@@ -247,6 +286,7 @@ struct BertSelfAttention {
     num_attention_heads: usize,
     attention_head_size: usize,
     span: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl BertSelfAttention {
@@ -266,6 +306,7 @@ impl BertSelfAttention {
             num_attention_heads: config.num_attention_heads,
             attention_head_size,
             span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
         })
     }
@@ -291,7 +332,10 @@ impl BertSelfAttention {
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
-        let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
+        let attention_probs = {
+            let _enter_sm = self.span_softmax.enter();
+            attention_scores.softmax(candle::D::Minus1)?
+        };
         let attention_probs = self.dropout.forward(&attention_probs)?;
         let context_layer = attention_probs.matmul(&value_layer)?;
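
Note that these spans only record anything once a tracing subscriber is installed in the binary that drives the model. A minimal sketch of one way to collect them into a Chrome trace is below, assuming the tracing-chrome and tracing-subscriber crates as dependencies (this is similar to what the candle examples enable behind a --tracing flag).

```rust
// Sketch only: install a subscriber so the "layer-norm", "self-attn" and
// "softmax" spans above are recorded to a Chrome trace file.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run: the trace file is flushed when
    // it is dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... load the BERT model and run a forward pass here ...
    // The resulting trace-*.json file can be opened in chrome://tracing or
    // ui.perfetto.dev to see how much time each span accounts for.
}
```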