mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add more tracing details to bert. (#188)
This commit is contained in:
@ -1,5 +1,5 @@
|
||||
use candle::{DType, Device, Result, Tensor};
|
||||
use candle_nn::{Embedding, LayerNorm, VarBuilder};
|
||||
use candle_nn::{Embedding, VarBuilder};
|
||||
use serde::Deserialize;
|
||||
|
||||
pub const DTYPE: DType = DType::F32;
|
||||
@ -61,6 +61,45 @@ impl Linear {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct LayerNorm {
|
||||
weight: Tensor,
|
||||
bias: Tensor,
|
||||
eps: f64,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl LayerNorm {
|
||||
pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
|
||||
Self {
|
||||
weight,
|
||||
bias,
|
||||
eps,
|
||||
span,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let x_dtype = x.dtype();
|
||||
let internal_dtype = match x_dtype {
|
||||
DType::F16 | DType::BF16 => DType::F32,
|
||||
d => d,
|
||||
};
|
||||
let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
|
||||
let x = x.to_dtype(internal_dtype)?;
|
||||
let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x = x.broadcast_sub(&mean_x)?;
|
||||
let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
|
||||
let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
|
||||
let x = x_normed
|
||||
.to_dtype(x_dtype)?
|
||||
.broadcast_mul(&self.weight)?
|
||||
.broadcast_add(&self.bias)?;
|
||||
Ok(x)
|
||||
}
|
||||
}
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum PositionEmbeddingType {
|
||||
@ -247,6 +286,7 @@ struct BertSelfAttention {
|
||||
num_attention_heads: usize,
|
||||
attention_head_size: usize,
|
||||
span: tracing::Span,
|
||||
span_softmax: tracing::Span,
|
||||
}
|
||||
|
||||
impl BertSelfAttention {
|
||||
@ -266,6 +306,7 @@ impl BertSelfAttention {
|
||||
num_attention_heads: config.num_attention_heads,
|
||||
attention_head_size,
|
||||
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
|
||||
span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
|
||||
})
|
||||
}
|
||||
|
||||
@ -291,7 +332,10 @@ impl BertSelfAttention {
|
||||
|
||||
let attention_scores = query_layer.matmul(&key_layer.t()?)?;
|
||||
let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
|
||||
let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
|
||||
let attention_probs = {
|
||||
let _enter_sm = self.span_softmax.enter();
|
||||
attention_scores.softmax(candle::D::Minus1)?
|
||||
};
|
||||
let attention_probs = self.dropout.forward(&attention_probs)?;
|
||||
|
||||
let context_layer = attention_probs.matmul(&value_layer)?;
|
||||
|
Reference in New Issue
Block a user