Tracing for the phi model (#936)

* Add some tracing bits to mixformers.

* Add the missing file.

* Add the conv2d layer to with-tracing.

* Improve the tracing usage.
This commit is contained in:
Laurent Mazare
2023-09-23 09:19:34 +01:00
committed by GitHub
parent cda1786eed
commit b54acfa3d0
9 changed files with 140 additions and 100 deletions

View File

@ -1,5 +1,4 @@
use candle::{Device, Result, Tensor};
use candle_nn::Module;
pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
if steps < 1 {
@ -11,29 +10,3 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
.collect::<Vec<_>>();
Tensor::from_vec(vs, steps, &Device::Cpu)
}
// Wrap the conv2d op to provide some tracing.
#[derive(Debug)]
pub struct Conv2d {
inner: candle_nn::Conv2d,
span: tracing::Span,
}
impl Conv2d {
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
pub fn conv2d(
in_channels: usize,
out_channels: usize,
kernel_size: usize,
cfg: candle_nn::Conv2dConfig,
vs: candle_nn::VarBuilder,
) -> Result<Conv2d> {
let span = tracing::span!(tracing::Level::TRACE, "conv2d");
let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
Ok(Conv2d { inner, span })
}