Add some tracing to the quantized example. (#473)
@@ -5,7 +5,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 
 use candle::quantized::ggml_file::Content;
-use candle::quantized::{QMatMul, QTensor};
+use candle::quantized::QTensor;
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::Embedding;
 use candle_transformers::generation::LogitsProcessor;
@@ -16,15 +16,22 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
 struct RmsNorm {
     scale: Tensor,
     eps: f64,
+    span: tracing::Span,
 }
 
 impl RmsNorm {
     fn new(scale: QTensor) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
         let scale = scale.dequantize(&Device::Cpu)?;
-        Ok(Self { scale, eps: 1e-5 })
+        Ok(Self {
+            scale,
+            eps: 1e-5,
+            span,
+        })
     }
 
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_sz, seq_len, hidden_size) = x.dims3()?;
         let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
         let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
@@ -39,6 +46,25 @@ impl RmsNorm {
     }
 }
 
+// QMatMul wrapper adding some tracing.
+struct QMatMul {
+    inner: candle::quantized::QMatMul,
+    span: tracing::Span,
+}
+
+impl QMatMul {
+    fn from_qtensor(qtensor: QTensor) -> Self {
+        let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
+        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
+        Self { inner, span }
+    }
+
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
 struct LayerWeights {
     attention_wq: QMatMul,
     attention_wk: QMatMul,
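The wrapper above illustrates the pattern used throughout this commit: build the `tracing::Span` once when the module is constructed, then `enter()` it in the hot path so each call only pays the enter/exit bookkeeping. A minimal standalone sketch of the same idea (the `TracedOp` type and its dummy workload are illustrative, not part of the commit):

```rust
use tracing::{span, Level};

// Hypothetical op wrapper: the span is created once at construction time
// and entered on every call.
struct TracedOp {
    span: tracing::Span,
}

impl TracedOp {
    fn new() -> Self {
        // Span metadata is allocated here, not per call.
        let span = span!(Level::TRACE, "traced-op");
        Self { span }
    }

    fn run(&self, n: u64) -> u64 {
        // The guard records the enter timestamp; dropping it records the exit.
        let _enter = self.span.enter();
        (0..n).sum() // stand-in for the real work, e.g. a quantized matmul
    }
}

fn main() {
    // With no subscriber installed the spans are no-ops; set one up as in
    // the main() changes further down to get a chrome trace.
    let op = TracedOp::new();
    println!("{}", op.run(1_000));
}
```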
@@ -54,6 +80,9 @@ struct LayerWeights {
     cos: Tensor,
     sin: Tensor,
     kv_cache: Option<(Tensor, Tensor)>,
+    span_attn: tracing::Span,
+    span_rot: tracing::Span,
+    span_mlp: tracing::Span,
 }
 
 fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
@@ -65,6 +94,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
 
 impl LayerWeights {
     fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let _enter = self.span_rot.enter();
         let (b_sz, _, seq_len, n_embd) = x.dims4()?;
         let cos = self.cos.narrow(0, index_pos, seq_len)?;
         let sin = self.sin.narrow(0, index_pos, seq_len)?;
@@ -78,6 +108,7 @@ impl LayerWeights {
     }
 
     fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
+        let _enter = self.span_attn.enter();
         let (b_sz, seq_len, n_embd) = x.dims3()?;
         let q = self.attention_wq.forward(x)?;
         let k = self.attention_wk.forward(x)?;
@@ -127,6 +158,8 @@ struct ModelWeights {
     // TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K.
     output: candle_nn::Linear,
     masks: HashMap<usize, Tensor>,
+    span: tracing::Span,
+    span_output: tracing::Span,
 }
 
 struct WeightMap(HashMap<String, QTensor>);
@@ -177,6 +210,9 @@ impl ModelWeights {
             let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
             let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
             let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
+            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
+            let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
+            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
             layers.push(LayerWeights {
                 attention_wq: QMatMul::from_qtensor(attention_wq),
                 attention_wk: QMatMul::from_qtensor(attention_wk),
@@ -192,14 +228,21 @@ impl ModelWeights {
                 cos: cos.clone(),
                 sin: sin.clone(),
                 kv_cache: None,
+                span_attn,
+                span_rot,
+                span_mlp,
             })
         }
+        let span = tracing::span!(tracing::Level::TRACE, "model");
+        let span_output = tracing::span!(tracing::Level::TRACE, "output");
         Ok(Self {
             tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
             layers,
             norm,
             output,
             masks: HashMap::new(),
+            span,
+            span_output,
         })
     }
 
@@ -219,6 +262,7 @@ impl ModelWeights {
     fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
         let (_b_sz, seq_len) = x.dims2()?;
         let mask = self.mask(seq_len)?;
+        let _enter = self.span.enter();
         let mut layer_in = self.tok_embeddings.forward(x)?;
         for layer in self.layers.iter_mut() {
             let x = layer_in;
@@ -228,6 +272,7 @@ impl ModelWeights {
             let x = (attn + residual)?;
 
             // MLP
+            let _enter = layer.span_mlp.enter();
             let residual = &x;
             let x = layer.ffn_norm.forward(&x)?;
             let w1 = layer.feed_forward_w1.forward(&x)?;
@@ -239,6 +284,7 @@ impl ModelWeights {
         }
         let x = self.norm.forward(&layer_in)?;
         let x = x.i((.., seq_len - 1, ..))?;
+        let _enter = self.span_output.enter();
         self.output.forward(&x)
     }
 }
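One Rust detail these hunks rely on: the guard returned by `enter()` is bound to `_enter`, not `_`. A `_` pattern drops the value immediately, which would close the span before the traced work runs; a named binding such as `_enter` keeps the guard alive until the end of the enclosing scope, so the span brackets the whole block.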
@@ -255,7 +301,7 @@ struct Args {
     prompt: Option<String>,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, default_value_t = 100)]
+    #[arg(short = 'n', long, default_value_t = 100)]
     sample_len: usize,
 
     /// The tokenizer config in json format.
@@ -269,6 +315,10 @@ struct Args {
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
+
+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
 }
 
 impl Args {
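With the new `--tracing` flag and the `-n` short option, a traced run can be requested from the command line. Assuming the example keeps the `quantized` name from the repository layout, an invocation would look like:

```
cargo run --example quantized --release -- --tracing -n 100
```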
@@ -298,7 +348,17 @@ impl Args {
 }
 
 fn main() -> anyhow::Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let args = Args::parse();
+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
     let mut file = std::fs::File::open(&args.model()?)?;
     let start = std::time::Instant::now();
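Note that the `_guard` binding has to live for the whole run: `ChromeLayerBuilder::new().build()` hands back the layer together with a flush guard, and `tracing_chrome` writes the `trace-timestamp.json` file out when that guard is dropped at the end of `main`. The resulting JSON can then be loaded in chrome://tracing or https://ui.perfetto.dev to inspect per-span timings.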
A second file in the commit drops a leftover debug print from the same tracing setup:

@@ -89,7 +89,6 @@ fn main() -> Result<()> {
 
     let args = Args::parse();
     let _guard = if args.tracing {
-        println!("tracing...");
         let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
         tracing_subscriber::registry().with(chrome_layer).init();
         Some(guard)