diff --git a/.gitignore b/.gitignore index c433e74b..400ff9bf 100644 --- a/.gitignore +++ b/.gitignore @@ -13,10 +13,13 @@ Cargo.lock # MSVC Windows builds of rustc generate these, which store debugging information *.pdb -*tokenizer.json +*tokenizer*.json *.npz perf.data flamegraph.svg *.so *.swp + +candle-wasm-example/*.wav +candle-wasm-example/*.safetensors diff --git a/Cargo.toml b/Cargo.toml index 218e717b..efbc80d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ members = [ "candle-nn", "candle-pyo3", "candle-transformers", + "candle-wasm-example", ] [profile.release-with-debug] diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 9e23cf9d..6ef709ce 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -155,6 +155,11 @@ impl MmapedFile { } impl<'a> SafeTensors<'a> { + pub fn from_buffer(buffer: &'a [u8]) -> Result { + let st = safetensors::SafeTensors::deserialize(buffer)?; + Ok(SafeTensors(st)) + } + pub fn tensor(&self, name: &str, device: &Device) -> Result { convert(self.0.tensor(name)?, device) } diff --git a/candle-wasm-example/Cargo.toml b/candle-wasm-example/Cargo.toml new file mode 100644 index 00000000..57e97584 --- /dev/null +++ b/candle-wasm-example/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "candle-wasm-example" +version = "0.1.0" +edition = "2021" + +description = "Wasm example for the candle ML framework." +repository = "https://github.com/LaurentMazare/candle" +keywords = ["blas", "tensor", "machine-learning"] +categories = ["science"] +license = "MIT/Apache-2.0" +readme = "README.md" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +candle = { path = "../candle-core", default-features=false } +candle-nn = { path = "../candle-nn", default-features=false } +wasm-bindgen = "0.2.87" +getrandom = { version = "0.2", features = ["js"] } +tokenizers = { version = "0.13.3", default-features=false, features=["unstable_wasm"] } +serde = { version = "1.0.166", features = ["derive"] } +serde_json = "1.0.99" +wav = "1.0.0" +rand = "0.8.5" +num-traits = "0.2.15" +anyhow = "1.0.71" +js-sys = "0.3.64" +wasm-bindgen-futures = "0.4.37" + +[dependencies.web-sys] +version = "0.3.64" +features = [ + 'Blob', + 'Document', + 'Element', + 'HtmlElement', + 'Node', + 'Window', + 'Request', + 'RequestCache', + 'RequestInit', + 'RequestMode', + 'Response', +] diff --git a/candle-wasm-example/index.html b/candle-wasm-example/index.html new file mode 100644 index 00000000..2ba74929 --- /dev/null +++ b/candle-wasm-example/index.html @@ -0,0 +1,9 @@ + + + + + Hello Candle - Rust + + + + diff --git a/candle-wasm-example/index.js b/candle-wasm-example/index.js new file mode 100644 index 00000000..74bb5e99 --- /dev/null +++ b/candle-wasm-example/index.js @@ -0,0 +1,8 @@ +import init from "./pkg/candle_wasm.js"; + +const runWasm = async () => { + const candleWasm = await init("./pkg/candle_wasm_bg.wasm"); + candleWasm.test_fn(); + await candleWasm.run_fn(); +}; +runWasm(); diff --git a/candle-wasm-example/src/audio.rs b/candle-wasm-example/src/audio.rs new file mode 100644 index 00000000..d095e239 --- /dev/null +++ b/candle-wasm-example/src/audio.rs @@ -0,0 +1,216 @@ +// Audio processing code, adapted from whisper.cpp +// https://github.com/ggerganov/whisper.cpp + +pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} + +impl Float for f32 {} +impl Float for f64 {} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 +fn fft(inp: &[T]) -> 
Vec { + let n = inp.len(); + let zero = T::zero(); + if n == 1 { + return vec![inp[0], zero]; + } + if n % 2 == 1 { + return dft(inp); + } + let mut out = vec![zero; n * 2]; + + let mut even = vec![]; + even.reserve(n / 2); + let mut odd = vec![]; + odd.reserve(n / 2); + + for (i, &inp) in inp.iter().enumerate() { + if i % 2 == 0 { + even.push(inp) + } else { + odd.push(inp); + } + } + + let even_fft = fft(&even); + let odd_fft = fft(&odd); + + let two_pi = T::PI() + T::PI(); + let n_t = T::from(n).unwrap(); + for k in 0..n / 2 { + let k_t = T::from(k).unwrap(); + let theta = two_pi * k_t / n_t; + let re = theta.cos(); + let im = -theta.sin(); + + let re_odd = odd_fft[2 * k]; + let im_odd = odd_fft[2 * k + 1]; + + out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; + + out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd; + out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; + } + out +} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337 +fn dft(inp: &[T]) -> Vec { + let zero = T::zero(); + let n = inp.len(); + let two_pi = T::PI() + T::PI(); + + let mut out = Vec::new(); + out.reserve(2 * n); + let n_t = T::from(n).unwrap(); + for k in 0..n { + let k_t = T::from(k).unwrap(); + let mut re = zero; + let mut im = zero; + + for (j, &inp) in inp.iter().enumerate() { + let j_t = T::from(j).unwrap(); + let angle = two_pi * k_t * j_t / n_t; + re += inp * angle.cos(); + im -= inp * angle.sin(); + } + + out.push(re); + out.push(im); + } + out +} + +#[allow(clippy::too_many_arguments)] +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414 +fn log_mel_spectrogram_w( + ith: usize, + hann: &[T], + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + speed_up: bool, + n_len: usize, + n_mel: usize, + n_threads: usize, +) -> Vec { + let n_fft = if speed_up { + 1 + fft_size / 4 + } else { + 1 + fft_size / 2 + }; + + let zero = T::zero(); + let half = T::from(0.5).unwrap(); + let mut fft_in = vec![zero; fft_size]; + let mut mel = vec![zero; n_len * n_mel]; + + for i in (ith..n_len).step_by(n_threads) { + let offset = i * fft_step; + + // apply Hanning window + for j in 0..fft_size { + fft_in[j] = if offset + j < samples.len() { + hann[j] * samples[offset + j] + } else { + zero + } + } + + // FFT -> mag^2 + let mut fft_out: Vec = fft(&fft_in); + + for j in 0..fft_size { + fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1]; + } + for j in 1..fft_size / 2 { + let v = fft_out[fft_size - j]; + fft_out[j] += v; + } + + if speed_up { + // scale down in the frequency domain results in a speed up in the time domain + for j in 0..n_fft { + fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]); + } + } + + // mel spectrogram + for j in 0..n_mel { + let mut sum = zero; + for k in 0..n_fft { + sum += fft_out[k] * filters[j * n_fft + k]; + } + mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); + } + } + mel +} + +fn log_mel_spectrogram_( + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + n_mel: usize, + speed_up: bool, +) -> Vec { + let zero = T::zero(); + let two_pi = T::PI() + T::PI(); + let half = T::from(0.5).unwrap(); + let one = T::from(1.0).unwrap(); + let four = T::from(4.0).unwrap(); + let fft_size_t = T::from(fft_size).unwrap(); + + let hann: Vec = (0..fft_size) + .map(|i| half * (one - 
((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) + .collect(); + let n_len = samples.len() / fft_step; + + // pad audio with at least one extra chunk of zeros + let pad = 100 * super::CHUNK_LENGTH / 2; + let n_len = if n_len % pad != 0 { + (n_len / pad + 1) * pad + } else { + n_len + }; + let n_len = n_len + pad; + let samples = { + let mut samples_padded = samples.to_vec(); + let to_add = n_len * fft_step - samples.len(); + samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded + }; + + // Use a single thread for now. + let mut mel = log_mel_spectrogram_w( + 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1, + ); + let mmax = mel + .iter() + .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) + .copied() + .unwrap_or(zero) + - T::from(8).unwrap(); + for m in mel.iter_mut() { + let v = T::max(*m, mmax); + *m = v / four + one + } + mel +} + +pub fn pcm_to_mel( + samples: &[T], + filters: &[T], +) -> anyhow::Result> { + let mel = log_mel_spectrogram_( + samples, + filters, + super::N_FFT, + super::HOP_LENGTH, + super::N_MELS, + false, + ); + Ok(mel) +} diff --git a/candle-wasm-example/src/lib.rs b/candle-wasm-example/src/lib.rs new file mode 100644 index 00000000..8e51f872 --- /dev/null +++ b/candle-wasm-example/src/lib.rs @@ -0,0 +1,335 @@ +#![allow(dead_code)] +use anyhow::Error as E; +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use rand::{distributions::Distribution, SeedableRng}; +use tokenizers::Tokenizer; +use wasm_bindgen::prelude::*; +use wasm_bindgen_futures::JsFuture; + +mod audio; +mod model; +use model::{Config, Whisper}; + +const DTYPE: DType = DType::F32; + +// Audio parameters. +const SAMPLE_RATE: usize = 16000; +const N_FFT: usize = 400; +const N_MELS: usize = 80; +const HOP_LENGTH: usize = 160; +const CHUNK_LENGTH: usize = 30; +const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk +const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input +const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2 +const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame +const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token + +const NO_SPEECH_THRESHOLD: f64 = 0.6; +const LOGPROB_THRESHOLD: f64 = -1.0; +const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; +const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; + +// Tokenizer dependent bits. 
+const SOT_TOKEN: u32 = 50257; +const EOT_TOKEN: u32 = 50256; +const NO_SPEECH_TOKEN: u32 = 50361; +const NO_TIMESTAMP_TOKEN: u32 = 50362; +// From the _get_suppress_tokens function + 50362 (no timestamp) +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/decoding.py#L605 +const SUPPRESS_TOKENS: [u32; 91] = [ + 1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 357, + 366, 438, 532, 685, 705, 796, 930, 1058, 1220, 1267, 1279, 1303, 1343, 1377, 1391, 1635, 1782, + 1875, 2162, 2361, 2488, 3467, 4008, 4211, 4600, 4808, 5299, 5855, 6329, 7203, 9609, 9959, + 10563, 10786, 11420, 11709, 11907, 13163, 13697, 13700, 14808, 15306, 16410, 16791, 17992, + 19203, 19510, 20724, 22305, 22935, 27007, 30109, 30420, 33409, 34949, 40283, 40493, 40549, + 47282, 49146, 50257, 50357, 50358, 50359, 50360, 50361, 50362, +]; + +#[wasm_bindgen] +extern "C" { + // Use `js_namespace` here to bind `console.log(..)` instead of just + // `log(..)` + #[wasm_bindgen(js_namespace = console)] + fn log(s: &str); +} + +macro_rules! console_log { + // Note that this is using the `log` function imported above during + // `bare_bones` + ($($t:tt)*) => (log(&format_args!($($t)*).to_string())) +} + +#[derive(Debug, Clone)] +struct DecodingResult { + tokens: Vec, + text: String, + avg_logprob: f64, + no_speech_prob: f64, + temperature: f64, + compression_ratio: f64, +} + +#[derive(Debug, Clone)] +struct Segment { + start: f64, + duration: f64, + dr: DecodingResult, +} + +struct Decoder { + model: Whisper, + rng: rand::rngs::StdRng, + tokenizer: Tokenizer, + suppress_tokens: Tensor, +} + +impl Decoder { + fn new( + model: Whisper, + tokenizer: Tokenizer, + seed: u64, + device: &Device, + ) -> anyhow::Result { + let suppress_tokens: Vec = (0..model.config.vocab_size as u32) + .map(|i| { + if SUPPRESS_TOKENS.contains(&i) { + f32::NEG_INFINITY + } else { + 0f32 + } + }) + .collect(); + let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; + Ok(Self { + model, + rng: rand::rngs::StdRng::seed_from_u64(seed), + tokenizer, + suppress_tokens, + }) + } + + fn decode(&mut self, mel: &Tensor, t: f64) -> anyhow::Result { + let model = &self.model; + let audio_features = model.encoder.forward(mel)?; + console_log!("audio features: {:?}", audio_features.dims()); + let sample_len = model.config.max_target_positions / 2; + let mut sum_logprob = 0f64; + let mut no_speech_prob = f64::NAN; + let mut tokens = vec![SOT_TOKEN]; + for i in 0..sample_len { + let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?; + + // The model expects a batch dim but this inference loop does not handle + // it so we add it at this point. + let tokens_t = tokens_t.unsqueeze(0)?; + let logits = model.decoder.forward(&tokens_t, &audio_features)?; + let logits = logits.squeeze(0)?; + + // Extract the no speech probability on the first iteration by looking at the first + // token logits and the probability for the according token. + if i == 0 { + no_speech_prob = logits + .get(0)? + .softmax(0)? + .get(NO_SPEECH_TOKEN as usize)? + .to_scalar::()? as f64; + } + + let (seq_len, _) = logits.shape().r2()?; + let logits = logits + .get(seq_len - 1)? 
+ .broadcast_add(&self.suppress_tokens)?; + let next_token = if t > 0f64 { + let prs = (&logits / t)?.softmax(0)?; + let logits_v: Vec = prs.to_vec1()?; + let distr = rand::distributions::WeightedIndex::new(&logits_v)?; + distr.sample(&mut self.rng) as u32 + } else { + let logits_v: Vec = logits.to_vec1()?; + logits_v + .iter() + .enumerate() + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap() + }; + tokens.push(next_token); + let prob = logits + .softmax(candle::D::Minus1)? + .get(next_token as usize)? + .to_scalar::()? as f64; + if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions { + break; + } + sum_logprob += prob.ln(); + } + let text = self + .tokenizer + .decode(tokens.clone(), true) + .map_err(E::msg)?; + let avg_logprob = sum_logprob / tokens.len() as f64; + + Ok(DecodingResult { + tokens, + text, + avg_logprob, + no_speech_prob, + temperature: t, + compression_ratio: f64::NAN, + }) + } + + fn decode_with_fallback(&mut self, segment: &Tensor) -> anyhow::Result { + for (i, &t) in TEMPERATURES.iter().enumerate() { + let dr: Result = self.decode(segment, t); + if i == TEMPERATURES.len() - 1 { + return dr; + } + // On errors, we try again with a different temperature. + match dr { + Ok(dr) => { + let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD + || dr.avg_logprob < LOGPROB_THRESHOLD; + if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD { + return Ok(dr); + } + } + Err(err) => { + console_log!("Error running at {t}: {err}") + } + } + } + unreachable!() + } + + fn run(&mut self, mel: &Tensor) -> anyhow::Result> { + let (_, _, content_frames) = mel.shape().r3()?; + let mut seek = 0; + let mut segments = vec![]; + while seek < content_frames { + let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let segment_size = usize::min(content_frames - seek, N_FRAMES); + let mel_segment = mel.narrow(2, seek, segment_size)?; + let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let dr = self.decode_with_fallback(&mel_segment)?; + seek += segment_size; + if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { + console_log!("no speech detected, skipping {seek} {dr:?}"); + continue; + } + let segment = Segment { + start: time_offset, + duration: segment_duration, + dr, + }; + console_log!("{seek}: {segment:?}"); + segments.push(segment) + } + Ok(segments) + } +} + +async fn fetch_url(url: &str) -> Result, JsValue> { + use web_sys::{Request, RequestCache, RequestInit, RequestMode, Response}; + let window = web_sys::window().ok_or("window")?; + let mut opts = RequestInit::new(); + let opts = opts + .method("GET") + .mode(RequestMode::Cors) + .cache(RequestCache::NoCache); + + let request = Request::new_with_str_and_init(url, opts)?; + + let resp_value = JsFuture::from(window.fetch_with_request(&request)).await?; + + // `resp_value` is a `Response` object. 
+ assert!(resp_value.is_instance_of::()); + let resp: Response = resp_value.dyn_into()?; + let data = JsFuture::from(resp.blob()?).await?; + let blob = web_sys::Blob::from(data); + let array_buffer = JsFuture::from(blob.array_buffer()).await?; + let data = js_sys::Uint8Array::new(&array_buffer).to_vec(); + Ok(data) +} + +fn w(x: T) -> String { + x.to_string() +} + +async fn run_impl() -> Result<(), JsValue> { + let device = Device::Cpu; + let tokenizer_config = fetch_url("tokenizer.en.json").await?; + let tokenizer = Tokenizer::from_bytes(tokenizer_config).map_err(w)?; + + let mel_filters = fetch_url("mel_filters.safetensors").await?; + let mel_filters = candle::safetensors::SafeTensors::from_buffer(&mel_filters).map_err(w)?; + let mel_filters = mel_filters.tensor("mel_80", &device).map_err(w)?; + console_log!("loaded mel filters {:?}", mel_filters.shape()); + let mel_filters = mel_filters + .flatten_all() + .map_err(w)? + .to_vec1::() + .map_err(w)?; + + let wav_input = fetch_url("jfk.wav").await?; + let mut wav_input = std::io::Cursor::new(wav_input); + let (header, data) = wav::read(&mut wav_input).map_err(w)?; + console_log!("loaded wav data: {header:?}"); + if header.sampling_rate != SAMPLE_RATE as u32 { + Err(format!( + "wav file must have a {} sampling rate", + SAMPLE_RATE + ))? + } + let data = data.as_sixteen().expect("expected 16 bit wav file"); + let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] + .iter() + .map(|v| *v as f32 / 32768.) + .collect(); + console_log!("pcm data loaded {}", pcm_data.len()); + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters).map_err(w)?; + let mel_len = mel.len(); + let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device).map_err(w)?; + console_log!("loaded mel: {:?}", mel.dims()); + + let weights = fetch_url("tiny.en.safetensors").await?; + let weights = candle::safetensors::SafeTensors::from_buffer(&weights).map_err(w)?; + let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); + let config = Config::tiny_en(); + let model = Whisper::load(&vb, config).map_err(w)?; + let mut dc = Decoder::new(model, tokenizer, 299792458, &device).map_err(w)?; + dc.run(&mel).map_err(w)?; + Ok(()) +} + +fn test_fn_impl() -> anyhow::Result { + let t1 = Tensor::randn((3, 4), DType::F32, &Device::Cpu, 0., 1.)?; + let t2 = Tensor::randn((4, 2), DType::F32, &Device::Cpu, 0., 1.)?; + let t = t1.matmul(&t2)?; + console_log!("matmul result: {t}"); + let res = format!("Hello Candle!\n\nt1:\n{t1}\n\nt2:\n{t2}\n\nt1@t2:\n{t}\n"); + Ok(res) +} + +#[wasm_bindgen] +pub fn test_fn() -> std::result::Result<(), JsValue> { + let result = match test_fn_impl() { + Ok(v) => v, + Err(err) => format!("error: {err:?}"), + }; + let window = web_sys::window().expect("no global `window` exists"); + let document = window.document().expect("should have a document on window"); + let p_element = document.create_element("p")?; + p_element.set_text_content(Some(&result)); + let body = document.body().expect("document should have a body"); + body.append_child(&p_element)?; + Ok(()) +} + +#[wasm_bindgen] +pub async fn run_fn() -> std::result::Result<(), JsValue> { + console_log!("run_fn starting..."); + run_impl().await?; + Ok(()) +} diff --git a/candle-wasm-example/src/model.rs b/candle-wasm-example/src/model.rs new file mode 100644 index 00000000..b19ff90a --- /dev/null +++ b/candle-wasm-example/src/model.rs @@ -0,0 +1,363 @@ +#![allow(dead_code)] +// We use anyhow rather than candle errors as it provides better support for getting the backtrace 
+// back when using RUST_LIB_BACKTRACE=1. +use anyhow::Result; +use candle::{Device, Tensor}; +use candle_nn::{Conv1d, Conv1dConfig, Embedding, LayerNorm, Linear, VarBuilder}; +use serde::Deserialize; + +// The names in comments correspond to the original implementation: +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L17 +#[derive(Debug, Clone, PartialEq, Deserialize)] +pub struct Config { + pub num_mel_bins: usize, // n_mels + pub max_source_positions: usize, // n_audio_ctx + pub d_model: usize, // n_audio_state + pub encoder_attention_heads: usize, // n_audio_head + pub encoder_layers: usize, // n_audio_layer + pub vocab_size: usize, // n_vocab + pub max_target_positions: usize, // n_text_ctx + // pub n_text_state: usize, + pub decoder_attention_heads: usize, // n_text_head + pub decoder_layers: usize, // n_text_layer +} + +impl Config { + pub fn tiny_en() -> Self { + Self { + num_mel_bins: 80, + vocab_size: 51864, + max_source_positions: 1500, + d_model: 384, + encoder_attention_heads: 6, + encoder_layers: 4, + max_target_positions: 448, + // n_text_state: 384, + decoder_attention_heads: 6, + decoder_layers: 4, + } + } +} + +fn embedding(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result { + let embeddings = vb.get((vocab_size, hidden_size), "weight")?; + Ok(Embedding::new(embeddings, hidden_size)) +} + +fn linear(size1: usize, size2: usize, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + let bias = vb.get(size2, "bias")?; + Ok(Linear::new(weight, Some(bias))) +} + +fn linear_no_bias(size1: usize, size2: usize, vb: VarBuilder) -> Result { + let weight = vb.get((size2, size1), "weight")?; + Ok(Linear::new(weight, None)) +} + +fn conv1d( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: Conv1dConfig, + vb: VarBuilder, +) -> Result { + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; + let bias = vb.get(out_channels, "bias")?; + Ok(Conv1d::new(weight, Some(bias), config)) +} + +fn conv1d_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: Conv1dConfig, + vb: VarBuilder, +) -> Result { + let weight = vb.get((out_channels, in_channels, kernel_size), "weight")?; + Ok(Conv1d::new(weight, None, config)) +} + +struct Dropout { + pr: f64, +} + +impl Dropout { + fn new(pr: f64) -> Self { + Self { pr } + } + + fn forward(&self, x: &Tensor) -> Result { + // TODO + Ok(x.clone()) + } +} + +fn layer_norm(size: usize, vb: VarBuilder) -> Result { + let weight = vb.get(size, "weight")?; + let bias = vb.get(size, "bias")?; + Ok(LayerNorm::new(weight, bias, 1e-5)) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +struct MultiHeadAttention { + query: Linear, + key: Linear, + value: Linear, + out: Linear, + n_head: usize, +} + +impl MultiHeadAttention { + fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result { + let query = linear(n_state, n_state, vb.pp("q_proj"))?; + let value = linear(n_state, n_state, vb.pp("v_proj"))?; + let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?; + let out = linear(n_state, n_state, vb.pp("out_proj"))?; + Ok(Self { + query, + key, + value, + out, + n_head, + }) + } + + fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { + let q = self.query.forward(x)?; + let k = self.key.forward(xa.unwrap_or(x))?; + let v = self.value.forward(xa.unwrap_or(x))?; + let wv = self.qkv_attention(&q, &k, 
&v, mask)?; + let out = self.out.forward(&wv)?; + Ok(out) + } + + fn reshape_head(&self, x: &Tensor) -> Result { + let (n_batch, n_ctx, n_state) = x.shape().r3()?; + let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; + Ok(x.reshape(target_dims)?.transpose(1, 2)?) + } + + fn qkv_attention( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + ) -> Result { + let (_, n_ctx, n_state) = q.shape().r3()?; + let scale = ((n_state / self.n_head) as f64).powf(-0.25); + let q = (self.reshape_head(q)? * scale)?; + let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; + let v = self.reshape_head(v)?.contiguous()?; + let mut qk = q.matmul(&k)?; + if let Some(mask) = mask { + let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?; + qk = qk.broadcast_add(&mask)? + } + let w = qk.softmax(candle::D::Minus1)?; + let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?; + Ok(wv) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +struct ResidualAttentionBlock { + attn: MultiHeadAttention, + attn_ln: LayerNorm, + cross_attn: Option<(MultiHeadAttention, LayerNorm)>, + mlp_linear1: Linear, + mlp_linear2: Linear, + mlp_ln: LayerNorm, +} + +impl ResidualAttentionBlock { + fn load(n_state: usize, n_head: usize, ca: bool, vb: VarBuilder) -> Result { + let attn = MultiHeadAttention::load(n_state, n_head, vb.pp("self_attn"))?; + let attn_ln = layer_norm(n_state, vb.pp("self_attn_layer_norm"))?; + let cross_attn = if ca { + let cross_attn = MultiHeadAttention::load(n_state, n_head, vb.pp("encoder_attn"))?; + let cross_attn_ln = layer_norm(n_state, vb.pp("encoder_attn_layer_norm"))?; + Some((cross_attn, cross_attn_ln)) + } else { + None + }; + let n_mlp = n_state * 4; + let mlp_linear1 = linear(n_state, n_mlp, vb.pp("fc1"))?; + let mlp_linear2 = linear(n_mlp, n_state, vb.pp("fc2"))?; + let mlp_ln = layer_norm(n_state, vb.pp("final_layer_norm"))?; + Ok(Self { + attn, + attn_ln, + cross_attn, + mlp_linear1, + mlp_linear2, + mlp_ln, + }) + } + + fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { + let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?; + let mut x = (x + attn)?; + if let Some((attn, ln)) = &self.cross_attn { + x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?; + } + let mlp = self.mlp_linear2.forward( + &self + .mlp_linear1 + .forward(&self.mlp_ln.forward(&x)?)? + .gelu()?, + )?; + Ok((x + mlp)?) + } +} + +fn sinusoids(length: usize, channels: usize) -> Result { + let max_timescale = 10000f32; + let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; + let inv_timescales: Vec<_> = (0..channels / 2) + .map(|i| (i as f32 * (-log_timescale_increment)).exp()) + .collect(); + let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let arange = Tensor::arange(0, length as u32, &Device::Cpu)? + .to_dtype(candle::DType::F32)? + .unsqueeze(1)?; + let sh = (length, channels / 2); + let scaled_time = (arange.broadcast_as(sh)? 
* inv_timescales.broadcast_as(sh)?)?; + let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; + Ok(sincos) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +pub struct AudioEncoder { + conv1: Conv1d, + conv2: Conv1d, + positional_embedding: Tensor, + blocks: Vec, + ln_post: LayerNorm, +} + +impl AudioEncoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let n_state = cfg.d_model; + let n_head = cfg.encoder_attention_heads; + let n_ctx = cfg.max_source_positions; + let cfg1 = Conv1dConfig { + padding: 1, + stride: 1, + }; + let cfg2 = Conv1dConfig { + padding: 1, + stride: 2, + }; + let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; + let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(vb.device())?; + let blocks = (0..cfg.encoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, false, vb.pp(&format!("layers.{i}"))) + }) + .collect::>>()?; + let ln_post = layer_norm(n_state, vb.pp("layer_norm"))?; + Ok(Self { + conv1, + conv2, + positional_embedding, + blocks, + ln_post, + }) + } + pub fn forward(&self, x: &Tensor) -> Result { + let x = self.conv1.forward(x)?.gelu()?; + let x = self.conv2.forward(&x)?.gelu()?; + let x = x.transpose(1, 2)?; + let (_bsize, seq_len, _hidden) = x.shape().r3()?; + let positional_embedding = self.positional_embedding.narrow(0, 0, seq_len)?; + let mut x = x.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter() { + x = block.forward(&x, None, None)? + } + let x = self.ln_post.forward(&x)?; + Ok(x) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +pub struct TextDecoder { + token_embedding: Embedding, + positional_embedding: Tensor, + blocks: Vec, + ln: LayerNorm, + mask: Tensor, +} + +impl TextDecoder { + fn load(vb: VarBuilder, cfg: &Config) -> Result { + let n_state = cfg.d_model; + let n_head = cfg.decoder_attention_heads; + let n_ctx = cfg.max_target_positions; + let token_embedding = embedding(cfg.vocab_size, n_state, vb.pp("embed_tokens"))?; + let positional_embedding = vb.get((n_ctx, n_state), "embed_positions.weight")?; + let blocks = (0..cfg.decoder_layers) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, true, vb.pp(&format!("layers.{i}"))) + }) + .collect::>>()?; + let ln = layer_norm(n_state, vb.pp("layer_norm"))?; + let mask: Vec<_> = (0..n_ctx) + .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), vb.device())?; + + Ok(Self { + token_embedding, + positional_embedding, + blocks, + ln, + mask, + }) + } + + pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result { + let x_dims = x.dims(); + let last = x_dims[x_dims.len() - 1]; + let token_embedding = self.token_embedding.forward(x)?; + let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; + let mut x = token_embedding.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter() { + x = block.forward(&x, Some(xa), Some(&self.mask))?; + } + let x = self.ln.forward(&x)?; + let w = self + .token_embedding + .embeddings() + .broadcast_left(x_dims[0])?; + let logits = x.matmul(&w.t()?)?; + Ok(logits) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +pub struct Whisper { + pub encoder: AudioEncoder, + pub decoder: 
TextDecoder, + pub config: Config, +} + +impl Whisper { + pub fn load(vb: &VarBuilder, config: Config) -> Result { + let encoder = AudioEncoder::load(vb.pp("model.encoder"), &config)?; + let decoder = TextDecoder::load(vb.pp("model.decoder"), &config)?; + Ok(Self { + encoder, + decoder, + config, + }) + } + + pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result { + let enc = self.encoder.forward(mel)?; + let dec = self.decoder.forward(tokens, &enc)?; + Ok(dec) + } +}
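The wasm example cannot mmap weight files, so this diff adds a buffer-based `SafeTensors::from_buffer` constructor to candle-core and uses it from `run_impl` after fetching the bytes over HTTP. The same path also works natively; below is a rough sketch (not part of the diff, with a hypothetical `load_tiny_en` helper and the same file name the example fetches) of loading weights from an in-memory buffer and building the model with the `VarBuilder::from_safetensors` / `Whisper::load` calls shown above.

    // Hypothetical native sketch of the buffer-based loading path used in `run_impl`.
    // `tiny.en.safetensors` is the file the wasm example fetches; adjust the path as needed.
    use candle::{DType, Device};
    use candle_nn::VarBuilder;
    use crate::model::{Config, Whisper};

    fn load_tiny_en() -> anyhow::Result<Whisper> {
        let device = Device::Cpu;
        // Read the whole file into memory instead of mmap-ing it, mirroring the browser path.
        let weights = std::fs::read("tiny.en.safetensors")?;
        let weights = candle::safetensors::SafeTensors::from_buffer(&weights)?;
        let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
        Whisper::load(&vb, Config::tiny_en())
    }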
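`audio::pcm_to_mel` returns the spectrogram as a flat `Vec<T>` in `(n_mel, n_len)` layout, which `run_impl` reshapes into a `(1, N_MELS, mel_len / N_MELS)` tensor; the frame count is first rounded up to a multiple of half a chunk (15 s) and one extra half-chunk of zeros is appended. A hypothetical check of that layout (not part of the diff, assuming it lives inside this crate next to the constants in `lib.rs`):

    // Hypothetical layout check for the mel pipeline; it calls the crate-private `audio`
    // module and the SAMPLE_RATE / N_MELS / N_FFT constants defined in `lib.rs`.
    fn mel_layout_check() -> anyhow::Result<()> {
        let pcm = vec![0f32; SAMPLE_RATE]; // 1 s of silence at 16 kHz
        let filters = vec![0f32; N_MELS * (1 + N_FFT / 2)]; // 80 x 201 all-zero filter bank
        let mel = audio::pcm_to_mel(&pcm, &filters)?;
        // 16000 samples / 160 hop = 100 frames, rounded up to 1500, plus 1500 frames of padding.
        assert_eq!(mel.len() / N_MELS, 3000);
        Ok(())
    }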