Consolidate the with-tracing usage. (#1234)

This commit is contained in:
Laurent Mazare
2023-11-01 19:21:36 +01:00
committed by GitHub
parent 693fad511c
commit 1704f1b3ae
4 changed files with 8 additions and 102 deletions

View File

@ -1,3 +1,4 @@
pub use crate::models::with_tracing::Linear;
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};
@ -9,13 +10,11 @@ pub mod tiny_vit;
pub mod transformer;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
if bias {
crate::models::with_tracing::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
@ -85,16 +84,3 @@ impl Module for MlpBlock {
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}