Add some tracing to llama. (#318)

commit df6667ba88 (parent a79286885c)
Author: Laurent Mazare
Date: 2023-08-03 13:52:22 +01:00
Committed by: GitHub

2 changed files with 53 additions and 4 deletions

View File

@@ -111,6 +111,10 @@ struct Args {
     #[arg(long)]
     use_f32: bool,

+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
     #[arg(long)]
     model_id: Option<String>,
@@ -123,8 +127,18 @@ struct Args {
 fn main() -> Result<()> {
     use tokenizers::Tokenizer;
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;

     let args = Args::parse();
+    let _guard = if args.tracing {
+        println!("tracing...");
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };

     let device = candle_examples::device(args.cpu)?;
     let config = if args.v1 {
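
For readers unfamiliar with the setup this hunk wires in: the Chrome layer buffers span timings and writes them to a trace-timestamp.json file when the returned guard is dropped (hence binding it to _guard for the whole run), and the resulting file can be loaded in chrome://tracing or Perfetto. Below is a minimal stand-alone sketch of the same setup, not part of this commit, assuming the tracing, tracing-chrome and tracing-subscriber crates as dependencies.

// Minimal sketch (illustrative only): enable Chrome tracing for a program.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run; the trace file is flushed when it drops.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // Any span entered from here on shows up as a named, timed slice in the trace.
    let span = tracing::span!(tracing::Level::TRACE, "demo-work");
    let _enter = span.enter();
    let _ = (0..1_000_000u64).sum::<u64>();
}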

View File

@@ -1,5 +1,5 @@
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
-use candle_nn::{Embedding, Linear, VarBuilder};
+use candle_nn::{Embedding, VarBuilder};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -47,6 +47,21 @@ impl Config {
     }
 }

+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
 #[derive(Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@@ -106,8 +121,9 @@ fn silu(xs: &Tensor) -> Result<Tensor> {
 }

 fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+    Ok(Linear { inner, span })
 }

 fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
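
The wrapper above establishes the pattern repeated for every module in the rest of the file: build a tracing::Span once when the layer is loaded, enter it at the top of forward, and let the guard close the span when it goes out of scope. A rough generic sketch of that pattern follows; the Traced type and its method names are illustrative and not part of the commit.

// Illustrative sketch: create the span once, re-enter it on every call so each
// call becomes one named, timed slice in the Chrome trace.
struct Traced<M> {
    inner: M,
    span: tracing::Span,
}

impl<M> Traced<M> {
    fn new(inner: M, name: &'static str) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "forward", module = name);
        Self { inner, span }
    }

    fn with_span<T>(&self, f: impl FnOnce(&M) -> T) -> T {
        let _enter = self.span.enter(); // timing for this call starts here
        f(&self.inner)                  // ...and ends when _enter is dropped
    }
}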
@@ -118,15 +134,18 @@ fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
 struct RmsNorm {
     scale: Tensor,
     eps: f64,
+    span: tracing::Span,
 }

 impl RmsNorm {
     fn load(size: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = vb.get(size, "weight")?;
-        Ok(Self { scale, eps })
+        Ok(Self { scale, eps, span })
     }

     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let in_dtype = x.dtype();
         // This is a no-op if x's dtype is already f32.
         let x = x.to_dtype(DType::F32)?;
@@ -155,6 +174,8 @@ struct CausalSelfAttention {
     head_dim: usize,
     cache: Cache,
     use_flash_attn: bool,
+    span: tracing::Span,
+    span_rot: tracing::Span,
 }

 #[cfg(feature = "flash-attn")]
@@ -175,6 +196,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
 impl CausalSelfAttention {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let _enter = self.span_rot.enter();
         let (b_sz, _, seq_len, n_embd) = x.dims4()?;
         let cos = self.cache.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.cache.sin.narrow(0, index_pos, seq_len)?;
@@ -188,6 +210,7 @@ impl CausalSelfAttention {
     }

     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.q_proj.forward(x)?;
         let k = self.k_proj.forward(x)?;
@@ -269,6 +292,8 @@ impl CausalSelfAttention {
     }

     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "attn");
+        let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
         let size_in = cfg.hidden_size;
         let size_q = (cfg.hidden_size / cfg.n_head) * cfg.n_head;
         let size_kv = (cfg.hidden_size / cfg.n_head) * cfg.n_key_value_head;
@@ -286,6 +311,8 @@ impl CausalSelfAttention {
             head_dim: cfg.hidden_size / cfg.n_head,
             cache: cache.clone(),
             use_flash_attn: cfg.use_flash_attn,
+            span,
+            span_rot,
         })
     }
 }
@@ -301,15 +328,18 @@ struct Mlp {
     c_fc1: Linear,
     c_fc2: Linear,
     c_proj: Linear,
+    span: tracing::Span,
 }

 impl Mlp {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let x = (silu(&self.c_fc1.forward(x)?)? * self.c_fc2.forward(x)?)?;
         self.c_proj.forward(&x)
     }

     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "mlp");
         let h_size = cfg.hidden_size;
         let i_size = cfg.intermediate_size;
         let c_fc1 = linear(h_size, i_size, vb.pp("gate_proj"))?;
@@ -319,6 +349,7 @@ impl Mlp {
             c_fc1,
             c_fc2,
             c_proj,
+            span,
         })
     }
 }
@@ -328,10 +359,12 @@ struct Block {
     attn: CausalSelfAttention,
     rms_2: RmsNorm,
     mlp: Mlp,
+    span: tracing::Span,
 }

 impl Block {
     fn forward(&self, x: &Tensor, index_pos: usize, block_idx: usize) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = x;
         let x = self.rms_1.forward(x)?;
         let x = (self.attn.forward(&x, index_pos, block_idx)? + residual)?;
@@ -341,6 +374,7 @@ impl Block {
     }

     fn load(vb: VarBuilder, cache: &Cache, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "block");
         let attn = CausalSelfAttention::load(vb.pp("self_attn"), cache, cfg)?;
         let mlp = Mlp::load(vb.pp("mlp"), cfg)?;
         let rms_1 = RmsNorm::load(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
@@ -354,6 +388,7 @@ impl Block {
             attn,
             rms_2,
             mlp,
+            span,
         })
     }
 }