Run the text decoding bit.

This commit is contained in:
laurent
2023-07-04 12:03:28 +01:00
parent 950b4af49e
commit aea090401d
2 changed files with 24 additions and 17 deletions

View File

@ -8,6 +8,6 @@ data = torch.load("tiny.en.pt")
weights = {} weights = {}
for k, v in data["model_state_dict"].items(): for k, v in data["model_state_dict"].items():
weights[k] = v.contiguous() weights[k] = v.contiguous()
print(k, v.shape) print(k, v.shape, v.dtype)
save_file(weights, "tiny.en.safetensors") save_file(weights, "tiny.en.safetensors")
print(data["dims"]) print(data["dims"])

View File

@ -3,7 +3,7 @@
// TODO: // TODO:
// - kv-cache support? // - kv-cache support?
use anyhow::{Error as E, Result}; use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor}; use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use clap::Parser; use clap::Parser;
use std::collections::HashMap; use std::collections::HashMap;
@ -97,7 +97,7 @@ struct Config {
} }
impl Config { impl Config {
fn tiny() -> Self { fn tiny_en() -> Self {
Self { Self {
n_mels: 80, n_mels: 80,
n_vocab: 51864, n_vocab: 51864,
@ -302,8 +302,8 @@ struct MultiHeadAttention {
impl MultiHeadAttention { impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> { fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?; let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
let value = Linear::load_no_bias(n_state, n_state, &format!("{p}.value"), vb)?; let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
let key = Linear::load(n_state, n_state, &format!("{p}.key"), vb)?; let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?; let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
Ok(Self { Ok(Self {
query, query,
@ -500,12 +500,13 @@ impl TextDecoder {
let last = x_dims[x_dims.len() - 1]; let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?; let token_embedding = self.token_embedding.forward(x)?;
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?; let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = (token_embedding + positional_embedding)?; let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() { for block in self.blocks.iter() {
x = block.forward(&x, Some(xa))? x = block.forward(&x, Some(xa))?
} }
let x = self.ln.forward(&x)?; let x = self.ln.forward(&x)?;
let logits = x.matmul(&self.token_embedding.embeddings.t()?)?; let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
let logits = x.matmul(&w.t()?)?;
Ok(logits) Ok(logits)
} }
} }
@ -517,9 +518,9 @@ struct Whisper {
} }
impl Whisper { impl Whisper {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = AudioEncoder::load(&format!("{p}.encoder"), vb, cfg)?; let encoder = AudioEncoder::load("encoder", vb, cfg)?;
let decoder = TextDecoder::load(&format!("{p}.decoder"), vb, cfg)?; let decoder = TextDecoder::load("decoder", vb, cfg)?;
Ok(Self { encoder, decoder }) Ok(Self { encoder, decoder })
} }
fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> { fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
@ -537,15 +538,13 @@ struct Args {
cpu: bool, cpu: bool,
#[arg(long)] #[arg(long)]
tokenizer_config: String, weights: String,
#[arg(long)] #[arg(long)]
weights: String, input: String,
} }
fn main() -> Result<()> { fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse(); let args = Args::parse();
let device = if args.cpu { let device = if args.cpu {
Device::Cpu Device::Cpu
@ -553,11 +552,19 @@ fn main() -> Result<()> {
Device::new_cuda(0)? Device::new_cuda(0)?
}; };
let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?; let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
let _tokenizer = tokenizer.with_padding(None).with_truncation(None); let input = input.deserialize()?;
let x = input.tensor("x", &device)?.to_dtype(DType::U32)?;
let xa = input.tensor("xa", &device)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
let weights = weights.deserialize()?; let weights = weights.deserialize()?;
let _vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device); let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
let cfg = Config::tiny_en();
let model = Whisper::load(&vb, &cfg)?;
let logits = model.decoder.forward(&x, &xa)?;
println!("{logits}");
println!("python logits: {}", input.tensor("logits", &device)?);
Ok(()) Ok(())
} }