Mirror of https://github.com/huggingface/candle.git
Add some tracing to the whisper example. (#375)
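For context: the `tracing::span!` calls added below only produce data once a tracing subscriber is installed in the example's entry point. A minimal sketch of such a setup, assuming the `tracing-chrome` and `tracing-subscriber` crates (neither shown in this diff):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run: the trace file is flushed
    // when it is dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();
    // ... run the whisper example; every span enter/exit is now recorded
    // to a trace-*.json file viewable in chrome://tracing or Perfetto.
}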
@@ -1,8 +1,5 @@
-// We use anyhow rather than candle errors as it provides better support for getting the backtrace
-// back when using RUST_LIB_BACKTRACE=1.
-use anyhow::Result;
-use candle::{Device, Tensor};
-use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
+use candle::{Device, Result, Tensor};
+use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
 use serde::Deserialize;
 
 // The names in comments correspond to the original implementation:
@@ -22,6 +19,7 @@ pub struct Config {
 }
 
 impl Config {
+    #[allow(dead_code)]
     pub fn tiny_en() -> Self {
         Self {
             num_mel_bins: 80,
@@ -42,16 +40,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
     let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
     Ok(Embedding::new(embeddings, hidden_size))
 }
 
+// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
+// model.
+#[derive(Debug)]
+pub struct Linear {
+    inner: candle_nn::Linear,
+    span: tracing::Span,
+}
+
+impl Linear {
+    fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
+        self.inner.forward(x)
+    }
+}
+
 fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    let bias = vb.get(size2, "bias")?;
-    Ok(Linear::new(weight, Some(bias)))
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear(size1, size2, vb)?;
+    Ok(Linear { inner, span })
 }
 
 fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
-    let weight = vb.get((size2, size1), "weight")?;
-    Ok(Linear::new(weight, None))
+    let span = tracing::span!(tracing::Level::TRACE, "linear");
+    let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
+    Ok(Linear { inner, span })
 }
 
 fn conv1d(
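
The wrapper above relies on the RAII guard returned by `Span::enter`: the span is created once when the layer is loaded, and each `forward` call is timed by the guard's lifetime. A standalone sketch of the pattern, illustrative only and not part of the diff:

fn main() {
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    {
        let _enter = span.enter();
        // ... the work to be profiled runs here ...
    } // _enter is dropped here and the subscriber records the span exit
}
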
@@ -66,32 +80,6 @@ fn conv1d(
     Ok(Conv1d::new(weight, Some(bias), config))
 }
 
-fn conv1d_no_bias(
-    in_channels: usize,
-    out_channels: usize,
-    kernel_size: usize,
-    config: Conv1dConfig,
-    vb: VarBuilder,
-) -> Result<Conv1d> {
-    let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
-    Ok(Conv1d::new(weight, None, config))
-}
-
-struct Dropout {
-    pr: f64,
-}
-
-impl Dropout {
-    fn new(pr: f64) -> Self {
-        Self { pr }
-    }
-
-    fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        // TODO
-        Ok(x.clone())
-    }
-}
-
 fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
     let weight = vb.get(size, "weight")?;
     let bias = vb.get(size, "bias")?;
@@ -105,10 +93,12 @@ struct MultiHeadAttention {
     value: Linear,
     out: Linear,
     n_head: usize,
+    span: tracing::Span,
 }
 
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
         let query = linear(n_state, n_state, vb.pp("q_proj"))?;
         let value = linear(n_state, n_state, vb.pp("v_proj"))?;
         let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
@@ -119,10 +109,12 @@ impl MultiHeadAttention {
             value,
             out,
             n_head,
+            span,
         })
     }
 
     fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let q = self.query.forward(x)?;
         let k = self.key.forward(xa.unwrap_or(x))?;
         let v = self.value.forward(xa.unwrap_or(x))?;
@@ -134,7 +126,7 @@ impl MultiHeadAttention {
     fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
         let (n_batch, n_ctx, n_state) = x.dims3()?;
         let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
-        Ok(x.reshape(target_dims)?.transpose(1, 2)?)
+        x.reshape(target_dims)?.transpose(1, 2)
     }
 
     fn qkv_attention(
@@ -168,10 +160,12 @@ struct ResidualAttentionBlock {
     mlp_linear1: Linear,
     mlp_linear2: Linear,
     mlp_ln: LayerNorm,
+    span: tracing::Span,
 }
 
 impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
         let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
         let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
         let cross_attn = if ca {
@@ -192,10 +186,12 @@ impl ResidualAttentionBlock {
             mlp_linear1,
             mlp_linear2,
             mlp_ln,
+            span,
         })
     }
 
     fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
         let mut x = (x + attn)?;
         if let Some((attn, ln)) = &self.cross_attn {
@@ -207,7 +203,7 @@ impl ResidualAttentionBlock {
                 .forward(&self.mlp_ln.forward(&x)?)?
                 .gelu()?,
         )?;
-        Ok((x + mlp)?)
+        x + mlp
     }
 }
 
@@ -234,10 +230,16 @@ pub struct AudioEncoder {
     positional_embedding: Tensor,
     blocks: Vec<ResidualAttentionBlock>,
     ln_post: LayerNorm,
+    span: tracing::Span,
+    conv1_span: tracing::Span,
+    conv2_span: tracing::Span,
 }
 
 impl AudioEncoder {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
+        let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
+        let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
         let n_state = cfg.d_model;
         let n_head = cfg.encoder_attention_heads;
         let n_ctx = cfg.max_source_positions;
@@ -264,11 +266,22 @@
             positional_embedding,
             blocks,
             ln_post,
+            conv1_span,
+            conv2_span,
+            span,
         })
     }
 
     pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        let x = self.conv1.forward(x)?.gelu()?;
-        let x = self.conv2.forward(&x)?.gelu()?;
+        let _enter = self.span.enter();
+        let x = {
+            let _enter = self.conv1_span.enter();
+            self.conv1.forward(x)?.gelu()?
+        };
+        let x = {
+            let _enter = self.conv2_span.enter();
+            self.conv2.forward(&x)?.gelu()?
+        };
         let x = x.transpose(1, 2)?;
         let (_bsize, seq_len, _hidden) = x.dims3()?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
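
Note that the encoder enters its own span before the per-convolution spans, and spans nest: in a trace viewer, `conv1` and `conv2` show up as children of `audio-encoder`, mirroring the block structure of `forward`. An illustrative sketch of that nesting, not part of the diff:

fn main() {
    let encoder_span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
    let _outer = encoder_span.enter();
    {
        let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
        let _inner = conv1_span.enter();
        // work here is attributed to audio-encoder > conv1 in the trace
    }
}
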
@@ -288,10 +301,12 @@ pub struct TextDecoder {
     blocks: Vec<ResidualAttentionBlock>,
     ln: LayerNorm,
     mask: Tensor,
+    span: tracing::Span,
 }
 
 impl TextDecoder {
     fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
+        let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
         let n_state = cfg.d_model;
         let n_head = cfg.decoder_attention_heads;
         let n_ctx = cfg.max_target_positions;
@@ -314,10 +329,12 @@ impl TextDecoder {
             blocks,
             ln,
             mask,
+            span,
         })
     }
 
     pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let x_dims = x.dims();
         let last = x_dims[x_dims.len() - 1];
         let token_embedding = self.token_embedding.forward(x)?;
@@ -354,6 +371,7 @@ impl Whisper {
         })
     }
 
+    #[allow(dead_code)]
     pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
         let enc = self.encoder.forward(mel)?;
         let dec = self.decoder.forward(tokens, &enc)?;
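
One caveat: every span in this commit is created at `tracing::Level::TRACE`, the most verbose level, so subscribers with a default filter may discard them. A hedged sketch using `tracing-subscriber`'s `EnvFilter` (requires the crate's `env-filter` feature) to opt in at runtime:

use tracing_subscriber::{fmt, prelude::*, EnvFilter};

fn main() {
    // Honors RUST_LOG, e.g. `RUST_LOG=trace cargo run --example whisper`.
    tracing_subscriber::registry()
        .with(fmt::layer())
        .with(EnvFilter::from_default_env())
        .init();
}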