Run the text decoding bit.

This commit is contained in:
laurent
2023-07-04 12:03:28 +01:00
parent 950b4af49e
commit aea090401d
2 changed files with 24 additions and 17 deletions

View File

@ -8,6 +8,6 @@ data = torch.load("tiny.en.pt")
weights = {}
for k, v in data["model_state_dict"].items():
weights[k] = v.contiguous()
print(k, v.shape)
print(k, v.shape, v.dtype)
save_file(weights, "tiny.en.safetensors")
print(data["dims"])

View File

@ -3,7 +3,7 @@
// TODO:
// - kv-cache support?
use anyhow::{Error as E, Result};
use anyhow::Result;
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
use clap::Parser;
use std::collections::HashMap;
@ -97,7 +97,7 @@ struct Config {
}
impl Config {
fn tiny() -> Self {
fn tiny_en() -> Self {
Self {
n_mels: 80,
n_vocab: 51864,
@ -302,8 +302,8 @@ struct MultiHeadAttention {
impl MultiHeadAttention {
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
let value = Linear::load_no_bias(n_state, n_state, &format!("{p}.value"), vb)?;
let key = Linear::load(n_state, n_state, &format!("{p}.key"), vb)?;
let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
Ok(Self {
query,
@ -500,12 +500,13 @@ impl TextDecoder {
let last = x_dims[x_dims.len() - 1];
let token_embedding = self.token_embedding.forward(x)?;
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
let mut x = (token_embedding + positional_embedding)?;
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
for block in self.blocks.iter() {
x = block.forward(&x, Some(xa))?
}
let x = self.ln.forward(&x)?;
let logits = x.matmul(&self.token_embedding.embeddings.t()?)?;
let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
let logits = x.matmul(&w.t()?)?;
Ok(logits)
}
}
@ -517,9 +518,9 @@ struct Whisper {
}
impl Whisper {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = AudioEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
let decoder = TextDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let encoder = AudioEncoder::load("encoder", vb, cfg)?;
let decoder = TextDecoder::load("decoder", vb, cfg)?;
Ok(Self { encoder, decoder })
}
fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
@ -537,15 +538,13 @@ struct Args {
cpu: bool,
#[arg(long)]
tokenizer_config: String,
weights: String,
#[arg(long)]
weights: String,
input: String,
}
fn main() -> Result<()> {
use tokenizers::Tokenizer;
let args = Args::parse();
let device = if args.cpu {
Device::Cpu
@ -553,11 +552,19 @@ fn main() -> Result<()> {
Device::new_cuda(0)?
};
let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
let input = input.deserialize()?;
let x = input.tensor("x", &device)?.to_dtype(DType::U32)?;
let xa = input.tensor("xa", &device)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
let weights = weights.deserialize()?;
let _vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
let cfg = Config::tiny_en();
let model = Whisper::load(&vb, &cfg)?;
let logits = model.decoder.forward(&x, &xa)?;
println!("{logits}");
println!("python logits: {}", input.tensor("logits", &device)?);
Ok(())
}