diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 29e2904e..aca520da 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -1,3 +1,4 @@ +use super::with_tracing::{linear, Linear}; use candle::{DType, Device, Result, Tensor}; use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; @@ -32,35 +33,6 @@ impl HiddenActLayer { } } -#[derive(Debug)] -pub struct Linear { - weight: Tensor, - bias: Option, - span: tracing::Span, -} - -impl Linear { - pub fn new(weight: Tensor, bias: Option) -> Self { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - Self { weight, bias, span } - } -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> candle::Result { - let _enter = self.span.enter(); - let w = match x.dims() { - &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?, - _ => self.weight.t()?, - }; - let x = x.matmul(&w)?; - match &self.bias { - None => Ok(x), - Some(bias) => x.broadcast_add(bias), - } - } -} - #[derive(Debug)] pub struct LayerNorm { weight: Tensor, @@ -184,12 +156,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result Result { - let weight = vb.get((size2, size1), "weight")?; - let bias = vb.get(size2, "bias")?; - Ok(Linear::new(weight, Some(bias))) -} - struct Dropout { #[allow(dead_code)] pr: f64, diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index eed4df5e..7e8c8920 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -1,3 +1,4 @@ +use super::with_tracing::{linear_no_bias as linear, Linear}; use candle::{DType, Device, IndexOp, Result, Tensor, D}; use candle_nn::{Embedding, Module, VarBuilder}; use serde::Deserialize; @@ -81,21 +82,6 @@ impl Config { } } -// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting -// model. -#[derive(Debug)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Linear { - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - #[derive(Clone)] pub struct Cache { masks: Arc>>, @@ -150,12 +136,6 @@ impl Cache { } } -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear_no_bias(size1, size2, vb)?; - Ok(Linear { inner, span }) -} - fn embedding(cfg: &Config, vb: VarBuilder) -> Result { let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?; Ok(Embedding::new(embeddings, cfg.hidden_size)) diff --git a/candle-transformers/src/models/segment_anything/mod.rs b/candle-transformers/src/models/segment_anything/mod.rs index c29db70a..c54493d2 100644 --- a/candle-transformers/src/models/segment_anything/mod.rs +++ b/candle-transformers/src/models/segment_anything/mod.rs @@ -1,3 +1,4 @@ +pub use crate::models::with_tracing::Linear; use candle::{Result, Tensor}; use candle_nn::{Module, VarBuilder}; @@ -9,13 +10,11 @@ pub mod tiny_vit; pub mod transformer; pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result { - let inner = if bias { - candle_nn::linear(in_dim, out_dim, vb)? + if bias { + crate::models::with_tracing::linear(in_dim, out_dim, vb) } else { - candle_nn::linear_no_bias(in_dim, out_dim, vb)? - }; - let span = tracing::span!(tracing::Level::TRACE, "linear"); - Ok(Linear { inner, span }) + crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb) + } } #[derive(Debug)] @@ -85,16 +84,3 @@ impl Module for MlpBlock { .apply(&self.lin2) } } - -#[derive(Debug)] -pub struct Linear { - inner: candle_nn::Linear, - span: tracing::Span, -} - -impl Module for Linear { - fn forward(&self, x: &Tensor) -> Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} diff --git a/candle-transformers/src/models/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs index 6078944c..25454ba6 100644 --- a/candle-transformers/src/models/whisper/model.rs +++ b/candle-transformers/src/models/whisper/model.rs @@ -1,4 +1,5 @@ use super::Config; +use crate::models::with_tracing::{linear, linear_no_bias, Linear}; use candle::{Device, IndexOp, Result, Tensor, D}; use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder}; @@ -6,33 +7,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result Result { - let _enter = self.span.enter(); - self.inner.forward(x) - } -} - -fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear(size1, size2, vb)?; - Ok(Linear { inner, span }) -} - -fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let span = tracing::span!(tracing::Level::TRACE, "linear"); - let inner = candle_nn::linear_no_bias(size1, size2, vb)?; - Ok(Linear { inner, span }) -} fn conv1d( in_channels: usize,