Consolidate the with-tracing usage. (#1234)

This commit is contained in:
Laurent Mazare
2023-11-01 19:21:36 +01:00
committed by GitHub
parent 693fad511c
commit 1704f1b3ae
4 changed files with 8 additions and 102 deletions

View File

@ -1,3 +1,4 @@
use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::{DType, Device, IndexOp, Result, Tensor, D};
use candle_nn::{Embedding, Module, VarBuilder};
use serde::Deserialize;
@ -81,21 +82,6 @@ impl Config {
}
}
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
// model.
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Clone)]
pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@ -150,12 +136,6 @@ impl Cache {
}
}
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
let span = tracing::span!(tracing::Level::TRACE, "linear");
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
Ok(Linear { inner, span })
}
fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
Ok(Embedding::new(embeddings, cfg.hidden_size))