From e85573a4bd4641e5cc5cc1bed85a4550d104396d Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 5 Jul 2023 14:35:27 +0000 Subject: [PATCH] Adapting whisper for Hub use. --- candle-examples/examples/whisper/main.rs | 54 ++++++++--- candle-examples/examples/whisper/model.rs | 107 ++++++++++++---------- 2 files changed, 101 insertions(+), 60 deletions(-) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index 6ea3e536..0df0fc42 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -7,6 +7,7 @@ use anyhow::{Error as E, Result}; use candle::{DType, Device, Tensor}; +use candle_hub::{api::Api, Repo, RepoType}; use clap::Parser; use rand::{distributions::Distribution, SeedableRng}; use tokenizers::Tokenizer; @@ -68,7 +69,7 @@ impl Decode { let model = &self.model; let audio_features = model.encoder.forward(mel)?; println!("audio features: {:?}", audio_features.dims()); - let sample_len = model.config.n_text_ctx / 2; + let sample_len = model.config.max_target_positions / 2; let mut sum_logprob = 0f64; let mut no_speech_prob = f64::NAN; let mut tokens = vec![SOT_TOKEN]; @@ -112,7 +113,7 @@ impl Decode { .softmax(logits.rank() - 1)? .get(next_token as usize)? .to_scalar::<f32>()? as f64; - if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx { + if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions { break; } sum_logprob += prob.ln(); @@ -165,15 +166,16 @@ struct Args { cpu: bool, #[arg(long)] - weights: String, + model_id: Option<String>, + + /// The model to use, check out available models: https://huggingface.co/models?search=whisper + #[arg(long)] + revision: Option<String>, /// The input to be processed, in wav formats. - #[arg(long)] + #[arg(long, default_value = "jfk.wav")] input: String, - #[arg(long)] - tokenizer_config: String, - /// The seed to use when generating random samples. 
#[arg(long, default_value_t = 299792458)] seed: u64, @@ -186,7 +188,8 @@ struct Args { filters: String, } -fn main() -> Result<()> { +#[tokio::main] +async fn main() -> Result<()> { let args = Args::parse(); let device = if args.cpu { Device::Cpu @@ -195,7 +198,35 @@ fn main() -> Result<()> { }; let rng = rand::rngs::StdRng::seed_from_u64(args.seed); - let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; + let default_model = "openai/whisper-tiny.en".to_string(); + let path = std::path::PathBuf::from(default_model.clone()); + let default_revision = "refs/pr/15".to_string(); + let (model_id, revision) = match (args.model_id, args.revision) { + (Some(model_id), Some(revision)) => (model_id, revision), + (Some(model_id), None) => (model_id, "main".to_string()), + (None, Some(revision)) => (default_model, revision), + (None, None) => (default_model, default_revision), + }; + + let (config_filename, tokenizer_filename, weights_filename) = if path.exists() { + let mut config_filename = path.clone(); + config_filename.push("config.json"); + let mut tokenizer_filename = path.clone(); + tokenizer_filename.push("tokenizer.json"); + let mut model_filename = path.clone(); + model_filename.push("model.safetensors"); + (config_filename, tokenizer_filename, model_filename) + } else { + let repo = Repo::with_revision(model_id, RepoType::Model, revision); + let api = Api::new()?; + ( + api.get(&repo, "config.json").await?, + api.get(&repo, "tokenizer.json").await?, + api.get(&repo, "model.safetensors").await?, + ) + }; + println!("Weights {weights_filename:?}"); + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? 
}; let mel_filters = mel_filters.deserialize()?; @@ -220,10 +251,11 @@ fn main() -> Result<()> { let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?; println!("loaded mel: {:?}", mel.dims()); - let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; + let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = weights.deserialize()?; let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device); - let model = Whisper::load(&vb, Config::tiny_en())?; + let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; + let model = Whisper::load(&vb, config)?; let mut dc = Decode { model, rng, diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index bf322c51..50f63b6d 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -2,6 +2,7 @@ // back when using RUST_LIB_BACKTRACE=1. use anyhow::Result; use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; +use serde::Deserialize; use std::collections::HashMap; pub struct VarBuilder<'a> { @@ -76,33 +77,33 @@ impl HiddenAct { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone, PartialEq, Deserialize)] pub struct Config { - pub n_mels: usize, - pub n_audio_ctx: usize, - pub n_audio_state: usize, - pub n_audio_head: usize, - pub n_audio_layer: usize, - pub n_vocab: usize, - pub n_text_ctx: usize, - pub n_text_state: usize, - pub n_text_head: usize, - pub n_text_layer: usize, + pub num_mel_bins: usize, + pub max_source_positions: usize, + pub d_model: usize, + pub encoder_attention_heads: usize, + pub encoder_layers: usize, + pub vocab_size: usize, + pub max_target_positions: usize, + // pub n_text_state: usize, + pub decoder_attention_heads: usize, + pub decoder_layers: usize, } impl Config { pub fn tiny_en() -> Self { Self { - n_mels: 80, - n_vocab: 51864, - n_audio_ctx: 1500, - n_audio_state: 384, 
- n_audio_head: 6, - n_audio_layer: 4, - n_text_ctx: 448, - n_text_state: 384, - n_text_head: 6, - n_text_layer: 4, + num_mel_bins: 80, + vocab_size: 51864, + max_source_positions: 1500, + d_model: 384, + encoder_attention_heads: 6, + encoder_layers: 4, + max_target_positions: 448, + // n_text_state: 384, + decoder_attention_heads: 6, + decoder_layers: 4, } } } @@ -297,10 +298,10 @@ struct MultiHeadAttention { impl MultiHeadAttention { fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> { - let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?; - let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?; - let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?; - let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?; + let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?; + let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?; + let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?; + let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?; Ok(Self { query, key, @@ -360,20 +361,21 @@ struct ResidualAttentionBlock { impl ResidualAttentionBlock { fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> { - let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?; - let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?; + let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?; + let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?; let cross_attn = if ca { let cross_attn = - MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?; - let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?; + MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?; + let cross_attn_ln = + LayerNorm::load(n_state, 
&format!("{p}.encoder_attn_layer_norm"), vb)?; Some((cross_attn, cross_attn_ln)) } else { None }; let n_mlp = n_state * 4; - let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?; - let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.mlp.2"), vb)?; - let mlp_ln = LayerNorm::load(n_state, &format!("{p}.mlp_ln"), vb)?; + let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?; + let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?; + let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?; Ok(Self { attn, attn_ln, @@ -426,9 +428,9 @@ pub struct AudioEncoder { impl AudioEncoder { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { - let n_state = cfg.n_audio_state; - let n_head = cfg.n_audio_head; - let n_ctx = cfg.n_audio_ctx; + let n_state = cfg.d_model; + let n_head = cfg.encoder_attention_heads; + let n_ctx = cfg.max_source_positions; let cfg1 = ConvConfig { padding: 1, stride: 1, @@ -437,15 +439,22 @@ impl AudioEncoder { padding: 1, stride: 2, }; - let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?; + let conv1 = Conv1D::load( + cfg.num_mel_bins, + n_state, + 3, + cfg1, + &format!("{p}.conv1"), + vb, + )?; let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?; - let blocks = (0..cfg.n_audio_layer) + let blocks = (0..cfg.encoder_layers) .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb) + ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb) }) .collect::<Result<Vec<_>>>()?; - let ln_post = LayerNorm::load(n_state, &format!("{p}.ln_post"), vb)?; + let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?; Ok(Self { conv1, conv2, @@ -480,19 +489,19 @@ pub struct TextDecoder { impl TextDecoder { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { - let 
n_state = cfg.n_text_state; - let n_head = cfg.n_text_head; - let n_ctx = cfg.n_text_ctx; + let n_state = cfg.d_model; + let n_head = cfg.decoder_attention_heads; + let n_ctx = cfg.max_target_positions; let token_embedding = - Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?; + Embedding::load(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?; let positional_embedding = - vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?; - let blocks = (0..cfg.n_text_layer) + vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?; + let blocks = (0..cfg.decoder_layers) .map(|i| { - ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb) + ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb) }) .collect::<Result<Vec<_>>>()?; - let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?; + let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?; let mask: Vec<_> = (0..n_ctx) .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) .collect(); @@ -532,8 +541,8 @@ pub struct Whisper { impl Whisper { pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> { - let encoder = AudioEncoder::load("encoder", vb, &config)?; - let decoder = TextDecoder::load("decoder", vb, &config)?; + let encoder = AudioEncoder::load("model.encoder", vb, &config)?; + let decoder = TextDecoder::load("model.decoder", vb, &config)?; Ok(Self { encoder, decoder,