Merge pull request #78 from LaurentMazare/whisper_update

Adapting whisper for Hub use.
This commit is contained in:
Laurent Mazare
2023-07-06 07:21:58 +01:00
committed by GitHub
2 changed files with 117 additions and 61 deletions

View File

@ -7,6 +7,7 @@
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_hub::{api::Api, Repo, RepoType};
use clap::Parser;
use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer;
@ -68,7 +69,7 @@ impl Decode {
let model = &self.model;
let audio_features = model.encoder.forward(mel)?;
println!("audio features: {:?}", audio_features.dims());
let sample_len = model.config.n_text_ctx / 2;
let sample_len = model.config.max_target_positions / 2;
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![SOT_TOKEN];
@ -112,7 +113,7 @@ impl Decode {
.softmax(candle::D::Minus1)?
.get(next_token as usize)?
.to_scalar::<f32>()? as f64;
if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
break;
}
sum_logprob += prob.ln();
@ -165,14 +166,15 @@ struct Args {
cpu: bool,
#[arg(long)]
weights: String,
model_id: Option<String>,
/// The input to be processed, in wav formats.
/// The model to use, check out available models: https://huggingface.co/models?search=whisper
#[arg(long)]
input: String,
revision: Option<String>,
/// The input to be processed, in wav formats, will default to `jfk.wav` https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav
#[arg(long)]
tokenizer_config: String,
input: Option<String>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
@ -186,7 +188,8 @@ struct Args {
filters: String,
}
fn main() -> Result<()> {
#[tokio::main]
async fn main() -> Result<()> {
let args = Args::parse();
let device = if args.cpu {
Device::Cpu
@ -195,7 +198,49 @@ fn main() -> Result<()> {
};
let rng = rand::rngs::StdRng::seed_from_u64(args.seed);
let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
let default_model = "openai/whisper-tiny.en".to_string();
let path = std::path::PathBuf::from(default_model.clone());
let default_revision = "refs/pr/15".to_string();
let (model_id, revision) = match (args.model_id, args.revision) {
(Some(model_id), Some(revision)) => (model_id, revision),
(Some(model_id), None) => (model_id, "main".to_string()),
(None, Some(revision)) => (default_model, revision),
(None, None) => (default_model, default_revision),
};
let (config_filename, tokenizer_filename, weights_filename, input) = if path.exists() {
let mut config_filename = path.clone();
config_filename.push("config.json");
let mut tokenizer_filename = path.clone();
tokenizer_filename.push("tokenizer.json");
let mut model_filename = path.clone();
model_filename.push("model.safetensors");
(
config_filename,
tokenizer_filename,
model_filename,
std::path::PathBuf::from(args.input.expect("You didn't specify a file to read from yet, are using a local model, please add `--input example.wav` to read some audio file")),
)
} else {
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let api = Api::new()?;
(
api.get(&repo, "config.json").await?,
api.get(&repo, "tokenizer.json").await?,
api.get(&repo, "model.safetensors").await?,
if let Some(input) = args.input {
std::path::PathBuf::from(input)
} else {
println!("No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav");
api.get(
&Repo::new("Narsil/candle-examples".to_string(), RepoType::Dataset),
"samples_jfk.wav",
)
.await?
},
)
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
let mel_filters = mel_filters.deserialize()?;
@ -203,7 +248,7 @@ fn main() -> Result<()> {
println!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let mut input = std::fs::File::open(args.input)?;
let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}");
if header.sampling_rate != SAMPLE_RATE as u32 {
@ -220,10 +265,11 @@ fn main() -> Result<()> {
let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
let model = Whisper::load(&vb, Config::tiny_en())?;
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let model = Whisper::load(&vb, config)?;
let mut dc = Decode {
model,
rng,
@ -233,6 +279,7 @@ fn main() -> Result<()> {
let (_, _, content_frames) = mel.shape().r3()?;
let mut seek = 0;
let mut segments = vec![];
let start = std::time::Instant::now();
while seek < content_frames {
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES);
@ -249,7 +296,7 @@ fn main() -> Result<()> {
duration: segment_duration,
dr,
};
println!("{seek}: {segment:?}");
println!("{seek}: {segment:?} : Took {:?}", start.elapsed());
segments.push(segment)
}
Ok(())

View File

@ -2,6 +2,7 @@
// back when using RUST_LIB_BACKTRACE=1.
use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use serde::Deserialize;
use std::collections::HashMap;
pub struct VarBuilder<'a> {
@ -76,33 +77,33 @@ impl HiddenAct {
}
}
#[derive(Debug, Clone, PartialEq)]
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
pub n_mels: usize,
pub n_audio_ctx: usize,
pub n_audio_state: usize,
pub n_audio_head: usize,
pub n_audio_layer: usize,
pub n_vocab: usize,
pub n_text_ctx: usize,
pub n_text_state: usize,
pub n_text_head: usize,
pub n_text_layer: usize,
pub num_mel_bins: usize,
pub max_source_positions: usize,
pub d_model: usize,
pub encoder_attention_heads: usize,
pub encoder_layers: usize,
pub vocab_size: usize,
pub max_target_positions: usize,
// pub n_text_state: usize,
pub decoder_attention_heads: usize,
pub decoder_layers: usize,
}
impl Config {
pub fn tiny_en() -> Self {
Self {
n_mels: 80,
n_vocab: 51864,
n_audio_ctx: 1500,
n_audio_state: 384,
n_audio_head: 6,
n_audio_layer: 4,
n_text_ctx: 448,
n_text_state: 384,
n_text_head: 6,
n_text_layer: 4,
num_mel_bins: 80,
vocab_size: 51864,
max_source_positions: 1500,
d_model: 384,
encoder_attention_heads: 6,
encoder_layers: 4,
max_target_positions: 448,
// n_text_state: 384,
decoder_attention_heads: 6,
decoder_layers: 4,
}
}
}
@ -297,10 +298,10 @@ struct MultiHeadAttention {
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?;
let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?;
let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?;
Ok(Self {
query,
key,
@ -360,20 +361,21 @@ struct ResidualAttentionBlock {
impl ResidualAttentionBlock {
fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
let cross_attn = if ca {
let cross_attn =
MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
let cross_attn_ln =
LayerNorm::load(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
Some((cross_attn, cross_attn_ln))
} else {
None
};
let n_mlp = n_state * 4;
let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.mlp.2"), vb)?;
let mlp_ln = LayerNorm::load(n_state, &format!("{p}.mlp_ln"), vb)?;
let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?;
Ok(Self {
attn,
attn_ln,
@ -426,9 +428,9 @@ pub struct AudioEncoder {
impl AudioEncoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.n_audio_state;
let n_head = cfg.n_audio_head;
let n_ctx = cfg.n_audio_ctx;
let n_state = cfg.d_model;
let n_head = cfg.encoder_attention_heads;
let n_ctx = cfg.max_source_positions;
let cfg1 = ConvConfig {
padding: 1,
stride: 1,
@ -437,15 +439,22 @@ impl AudioEncoder {
padding: 1,
stride: 2,
};
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
let conv1 = Conv1D::load(
cfg.num_mel_bins,
n_state,
3,
cfg1,
&format!("{p}.conv1"),
vb,
)?;
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
let blocks = (0..cfg.n_audio_layer)
let blocks = (0..cfg.encoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
})
.collect::<Result<Vec<_>>>()?;
let ln_post = LayerNorm::load(n_state, &format!("{p}.ln_post"), vb)?;
let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
Ok(Self {
conv1,
conv2,
@ -480,19 +489,19 @@ pub struct TextDecoder {
impl TextDecoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.n_text_state;
let n_head = cfg.n_text_head;
let n_ctx = cfg.n_text_ctx;
let n_state = cfg.d_model;
let n_head = cfg.decoder_attention_heads;
let n_ctx = cfg.max_target_positions;
let token_embedding =
Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
Embedding::load(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
let positional_embedding =
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
let blocks = (0..cfg.n_text_layer)
vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
let blocks = (0..cfg.decoder_layers)
.map(|i| {
ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
})
.collect::<Result<Vec<_>>>()?;
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
let mask: Vec<_> = (0..n_ctx)
.flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
.collect();
@ -532,8 +541,8 @@ pub struct Whisper {
impl Whisper {
pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
let encoder = AudioEncoder::load("encoder", vb, &config)?;
let decoder = TextDecoder::load("decoder", vb, &config)?;
let encoder = AudioEncoder::load("model.encoder", vb, &config)?;
let decoder = TextDecoder::load("model.decoder", vb, &config)?;
Ok(Self {
encoder,
decoder,