Consolidate the with-tracing usage. (#1234)

Author: Laurent Mazare
Date: 2023-11-01 19:21:36 +01:00
Committed by: GitHub
Parent: 693fad511c
Commit: 1704f1b3ae
4 changed files with 8 additions and 102 deletions
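
Context for the diff below: the change removes this file's hand-rolled `Linear` wrapper and instead imports the shared one from the models' with_tracing module. As a minimal sketch of the pattern being consolidated, assuming the shared wrapper simply delegates to candle_nn::Linear and records a TRACE span around each forward call (the exact definitions live in with_tracing.rs, not in this diff):

use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

// Sketch: a thin wrapper that times every forward call via tracing.
#[derive(Debug)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

// Sketch: build the inner layer from the checkpoint, then attach the span.
pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let inner = candle_nn::linear(in_dim, out_dim, vb)?;
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    Ok(Linear { inner, span })
}

impl Module for Linear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}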


@@ -1,3 +1,4 @@
+use super::with_tracing::{linear, Linear};
 use candle::{DType, Device, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
@@ -32,35 +33,6 @@ impl HiddenActLayer {
     }
 }
-#[derive(Debug)]
-pub struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-    span: tracing::Span,
-}
-impl Linear {
-    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
-        let span = tracing::span!(tracing::Level::TRACE, "linear");
-        Self { weight, bias, span }
-    }
-}
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let _enter = self.span.enter();
-        let w = match x.dims() {
-            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-            _ => self.weight.t()?,
-        };
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
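
The forward pass removed above duplicated what candle_nn::Linear already does: for a 3-D input the weight is broadcast across the leading batch dimension before the matmul, and the bias, when present, is broadcast-added afterwards. A small self-contained example of that behavior (shapes chosen purely for illustration):

use candle::{DType, Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Weight is (out_features, in_features) = (4, 3); no bias.
    let weight = Tensor::ones((4, 3), DType::F32, &dev)?;
    let layer = Linear::new(weight, None);
    // Batched (bsize, seq_len, in_features) input: the weight is
    // broadcast over the leading dimension, as in the deleted match arm.
    let xs = Tensor::ones((2, 5, 3), DType::F32, &dev)?;
    let ys = layer.forward(&xs)?;
    assert_eq!(ys.dims(), &[2, 5, 4]);
    Ok(())
}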
@@ -184,12 +156,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, hidden_size))
 }
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    let bias = vb.get(size2, "bias")?;
-    Ok(Linear::new(weight, Some(bias)))
-}
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
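
With the local helper gone, call sites obtain the traced layer from the shared `linear` function and a VarBuilder prefix. A hypothetical call site (the "dense" prefix and surrounding function are illustrative, not taken from this diff):

use super::with_tracing::{linear, Linear};
use candle::Result;
use candle_nn::VarBuilder;

// Hypothetical: loads `dense.weight` / `dense.bias` from the checkpoint
// and returns the span-instrumented layer.
fn load_dense(hidden_size: usize, vb: VarBuilder) -> Result<Linear> {
    linear(hidden_size, hidden_size, vb.pp("dense"))
}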