Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)
Tracing for the phi model (#936)
* Add some tracing bits to mixformers.
* Add the missing file.
* Add the conv2d layer to with-tracing.
* Improve the tracing usage.
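All of these changes follow one pattern: wrap a candle_nn layer in a small struct that owns a tracing::Span and enters that span on every forward call, so each op shows up as a timed slice in the generated trace. The actual wrappers live in the new candle-transformers/src/models/with_tracing.rs shown at the end of this diff; the sketch below (illustrative names, not code from this commit) is just the shape of the pattern:

use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;

// A traced drop-in replacement for candle_nn::Linear: the span is created once
// at construction time, and every forward pass is recorded while its guard is alive.
#[derive(Debug)]
pub struct TracedLinear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

pub fn traced_linear(d_in: usize, d_out: usize, vb: VarBuilder) -> Result<TracedLinear> {
    let inner = candle_nn::linear(d_in, d_out, vb)?;
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    Ok(TracedLinear { inner, span })
}

impl Module for TracedLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}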
@@ -70,7 +70,7 @@ impl TextGeneration {
         }
         let dt = start_gen.elapsed();
         println!(
-            "{sample_len} tokens generated ({:.3} token/s)",
+            "\n{sample_len} tokens generated ({:.2} token/s)",
             sample_len as f64 / dt.as_secs_f64(),
         );
         Ok(())
@@ -84,6 +84,10 @@ struct Args {
     #[arg(long)]
     cpu: bool,

+    /// Enable tracing (generates a trace-timestamp.json file).
+    #[arg(long)]
+    tracing: bool,
+
     #[arg(long)]
     prompt: String,

@@ -114,8 +118,19 @@ struct Args {
 }

 fn main() -> Result<()> {
+    use tracing_chrome::ChromeLayerBuilder;
+    use tracing_subscriber::prelude::*;
+
     let args = Args::parse();

+    let _guard = if args.tracing {
+        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+        tracing_subscriber::registry().with(chrome_layer).init();
+        Some(guard)
+    } else {
+        None
+    };
+
     let start = std::time::Instant::now();
     let api = Api::new()?;
     let repo = api.repo(Repo::with_revision(
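Note on the hunk above: ChromeLayerBuilder::new().build() returns the tracing layer together with a flush guard, and binding that guard to _guard (rather than discarding it with _) keeps it alive until main returns so the trace is flushed to disk. Running the example with --tracing should then write a trace-timestamp.json file (as the new doc comment says) in Chrome trace format, which can be inspected in chrome://tracing or Perfetto.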
|
@@ -1,3 +1,4 @@
+use crate::models::with_tracing::{linear, Embedding as E, Linear};
 /// MixFormer model.
 /// https://huggingface.co/microsoft/phi-1_5
 /// https://arxiv.org/abs/2309.05463
@@ -58,12 +59,12 @@ impl Config {

 #[derive(Debug)]
 struct Embedding {
-    wte: candle_nn::Embedding,
+    wte: E,
 }

 impl Embedding {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wte = candle_nn::embedding(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
+        let wte = E::new(cfg.vocab_size, cfg.n_embd, vb.pp("wte"))?;
         Ok(Self { wte })
     }
 }
@@ -143,16 +144,16 @@ impl RotaryEmbedding {
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
 struct MLP {
-    fc1: candle_nn::Linear,
-    fc2: candle_nn::Linear,
+    fc1: Linear,
+    fc2: Linear,
     act: Activation,
 }

 impl MLP {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let n_inner = cfg.n_inner.unwrap_or(4 * cfg.n_embd);
-        let fc1 = candle_nn::linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
-        let fc2 = candle_nn::linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
+        let fc1 = linear(cfg.n_embd, n_inner, vb.pp("fc1"))?;
+        let fc2 = linear(n_inner, cfg.n_embd, vb.pp("fc2"))?;
         Ok(Self {
             fc1,
             fc2,
@@ -170,13 +171,13 @@ impl Module for MLP {
 #[derive(Debug)]
 struct CausalLMHead {
     ln: candle_nn::LayerNorm,
-    linear: candle_nn::Linear,
+    linear: Linear,
 }

 impl CausalLMHead {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
-        let linear = candle_nn::linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
+        let linear = linear(cfg.n_embd, cfg.vocab_size, vb.pp("linear"))?;
         Ok(Self { ln, linear })
     }
 }
@@ -192,20 +193,21 @@ impl Module for CausalLMHead {
 #[derive(Debug)]
 #[allow(clippy::upper_case_acronyms)]
 struct MHA {
-    wqkv: candle_nn::Linear,
-    out_proj: candle_nn::Linear,
+    wqkv: Linear,
+    out_proj: Linear,
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
     softmax_scale: f64,
+    span: tracing::Span,
 }

 impl MHA {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let head_dim = cfg.n_embd / cfg.n_head;
         let op_size = cfg.n_embd;
-        let wqkv = candle_nn::linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
-        let out_proj = candle_nn::linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
+        let wqkv = linear(cfg.n_embd, 3 * op_size, vb.pp("Wqkv"))?;
+        let out_proj = linear(op_size, cfg.n_embd, vb.pp("out_proj"))?;
         let rotary_emb = RotaryEmbedding::new(cfg.rotary_dim, MAX_SEQ_LEN, vb.device())?;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
         Ok(Self {
@@ -215,10 +217,12 @@ impl MHA {
             kv_cache: None,
             rotary_emb,
             softmax_scale,
+            span: tracing::span!(tracing::Level::TRACE, "mha"),
         })
     }

     fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b_size, seq_len, _n_embd) = xs.dims3()?;
         let qkv = self
             .wqkv
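The self.span.enter() guard above lives until the end of forward, so the whole attention computation for one call is recorded as a single "mha" slice in the trace.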
@@ -267,6 +271,7 @@ struct ParallelBlock {
     ln: candle_nn::LayerNorm,
     mixer: MHA,
     mlp: MLP,
+    span: tracing::Span,
 }

 impl ParallelBlock {
@@ -274,10 +279,16 @@ impl ParallelBlock {
         let ln = candle_nn::layer_norm(cfg.n_embd, cfg.layer_norm_epsilon, vb.pp("ln"))?;
         let mixer = MHA::new(cfg, vb.pp("mixer"))?;
         let mlp = MLP::new(cfg, vb.pp("mlp"))?;
-        Ok(Self { ln, mixer, mlp })
+        Ok(Self {
+            ln,
+            mixer,
+            mlp,
+            span: tracing::span!(tracing::Level::TRACE, "block"),
+        })
     }

     fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let residual = xs;
         let xs = xs.apply(&self.ln)?;
         let attn_outputs = self.mixer.forward(&xs)?;
@@ -291,6 +302,7 @@ pub struct MixFormerSequentialForCausalLM {
     embedding: Embedding,
     blocks: Vec<ParallelBlock>,
     head: CausalLMHead,
+    span: tracing::Span,
 }

 impl MixFormerSequentialForCausalLM {
@@ -307,10 +319,12 @@ impl MixFormerSequentialForCausalLM {
             embedding,
             blocks,
             head,
+            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
         })
     }

     pub fn forward(&mut self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (_b_size, seq_len) = xs.dims2()?;
         let mut xs = xs.apply(&self.embedding)?;
         for block in self.blocks.iter_mut() {
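Since MixFormerSequentialForCausalLM::forward, ParallelBlock::forward, and MHA::forward each enter their own span, the resulting slices should nest as mixformer > block > mha (with the traced linear layers inside), giving a per-layer breakdown of where generation time goes.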
|
@@ -11,4 +11,5 @@ pub mod segment_anything;
 pub mod stable_diffusion;
 pub mod t5;
 pub mod whisper;
+pub mod with_tracing;
 pub mod wuerstchen;
@@ -4,7 +4,7 @@
 //!
 //! Denoising Diffusion Implicit Models, K. He and al, 2015.
 //! https://arxiv.org/abs/1512.03385
-use super::utils::{conv2d, Conv2d};
+use crate::models::with_tracing::{conv2d, Conv2d};
 use candle::{Result, Tensor, D};
 use candle_nn as nn;
 use candle_nn::Module;
@@ -4,7 +4,7 @@
 //! timestep and return a denoised version of the input.
 use super::embeddings::{TimestepEmbedding, Timesteps};
 use super::unet_2d_blocks::*;
-use super::utils::{conv2d, Conv2d};
+use crate::models::with_tracing::{conv2d, Conv2d};
 use candle::{Result, Tensor};
 use candle_nn as nn;
 use candle_nn::Module;
@@ -4,7 +4,7 @@ use super::attention::{
     AttentionBlock, AttentionBlockConfig, SpatialTransformer, SpatialTransformerConfig,
 };
 use super::resnet::{ResnetBlock2D, ResnetBlock2DConfig};
-use super::utils::{conv2d, Conv2d};
+use crate::models::with_tracing::{conv2d, Conv2d};
 use candle::{Module, Result, Tensor, D};
 use candle_nn as nn;

@@ -1,5 +1,4 @@
 use candle::{Device, Result, Tensor};
-use candle_nn::Module;

 pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
     if steps < 1 {
@@ -11,29 +10,3 @@ pub fn linspace(start: f64, stop: f64, steps: usize) -> Result<Tensor> {
         .collect::<Vec<_>>();
     Tensor::from_vec(vs, steps, &Device::Cpu)
 }
-
-// Wrap the conv2d op to provide some tracing.
-#[derive(Debug)]
-pub struct Conv2d {
-    inner: candle_nn::Conv2d,
-    span: tracing::Span,
-}
-
-impl Conv2d {
-    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(x)
-    }
-}
-
-pub fn conv2d(
-    in_channels: usize,
-    out_channels: usize,
-    kernel_size: usize,
-    cfg: candle_nn::Conv2dConfig,
-    vs: candle_nn::VarBuilder,
-) -> Result<Conv2d> {
-    let span = tracing::span!(tracing::Level::TRACE, "conv2d");
-    let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
-    Ok(Conv2d { inner, span })
-}
@@ -1,57 +1,12 @@
 // T5 Text Model
 // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py

+use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
 use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use serde::Deserialize;
 use std::sync::Arc;

-#[derive(Debug)]
-struct Embedding {
-    inner: candle_nn::Embedding,
-    span: tracing::Span,
-}
-
-impl Embedding {
-    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
-        let inner = candle_nn::embedding(d1, d2, vb)?;
-        let span = tracing::span!(tracing::Level::TRACE, "embedding");
-        Ok(Self { inner, span })
-    }
-
-    fn embeddings(&self) -> &Tensor {
-        self.inner.embeddings()
-    }
-}
-
-impl Module for Embedding {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(xs)
-    }
-}
-
-#[derive(Debug)]
-struct Linear {
-    inner: candle_nn::Linear,
-    span: tracing::Span,
-}
-
-impl Linear {
-    fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
-        let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
-        let span = tracing::span!(tracing::Level::TRACE, "linear");
-        Ok(Self { inner, span })
-    }
-}
-
-impl Module for Linear {
-    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-        self.inner.forward(xs)
-    }
-}
-
 fn default_relative_attention_max_distance() -> usize {
     128
 }
@@ -205,8 +160,8 @@ struct T5DenseActDense {

 impl T5DenseActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
-        let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi"))?;
+        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi,
             wo,
@@ -237,9 +192,9 @@ struct T5DenseGatedActDense {

 impl T5DenseGatedActDense {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
-        let wi_0 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
-        let wi_1 = Linear::new(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
-        let wo = Linear::new(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
+        let wi_0 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_0"))?;
+        let wi_1 = linear_no_bias(cfg.d_model, cfg.d_ff, vb.pp("wi_1"))?;
+        let wo = linear_no_bias(cfg.d_ff, cfg.d_model, vb.pp("wo"))?;
         Ok(Self {
             wi_0,
             wi_1,
@@ -334,10 +289,10 @@ impl T5Attention {
         cfg: &Config,
     ) -> Result<Self> {
         let inner_dim = cfg.num_heads * cfg.d_kv;
-        let q = Linear::new(cfg.d_model, inner_dim, vb.pp("q"))?;
-        let k = Linear::new(cfg.d_model, inner_dim, vb.pp("k"))?;
-        let v = Linear::new(cfg.d_model, inner_dim, vb.pp("v"))?;
-        let o = Linear::new(inner_dim, cfg.d_model, vb.pp("o"))?;
+        let q = linear_no_bias(cfg.d_model, inner_dim, vb.pp("q"))?;
+        let k = linear_no_bias(cfg.d_model, inner_dim, vb.pp("k"))?;
+        let v = linear_no_bias(cfg.d_model, inner_dim, vb.pp("v"))?;
+        let o = linear_no_bias(inner_dim, cfg.d_model, vb.pp("o"))?;
         let relative_attention_bias = if has_relative_attention_bias {
             let emb = Embedding::new(
                 cfg.relative_attention_num_buckets,
@@ -772,7 +727,11 @@ impl T5ForConditionalGeneration {
         let lm_head = if tie_word_embeddings {
             None
         } else {
-            Some(Linear::new(cfg.d_model, cfg.vocab_size, vb.pp("lm_head"))?)
+            Some(linear_no_bias(
+                cfg.d_model,
+                cfg.vocab_size,
+                vb.pp("lm_head"),
+            )?)
         };

         Ok(Self {
candle-transformers/src/models/with_tracing.rs (new file, 78 lines)
@@ -0,0 +1,78 @@
+use candle::{Module, Result, Tensor};
+use candle_nn::VarBuilder;
+
+#[derive(Debug)]
+pub struct Embedding {
+    inner: candle_nn::Embedding,
+    span: tracing::Span,
+}
+
+impl Embedding {
+    pub fn new(d1: usize, d2: usize, vb: VarBuilder) -> Result<Self> {
+        let inner = candle_nn::embedding(d1, d2, vb)?;
+        let span = tracing::span!(tracing::Level::TRACE, "embedding");
+        Ok(Self { inner, span })
+    }
+
+    pub fn embeddings(&self) -> &Tensor {
+        self.inner.embeddings()
+    }
+}
+
+impl Module for Embedding {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+pub fn linear(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+    let inner = candle_nn::linear(d1, d2, vb)?;
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    Ok(Linear { inner, span })
+}
+
+pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result<Linear> {
+    let inner = candle_nn::linear_no_bias(d1, d2, vb)?;
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    Ok(Linear { inner, span })
+}
+
+impl Module for Linear {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(xs)
+    }
+}
+
+// Wrap the conv2d op to provide some tracing.
+#[derive(Debug)]
+pub struct Conv2d {
+    inner: candle_nn::Conv2d,
+    span: tracing::Span,
+}
+
+impl Conv2d {
+    pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
+pub fn conv2d(
+    in_channels: usize,
+    out_channels: usize,
+    kernel_size: usize,
+    cfg: candle_nn::Conv2dConfig,
+    vs: candle_nn::VarBuilder,
+) -> Result<Conv2d> {
+    let span = tracing::span!(tracing::Level::TRACE, "conv2d");
+    let inner = candle_nn::conv2d(in_channels, out_channels, kernel_size, cfg, vs)?;
+    Ok(Conv2d { inner, span })
+}
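The wrappers above keep the same constructor signatures as their candle_nn counterparts, so adopting them is mostly an import change, as the mixformer and T5 hunks show. For reference, here is a minimal end-to-end sketch of how the pieces could fit together outside this diff (hypothetical example, assuming candle, candle-nn, candle-transformers, tracing-chrome and tracing-subscriber are available as dependencies):

use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::{VarBuilder, VarMap};
use candle_transformers::models::with_tracing::linear;
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() -> Result<()> {
    // Same setup as in the phi example above: install the Chrome-trace layer
    // and keep the flush guard alive until the end of main.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // A toy traced linear layer: 4 inputs, 2 outputs, freshly initialized on CPU.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    let layer = linear(4, 2, vb.pp("toy"))?;

    // Each forward call is recorded as a "linear" span in the generated trace file.
    let xs = Tensor::zeros((1, 4), DType::F32, &Device::Cpu)?;
    let _ys = layer.forward(&xs)?;
    Ok(())
}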