mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Add some tracing to the quantized example. (#473)
This commit is contained in:
@ -5,7 +5,7 @@ use std::io::Write;
|
|||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
use candle::quantized::ggml_file::Content;
|
use candle::quantized::ggml_file::Content;
|
||||||
use candle::quantized::{QMatMul, QTensor};
|
use candle::quantized::QTensor;
|
||||||
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Result, Tensor, D};
|
||||||
use candle_nn::Embedding;
|
use candle_nn::Embedding;
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
@ -16,15 +16,22 @@ const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
|||||||
struct RmsNorm {
|
struct RmsNorm {
|
||||||
scale: Tensor,
|
scale: Tensor,
|
||||||
eps: f64,
|
eps: f64,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RmsNorm {
|
impl RmsNorm {
|
||||||
fn new(scale: QTensor) -> Result<Self> {
|
fn new(scale: QTensor) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "rms-norm");
|
||||||
let scale = scale.dequantize(&Device::Cpu)?;
|
let scale = scale.dequantize(&Device::Cpu)?;
|
||||||
Ok(Self { scale, eps: 1e-5 })
|
Ok(Self {
|
||||||
|
scale,
|
||||||
|
eps: 1e-5,
|
||||||
|
span,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
let (b_sz, seq_len, hidden_size) = x.dims3()?;
|
||||||
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||||
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
let norm_x = norm_x.broadcast_as((b_sz, seq_len, hidden_size))?;
|
||||||
@ -39,6 +46,25 @@ impl RmsNorm {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QMatMul wrapper adding some tracing.
|
||||||
|
struct QMatMul {
|
||||||
|
inner: candle::quantized::QMatMul,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QMatMul {
|
||||||
|
fn from_qtensor(qtensor: QTensor) -> Self {
|
||||||
|
let inner = candle::quantized::QMatMul::from_qtensor(qtensor);
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||||
|
Self { inner, span }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
self.inner.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
struct LayerWeights {
|
struct LayerWeights {
|
||||||
attention_wq: QMatMul,
|
attention_wq: QMatMul,
|
||||||
attention_wk: QMatMul,
|
attention_wk: QMatMul,
|
||||||
@ -54,6 +80,9 @@ struct LayerWeights {
|
|||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
kv_cache: Option<(Tensor, Tensor)>,
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
span_attn: tracing::Span,
|
||||||
|
span_rot: tracing::Span,
|
||||||
|
span_mlp: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||||
@ -65,6 +94,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
|
|
||||||
impl LayerWeights {
|
impl LayerWeights {
|
||||||
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
fn apply_rotary_emb(&self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
|
let _enter = self.span_rot.enter();
|
||||||
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
|
let (b_sz, _, seq_len, n_embd) = x.dims4()?;
|
||||||
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||||
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||||
@ -78,6 +108,7 @@ impl LayerWeights {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
|
fn forward_attn(&mut self, x: &Tensor, mask: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
|
let _enter = self.span_attn.enter();
|
||||||
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
let (b_sz, seq_len, n_embd) = x.dims3()?;
|
||||||
let q = self.attention_wq.forward(x)?;
|
let q = self.attention_wq.forward(x)?;
|
||||||
let k = self.attention_wk.forward(x)?;
|
let k = self.attention_wk.forward(x)?;
|
||||||
@ -127,6 +158,8 @@ struct ModelWeights {
|
|||||||
// TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K.
|
// TODO: Switch to using QMatMul instead of linear once we have support for Q6K/Q8K.
|
||||||
output: candle_nn::Linear,
|
output: candle_nn::Linear,
|
||||||
masks: HashMap<usize, Tensor>,
|
masks: HashMap<usize, Tensor>,
|
||||||
|
span: tracing::Span,
|
||||||
|
span_output: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
struct WeightMap(HashMap<String, QTensor>);
|
struct WeightMap(HashMap<String, QTensor>);
|
||||||
@ -177,6 +210,9 @@ impl ModelWeights {
|
|||||||
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
|
let feed_forward_w3 = ct.remove(&format!("{prefix}.feed_forward.w3.weight"))?;
|
||||||
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
|
let attention_norm = ct.remove(&format!("{prefix}.attention_norm.weight"))?;
|
||||||
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
|
let ffn_norm = ct.remove(&format!("{prefix}.ffn_norm.weight"))?;
|
||||||
|
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||||
|
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
|
||||||
|
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||||
layers.push(LayerWeights {
|
layers.push(LayerWeights {
|
||||||
attention_wq: QMatMul::from_qtensor(attention_wq),
|
attention_wq: QMatMul::from_qtensor(attention_wq),
|
||||||
attention_wk: QMatMul::from_qtensor(attention_wk),
|
attention_wk: QMatMul::from_qtensor(attention_wk),
|
||||||
@ -192,14 +228,21 @@ impl ModelWeights {
|
|||||||
cos: cos.clone(),
|
cos: cos.clone(),
|
||||||
sin: sin.clone(),
|
sin: sin.clone(),
|
||||||
kv_cache: None,
|
kv_cache: None,
|
||||||
|
span_attn,
|
||||||
|
span_rot,
|
||||||
|
span_mlp,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||||
|
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
|
tok_embeddings: Embedding::new(tok_embeddings, ct.hparams.n_embd as usize),
|
||||||
layers,
|
layers,
|
||||||
norm,
|
norm,
|
||||||
output,
|
output,
|
||||||
masks: HashMap::new(),
|
masks: HashMap::new(),
|
||||||
|
span,
|
||||||
|
span_output,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -219,6 +262,7 @@ impl ModelWeights {
|
|||||||
fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
let (_b_sz, seq_len) = x.dims2()?;
|
let (_b_sz, seq_len) = x.dims2()?;
|
||||||
let mask = self.mask(seq_len)?;
|
let mask = self.mask(seq_len)?;
|
||||||
|
let _enter = self.span.enter();
|
||||||
let mut layer_in = self.tok_embeddings.forward(x)?;
|
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
let x = layer_in;
|
let x = layer_in;
|
||||||
@ -228,6 +272,7 @@ impl ModelWeights {
|
|||||||
let x = (attn + residual)?;
|
let x = (attn + residual)?;
|
||||||
|
|
||||||
// MLP
|
// MLP
|
||||||
|
let _enter = layer.span_mlp.enter();
|
||||||
let residual = &x;
|
let residual = &x;
|
||||||
let x = layer.ffn_norm.forward(&x)?;
|
let x = layer.ffn_norm.forward(&x)?;
|
||||||
let w1 = layer.feed_forward_w1.forward(&x)?;
|
let w1 = layer.feed_forward_w1.forward(&x)?;
|
||||||
@ -239,6 +284,7 @@ impl ModelWeights {
|
|||||||
}
|
}
|
||||||
let x = self.norm.forward(&layer_in)?;
|
let x = self.norm.forward(&layer_in)?;
|
||||||
let x = x.i((.., seq_len - 1, ..))?;
|
let x = x.i((.., seq_len - 1, ..))?;
|
||||||
|
let _enter = self.span_output.enter();
|
||||||
self.output.forward(&x)
|
self.output.forward(&x)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -255,7 +301,7 @@ struct Args {
|
|||||||
prompt: Option<String>,
|
prompt: Option<String>,
|
||||||
|
|
||||||
/// The length of the sample to generate (in tokens).
|
/// The length of the sample to generate (in tokens).
|
||||||
#[arg(long, default_value_t = 100)]
|
#[arg(short = 'n', long, default_value_t = 100)]
|
||||||
sample_len: usize,
|
sample_len: usize,
|
||||||
|
|
||||||
/// The tokenizer config in json format.
|
/// The tokenizer config in json format.
|
||||||
@ -269,6 +315,10 @@ struct Args {
|
|||||||
/// The seed to use when generating random samples.
|
/// The seed to use when generating random samples.
|
||||||
#[arg(long, default_value_t = 299792458)]
|
#[arg(long, default_value_t = 299792458)]
|
||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
@ -298,7 +348,17 @@ impl Args {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let mut file = std::fs::File::open(&args.model()?)?;
|
let mut file = std::fs::File::open(&args.model()?)?;
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
|
@ -89,7 +89,6 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let _guard = if args.tracing {
|
let _guard = if args.tracing {
|
||||||
println!("tracing...");
|
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
Some(guard)
|
Some(guard)
|
||||||
|
Reference in New Issue
Block a user