diff --git a/candle-core/src/backprop.rs b/candle-core/src/backprop.rs index 45448505..a44f732f 100644 --- a/candle-core/src/backprop.rs +++ b/candle-core/src/backprop.rs @@ -33,7 +33,12 @@ impl Tensor { track_grad |= tg; nodes } - Op::Add(lhs, rhs) + Op::Conv1D { + arg: lhs, + kernel: rhs, + .. + } + | Op::Add(lhs, rhs) | Op::Mul(lhs, rhs) | Op::Sub(lhs, rhs) | Op::Div(lhs, rhs) @@ -147,6 +152,7 @@ impl Tensor { let f_grad = pred.where_cond(&zeros, &grad)?; *f_sum_grad = f_sum_grad.add(&f_grad)?; } + Op::Conv1D { .. } => return Err(Error::BackwardNotSupported { op: "conv1d" }), Op::Embedding(_lhs, _rhs) => { return Err(Error::BackwardNotSupported { op: "embedding" }) } diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs new file mode 100644 index 00000000..041bb6fb --- /dev/null +++ b/candle-core/src/conv.rs @@ -0,0 +1,27 @@ +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct ParamsConv1D { + pub(crate) b_size: Option, + // Maybe we should have a version without l_in as this bit depends on the input and not only on + // the weights. + pub(crate) l_in: usize, + pub(crate) c_out: usize, + pub(crate) c_in: usize, + pub(crate) k_size: usize, + pub(crate) padding: usize, + pub(crate) stride: usize, +} + +impl ParamsConv1D { + pub(crate) fn l_out(&self) -> usize { + let dilation = 1; + (self.l_in + 2 * self.padding - dilation * (self.k_size - 1) - 1) / self.stride + 1 + } + + pub(crate) fn out_dims(&self) -> Vec { + let l_out = self.l_out(); + match self.b_size { + None => vec![self.c_out, l_out], + Some(n) => vec![n, self.c_out, l_out], + } + } +} diff --git a/candle-core/src/cpu_backend.rs b/candle-core/src/cpu_backend.rs index 0871175f..b2345756 100644 --- a/candle-core/src/cpu_backend.rs +++ b/candle-core/src/cpu_backend.rs @@ -202,6 +202,64 @@ fn copy_strided_src_( } } +struct Conv1D<'a>(&'a crate::conv::ParamsConv1D); + +impl<'a> Map2 for Conv1D<'a> { + const OP: &'static str = "conv1d"; + fn f( + &self, + inp: &[T], + inp_l: &Layout, + k: &[T], + k_l: &Layout, + ) -> Result> { + // TODO: Optimize this (proper algorithm, simd, multithread, remove bound checks, etc). + let p = self.0; + let inp = &inp[inp_l.start_offset()..]; + let k = &k[k_l.start_offset()..]; + let inp_stride = inp_l.stride(); + let (inp_stride0, inp_stride) = if inp_stride.len() == 3 { + (inp_stride[0], &inp_stride[1..]) + } else { + (0, inp_stride) // This value never gets used anyway + }; + let k_stride = k_l.stride(); + let k_over_2 = p.k_size / 2; + let l_out = p.l_out(); + let dst_elems = p.c_out * l_out * p.b_size.unwrap_or(1); + let mut dst = vec![T::zero(); dst_elems]; + // The output shape is [b_size, c_out, l_out] + for b_idx in 0..p.b_size.unwrap_or(1) { + let inp_idx = b_idx * inp_stride0; + let dst_idx = b_idx * p.c_out * l_out; + for dst_c_idx in 0..p.c_out { + let dst_idx = dst_idx + dst_c_idx * l_out; + for dst_l in 0..l_out { + let dst_idx = dst_idx + dst_l; + let mut d = T::zero(); + for offset in 0..p.k_size { + let src_l_plus = p.stride * dst_l + offset; + // inp[bidx, src_c_idx, dst_l + offset - k//2] * k[dst_c_idx, src_c_idx, offset] + if k_over_2 <= src_l_plus && src_l_plus < k_over_2 + p.l_in { + let src_l = src_l_plus - k_over_2; + for src_c_idx in 0..p.c_in { + let inp_idx = + inp_idx + src_c_idx * inp_stride[0] + src_l * inp_stride[1]; + let k_idx = dst_c_idx * k_stride[0] + + src_c_idx * k_stride[1] + + offset * k_stride[2]; + d += inp[inp_idx] * k[k_idx] + } + } + } + dst[dst_idx] = d + } + } + } + Ok(dst) + } +} + struct MatMul((usize, usize, usize, usize)); impl Map2 for MatMul { @@ -627,6 +685,16 @@ impl CpuStorage { WCond(pred, layout).map(t, t_l, f, f_l) } + pub(crate) fn conv1d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv1D, + ) -> Result { + Conv1D(params).map(self, l, kernel, kernel_l) + } + pub(crate) fn embedding(&self, ids_l: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { let ids = self.as_slice::()?; let (vocab_size, hidden_size) = rhs_l.shape().r2()?; diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 0c87004b..917655fc 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -801,6 +801,16 @@ impl CudaStorage { Ok(Self { slice, device }) } + pub(crate) fn conv1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv1D, + ) -> Result { + todo!() + } + pub(crate) fn embedding(&self, layout: &Layout, rhs: &Self, rhs_l: &Layout) -> Result { let device = self.device().clone(); let slice = Embedding(self, layout).map(&rhs.slice, &device, rhs_l)?; diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index b025eeab..0dbd8d54 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -100,6 +100,16 @@ impl CudaStorage { Err(Error::NotCompiledWithCudaSupport) } + pub(crate) fn conv1d( + &self, + _l: &Layout, + _kernel: &Self, + _kernel_l: &Layout, + _params: &crate::conv::ParamsConv1D, + ) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } + pub(crate) fn embedding(&self, _: &Layout, _: &Self, _: &Layout) -> Result { Err(Error::NotCompiledWithCudaSupport) } diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 0d4c2a8d..2365a34d 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -1,4 +1,5 @@ mod backprop; +mod conv; mod cpu_backend; #[cfg(feature = "cuda")] mod cuda_backend; diff --git a/candle-core/src/op.rs b/candle-core/src/op.rs index 860be0b3..ee57b325 100644 --- a/candle-core/src/op.rs +++ b/candle-core/src/op.rs @@ -12,6 +12,14 @@ pub(crate) enum Op { Embedding(Tensor, Tensor), WhereCond(Tensor, Tensor, Tensor), + #[allow(dead_code)] + Conv1D { + arg: Tensor, + kernel: Tensor, + padding: usize, + stride: usize, + }, + Cat(Vec, usize), #[allow(dead_code)] // add is currently unused. diff --git a/candle-core/src/storage.rs b/candle-core/src/storage.rs index 4e630a58..53ea1544 100644 --- a/candle-core/src/storage.rs +++ b/candle-core/src/storage.rs @@ -144,6 +144,32 @@ impl Storage { } } + pub(crate) fn conv1d( + &self, + l: &Layout, + kernel: &Self, + kernel_l: &Layout, + params: &crate::conv::ParamsConv1D, + ) -> Result { + self.same_device(kernel, "conv1d")?; + self.same_dtype(kernel, "conv1d")?; + match (self, &kernel) { + (Storage::Cpu(inp), Storage::Cpu(kernel)) => { + let s = inp.conv1d(l, kernel, kernel_l, params)?; + Ok(Self::Cpu(s)) + } + (Storage::Cuda(inp), Storage::Cuda(kernel)) => { + let s = inp.conv1d(l, kernel, kernel_l, params)?; + Ok(Self::Cuda(s)) + } + (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp { + lhs: lhs.device().location(), + rhs: rhs.device().location(), + op: "conv1d", + }), + } + } + pub(crate) fn where_cond( &self, layout: &Layout, diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index a468d879..95f663f0 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -326,7 +326,7 @@ impl Tensor { } let from_cpu_storage = |cpu_storage: &crate::CpuStorage| { let data = S::cpu_storage_as_slice(cpu_storage)?; - Ok::<_, Error>(data[0]) + Ok::<_, Error>(data[self.layout().start_offset()]) }; match self.storage.as_ref() { Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage), @@ -432,6 +432,42 @@ impl Tensor { Ok(from_storage(storage, dims, op, false)) } + pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result { + let (c_out, c_in_k, k_size) = kernel.shape().r3()?; + let (b_size, c_in, l_in) = match *self.dims() { + [b_size, c_in, l_in] => (Some(b_size), c_in, l_in), + [c_in, l_in] => (None, c_in, l_in), + _ => todo!("proper error message"), + }; + if c_in != c_in_k { + todo!("proper error message") + } + let params = crate::conv::ParamsConv1D { + b_size, + l_in, + c_out, + c_in, + k_size, + padding, + stride, + }; + let storage = + self.storage + .conv1d(self.layout(), &kernel.storage, kernel.layout(), ¶ms)?; + let op = if self.track_op() || kernel.track_op() { + Some(Op::Conv1D { + arg: self.clone(), + kernel: kernel.clone(), + padding, + stride, + }) + } else { + None + }; + let out_dims = params.out_dims(); + Ok(from_storage(storage, out_dims, op, false)) + } + pub fn matmul(&self, rhs: &Self) -> Result { let a_dims = self.shape().dims(); let b_dims = rhs.shape().dims(); diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 53a1a150..a39ee3a3 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -14,6 +14,7 @@ readme = "README.md" candle = { path = "../candle-core", default-features=false } serde = { version = "1.0.166", features = ["derive"] } serde_json = "1.0.99" +num-traits = "0.2.15" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } @@ -22,6 +23,7 @@ clap = { version = "4.2.4", features = ["derive"] } rand = "0.8.5" tokenizers = { version = "0.13.3", default-features=false, features=["onig"] } tokio = { version = "1.28.2", features = ["macros", "rt-multi-thread"] } +wav = "1.0.0" [features] default = ["cuda"] diff --git a/candle-examples/examples/whisper/audio.rs b/candle-examples/examples/whisper/audio.rs new file mode 100644 index 00000000..d095e239 --- /dev/null +++ b/candle-examples/examples/whisper/audio.rs @@ -0,0 +1,216 @@ +// Audio processing code, adapted from whisper.cpp +// https://github.com/ggerganov/whisper.cpp + +pub trait Float: num_traits::Float + num_traits::FloatConst + num_traits::NumAssign {} + +impl Float for f32 {} +impl Float for f64 {} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2357 +fn fft(inp: &[T]) -> Vec { + let n = inp.len(); + let zero = T::zero(); + if n == 1 { + return vec![inp[0], zero]; + } + if n % 2 == 1 { + return dft(inp); + } + let mut out = vec![zero; n * 2]; + + let mut even = vec![]; + even.reserve(n / 2); + let mut odd = vec![]; + odd.reserve(n / 2); + + for (i, &inp) in inp.iter().enumerate() { + if i % 2 == 0 { + even.push(inp) + } else { + odd.push(inp); + } + } + + let even_fft = fft(&even); + let odd_fft = fft(&odd); + + let two_pi = T::PI() + T::PI(); + let n_t = T::from(n).unwrap(); + for k in 0..n / 2 { + let k_t = T::from(k).unwrap(); + let theta = two_pi * k_t / n_t; + let re = theta.cos(); + let im = -theta.sin(); + + let re_odd = odd_fft[2 * k]; + let im_odd = odd_fft[2 * k + 1]; + + out[2 * k] = even_fft[2 * k] + re * re_odd - im * im_odd; + out[2 * k + 1] = even_fft[2 * k + 1] + re * im_odd + im * re_odd; + + out[2 * (k + n / 2)] = even_fft[2 * k] - re * re_odd + im * im_odd; + out[2 * (k + n / 2) + 1] = even_fft[2 * k + 1] - re * im_odd - im * re_odd; + } + out +} + +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2337 +fn dft(inp: &[T]) -> Vec { + let zero = T::zero(); + let n = inp.len(); + let two_pi = T::PI() + T::PI(); + + let mut out = Vec::new(); + out.reserve(2 * n); + let n_t = T::from(n).unwrap(); + for k in 0..n { + let k_t = T::from(k).unwrap(); + let mut re = zero; + let mut im = zero; + + for (j, &inp) in inp.iter().enumerate() { + let j_t = T::from(j).unwrap(); + let angle = two_pi * k_t * j_t / n_t; + re += inp * angle.cos(); + im -= inp * angle.sin(); + } + + out.push(re); + out.push(im); + } + out +} + +#[allow(clippy::too_many_arguments)] +// https://github.com/ggerganov/whisper.cpp/blob/4774d2feb01a772a15de81ffc34b34a1f294f020/whisper.cpp#L2414 +fn log_mel_spectrogram_w( + ith: usize, + hann: &[T], + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + speed_up: bool, + n_len: usize, + n_mel: usize, + n_threads: usize, +) -> Vec { + let n_fft = if speed_up { + 1 + fft_size / 4 + } else { + 1 + fft_size / 2 + }; + + let zero = T::zero(); + let half = T::from(0.5).unwrap(); + let mut fft_in = vec![zero; fft_size]; + let mut mel = vec![zero; n_len * n_mel]; + + for i in (ith..n_len).step_by(n_threads) { + let offset = i * fft_step; + + // apply Hanning window + for j in 0..fft_size { + fft_in[j] = if offset + j < samples.len() { + hann[j] * samples[offset + j] + } else { + zero + } + } + + // FFT -> mag^2 + let mut fft_out: Vec = fft(&fft_in); + + for j in 0..fft_size { + fft_out[j] = fft_out[2 * j] * fft_out[2 * j] + fft_out[2 * j + 1] * fft_out[2 * j + 1]; + } + for j in 1..fft_size / 2 { + let v = fft_out[fft_size - j]; + fft_out[j] += v; + } + + if speed_up { + // scale down in the frequency domain results in a speed up in the time domain + for j in 0..n_fft { + fft_out[j] = half * (fft_out[2 * j] + fft_out[2 * j + 1]); + } + } + + // mel spectrogram + for j in 0..n_mel { + let mut sum = zero; + for k in 0..n_fft { + sum += fft_out[k] * filters[j * n_fft + k]; + } + mel[j * n_len + i] = T::max(sum, T::from(1e-10).unwrap()).log10(); + } + } + mel +} + +fn log_mel_spectrogram_( + samples: &[T], + filters: &[T], + fft_size: usize, + fft_step: usize, + n_mel: usize, + speed_up: bool, +) -> Vec { + let zero = T::zero(); + let two_pi = T::PI() + T::PI(); + let half = T::from(0.5).unwrap(); + let one = T::from(1.0).unwrap(); + let four = T::from(4.0).unwrap(); + let fft_size_t = T::from(fft_size).unwrap(); + + let hann: Vec = (0..fft_size) + .map(|i| half * (one - ((two_pi * T::from(i).unwrap()) / fft_size_t).cos())) + .collect(); + let n_len = samples.len() / fft_step; + + // pad audio with at least one extra chunk of zeros + let pad = 100 * super::CHUNK_LENGTH / 2; + let n_len = if n_len % pad != 0 { + (n_len / pad + 1) * pad + } else { + n_len + }; + let n_len = n_len + pad; + let samples = { + let mut samples_padded = samples.to_vec(); + let to_add = n_len * fft_step - samples.len(); + samples_padded.extend(std::iter::repeat(zero).take(to_add)); + samples_padded + }; + + // Use a single thread for now. + let mut mel = log_mel_spectrogram_w( + 0, &hann, &samples, filters, fft_size, fft_step, speed_up, n_len, n_mel, 1, + ); + let mmax = mel + .iter() + .max_by(|&u, &v| u.partial_cmp(v).unwrap_or(std::cmp::Ordering::Greater)) + .copied() + .unwrap_or(zero) + - T::from(8).unwrap(); + for m in mel.iter_mut() { + let v = T::max(*m, mmax); + *m = v / four + one + } + mel +} + +pub fn pcm_to_mel( + samples: &[T], + filters: &[T], +) -> anyhow::Result> { + let mel = log_mel_spectrogram_( + samples, + filters, + super::N_FFT, + super::HOP_LENGTH, + super::N_MELS, + false, + ); + Ok(mel) +} diff --git a/candle-examples/examples/whisper/extract_weights.py b/candle-examples/examples/whisper/extract_weights.py new file mode 100644 index 00000000..65602703 --- /dev/null +++ b/candle-examples/examples/whisper/extract_weights.py @@ -0,0 +1,13 @@ +# Get the checkpoint from +# https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt + +import torch +from safetensors.torch import save_file + +data = torch.load("tiny.en.pt") +weights = {} +for k, v in data["model_state_dict"].items(): + weights[k] = v.contiguous() + print(k, v.shape, v.dtype) +save_file(weights, "tiny.en.safetensors") +print(data["dims"]) diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs new file mode 100644 index 00000000..6ea3e536 --- /dev/null +++ b/candle-examples/examples/whisper/main.rs @@ -0,0 +1,256 @@ +#![allow(dead_code)] +// https://github.com/openai/whisper/blob/main/whisper/model.py +// TODO: +// - kv-cache support? +// - Language detection? +// - Batch size greater than 1. + +use anyhow::{Error as E, Result}; +use candle::{DType, Device, Tensor}; +use clap::Parser; +use rand::{distributions::Distribution, SeedableRng}; +use tokenizers::Tokenizer; + +mod audio; +mod model; +use model::{Config, VarBuilder, Whisper}; + +const DTYPE: DType = DType::F32; + +// Audio parameters. +const SAMPLE_RATE: usize = 16000; +const N_FFT: usize = 400; +const N_MELS: usize = 80; +const HOP_LENGTH: usize = 160; +const CHUNK_LENGTH: usize = 30; +const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk +const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input +const N_SAMPLES_PER_TOKEN: usize = HOP_LENGTH * 2; // the initial convolutions has stride 2 +const FRAMES_PER_SECOND: usize = SAMPLE_RATE / HOP_LENGTH; // 10ms per audio frame +const TOKENS_PER_SECOND: usize = SAMPLE_RATE / N_SAMPLES_PER_TOKEN; // 20ms per audio token + +const NO_SPEECH_THRESHOLD: f64 = 0.6; +const LOGPROB_THRESHOLD: f64 = -1.0; +const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0]; +const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4; + +// Tokenizer dependent bits. +const SOT_TOKEN: u32 = 50257; +const EOT_TOKEN: u32 = 50256; +const NO_SPEECH_TOKEN: u32 = 50361; +const NO_TIMESTAMP_TOKEN: u32 = 50362; + +#[derive(Debug, Clone)] +struct DecodingResult { + tokens: Vec, + text: String, + avg_logprob: f64, + no_speech_prob: f64, + temperature: f64, + compression_ratio: f64, +} + +#[derive(Debug, Clone)] +struct Segment { + start: f64, + duration: f64, + dr: DecodingResult, +} + +struct Decode { + model: Whisper, + rng: rand::rngs::StdRng, + tokenizer: Tokenizer, +} + +impl Decode { + fn decode(&mut self, mel: &Tensor, t: f64) -> Result { + let model = &self.model; + let audio_features = model.encoder.forward(mel)?; + println!("audio features: {:?}", audio_features.dims()); + let sample_len = model.config.n_text_ctx / 2; + let mut sum_logprob = 0f64; + let mut no_speech_prob = f64::NAN; + let mut tokens = vec![SOT_TOKEN]; + for i in 0..sample_len { + let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?; + + // The model expects a batch dim but this inference loop does not handle + // it so we add it at this point. + let tokens_t = tokens_t.unsqueeze(0)?; + let logits = model.decoder.forward(&tokens_t, &audio_features)?; + let logits = logits.squeeze(0)?; + + // Extract the no speech probability on the first iteration by looking at the first + // token logits and the probability for the according token. + if i == 0 { + no_speech_prob = logits + .get(0)? + .softmax(0)? + .get(NO_SPEECH_TOKEN as usize)? + .to_scalar::()? as f64; + } + + let (seq_len, _) = logits.shape().r2()?; + let logits = logits.get(seq_len - 1)?; + let next_token = if t > 0f64 { + let prs = (&logits / t)?.softmax(0)?; + let logits_v: Vec = prs.to_vec1()?; + let distr = rand::distributions::WeightedIndex::new(&logits_v)?; + distr.sample(&mut self.rng) as u32 + } else { + let logits_v: Vec = logits.to_vec1()?; + logits_v + .iter() + .enumerate() + .max_by(|(_, u), (_, v)| u.total_cmp(v)) + .map(|(i, _)| i as u32) + .unwrap() + }; + tokens.push(next_token); + let prob = logits + .softmax(logits.rank() - 1)? + .get(next_token as usize)? + .to_scalar::()? as f64; + if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx { + break; + } + sum_logprob += prob.ln(); + } + let text = self + .tokenizer + .decode(tokens.clone(), true) + .map_err(E::msg)?; + let avg_logprob = sum_logprob / tokens.len() as f64; + + Ok(DecodingResult { + tokens, + text, + avg_logprob, + no_speech_prob, + temperature: t, + compression_ratio: f64::NAN, + }) + } + + fn decode_with_fallback(&mut self, segment: &Tensor) -> Result { + for (i, &t) in TEMPERATURES.iter().enumerate() { + let dr: Result = self.decode(segment, t); + if i == TEMPERATURES.len() - 1 { + return dr; + } + // On errors, we try again with a different temperature. + match dr { + Ok(dr) => { + let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD + || dr.avg_logprob < LOGPROB_THRESHOLD; + if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD { + return Ok(dr); + } + } + Err(err) => { + println!("Error running at {t}: {err}") + } + } + } + unreachable!() + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + #[arg(long)] + weights: String, + + /// The input to be processed, in wav formats. + #[arg(long)] + input: String, + + #[arg(long)] + tokenizer_config: String, + + /// The seed to use when generating random samples. + #[arg(long, default_value_t = 299792458)] + seed: u64, + + /// The mel filters in safetensors format. + #[arg( + long, + default_value = "candle-examples/examples/whisper/mel_filters.safetensors" + )] + filters: String, +} + +fn main() -> Result<()> { + let args = Args::parse(); + let device = if args.cpu { + Device::Cpu + } else { + Device::new_cuda(0)? + }; + let rng = rand::rngs::StdRng::seed_from_u64(args.seed); + + let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; + + let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? }; + let mel_filters = mel_filters.deserialize()?; + let mel_filters = mel_filters.tensor("mel_80", &device)?; + println!("loaded mel filters {:?}", mel_filters.shape()); + let mel_filters = mel_filters.flatten_all()?.to_vec1::()?; + + let mut input = std::fs::File::open(args.input)?; + let (header, data) = wav::read(&mut input)?; + println!("loaded wav data: {header:?}"); + if header.sampling_rate != SAMPLE_RATE as u32 { + anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE) + } + let data = data.as_sixteen().expect("expected 16 bit wav file"); + let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] + .iter() + .map(|v| *v as f32 / 32768.) + .collect(); + println!("pcm data loaded {}", pcm_data.len()); + let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?; + let mel_len = mel.len(); + let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?; + println!("loaded mel: {:?}", mel.dims()); + + let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; + let weights = weights.deserialize()?; + let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device); + let model = Whisper::load(&vb, Config::tiny_en())?; + let mut dc = Decode { + model, + rng, + tokenizer, + }; + + let (_, _, content_frames) = mel.shape().r3()?; + let mut seek = 0; + let mut segments = vec![]; + while seek < content_frames { + let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let segment_size = usize::min(content_frames - seek, N_FRAMES); + let mel_segment = mel.narrow(2, seek, segment_size)?; + let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; + let dr = dc.decode_with_fallback(&mel_segment)?; + seek += segment_size; + if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { + println!("no speech detected, skipping {seek} {dr:?}"); + continue; + } + let segment = Segment { + start: time_offset, + duration: segment_duration, + dr, + }; + println!("{seek}: {segment:?}"); + segments.push(segment) + } + Ok(()) +} diff --git a/candle-examples/examples/whisper/mel_filters.safetensors b/candle-examples/examples/whisper/mel_filters.safetensors new file mode 100644 index 00000000..98f3af44 Binary files /dev/null and b/candle-examples/examples/whisper/mel_filters.safetensors differ diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs new file mode 100644 index 00000000..53ee6a90 --- /dev/null +++ b/candle-examples/examples/whisper/model.rs @@ -0,0 +1,547 @@ +// We use anyhow rather than candle errors as it provides better support for getting the backtrace +// back when using RUST_LIB_BACKTRACE=1. +use anyhow::Result; +use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; +use std::collections::HashMap; + +pub struct VarBuilder<'a> { + safetensors: Option<(HashMap, Vec>)>, + dtype: DType, + device: Device, +} + +impl<'a> VarBuilder<'a> { + pub fn from_safetensors( + safetensors: Vec>, + dtype: DType, + device: Device, + ) -> Self { + let mut routing = HashMap::new(); + for (index, sf) in safetensors.iter().enumerate() { + for k in sf.names() { + routing.insert(k.to_string(), index); + } + } + Self { + safetensors: Some((routing, safetensors)), + device, + dtype, + } + } + + pub fn zeros(dtype: DType, device: Device) -> Self { + Self { + safetensors: None, + device, + dtype, + } + } + + pub fn get>(&self, s: S, tensor_name: &str) -> candle::Result { + let s: Shape = s.into(); + match &self.safetensors { + None => Tensor::zeros(s, self.dtype, &self.device), + Some((routing, safetensors)) => { + // Unwrap or 0 just to let the proper error flow. + let index = routing.get(tensor_name).unwrap_or(&0); + let tensor = safetensors[*index] + .tensor(tensor_name, &self.device)? + .to_dtype(self.dtype)?; + if *tensor.shape() != s { + let msg = format!("shape mismatch for {tensor_name}"); + Err(candle::Error::UnexpectedShape { + msg, + expected: s, + got: tensor.shape().clone(), + })? + } + Ok(tensor) + } + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum HiddenAct { + Gelu, + Relu, +} + +impl HiddenAct { + fn forward(&self, xs: &Tensor) -> candle::Result { + match self { + Self::Gelu => xs.gelu(), + Self::Relu => xs.relu(), + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub struct Config { + pub n_mels: usize, + pub n_audio_ctx: usize, + pub n_audio_state: usize, + pub n_audio_head: usize, + pub n_audio_layer: usize, + pub n_vocab: usize, + pub n_text_ctx: usize, + pub n_text_state: usize, + pub n_text_head: usize, + pub n_text_layer: usize, +} + +impl Config { + pub fn tiny_en() -> Self { + Self { + n_mels: 80, + n_vocab: 51864, + n_audio_ctx: 1500, + n_audio_state: 384, + n_audio_head: 6, + n_audio_layer: 4, + n_text_ctx: 448, + n_text_state: 384, + n_text_head: 6, + n_text_layer: 4, + } + } +} + +struct Embedding { + embeddings: Tensor, + hidden_size: usize, +} + +impl Embedding { + fn new(embeddings: Tensor, hidden_size: usize) -> Self { + Self { + embeddings, + hidden_size, + } + } + + fn load(vocab_size: usize, hidden_size: usize, p: &str, vb: &VarBuilder) -> Result { + let embeddings = vb.get((vocab_size, hidden_size), &format!("{p}.weight"))?; + Ok(Self::new(embeddings, hidden_size)) + } + + fn forward(&self, indexes: &Tensor) -> Result { + let mut final_dims = indexes.dims().to_vec(); + final_dims.push(self.hidden_size); + let indexes = indexes.flatten_all()?; + let values = Tensor::embedding(&indexes, &self.embeddings)?; + let values = values.reshape(final_dims)?; + Ok(values) + } +} + +struct Linear { + weight: Tensor, + bias: Option, +} + +impl Linear { + fn load(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result { + let weight = vb.get((size2, size1), &format!("{p}.weight"))?; + let bias = vb.get(size2, &format!("{p}.bias"))?; + Ok(Self { + weight, + bias: Some(bias), + }) + } + + fn load_no_bias(size1: usize, size2: usize, p: &str, vb: &VarBuilder) -> Result { + let weight = vb.get((size2, size1), &format!("{p}.weight"))?; + Ok(Self { weight, bias: None }) + } + + fn forward(&self, x: &Tensor) -> candle::Result { + let (bsize, _, _) = x.shape().r3()?; + let w = self.weight.broadcast_left(bsize)?.t()?; + let x = x.matmul(&w)?; + match &self.bias { + None => Ok(x), + Some(bias) => x.broadcast_add(bias), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ConvConfig { + padding: usize, + stride: usize, +} + +impl Default for ConvConfig { + fn default() -> Self { + Self { + padding: 0, + stride: 1, + } + } +} + +struct Conv1D { + weight: Tensor, + bias: Option, + config: ConvConfig, +} + +impl Conv1D { + fn load( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: ConvConfig, + p: &str, + vb: &VarBuilder, + ) -> Result { + let weight = vb.get( + (out_channels, in_channels, kernel_size), + &format!("{p}.weight"), + )?; + let bias = vb.get(out_channels, &format!("{p}.bias"))?; + Ok(Self { + weight, + bias: Some(bias), + config, + }) + } + + fn load_no_bias( + in_channels: usize, + out_channels: usize, + kernel_size: usize, + config: ConvConfig, + p: &str, + vb: &VarBuilder, + ) -> Result { + let weight = vb.get( + (out_channels, in_channels, kernel_size), + &format!("{p}.weight"), + )?; + Ok(Self { + weight, + bias: None, + config, + }) + } + + fn forward(&self, x: &Tensor) -> Result { + let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?; + match &self.bias { + None => Ok(x), + Some(bias) => { + let b = bias.shape().r1()?; + let bias = bias.reshape((1, b, 1))?; + Ok(x.broadcast_add(&bias)?) + } + } + } +} + +struct Dropout { + pr: f64, +} + +impl Dropout { + fn new(pr: f64) -> Self { + Self { pr } + } + + fn forward(&self, x: &Tensor) -> Result { + // TODO + Ok(x.clone()) + } +} + +// This layer norm version handles both weight and bias so removes the mean. +struct LayerNorm { + weight: Tensor, + bias: Tensor, + eps: f64, +} + +impl LayerNorm { + fn load(size: usize, p: &str, vb: &VarBuilder) -> Result { + let weight = vb.get(size, &format!("{p}.weight"))?; + let bias = vb.get(size, &format!("{p}.bias"))?; + Ok(Self { + weight, + bias, + eps: 1e-5, + }) + } + + fn forward(&self, x: &Tensor) -> Result { + let (_bsize, _seq_len, hidden_size) = x.shape().r3()?; + let mean_x = (x.sum(&[2])? / hidden_size as f64)?; + let x = x.broadcast_sub(&mean_x)?; + let norm_x = ((&x * &x)?.sum(&[2])? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + let x = x_normed + .broadcast_mul(&self.weight)? + .broadcast_add(&self.bias)?; + Ok(x) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L62 +struct MultiHeadAttention { + query: Linear, + key: Linear, + value: Linear, + out: Linear, + n_head: usize, +} + +impl MultiHeadAttention { + fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result { + let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?; + let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?; + let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?; + let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?; + Ok(Self { + query, + key, + value, + out, + n_head, + }) + } + + fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { + let q = self.query.forward(x)?; + let k = self.key.forward(xa.unwrap_or(x))?; + let v = self.value.forward(xa.unwrap_or(x))?; + let wv = self.qkv_attention(&q, &k, &v, mask)?; + let out = self.out.forward(&wv)?; + Ok(out) + } + + fn reshape_head(&self, x: &Tensor) -> Result { + let (n_batch, n_ctx, n_state) = x.shape().r3()?; + let target_dims = &[n_batch, n_ctx, self.n_head, n_state / self.n_head]; + Ok(x.reshape(target_dims)?.transpose(1, 2)?) + } + + fn qkv_attention( + &self, + q: &Tensor, + k: &Tensor, + v: &Tensor, + mask: Option<&Tensor>, + ) -> Result { + let (_, n_ctx, n_state) = q.shape().r3()?; + let scale = ((n_state / self.n_head) as f64).powf(-0.25); + let q = (self.reshape_head(q)? * scale)?; + let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?; + let v = self.reshape_head(v)?.contiguous()?; + let mut qk = q.matmul(&k)?; + if let Some(mask) = mask { + let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?; + qk = qk.broadcast_add(&mask)? + } + let w = qk.softmax(qk.rank() - 1)?; + let wv = w.matmul(&v)?.transpose(1, 2)?.flatten(Some(2), None)?; + Ok(wv) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L111 +struct ResidualAttentionBlock { + attn: MultiHeadAttention, + attn_ln: LayerNorm, + cross_attn: Option<(MultiHeadAttention, LayerNorm)>, + mlp_linear1: Linear, + mlp_linear2: Linear, + mlp_ln: LayerNorm, +} + +impl ResidualAttentionBlock { + fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result { + let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?; + let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?; + let cross_attn = if ca { + let cross_attn = + MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?; + let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?; + Some((cross_attn, cross_attn_ln)) + } else { + None + }; + let n_mlp = n_state * 4; + let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?; + let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.mlp.2"), vb)?; + let mlp_ln = LayerNorm::load(n_state, &format!("{p}.mlp_ln"), vb)?; + Ok(Self { + attn, + attn_ln, + cross_attn, + mlp_linear1, + mlp_linear2, + mlp_ln, + }) + } + + fn forward(&self, x: &Tensor, xa: Option<&Tensor>, mask: Option<&Tensor>) -> Result { + let attn = self.attn.forward(&self.attn_ln.forward(x)?, None, mask)?; + let mut x = (x + attn)?; + if let Some((attn, ln)) = &self.cross_attn { + x = (&x + attn.forward(&ln.forward(&x)?, xa, None)?)?; + } + let mlp = self.mlp_linear2.forward( + &self + .mlp_linear1 + .forward(&self.mlp_ln.forward(&x)?)? + .gelu()?, + )?; + Ok((x + mlp)?) + } +} + +fn sinusoids(length: usize, channels: usize) -> Result { + let max_timescale = 10000f32; + let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32; + let inv_timescales: Vec<_> = (0..channels / 2) + .map(|i| (i as f32 * (-log_timescale_increment)).exp()) + .collect(); + let arange: Vec<_> = (0..length).map(|c| c as f32).collect(); + let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?; + let arange = Tensor::new(arange.as_slice(), &Device::Cpu)?.unsqueeze(1)?; + let sh = (length, channels / 2); + let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?; + let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?; + Ok(sincos) +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143 +pub struct AudioEncoder { + conv1: Conv1D, + conv2: Conv1D, + positional_embedding: Tensor, + blocks: Vec, + ln_post: LayerNorm, +} + +impl AudioEncoder { + fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + let n_state = cfg.n_audio_state; + let n_head = cfg.n_audio_head; + let n_ctx = cfg.n_audio_ctx; + let cfg1 = ConvConfig { + padding: 1, + stride: 1, + }; + let cfg2 = ConvConfig { + padding: 1, + stride: 2, + }; + let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?; + let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; + let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?; + let blocks = (0..cfg.n_audio_layer) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb) + }) + .collect::>>()?; + let ln_post = LayerNorm::load(n_state, &format!("{p}.ln_post"), vb)?; + Ok(Self { + conv1, + conv2, + positional_embedding, + blocks, + ln_post, + }) + } + pub fn forward(&self, x: &Tensor) -> Result { + let x = self.conv1.forward(x)?.gelu()?; + let x = self.conv2.forward(&x)?.gelu()?; + let x = x.transpose(1, 2)?; + let mut x = x.broadcast_add(&self.positional_embedding)?; + for block in self.blocks.iter() { + x = block.forward(&x, None, None)? + } + let x = self.ln_post.forward(&x)?; + Ok(x) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L176 +pub struct TextDecoder { + token_embedding: Embedding, + positional_embedding: Tensor, + blocks: Vec, + ln: LayerNorm, + mask: Tensor, +} + +impl TextDecoder { + fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result { + let n_state = cfg.n_text_state; + let n_head = cfg.n_text_head; + let n_ctx = cfg.n_text_ctx; + let token_embedding = + Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?; + let positional_embedding = + vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?; + let blocks = (0..cfg.n_text_layer) + .map(|i| { + ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb) + }) + .collect::>>()?; + let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?; + let mask: Vec<_> = (0..n_ctx) + .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 })) + .collect(); + let mask = Tensor::from_vec(mask, (n_ctx, n_ctx), &vb.device)?; + + Ok(Self { + token_embedding, + positional_embedding, + blocks, + ln, + mask, + }) + } + + pub fn forward(&self, x: &Tensor, xa: &Tensor) -> Result { + let x_dims = x.dims(); + let last = x_dims[x_dims.len() - 1]; + let token_embedding = self.token_embedding.forward(x)?; + let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; + let mut x = token_embedding.broadcast_add(&positional_embedding)?; + for block in self.blocks.iter() { + x = block.forward(&x, Some(xa), Some(&self.mask))?; + } + let x = self.ln.forward(&x)?; + let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?; + let logits = x.matmul(&w.t()?)?; + Ok(logits) + } +} + +// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L221 +pub struct Whisper { + pub encoder: AudioEncoder, + pub decoder: TextDecoder, + pub config: Config, +} + +impl Whisper { + pub fn load(vb: &VarBuilder, config: Config) -> Result { + let encoder = AudioEncoder::load("encoder", vb, &config)?; + let decoder = TextDecoder::load("decoder", vb, &config)?; + Ok(Self { + encoder, + decoder, + config, + }) + } + + pub fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result { + let enc = self.encoder.forward(mel)?; + let dec = self.decoder.forward(tokens, &enc)?; + Ok(dec) + } +}