Tracing for StableLM and quantized StableLM. (#1068)

commit bc3351bce4
parent b34d7f0248
Author: Laurent Mazare
Date:   2023-10-10 08:09:25 +02:00
Committed by: GitHub

2 changed files with 24 additions and 0 deletions
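The change is mechanical: each sub-module (model, decoder layer, attention, MLP) gains a tracing::Span field that is created once in the constructor and entered at the top of forward, so the time spent in every call is attributed to its layer whenever a tracing subscriber is active. Below is a minimal sketch of that pattern; the struct is stripped down to the tracing plumbing only and does not mirror the real fields or constructor signature.

struct Mlp {
    // ... gate/up/down projections and activation elided ...
    span: tracing::Span,
}

impl Mlp {
    fn new() -> Self {
        Self {
            // TRACE-level span named "mlp", created once at construction time.
            span: tracing::span!(tracing::Level::TRACE, "mlp"),
        }
    }

    fn forward(&self) {
        // The guard keeps the span entered for the whole forward pass and
        // exits it automatically when `_enter` is dropped.
        let _enter = self.span.enter();
        // ... projections and activation would run here ...
    }
}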

File 1 of 2

@@ -14,6 +14,7 @@ struct MLP {
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
+span: tracing::Span,
}
impl MLP {
@@ -28,12 +29,14 @@ impl MLP {
up_proj,
down_proj,
act_fn: cfg.hidden_act,
+span: tracing::span!(tracing::Level::TRACE, "mlp"),
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+let _enter = self.span.enter();
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
@@ -55,6 +58,7 @@ struct Attention {
kv_cache: Option<(Tensor, Tensor)>,
use_cache: bool,
rotary_ndims: usize,
+span: tracing::Span,
}
impl Attention {
@@ -81,6 +85,7 @@ impl Attention {
kv_cache: None,
use_cache: cfg.use_cache,
rotary_ndims: cfg.rotary_ndims(),
+span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
@@ -102,6 +107,7 @@ impl Attention {
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
+let _enter = self.span.enter();
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
@@ -168,6 +174,7 @@ struct DecoderLayer {
mlp: MLP,
input_layernorm: LayerNorm,
post_attention_layernorm: LayerNorm,
+span: tracing::Span,
}
impl DecoderLayer {
@@ -185,6 +192,7 @@ impl DecoderLayer {
mlp,
input_layernorm,
post_attention_layernorm,
+span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
@@ -194,6 +202,7 @@ impl DecoderLayer {
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
+let _enter = self.span.enter();
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
@@ -211,6 +220,7 @@ pub struct Model {
norm: LayerNorm,
lm_head: Linear,
device: Device,
+span: tracing::Span,
}
impl Model {
@@ -233,6 +243,7 @@ impl Model {
norm,
lm_head,
device: vb.device().clone(),
+span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
@@ -258,6 +269,7 @@ impl Model {
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+let _enter = self.span.enter();
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None

File 2 of 2

@@ -109,6 +109,7 @@ struct MLP {
up_proj: Linear,
down_proj: Linear,
act_fn: Activation,
+span: tracing::Span,
}
impl MLP {
@@ -123,12 +124,14 @@ impl MLP {
up_proj,
down_proj,
act_fn: cfg.hidden_act,
+span: tracing::span!(tracing::Level::TRACE, "mlp"),
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+let _enter = self.span.enter();
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
let rhs = xs.apply(&self.up_proj)?;
(lhs * rhs)?.apply(&self.down_proj)
@@ -167,6 +170,7 @@ struct Attention {
use_cache: bool,
rotary_ndims: usize,
use_flash_attn: bool,
+span: tracing::Span,
}
impl Attention {
@@ -194,6 +198,7 @@ impl Attention {
use_cache: cfg.use_cache,
rotary_ndims: cfg.rotary_ndims(),
use_flash_attn: cfg.use_flash_attn,
+span: tracing::span!(tracing::Level::TRACE, "attn"),
})
}
@@ -215,6 +220,7 @@ impl Attention {
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
+let _enter = self.span.enter();
let (b_sz, q_len, _) = xs.dims3()?;
let query_states = self.q_proj.forward(xs)?;
@@ -288,6 +294,7 @@ struct DecoderLayer {
mlp: MLP,
input_layernorm: LayerNorm,
post_attention_layernorm: LayerNorm,
+span: tracing::Span,
}
impl DecoderLayer {
@@ -306,6 +313,7 @@ impl DecoderLayer {
mlp,
input_layernorm,
post_attention_layernorm,
+span: tracing::span!(tracing::Level::TRACE, "layer"),
})
}
@@ -315,6 +323,7 @@ impl DecoderLayer {
attention_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
+let _enter = self.span.enter();
let residual = xs;
let xs = self.input_layernorm.forward(xs)?;
let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
@@ -333,6 +342,7 @@ pub struct Model {
lm_head: Linear,
device: Device,
dtype: DType,
+span: tracing::Span,
}
impl Model {
@@ -356,6 +366,7 @@ impl Model {
lm_head,
device: vb.device().clone(),
dtype: vb.dtype(),
+span: tracing::span!(tracing::Level::TRACE, "model"),
})
}
@@ -381,6 +392,7 @@ impl Model {
}
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+let _enter = self.span.enter();
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
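
These spans only produce output when a tracing subscriber is installed in the binary that drives the model. A minimal sketch of one way to collect them into a Chrome trace, assuming the tracing-chrome and tracing-subscriber crates; this setup is not part of the diff above.

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run; dropping it flushes the
    // recorded spans to a trace-*.json file that can be opened in
    // chrome://tracing or Perfetto.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... load the model and call Model::forward here; the "model", "layer",
    // "attn" and "mlp" spans added in this commit will show up in the trace ...
}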