From d6313d244782bcf2358fe408ecaeb89ca8b88203 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 18 Jul 2023 08:11:05 +0100
Subject: [PATCH] Add more tracing details to bert. (#188)

---
 candle-examples/examples/bert/model.rs | 48 ++++++++++++++++++++++++--
 1 file changed, 46 insertions(+), 2 deletions(-)

diff --git a/candle-examples/examples/bert/model.rs b/candle-examples/examples/bert/model.rs
index 059f4280..7323db4b 100644
--- a/candle-examples/examples/bert/model.rs
+++ b/candle-examples/examples/bert/model.rs
@@ -1,5 +1,5 @@
 use candle::{DType, Device, Result, Tensor};
-use candle_nn::{Embedding, LayerNorm, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
 use serde::Deserialize;
 
 pub const DTYPE: DType = DType::F32;
@@ -61,6 +61,45 @@ impl Linear {
     }
 }
 
+#[derive(Debug)]
+pub struct LayerNorm {
+    weight: Tensor,
+    bias: Tensor,
+    eps: f64,
+    span: tracing::Span,
+}
+
+impl LayerNorm {
+    pub fn new(weight: Tensor, bias: Tensor, eps: f64) -> Self {
+        let span = tracing::span!(tracing::Level::TRACE, "layer-norm");
+        Self {
+            weight,
+            bias,
+            eps,
+            span,
+        }
+    }
+
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        let x_dtype = x.dtype();
+        let internal_dtype = match x_dtype {
+            DType::F16 | DType::BF16 => DType::F32,
+            d => d,
+        };
+        let (_bsize, _seq_len, hidden_size) = x.shape().r3()?;
+        let x = x.to_dtype(internal_dtype)?;
+        let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?;
+        let x = x.broadcast_sub(&mean_x)?;
+        let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?;
+        let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?;
+        let x = x_normed
+            .to_dtype(x_dtype)?
+            .broadcast_mul(&self.weight)?
+            .broadcast_add(&self.bias)?;
+        Ok(x)
+    }
+}
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
 #[serde(rename_all = "lowercase")]
 enum PositionEmbeddingType {
@@ -247,6 +286,7 @@ struct BertSelfAttention {
     num_attention_heads: usize,
     attention_head_size: usize,
     span: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl BertSelfAttention {
@@ -266,6 +306,7 @@ impl BertSelfAttention {
            num_attention_heads: config.num_attention_heads,
            attention_head_size,
            span: tracing::span!(tracing::Level::TRACE, "self-attn"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
        })
    }
 
@@ -291,7 +332,10 @@ impl BertSelfAttention {
 
         let attention_scores = query_layer.matmul(&key_layer.t()?)?;
         let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?;
-        let attention_probs = attention_scores.softmax(candle::D::Minus1)?;
+        let attention_probs = {
+            let _enter_sm = self.span_softmax.enter();
+            attention_scores.softmax(candle::D::Minus1)?
+        };
         let attention_probs = self.dropout.forward(&attention_probs)?;
         let context_layer = attention_probs.matmul(&value_layer)?;
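
Note (not part of the patch): the `layer-norm`, `self-attn`, and `softmax` spans above are only recorded when a `tracing` subscriber is installed in the example's entry point. The snippet below is a minimal sketch of one way to surface them as a Chrome trace, assuming the `tracing-chrome` and `tracing-subscriber` crates; the exact wiring is illustrative, not taken from this diff.

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run so the trace file is flushed on drop.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build and run the BERT model here; the layer-norm, self-attn and
    // softmax spans then appear as slices in the generated trace-*.json,
    // which can be opened in chrome://tracing or Perfetto.
}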