Mirror of https://github.com/huggingface/candle.git
Add more tracing details to bert. (#188)
@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, LayerNorm, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
 use serde::Deserialize;

 pub const DTYPE: DType = DType::F32;
@@ -61,6 +61,45 @@ impl Linear {
     }
 }

+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+        Self {
+            weight,
+            bias,
+            eps,
+            span,
+        }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let x = x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
+        Ok(x)
+    }
+}
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
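The forward pass above is standard layer normalization, computed in f32 for f16/bf16 inputs: subtract the mean over the hidden dimension, divide by the square root of the mean squared deviation plus eps, then apply the learned scale and bias. A minimal usage sketch follows; it is not part of the commit, and it assumes the candle Tensor::ones/Tensor::zeros constructors together with this file's DTYPE constant. The hidden size of 8 and eps of 1e-12 are purely illustrative.

// Hypothetical smoke test for the traced LayerNorm (illustration only).
fn layer_norm_demo() -> candle::Result<()> {
    let device = candle::Device::Cpu;
    // Scale and shift parameters over the hidden dimension (size 8 here).
    let weight = candle::Tensor::ones(8, DTYPE, &device)?;
    let bias = candle::Tensor::zeros(8, DTYPE, &device)?;
    let ln = LayerNorm::new(weight, bias, 1e-12);
    // Input must be (batch, seq_len, hidden_size), as enforced by shape().r3().
    let x = candle::Tensor::ones((2, 4, 8), DTYPE, &device)?;
    // The call is recorded under the "layer-norm" tracing span.
    let _y = ln.forward(&x)?;
    Ok(())
}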
@@ -247,6 +286,7 @@ struct BertSelfAttention {
     num_attention_heads: usize,
     attention_head_size: usize,
     span: tracing::Span,
+    span_softmax: tracing::Span,
 }

 impl BertSelfAttention {
@@ -266,6 +306,7 @@ impl BertSelfAttention {
             num_attention_heads: config.num_attention_heads,
             attention_head_size,
             span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
         })
     }

@@ -291,7 +332,10 @@ impl BertSelfAttention {

         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
-        let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
+        let attention_probs = {
+            let _enter_sm = self.span_softmax.enter();
+            attention_scores.softmax(candle::D::Minus1)?
+        };
         let attention_probs = self.dropout.forward(&attention_probs)?;

         let context_layer = attention_probs.matmul(&value_layer)?;
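The spans added in this commit only show up if the binary driving the model installs a tracing subscriber. A minimal setup sketch, assuming the tracing-chrome and tracing-subscriber crates; the subscriber wiring itself is not part of this diff.

// Hypothetical tracing setup; writes a Chrome trace viewable at chrome://tracing.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The returned guard flushes the trace file when it is dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... load the model and run a forward pass here; the "layer-norm",
    // "self-attn" and "softmax" spans will appear in the recorded trace.
}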