mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Consolidate the with-tracing usage. (#1234)
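This commit deduplicates the hand-rolled `Linear` tracing wrappers that had been copy-pasted into several model files, replacing them with imports from the shared `with_tracing` module. Judging by the hunk contexts, the four diffs below touch the bert, llama, segment_anything, and whisper models.

For reference, here is a minimal sketch of what the shared module provides, reconstructed from the duplicated code deleted below rather than from the actual with_tracing.rs, whose details may differ:

use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

// A `candle_nn::Linear` paired with a tracing span so that every forward
// pass shows up when profiling the model.
#[derive(Debug, Clone)]
pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Module for Linear {
    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(x)
    }
}

pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear(in_dim, out_dim, vb)?;
    Ok(Linear { inner, span })
}

pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    let inner = candle_nn::linear_no_bias(in_dim, out_dim, vb)?;
    Ok(Linear { inner, span })
}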
@@ -1,3 +1,4 @@
+use super::with_tracing::{linear, Linear};
 use candle::{DType, Device, Result, Tensor};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
@@ -32,35 +33,6 @@ impl HiddenActLayer {
     }
 }
-
-#[derive(Debug)]
-pub struct Linear {
-    weight: Tensor,
-    bias: Option<Tensor>,
-    span: tracing::Span,
-}
-
-impl Linear {
-    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
-        let span = tracing::span!(tracing::Level::TRACE, "linear");
-        Self { weight, bias, span }
-    }
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
-        let _enter = self.span.enter();
-        let w = match x.dims() {
-            &[bsize, _, _] => self.weight.broadcast_left(bsize)?.t()?,
-            _ => self.weight.t()?,
-        };
-        let x = x.matmul(&w)?;
-        match &self.bias {
-            None => Ok(x),
-            Some(bias) => x.broadcast_add(bias),
-        }
-    }
-}
-
 #[derive(Debug)]
 pub struct LayerNorm {
     weight: Tensor,
@@ -184,12 +156,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
     Ok(Embedding::new(embeddings, hidden_size))
 }
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    let bias = vb.get(size2, "bias")?;
-    Ok(Linear::new(weight, Some(bias)))
-}
-
 struct Dropout {
     #[allow(dead_code)]
     pr: f64,
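The deleted bert `Linear` hand-rolled its forward pass instead of delegating to `candle_nn::Linear`: for a 3-D input it broadcast the weight over the batch dimension before the matmul, which is the same shape handling `candle_nn::Linear` already performs. A minimal sketch of that broadcast, with illustrative dimensions:

use candle::{DType, Device, Result, Tensor};

fn batched_linear_demo() -> Result<()> {
    let x = Tensor::zeros((2, 5, 16), DType::F32, &Device::Cpu)?; // (batch, seq, in_dim)
    let w = Tensor::zeros((32, 16), DType::F32, &Device::Cpu)?; // (out_dim, in_dim)
    // Broadcast the weight to (2, 16, 32) so the batched matmul lines up.
    let y = x.matmul(&w.broadcast_left(2)?.t()?)?;
    assert_eq!(y.dims(), &[2, 5, 32]);
    Ok(())
}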
@@ -1,3 +1,4 @@
+use super::with_tracing::{linear_no_bias as linear, Linear};
 use candle::{DType, Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Embedding, Module, VarBuilder};
 use serde::Deserialize;
@@ -81,21 +82,6 @@ impl Config {
     }
 }
-
-// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
-// model.
-#[derive(Debug)]
-pub struct Linear {
-    inner: candle_nn::Linear,
-    span: tracing::Span,
-}
-
-impl Linear {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
 #[derive(Clone)]
 pub struct Cache {
     masks: Arc<Mutex<HashMap<usize, Tensor>>>,
@@ -150,12 +136,6 @@ impl Cache {
     }
 }
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let span = tracing::span!(tracing::Level::TRACE, "linear");
-    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
-    Ok(Linear { inner, span })
-}
-
 fn embedding(cfg: &Config, vb: VarBuilder) -> Result<Embedding> {
     let embeddings = vb.get((cfg.vocab_size, cfg.hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, cfg.hidden_size))
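Note the alias in llama's new import: `linear_no_bias as linear` keeps every existing call site compiling unchanged while swapping in the shared, bias-free, traced implementation. A hypothetical call site, assuming it lives inside the llama module (the name "q_proj" and the sizes are illustrative):

use super::with_tracing::{linear_no_bias as linear, Linear};
use candle::Result;
use candle_nn::VarBuilder;

fn q_proj(hidden_size: usize, vb: VarBuilder) -> Result<Linear> {
    // No bias tensor is looked up; the TRACE span is created inside the helper.
    linear(hidden_size, hidden_size, vb.pp("q_proj"))
}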
@@ -1,3 +1,4 @@
+pub use crate::models::with_tracing::Linear;
 use candle::{Result, Tensor};
 use candle_nn::{Module, VarBuilder};
 
@@ -9,13 +10,11 @@ pub mod tiny_vit;
 pub mod transformer;
 
 pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
-    let inner = if bias {
-        candle_nn::linear(in_dim, out_dim, vb)?
+    if bias {
+        crate::models::with_tracing::linear(in_dim, out_dim, vb)
     } else {
-        candle_nn::linear_no_bias(in_dim, out_dim, vb)?
-    };
-    let span = tracing::span!(tracing::Level::TRACE, "linear");
-    Ok(Linear { inner, span })
+        crate::models::with_tracing::linear_no_bias(in_dim, out_dim, vb)
+    }
 }
 
 #[derive(Debug)]
@@ -85,16 +84,3 @@ impl Module for MlpBlock {
         .apply(&self.lin2)
     }
 }
-
-#[derive(Debug)]
-pub struct Linear {
-    inner: candle_nn::Linear,
-    span: tracing::Span,
-}
-
-impl Module for Linear {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
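After the refactor, segment_anything's public `linear` helper is just a dispatcher over the two shared constructors, and the module re-exports the shared `Linear` type for its submodules. Hypothetical usage, assuming it runs inside that module (the name "proj" and the 256s are made up):

use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};

fn build_and_apply(vb: VarBuilder, x: &Tensor) -> Result<Tensor> {
    let proj = linear(vb.pp("proj"), 256, 256, /* bias */ true)?;
    proj.forward(x) // executes inside the "linear" TRACE span
}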
@@ -1,4 +1,5 @@
 use super::Config;
+use crate::models::with_tracing::{linear, linear_no_bias, Linear};
 use candle::{Device, IndexOp, Result, Tensor, D};
 use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Module, VarBuilder};
 
@@ -6,33 +7,6 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Embedding> {
     let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
 }
-//
-// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
-// model.
-#[derive(Debug, Clone)]
-pub struct Linear {
-    inner: candle_nn::Linear,
-    span: tracing::Span,
-}
-
-impl Linear {
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
-fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let span = tracing::span!(tracing::Level::TRACE, "linear");
-    let inner = candle_nn::linear(size1, size2, vb)?;
-    Ok(Linear { inner, span })
-}
-
-fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let span = tracing::span!(tracing::Level::TRACE, "linear");
-    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
-    Ok(Linear { inner, span })
-}
-
 fn conv1d(
     in_channels: usize,
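The spans these wrappers emit are only recorded once a tracing subscriber is installed. Candle's example binaries typically do this with tracing-chrome; the sketch below shows that setup as a representative pattern, not code from this commit:

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The guard flushes a trace-*.json file on drop; open it in
    // chrome://tracing or Perfetto to see the per-layer "linear" spans.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    // ... build and run a model here; every with_tracing::Linear::forward
    // now records a span.
}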