diff --git a/candle-examples/examples/whisper/extract_weights.py b/candle-examples/examples/whisper/extract_weights.py
index d6ccffc6..65602703 100644
--- a/candle-examples/examples/whisper/extract_weights.py
+++ b/candle-examples/examples/whisper/extract_weights.py
@@ -8,6 +8,6 @@ data = torch.load("tiny.en.pt")
 weights = {}
 for k, v in data["model_state_dict"].items():
     weights[k] = v.contiguous()
-    print(k, v.shape)
+    print(k, v.shape, v.dtype)
 save_file(weights, "tiny.en.safetensors")
 print(data["dims"])
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index a380d30e..c3a2769f 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -3,7 +3,7 @@
 // TODO:
 // - kv-cache support?
 
-use anyhow::{Error as E, Result};
+use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use clap::Parser;
 use std::collections::HashMap;
@@ -97,7 +97,7 @@ struct Config {
 }
 
 impl Config {
-    fn tiny() -> Self {
+    fn tiny_en() -> Self {
         Self {
             n_mels: 80,
             n_vocab: 51864,
@@ -302,8 +302,8 @@ struct MultiHeadAttention {
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
         let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
-        let value = Linear::load_no_bias(n_state, n_state, &format!("{p}.value"), vb)?;
-        let key = Linear::load(n_state, n_state, &format!("{p}.key"), vb)?;
+        let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
+        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
         let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
         Ok(Self {
             query,
@@ -500,12 +500,13 @@
         let last = x_dims[x_dims.len() - 1];
         let token_embedding = self.token_embedding.forward(x)?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
-        let mut x = (token_embedding + positional_embedding)?;
+        let mut x = token_embedding.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, Some(xa))?
         }
         let x = self.ln.forward(&x)?;
-        let logits = x.matmul(&self.token_embedding.embeddings.t()?)?;
+        let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
+        let logits = x.matmul(&w.t()?)?;
         Ok(logits)
     }
 }
@@ -517,9 +518,9 @@ struct Whisper {
 }
 
 impl Whisper {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let encoder = AudioEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
-        let decoder = TextDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
+    fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+        let encoder = AudioEncoder::load("encoder", vb, cfg)?;
+        let decoder = TextDecoder::load("decoder", vb, cfg)?;
         Ok(Self { encoder, decoder })
     }
     fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
@@ -537,15 +538,13 @@ struct Args {
     cpu: bool,
 
     #[arg(long)]
-    tokenizer_config: String,
+    weights: String,
 
     #[arg(long)]
-    weights: String,
+    input: String,
 }
 
 fn main() -> Result<()> {
-    use tokenizers::Tokenizer;
-
     let args = Args::parse();
     let device = if args.cpu {
         Device::Cpu
@@ -553,11 +552,19 @@ fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
 
-    let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
-    let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
+    let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
+    let input = input.deserialize()?;
+    let x = input.tensor("x", &device)?.to_dtype(DType::U32)?;
+    let xa = input.tensor("xa", &device)?;
 
     let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
     let weights = weights.deserialize()?;
-    let _vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
+    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
+    let cfg = Config::tiny_en();
+
+    let model = Whisper::load(&vb, &cfg)?;
+    let logits = model.decoder.forward(&x, &xa)?;
+    println!("{logits}");
+    println!("python logits: {}", input.tensor("logits", &device)?);
     Ok(())
 }
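
Note on testing: the new --input flag expects a safetensors file holding the decoder inputs "x" (token ids) and "xa" (encoder output), plus reference "logits" to compare the Rust output against. Below is a minimal sketch of a companion script that could produce such a file, assuming the OpenAI whisper Python package is installed; the script name, audio path, and start token are illustrative and not part of this change:

# make_input.py -- hypothetical helper, not part of this commit.
import torch
import whisper  # assumption: the openai-whisper package is installed
from safetensors.torch import save_file

model = whisper.load_model("tiny.en")
audio = whisper.pad_or_trim(whisper.load_audio("sample.wav"))  # illustrative path
mel = whisper.log_mel_spectrogram(audio).unsqueeze(0)  # (1, 80, 3000)
x = torch.tensor([[50257]])  # <|startoftranscript|> for the English-only models
with torch.no_grad():
    xa = model.encoder(mel)        # audio features consumed by the decoder
    logits = model.decoder(x, xa)  # reference logits for the comparison print
save_file(
    {"x": x, "xa": xa.contiguous(), "logits": logits.contiguous()},
    "input.safetensors",
)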