mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Add some tracing to the whisper example. (#375)
This commit is contained in:
@ -1,4 +1,3 @@
|
|||||||
#![allow(dead_code)]
|
|
||||||
// https://github.com/openai/whisper/blob/main/whisper/model.py
|
// https://github.com/openai/whisper/blob/main/whisper/model.py
|
||||||
// TODO:
|
// TODO:
|
||||||
// - kv-cache support?
|
// - kv-cache support?
|
||||||
@ -31,9 +30,6 @@ const HOP_LENGTH: usize = 160;
|
|||||||
const CHUNK_LENGTH: usize = 30;
|
const CHUNK_LENGTH: usize = 30;
|
||||||
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||||
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
||||||
const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2
|
|
||||||
const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame
|
|
||||||
const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token
|
|
||||||
|
|
||||||
const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
||||||
const LOGPROB_THRESHOLD: f64 = -1.0;
|
const LOGPROB_THRESHOLD: f64 = -1.0;
|
||||||
@ -44,7 +40,6 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
|||||||
const SOT_TOKEN: u32 = 50257;
|
const SOT_TOKEN: u32 = 50257;
|
||||||
const EOT_TOKEN: u32 = 50256;
|
const EOT_TOKEN: u32 = 50256;
|
||||||
const NO_SPEECH_TOKEN: u32 = 50361;
|
const NO_SPEECH_TOKEN: u32 = 50361;
|
||||||
const NO_TIMESTAMP_TOKEN: u32 = 50362;
|
|
||||||
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
// From the _get_suppress_tokens function + 50362 (no timestamp)
|
||||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605
|
||||||
const SUPPRESS_TOKENS: [u32; 91] = [
|
const SUPPRESS_TOKENS: [u32; 91] = [
|
||||||
@ -56,6 +51,7 @@ const SUPPRESS_TOKENS: [u32; 91] = [
|
|||||||
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
|
47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362,
|
||||||
];
|
];
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct DecodingResult {
|
struct DecodingResult {
|
||||||
tokens: Vec<u32>,
|
tokens: Vec<u32>,
|
||||||
@ -66,6 +62,7 @@ struct DecodingResult {
|
|||||||
compression_ratio: f64,
|
compression_ratio: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
struct Segment {
|
struct Segment {
|
||||||
start: f64,
|
start: f64,
|
||||||
@ -243,10 +240,25 @@ struct Args {
|
|||||||
/// The seed to use when generating random samples.
|
/// The seed to use when generating random samples.
|
||||||
#[arg(long, default_value_t = 299792458)]
|
#[arg(long, default_value_t = 299792458)]
|
||||||
seed: u64,
|
seed: u64,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
println!("tracing...");
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let default_model = "openai/whisper-tiny.en".to_string();
|
let default_model = "openai/whisper-tiny.en".to_string();
|
||||||
let path = std::path::PathBuf::from(default_model.clone());
|
let path = std::path::PathBuf::from(default_model.clone());
|
||||||
|
@ -1,8 +1,5 @@
|
|||||||
// We use anyhow rather than candle errors as it provides better support for getting the backtrace
|
use candle::{Device, Result, Tensor};
|
||||||
// back when using RUST_LIB_BACKTRACE=1.
|
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
|
||||||
use anyhow::Result;
|
|
||||||
use candle::{Device, Tensor};
|
|
||||||
use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder};
|
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
|
|
||||||
// The names in comments correspond to the original implementation:
|
// The names in comments correspond to the original implementation:
|
||||||
@ -22,6 +19,7 @@ pub struct Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn tiny_en() -> Self {
|
pub fn tiny_en() -> Self {
|
||||||
Self {
|
Self {
|
||||||
num_mel_bins: 80,
|
num_mel_bins: 80,
|
||||||
@ -42,16 +40,32 @@ fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Em
|
|||||||
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
let embeddings = vb.get((vocab_size, hidden_size), "weight")?;
|
||||||
Ok(Embedding::new(embeddings, hidden_size))
|
Ok(Embedding::new(embeddings, hidden_size))
|
||||||
}
|
}
|
||||||
|
//
|
||||||
|
// We wrap the `Linear` layer here to add some tracing so that it's easier to profile the resulting
|
||||||
|
// model.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Linear {
|
||||||
|
inner: candle_nn::Linear,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Linear {
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
self.inner.forward(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
let weight = vb.get((size2, size1), "weight")?;
|
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||||
let bias = vb.get(size2, "bias")?;
|
let inner = candle_nn::linear(size1, size2, vb)?;
|
||||||
Ok(Linear::new(weight, Some(bias)))
|
Ok(Linear { inner, span })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result<Linear> {
|
||||||
let weight = vb.get((size2, size1), "weight")?;
|
let span = tracing::span!(tracing::Level::TRACE, "linear");
|
||||||
Ok(Linear::new(weight, None))
|
let inner = candle_nn::linear_no_bias(size1, size2, vb)?;
|
||||||
|
Ok(Linear { inner, span })
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv1d(
|
fn conv1d(
|
||||||
@ -66,32 +80,6 @@ fn conv1d(
|
|||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv1d_no_bias(
|
|
||||||
in_channels: usize,
|
|
||||||
out_channels: usize,
|
|
||||||
kernel_size: usize,
|
|
||||||
config: Conv1dConfig,
|
|
||||||
vb: VarBuilder,
|
|
||||||
) -> Result<Conv1d> {
|
|
||||||
let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?;
|
|
||||||
Ok(Conv1d::new(weight, None, config))
|
|
||||||
}
|
|
||||||
|
|
||||||
struct Dropout {
|
|
||||||
pr: f64,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Dropout {
|
|
||||||
fn new(pr: f64) -> Self {
|
|
||||||
Self { pr }
|
|
||||||
}
|
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
|
||||||
// TODO
|
|
||||||
Ok(x.clone())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
fn layer_norm(size: usize, vb: VarBuilder) -> Result<LayerNorm> {
|
||||||
let weight = vb.get(size, "weight")?;
|
let weight = vb.get(size, "weight")?;
|
||||||
let bias = vb.get(size, "bias")?;
|
let bias = vb.get(size, "bias")?;
|
||||||
@ -105,10 +93,12 @@ struct MultiHeadAttention {
|
|||||||
value: Linear,
|
value: Linear,
|
||||||
out: Linear,
|
out: Linear,
|
||||||
n_head: usize,
|
n_head: usize,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MultiHeadAttention {
|
impl MultiHeadAttention {
|
||||||
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
|
fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
|
||||||
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
|
let query = linear(n_state, n_state, vb.pp("q_proj"))?;
|
||||||
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
|
let value = linear(n_state, n_state, vb.pp("v_proj"))?;
|
||||||
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
|
let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
|
||||||
@ -119,10 +109,12 @@ impl MultiHeadAttention {
|
|||||||
value,
|
value,
|
||||||
out,
|
out,
|
||||||
n_head,
|
n_head,
|
||||||
|
span,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let q = self.query.forward(x)?;
|
let q = self.query.forward(x)?;
|
||||||
let k = self.key.forward(xa.unwrap_or(x))?;
|
let k = self.key.forward(xa.unwrap_or(x))?;
|
||||||
let v = self.value.forward(xa.unwrap_or(x))?;
|
let v = self.value.forward(xa.unwrap_or(x))?;
|
||||||
@ -134,7 +126,7 @@ impl MultiHeadAttention {
|
|||||||
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
fn reshape_head(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
let (n_batch, n_ctx, n_state) = x.dims3()?;
|
||||||
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head];
|
||||||
Ok(x.reshape(target_dims)?.transpose(1, 2)?)
|
x.reshape(target_dims)?.transpose(1, 2)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn qkv_attention(
|
fn qkv_attention(
|
||||||
@ -168,10 +160,12 @@ struct ResidualAttentionBlock {
|
|||||||
mlp_linear1: Linear,
|
mlp_linear1: Linear,
|
||||||
mlp_linear2: Linear,
|
mlp_linear2: Linear,
|
||||||
mlp_ln: LayerNorm,
|
mlp_ln: LayerNorm,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResidualAttentionBlock {
|
impl ResidualAttentionBlock {
|
||||||
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
|
fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "residual-attn");
|
||||||
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
|
let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?;
|
||||||
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
|
let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?;
|
||||||
let cross_attn = if ca {
|
let cross_attn = if ca {
|
||||||
@ -192,10 +186,12 @@ impl ResidualAttentionBlock {
|
|||||||
mlp_linear1,
|
mlp_linear1,
|
||||||
mlp_linear2,
|
mlp_linear2,
|
||||||
mlp_ln,
|
mlp_ln,
|
||||||
|
span,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
|
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?;
|
||||||
let mut x = (x + attn)?;
|
let mut x = (x + attn)?;
|
||||||
if let Some((attn, ln)) = &self.cross_attn {
|
if let Some((attn, ln)) = &self.cross_attn {
|
||||||
@ -207,7 +203,7 @@ impl ResidualAttentionBlock {
|
|||||||
.forward(&self.mlp_ln.forward(&x)?)?
|
.forward(&self.mlp_ln.forward(&x)?)?
|
||||||
.gelu()?,
|
.gelu()?,
|
||||||
)?;
|
)?;
|
||||||
Ok((x + mlp)?)
|
x + mlp
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -234,10 +230,16 @@ pub struct AudioEncoder {
|
|||||||
positional_embedding: Tensor,
|
positional_embedding: Tensor,
|
||||||
blocks: Vec<ResidualAttentionBlock>,
|
blocks: Vec<ResidualAttentionBlock>,
|
||||||
ln_post: LayerNorm,
|
ln_post: LayerNorm,
|
||||||
|
span: tracing::Span,
|
||||||
|
conv1_span: tracing::Span,
|
||||||
|
conv2_span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AudioEncoder {
|
impl AudioEncoder {
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "audio-encoder");
|
||||||
|
let conv1_span = tracing::span!(tracing::Level::TRACE, "conv1");
|
||||||
|
let conv2_span = tracing::span!(tracing::Level::TRACE, "conv2");
|
||||||
let n_state = cfg.d_model;
|
let n_state = cfg.d_model;
|
||||||
let n_head = cfg.encoder_attention_heads;
|
let n_head = cfg.encoder_attention_heads;
|
||||||
let n_ctx = cfg.max_source_positions;
|
let n_ctx = cfg.max_source_positions;
|
||||||
@ -264,11 +266,22 @@ impl AudioEncoder {
|
|||||||
positional_embedding,
|
positional_embedding,
|
||||||
blocks,
|
blocks,
|
||||||
ln_post,
|
ln_post,
|
||||||
|
conv1_span,
|
||||||
|
conv2_span,
|
||||||
|
span,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let x = self.conv1.forward(x)?.gelu()?;
|
let _enter = self.span.enter();
|
||||||
let x = self.conv2.forward(&x)?.gelu()?;
|
let x = {
|
||||||
|
let _enter = self.conv1_span.enter();
|
||||||
|
self.conv1.forward(x)?.gelu()?
|
||||||
|
};
|
||||||
|
let x = {
|
||||||
|
let _enter = self.conv2_span.enter();
|
||||||
|
self.conv2.forward(&x)?.gelu()?
|
||||||
|
};
|
||||||
let x = x.transpose(1, 2)?;
|
let x = x.transpose(1, 2)?;
|
||||||
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
let (_bsize, seq_len, _hidden) = x.dims3()?;
|
||||||
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?;
|
||||||
@ -288,10 +301,12 @@ pub struct TextDecoder {
|
|||||||
blocks: Vec<ResidualAttentionBlock>,
|
blocks: Vec<ResidualAttentionBlock>,
|
||||||
ln: LayerNorm,
|
ln: LayerNorm,
|
||||||
mask: Tensor,
|
mask: Tensor,
|
||||||
|
span: tracing::Span,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TextDecoder {
|
impl TextDecoder {
|
||||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "text-decoder");
|
||||||
let n_state = cfg.d_model;
|
let n_state = cfg.d_model;
|
||||||
let n_head = cfg.decoder_attention_heads;
|
let n_head = cfg.decoder_attention_heads;
|
||||||
let n_ctx = cfg.max_target_positions;
|
let n_ctx = cfg.max_target_positions;
|
||||||
@ -314,10 +329,12 @@ impl TextDecoder {
|
|||||||
blocks,
|
blocks,
|
||||||
ln,
|
ln,
|
||||||
mask,
|
mask,
|
||||||
|
span,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
let x_dims = x.dims();
|
let x_dims = x.dims();
|
||||||
let last = x_dims[x_dims.len() - 1];
|
let last = x_dims[x_dims.len() - 1];
|
||||||
let token_embedding = self.token_embedding.forward(x)?;
|
let token_embedding = self.token_embedding.forward(x)?;
|
||||||
@ -354,6 +371,7 @@ impl Whisper {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
||||||
let enc = self.encoder.forward(mel)?;
|
let enc = self.encoder.forward(mel)?;
|
||||||
let dec = self.decoder.forward(tokens, &enc)?;
|
let dec = self.decoder.forward(tokens, &enc)?;
|
||||||
|
Reference in New Issue
Block a user