mirror of
https://github.com/huggingface/candle.git
synced 2025-06-14 09:57:10 +00:00
Run the text decoding bit.
This commit is contained in:
@ -8,6 +8,6 @@ data = torch.load("tiny.en.pt")
|
||||
weights = {}
|
||||
for k, v in data["model_state_dict"].items():
|
||||
weights[k] = v.contiguous()
|
||||
print(k, v.shape)
|
||||
print(k, v.shape, v.dtype)
|
||||
save_file(weights, "tiny.en.safetensors")
|
||||
print(data["dims"])
|
||||
|
@ -3,7 +3,7 @@
|
||||
// TODO:
|
||||
// - kv-cache support?
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use anyhow::Result;
|
||||
use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
|
||||
use clap::Parser;
|
||||
use std::collections::HashMap;
|
||||
@ -97,7 +97,7 @@ struct Config {
|
||||
}
|
||||
|
||||
impl Config {
|
||||
fn tiny() -> Self {
|
||||
fn tiny_en() -> Self {
|
||||
Self {
|
||||
n_mels: 80,
|
||||
n_vocab: 51864,
|
||||
@ -302,8 +302,8 @@ struct MultiHeadAttention {
|
||||
impl MultiHeadAttention {
|
||||
fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
|
||||
let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
|
||||
let value = Linear::load_no_bias(n_state, n_state, &format!("{p}.value"), vb)?;
|
||||
let key = Linear::load(n_state, n_state, &format!("{p}.key"), vb)?;
|
||||
let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
|
||||
let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
|
||||
let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
|
||||
Ok(Self {
|
||||
query,
|
||||
@ -500,12 +500,13 @@ impl TextDecoder {
|
||||
let last = x_dims[x_dims.len() - 1];
|
||||
let token_embedding = self.token_embedding.forward(x)?;
|
||||
let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
|
||||
let mut x = (token_embedding + positional_embedding)?;
|
||||
let mut x = token_embedding.broadcast_add(&positional_embedding)?;
|
||||
for block in self.blocks.iter() {
|
||||
x = block.forward(&x, Some(xa))?
|
||||
}
|
||||
let x = self.ln.forward(&x)?;
|
||||
let logits = x.matmul(&self.token_embedding.embeddings.t()?)?;
|
||||
let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
|
||||
let logits = x.matmul(&w.t()?)?;
|
||||
Ok(logits)
|
||||
}
|
||||
}
|
||||
@ -517,9 +518,9 @@ struct Whisper {
|
||||
}
|
||||
|
||||
impl Whisper {
|
||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let encoder = AudioEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
|
||||
let decoder = TextDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
|
||||
fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let encoder = AudioEncoder::load("encoder", vb, cfg)?;
|
||||
let decoder = TextDecoder::load("decoder", vb, cfg)?;
|
||||
Ok(Self { encoder, decoder })
|
||||
}
|
||||
fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
|
||||
@ -537,15 +538,13 @@ struct Args {
|
||||
cpu: bool,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_config: String,
|
||||
weights: String,
|
||||
|
||||
#[arg(long)]
|
||||
weights: String,
|
||||
input: String,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
let args = Args::parse();
|
||||
let device = if args.cpu {
|
||||
Device::Cpu
|
||||
@ -553,11 +552,19 @@ fn main() -> Result<()> {
|
||||
Device::new_cuda(0)?
|
||||
};
|
||||
|
||||
let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
|
||||
let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
|
||||
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
|
||||
let input = input.deserialize()?;
|
||||
let x = input.tensor("x", &device)?.to_dtype(DType::U32)?;
|
||||
let xa = input.tensor("xa", &device)?;
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let _vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
|
||||
let cfg = Config::tiny_en();
|
||||
|
||||
let model = Whisper::load(&vb, &cfg)?;
|
||||
let logits = model.decoder.forward(&x, &xa)?;
|
||||
println!("{logits}");
|
||||
println!("python logits: {}", input.tensor("logits", &device)?);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user