Mirror of https://github.com/huggingface/candle.git
Tracing for StableLM and quantized StableLM. (#1068)
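The two diffs below (one per StableLM implementation touched by this commit) apply the same span-per-module tracing pattern: each of MLP, Attention, DecoderLayer and Model gains a tracing::Span that is created once in the constructor and entered at the top of forward, so a tracing subscriber can attribute wall-clock time to the "mlp", "attn", "layer" and "model" spans. A self-contained sketch of the pattern, assuming the tracing and tracing-subscriber crates (the Block type and the subscriber setup are illustrative, not part of the commit):

```rust
use tracing::Level;
use tracing_subscriber::fmt::format::FmtSpan;

// Illustrative stand-in for one model component; the diffs below add the
// same `span` field to MLP, Attention, DecoderLayer and Model.
struct Block {
    scale: f64,
    span: tracing::Span,
}

impl Block {
    fn new(scale: f64) -> Self {
        // The span is created once, when the module is built...
        let span = tracing::span!(Level::TRACE, "block");
        Self { scale, span }
    }

    fn forward(&self, xs: &[f64]) -> Vec<f64> {
        // ...and entered on every forward call; the guard keeps the span
        // entered until the call returns, so the span's duration is the
        // duration of the forward pass.
        let _enter = self.span.enter();
        xs.iter().map(|x| x * self.scale).collect()
    }
}

fn main() {
    // Log span close events (with timings) to the console; any other
    // subscriber, e.g. a chrome-trace layer, consumes the spans the same way.
    tracing_subscriber::fmt()
        .with_max_level(Level::TRACE)
        .with_span_events(FmtSpan::CLOSE)
        .init();
    let block = Block::new(2.0);
    let _ = block.forward(&[1.0, 2.0, 3.0]);
}
```

Binding the guard to `_enter` rather than `_` matters: `_` would drop the guard immediately and close the span before the body runs.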
@@ -14,6 +14,7 @@ struct MLP {
     up_proj: Linear,
     down_proj: Linear,
     act_fn: Activation,
+    span: tracing::Span,
 }

 impl MLP {
@@ -28,12 +29,14 @@ impl MLP {
             up_proj,
             down_proj,
             act_fn: cfg.hidden_act,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
     }
 }

 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
         let rhs = xs.apply(&self.up_proj)?;
         (lhs * rhs)?.apply(&self.down_proj)
@@ -55,6 +58,7 @@ struct Attention {
     kv_cache: Option<(Tensor, Tensor)>,
     use_cache: bool,
     rotary_ndims: usize,
+    span: tracing::Span,
 }

 impl Attention {
@@ -81,6 +85,7 @@ impl Attention {
             kv_cache: None,
             use_cache: cfg.use_cache,
             rotary_ndims: cfg.rotary_ndims(),
+            span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }

@@ -102,6 +107,7 @@ impl Attention {
         attention_mask: Option<&Tensor>,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_sz, q_len, _) = xs.dims3()?;

         let query_states = self.q_proj.forward(xs)?;
@@ -168,6 +174,7 @@ struct DecoderLayer {
     mlp: MLP,
     input_layernorm: LayerNorm,
     post_attention_layernorm: LayerNorm,
+    span: tracing::Span,
 }

 impl DecoderLayer {
@@ -185,6 +192,7 @@ impl DecoderLayer {
             mlp,
             input_layernorm,
             post_attention_layernorm,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }

@@ -194,6 +202,7 @@ impl DecoderLayer {
         attention_mask: Option<&Tensor>,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let xs = self.input_layernorm.forward(xs)?;
         let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
@@ -211,6 +220,7 @@ pub struct Model {
     norm: LayerNorm,
     lm_head: Linear,
     device: Device,
+    span: tracing::Span,
 }

 impl Model {
@@ -233,6 +243,7 @@ impl Model {
             norm,
             lm_head,
             device: vb.device().clone(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }

@@ -258,6 +269,7 @@ impl Model {
     }

     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_size, seq_len) = input_ids.dims2()?;
         let attention_mask = if seq_len <= 1 {
             None
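The second file receives the same three additions per module: a `span: tracing::Span` field, a `tracing::span!(tracing::Level::TRACE, ...)` call in the constructor, and `let _enter = self.span.enter();` at the top of `forward`. The only visible differences are fields this variant already carries (`use_flash_attn`, `dtype`). Because `Model::forward` and `DecoderLayer::forward` hold their guards while calling into the modules below them, the "attn" and "mlp" spans nest under "layer", which in turn nests under "model" in the resulting trace.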
@@ -109,6 +109,7 @@ struct MLP {
     up_proj: Linear,
     down_proj: Linear,
     act_fn: Activation,
+    span: tracing::Span,
 }

 impl MLP {
@@ -123,12 +124,14 @@ impl MLP {
             up_proj,
             down_proj,
             act_fn: cfg.hidden_act,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
     }
 }

 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
         let rhs = xs.apply(&self.up_proj)?;
         (lhs * rhs)?.apply(&self.down_proj)
@@ -167,6 +170,7 @@ struct Attention {
     use_cache: bool,
     rotary_ndims: usize,
     use_flash_attn: bool,
+    span: tracing::Span,
 }

 impl Attention {
@@ -194,6 +198,7 @@ impl Attention {
             use_cache: cfg.use_cache,
             rotary_ndims: cfg.rotary_ndims(),
             use_flash_attn: cfg.use_flash_attn,
+            span: tracing::span!(tracing::Level::TRACE, "attn"),
         })
     }

@@ -215,6 +220,7 @@ impl Attention {
         attention_mask: Option<&Tensor>,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_sz, q_len, _) = xs.dims3()?;

         let query_states = self.q_proj.forward(xs)?;
@@ -288,6 +294,7 @@ struct DecoderLayer {
     mlp: MLP,
     input_layernorm: LayerNorm,
     post_attention_layernorm: LayerNorm,
+    span: tracing::Span,
 }

 impl DecoderLayer {
@@ -306,6 +313,7 @@ impl DecoderLayer {
             mlp,
             input_layernorm,
             post_attention_layernorm,
+            span: tracing::span!(tracing::Level::TRACE, "layer"),
         })
     }

@@ -315,6 +323,7 @@ impl DecoderLayer {
         attention_mask: Option<&Tensor>,
         seqlen_offset: usize,
     ) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let xs = self.input_layernorm.forward(xs)?;
         let xs = self.self_attn.forward(&xs, attention_mask, seqlen_offset)?;
@@ -333,6 +342,7 @@ pub struct Model {
     lm_head: Linear,
     device: Device,
     dtype: DType,
+    span: tracing::Span,
 }

 impl Model {
@@ -356,6 +366,7 @@ impl Model {
             lm_head,
             device: vb.device().clone(),
             dtype: vb.dtype(),
+            span: tracing::span!(tracing::Level::TRACE, "model"),
         })
     }

@@ -381,6 +392,7 @@ impl Model {
     }

     pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_size, seq_len) = input_ids.dims2()?;
         let attention_mask = if seq_len <= 1 {
             None
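To actually collect these spans, candle's example binaries typically install a chrome-trace layer behind a `--tracing` flag; a minimal sketch of that setup, assuming the `tracing-chrome` and `tracing-subscriber` crates (the surrounding model-building code is elided and illustrative):

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Writes a trace-<timestamp>.json file in the current directory; keep the
    // guard alive until the end of the run so the trace is flushed, then open
    // the file in chrome://tracing or ui.perfetto.dev.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build the StableLM model and run generation here; every forward
    // pass now records "model", "layer", "attn" and "mlp" spans.
}
```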