diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index c9e9ccc6..dfe7a27f 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -1,4 +1,3 @@ -#![allow(dead_code)] // https://github.com/openai/whisper/blob/main/whisper/model.py // TODO: // - kv-cache support? @@ -31,9 +30,6 @@ const HOP_LENGTH: usize = 160; const CHUNK_LENGTH: usize = 30; const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input -const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2 -const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame -const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token const NO_SPEECH_THRESHOLD: f64 = 0.6; const LOGPROB_THRESHOLD: f64 = -1.0; @@ -44,7 +40,6 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; const SOT_TOKEN: u32 = 50257; const EOT_TOKEN: u32 = 50256; const NO_SPEECH_TOKEN: u32 = 50361; -const NO_TIMESTAMP_TOKEN: u32 = 50362; // From the _get_suppress_tokens function + 50362 (no timestamp) // https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605 const SUPPRESS_TOKENS: [u32; 91] = [ @@ -56,6 +51,7 @@ const SUPPRESS_TOKENS: [u32; 91] = [ 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362, ]; +#[allow(dead_code)] #[derive(Debug, Clone)] struct DecodingResult { tokens: Vec, @@ -66,6 +62,7 @@ struct DecodingResult { compression_ratio: f64, } +#[allow(dead_code)] #[derive(Debug, Clone)] struct Segment { start: f64, @@ -243,10 +240,25 @@ struct Args { /// The seed to use when generating random samples. #[arg(long, default_value_t = 299792458)] seed: u64, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, } fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + let args = Args::parse(); + let _guard = if args.tracing { + println!("tracing..."); + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; let device = candle_examples::device(args.cpu)?; let default_model = "openai/whisper-tiny.en".to_string(); let path = std::path::PathBuf::from(default_model.clone()); diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index 4d80c0c8..7015199d 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -1,8 +1,5 @@ -// We use anyhow rather than candle errors as it provides better support for getting the backtrace -// back when using RUST_LIB_BACKTRACE=1. -use anyhow::Result; -use candle::{Device, Tensor}; -use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder}; +use candle::{Device, Result, Tensor}; +use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder}; use serde::Deserialize; // The names in comments correspond to the original implementation: @@ -22,6 +19,7 @@ pub struct Config { } impl Config { + #[allow(dead_code)] pub fn tiny_en() -> Self { Self { num_mel_bins: 80, @@ -42,16 +40,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result Result { + let _enter = self.span.enter(); + self.inner.forward(x) + } +} fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let weight = vb.get((size2, size1), "weight")?; - let bias = vb.get(size2, "bias")?; - Ok(Linear::new(weight, Some(bias))) + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear(size1, size2, vb)?; + Ok(Linear { inner, span }) } fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result { - let weight = vb.get((size2, size1), "weight")?; - Ok(Linear::new(weight, None)) + let span = tracing::span!(tracing::Level::TRACE, "linear"); + let inner = candle_nn::linear_no_bias(size1, size2, vb)?; + Ok(Linear { inner, span }) } fn conv1d( @@ -66,32 +80,6 @@ fn conv1d( Ok(Conv1d::new(weight, Some(bias), config)) } -fn conv1d_no_bias( - in_channels: usize, - out_channels: usize, - kernel_size: usize, - config: Conv1dConfig, - vb: VarBuilder, -) -> Result { - let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; - Ok(Conv1d::new(weight, None, config)) -} - -struct Dropout { - pr: f64, -} - -impl Dropout { - fn new(pr: f64) -> Self { - Self { pr } - } - - fn forward(&self, x: &Tensor) -> Result { - // TODO - Ok(x.clone()) - } -} - fn layer_norm(size: usize, vb: VarBuilder) -> Result { let weight = vb.get(size, "weight")?; let bias = vb.get(size, "bias")?; @@ -105,10 +93,12 @@ struct MultiHeadAttention { value: Linear, out: Linear, n_head: usize, + span: tracing::Span, } impl MultiHeadAttention { fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn"); let query = linear(n_state, n_state, vb.pp("q_proj"))?; let value = linear(n_state, n_state, vb.pp("v_proj"))?; let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; @@ -119,10 +109,12 @@ impl MultiHeadAttention { value, out, n_head, + span, }) } fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); let q = self.query.forward(x)?; let k = self.key.forward(xa.unwrap_or(x))?; let v = self.value.forward(xa.unwrap_or(x))?; @@ -134,7 +126,7 @@ impl MultiHeadAttention { fn reshape_head(&self, x: &Tensor) -> Result { let (n_batch, n_ctx, n_state) = x.dims3()?; let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; - Ok(x.reshape(target_dims)?.transpose(1, 2)?) + x.reshape(target_dims)?.transpose(1, 2) } fn qkv_attention( @@ -168,10 +160,12 @@ struct ResidualAttentionBlock { mlp_linear1: Linear, mlp_linear2: Linear, mlp_ln: LayerNorm, + span: tracing::Span, } impl ResidualAttentionBlock { fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "residual-attn"); let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; let cross_attn = if ca { @@ -192,10 +186,12 @@ impl ResidualAttentionBlock { mlp_linear1, mlp_linear2, mlp_ln, + span, }) } fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { + let _enter = self.span.enter(); let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?; let mut x = (x + attn)?; if let Some((attn, ln)) = &self.cross_attn { @@ -207,7 +203,7 @@ impl ResidualAttentionBlock { .forward(&self.mlp_ln.forward(&x)?)? .gelu()?, )?; - Ok((x + mlp)?) + x + mlp } } @@ -234,10 +230,16 @@ pub struct AudioEncoder { positional_embedding: Tensor, blocks: Vec, ln_post: LayerNorm, + span: tracing::Span, + conv1_span: tracing::Span, + conv2_span: tracing::Span, } impl AudioEncoder { fn load(vb: VarBuilder, cfg: &Config) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "audio-encoder"); + let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1"); + let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2"); let n_state = cfg.d_model; let n_head = cfg.encoder_attention_heads; let n_ctx = cfg.max_source_positions; @@ -264,11 +266,22 @@ impl AudioEncoder { positional_embedding, blocks, ln_post, + conv1_span, + conv2_span, + span, }) } + pub fn forward(&self, x: &Tensor) -> Result { - let x = self.conv1.forward(x)?.gelu()?; - let x = self.conv2.forward(&x)?.gelu()?; + let _enter = self.span.enter(); + let x = { + let _enter = self.conv1_span.enter(); + self.conv1.forward(x)?.gelu()? + }; + let x = { + let _enter = self.conv2_span.enter(); + self.conv2.forward(&x)?.gelu()? + }; let x = x.transpose(1, 2)?; let (_bsize, seq_len, _hidden) = x.dims3()?; let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; @@ -288,10 +301,12 @@ pub struct TextDecoder { blocks: Vec, ln: LayerNorm, mask: Tensor, + span: tracing::Span, } impl TextDecoder { fn load(vb: VarBuilder, cfg: &Config) -> Result { + let span = tracing::span!(tracing::Level::TRACE, "text-decoder"); let n_state = cfg.d_model; let n_head = cfg.decoder_attention_heads; let n_ctx = cfg.max_target_positions; @@ -314,10 +329,12 @@ impl TextDecoder { blocks, ln, mask, + span, }) } pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result { + let _enter = self.span.enter(); let x_dims = x.dims(); let last = x_dims[x_dims.len() - 1]; let token_embedding = self.token_embedding.forward(x)?; @@ -354,6 +371,7 @@ impl Whisper { }) } + #[allow(dead_code)] pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result { let enc = self.encoder.forward(mel)?; let dec = self.decoder.forward(tokens, &enc)?;