mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Adapting whisper for Hub use.
@@ -2,6 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use serde::Deserialize;
 use std::collections::HashMap;
 
 pub struct VarBuilder<'a> {
@@ -76,33 +77,33 @@ impl HiddenAct {
     }
 }
 
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    pub n_mels: usize,
-    pub n_audio_ctx: usize,
-    pub n_audio_state: usize,
-    pub n_audio_head: usize,
-    pub n_audio_layer: usize,
-    pub n_vocab: usize,
-    pub n_text_ctx: usize,
-    pub n_text_state: usize,
-    pub n_text_head: usize,
-    pub n_text_layer: usize,
+    pub num_mel_bins: usize,
+    pub max_source_positions: usize,
+    pub d_model: usize,
+    pub encoder_attention_heads: usize,
+    pub encoder_layers: usize,
+    pub vocab_size: usize,
+    pub max_target_positions: usize,
+    // pub n_text_state: usize,
+    pub decoder_attention_heads: usize,
+    pub decoder_layers: usize,
 }
 
 impl Config {
     pub fn tiny_en() -> Self {
         Self {
-            n_mels: 80,
-            n_vocab: 51864,
-            n_audio_ctx: 1500,
-            n_audio_state: 384,
-            n_audio_head: 6,
-            n_audio_layer: 4,
-            n_text_ctx: 448,
-            n_text_state: 384,
-            n_text_head: 6,
-            n_text_layer: 4,
+            num_mel_bins: 80,
+            vocab_size: 51864,
+            max_source_positions: 1500,
+            d_model: 384,
+            encoder_attention_heads: 6,
+            encoder_layers: 4,
+            max_target_positions: 448,
+            // n_text_state: 384,
+            decoder_attention_heads: 6,
+            decoder_layers: 4,
         }
     }
 }
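Since `Config` now derives `Deserialize` and its field names match the keys in a Whisper checkpoint's `config.json` on the Hub, the configuration can be parsed straight from that file instead of being hard-coded. A minimal sketch, assuming `serde_json` is available as a dependency (it is not added by this diff):

```rust
use anyhow::Result;

// Hypothetical helper, not part of this diff: parse a Hub `config.json`
// (num_mel_bins, d_model, encoder_layers, ...) into the Config struct above.
fn read_config(path: &std::path::Path) -> Result<Config> {
    let json = std::fs::read_to_string(path)?;
    let config: Config = serde_json::from_str(&json)?;
    Ok(config)
}
```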
@@ -297,10 +298,10 @@ struct MultiHeadAttention {
 
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
-        let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
-        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
-        let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
+        let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?;
+        let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?;
+        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
+        let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?;
         Ok(Self {
             query,
             key,
@@ -360,20 +361,21 @@ struct ResidualAttentionBlock {
 
 impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
-        let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
+        let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
+        let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
         let cross_attn = if ca {
             let cross_attn =
-                MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
-            let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
+                MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
+            let cross_attn_ln =
+                LayerNorm::load(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
             Some((cross_attn, cross_attn_ln))
         } else {
             None
         };
         let n_mlp = n_state * 4;
-        let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
-        let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.mlp.2"), vb)?;
-        let mlp_ln = LayerNorm::load(n_state, &format!("{p}.mlp_ln"), vb)?;
+        let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
+        let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
+        let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?;
         Ok(Self {
             attn,
             attn_ln,
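For reference, the per-block parameter renames introduced in this hunk and the `MultiHeadAttention` one above, read directly off the `format!` calls (`{p}` is the enclosing block prefix). Collected here as a plain lookup table; the constant itself is illustrative and not part of this diff:

```rust
// Old (OpenAI-style) suffix  ->  new (Hub / transformers-style) suffix.
const BLOCK_RENAMES: &[(&str, &str)] = &[
    ("attn", "self_attn"),
    ("attn_ln", "self_attn_layer_norm"),
    ("cross_attn", "encoder_attn"),
    ("cross_attn_ln", "encoder_attn_layer_norm"),
    ("mlp.0", "fc1"),
    ("mlp.2", "fc2"),
    ("mlp_ln", "final_layer_norm"),
    ("query", "q_proj"),
    ("key", "k_proj"),
    ("value", "v_proj"),
    ("out", "out_proj"),
];
```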
@@ -426,9 +428,9 @@ pub struct AudioEncoder {
 
 impl AudioEncoder {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let n_state = cfg.n_audio_state;
-        let n_head = cfg.n_audio_head;
-        let n_ctx = cfg.n_audio_ctx;
+        let n_state = cfg.d_model;
+        let n_head = cfg.encoder_attention_heads;
+        let n_ctx = cfg.max_source_positions;
         let cfg1 = ConvConfig {
             padding: 1,
             stride: 1,
@@ -437,15 +439,22 @@ impl AudioEncoder {
             padding: 1,
             stride: 2,
         };
-        let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
+        let conv1 = Conv1D::load(
+            cfg.num_mel_bins,
+            n_state,
+            3,
+            cfg1,
+            &format!("{p}.conv1"),
+            vb,
+        )?;
         let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
         let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
-        let blocks = (0..cfg.n_audio_layer)
+        let blocks = (0..cfg.encoder_layers)
             .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
+                ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln_post = LayerNorm::load(n_state, &format!("{p}.ln_post"), vb)?;
+        let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
         Ok(Self {
             conv1,
             conv2,
@@ -480,19 +489,19 @@ pub struct TextDecoder {
 
 impl TextDecoder {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let n_state = cfg.n_text_state;
-        let n_head = cfg.n_text_head;
-        let n_ctx = cfg.n_text_ctx;
+        let n_state = cfg.d_model;
+        let n_head = cfg.decoder_attention_heads;
+        let n_ctx = cfg.max_target_positions;
         let token_embedding =
-            Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
+            Embedding::load(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
         let positional_embedding =
-            vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
-        let blocks = (0..cfg.n_text_layer)
+            vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
+        let blocks = (0..cfg.decoder_layers)
             .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
+                ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
+        let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
         let mask: Vec<_> = (0..n_ctx)
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
@@ -532,8 +541,8 @@ pub struct Whisper {
 
 impl Whisper {
     pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
-        let encoder = AudioEncoder::load("encoder", vb, &config)?;
-        let decoder = TextDecoder::load("decoder", vb, &config)?;
+        let encoder = AudioEncoder::load("model.encoder", vb, &config)?;
+        let decoder = TextDecoder::load("model.decoder", vb, &config)?;
         Ok(Self {
             encoder,
             decoder,
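Composing the new prefixes with the renames above gives the fully-qualified tensor names the loader now resolves against a Hub checkpoint. A minimal sketch under stated assumptions: `vb` has already been built over the checkpoint's tensors (its construction is unchanged by this diff) and `config` was parsed from the matching `config.json`:

```rust
use anyhow::Result;

// Hypothetical helper, not part of this diff; it only spells out that the
// loader now looks up Hub-style names composed from the format! calls above:
//   model.encoder.conv1
//   model.encoder.layers.0.self_attn.q_proj
//   model.encoder.layer_norm
//   model.decoder.embed_tokens
//   model.decoder.embed_positions.weight
//   model.decoder.layers.0.encoder_attn.k_proj
//   model.decoder.layer_norm
fn load_hub_whisper(vb: &VarBuilder, config: Config) -> Result<Whisper> {
    Whisper::load(vb, config)
}
```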