mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Tracing for the phi model (#936)
* Add some tracing bits to mixformers. * Add the missing file. * Add the conv2d layer to with-tracing. * Improve the tracing usage.
This commit is contained in:
78
candle-transformers/src/models/with_tracing.rs
Normal file
78
candle-transformers/src/models/with_tracing.rs
Normal file
@ -0,0 +1,78 @@
|
||||
use candle::{Module, Result, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedding {
|
||||
inner: candle_nn::Embedding,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Embedding {
|
||||
pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let inner = candle_nn::embedding(d1, d2, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "embedding");
|
||||
Ok(Self { inner, span })
|
||||
}
|
||||
|
||||
pub fn embeddings(&self) -> &Tensor {
|
||||
self.inner.embeddings()
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Embedding {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Linear {
|
||||
inner: candle_nn::Linear,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let inner = candle_nn::linear(d1, d2, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||
let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
|
||||
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||
Ok(Linear { inner, span })
|
||||
}
|
||||
|
||||
impl Module for Linear {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(xs)
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap the conv2d op to provide some tracing.
|
||||
#[derive(Debug)]
|
||||
pub struct Conv2d {
|
||||
inner: candle_nn::Conv2d,
|
||||
span: tracing::Span,
|
||||
}
|
||||
|
||||
impl Conv2d {
|
||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
self.inner.forward(x)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn conv2d(
|
||||
in_channels: usize,
|
||||
out_channels: usize,
|
||||
kernel_size: usize,
|
||||
cfg: candle_nn::Conv2dConfig,
|
||||
vs: candle_nn::VarBuilder,
|
||||
) -> Result<Conv2d> {
|
||||
let span = tracing::span!(tracing::Level::TRACE, "conv2d");
|
||||
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
|
||||
Ok(Conv2d { inner, span })
|
||||
}
|
Reference in New Issue
Block a user