mirror of https://github.com/huggingface/candle.git
Run the text decoding bit.
@@ -8,6 +8,6 @@ data = torch.load("tiny.en.pt")
 weights = {}
 for k, v in data["model_state_dict"].items():
     weights[k] = v.contiguous()
-    print(k, v.shape)
+    print(k, v.shape, v.dtype)
 save_file(weights, "tiny.en.safetensors")
 print(data["dims"])
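
For context, the hunk above comes from the small PyTorch-to-safetensors conversion script; only the print line changed in this commit. A minimal self-contained sketch of the whole script, assuming the usual imports, is:

    import torch
    from safetensors.torch import save_file

    data = torch.load("tiny.en.pt")

    weights = {}
    for k, v in data["model_state_dict"].items():
        # safetensors requires contiguous tensors, hence the copy.
        weights[k] = v.contiguous()
        print(k, v.shape, v.dtype)
    save_file(weights, "tiny.en.safetensors")
    print(data["dims"])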
@@ -3,7 +3,7 @@
 // TODO:
 // - kv-cache support?
 
-use anyhow::{Error as E, Result};
+use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
 use clap::Parser;
 use std::collections::HashMap;
@@ -97,7 +97,7 @@ struct Config {
 }
 
 impl Config {
-    fn tiny() -> Self {
+    fn tiny_en() -> Self {
         Self {
             n_mels: 80,
             n_vocab: 51864,
@@ -302,8 +302,8 @@ struct MultiHeadAttention {
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
         let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
-        let value = Linear::load_no_bias(n_state, n_state, &format!("{p}.value"), vb)?;
-        let key = Linear::load(n_state, n_state, &format!("{p}.key"), vb)?;
+        let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
+        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
         let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
         Ok(Self {
             query,
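
The value/key swap brings the loader in line with the reference PyTorch implementation of Whisper, where the key projection is the only linear layer without a bias. Paraphrased from openai/whisper:

    import torch.nn as nn

    # Paraphrase of the reference MultiHeadAttention constructor: only the
    # key projection is bias-free; query, value and out all carry a bias.
    class MultiHeadAttention(nn.Module):
        def __init__(self, n_state: int, n_head: int):
            super().__init__()
            self.n_head = n_head
            self.query = nn.Linear(n_state, n_state)
            self.key = nn.Linear(n_state, n_state, bias=False)
            self.value = nn.Linear(n_state, n_state)
            self.out = nn.Linear(n_state, n_state)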
@@ -500,12 +500,13 @@ impl TextDecoder {
         let last = x_dims[x_dims.len() - 1];
         let token_embedding = self.token_embedding.forward(x)?;
         let positional_embedding = self.positional_embedding.narrow(0, 0, last)?;
-        let mut x = (token_embedding + positional_embedding)?;
+        let mut x = token_embedding.broadcast_add(&positional_embedding)?;
         for block in self.blocks.iter() {
             x = block.forward(&x, Some(xa))?
         }
         let x = self.ln.forward(&x)?;
-        let logits = x.matmul(&self.token_embedding.embeddings.t()?)?;
+        let w = self.token_embedding.embeddings.broadcast_left(x_dims[0])?;
+        let logits = x.matmul(&w.t()?)?;
         Ok(logits)
     }
 }
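
Both changes in this hunk are about broadcasting: token_embedding carries a batch dimension that positional_embedding lacks, and the tied output head multiplies a 3D activation by the 2D embedding table, so the table is replicated across the batch first. A NumPy sketch of the intended shape algebra, with illustrative sizes (tiny.en uses n_state = 384):

    import numpy as np

    batch, seq, n_state, n_vocab = 1, 3, 384, 51864

    token_embedding = np.zeros((batch, seq, n_state))  # embedding lookup output
    positional_embedding = np.zeros((seq, n_state))    # narrow(0, 0, seq) of the table

    # broadcast_add: (batch, seq, n_state) + (seq, n_state) -> (batch, seq, n_state)
    x = token_embedding + positional_embedding

    # Tied head: logits come from the transposed embedding table. broadcast_left
    # replicates the (n_vocab, n_state) table across the batch so the batched
    # matmul has matching leading dimensions.
    embeddings = np.zeros((n_vocab, n_state))
    w = np.broadcast_to(embeddings, (batch, n_vocab, n_state))
    logits = x @ w.transpose(0, 2, 1)  # (batch, seq, n_vocab)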
@@ -517,9 +518,9 @@ struct Whisper {
 }
 
 impl Whisper {
-    fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let encoder = AudioEncoder::load(&format!("{p}.encoder"), vb, cfg)?;
-        let decoder = TextDecoder::load(&format!("{p}.decoder"), vb, cfg)?;
+    fn load(vb: &VarBuilder, cfg: &Config) -> Result<Self> {
+        let encoder = AudioEncoder::load("encoder", vb, cfg)?;
+        let decoder = TextDecoder::load("decoder", vb, cfg)?;
         Ok(Self { encoder, decoder })
     }
     fn forward(&self, mel: &Tensor, tokens: &Tensor) -> Result<Tensor> {
@@ -537,15 +538,13 @@ struct Args {
     cpu: bool,
 
     #[arg(long)]
-    tokenizer_config: String,
+    weights: String,
 
     #[arg(long)]
-    weights: String,
+    input: String,
 }
 
 fn main() -> Result<()> {
-    use tokenizers::Tokenizer;
-
     let args = Args::parse();
     let device = if args.cpu {
         Device::Cpu
@@ -553,11 +552,19 @@ fn main() -> Result<()> {
         Device::new_cuda(0)?
     };
 
-    let mut tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
-    let _tokenizer = tokenizer.with_padding(None).with_truncation(None);
+    let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
+    let input = input.deserialize()?;
+    let x = input.tensor("x", &device)?.to_dtype(DType::U32)?;
+    let xa = input.tensor("xa", &device)?;
 
     let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
     let weights = weights.deserialize()?;
-    let _vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
+    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device.clone());
+    let cfg = Config::tiny_en();
+
+    let model = Whisper::load(&vb, &cfg)?;
+    let logits = model.decoder.forward(&x, &xa)?;
+    println!("{logits}");
+    println!("python logits: {}", input.tensor("logits", &device)?);
     Ok(())
 }
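
The example now expects a second safetensors file holding a token tensor "x", an encoder output "xa", and reference "logits" to compare against. The script that produces that file is not part of this commit; a hypothetical companion sketch using the reference implementation could look like the following (the model choice and token value are assumptions, not taken from the diff):

    import torch
    import whisper  # reference implementation: pip install openai-whisper
    from safetensors.torch import save_file

    model = whisper.load_model("tiny.en")
    mel = torch.zeros(1, 80, 3000)    # dummy 30s mel spectrogram (n_mels = 80)
    tokens = torch.tensor([[50257]])  # hypothetical start-of-transcript token

    with torch.no_grad():
        xa = model.encoder(mel)             # audio features fed to the decoder
        logits = model.decoder(tokens, xa)  # reference logits to compare against

    save_file(
        {"x": tokens.to(torch.int32), "xa": xa, "logits": logits.contiguous()},
        "input.safetensors",
    )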