Mirror of https://github.com/huggingface/candle.git, synced 2025-06-17 02:58:50 +00:00
Adapting whisper for Hub use.
@@ -7,6 +7,7 @@
 
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
+use candle_hub::{api::Api, Repo, RepoType};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
@@ -68,7 +69,7 @@ impl Decode {
         let model = &self.model;
         let audio_features = model.encoder.forward(mel)?;
         println!("audio features: {:?}", audio_features.dims());
-        let sample_len = model.config.n_text_ctx / 2;
+        let sample_len = model.config.max_target_positions / 2;
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![SOT_TOKEN];
@@ -112,7 +113,7 @@ impl Decode {
                 .softmax(logits.rank() - 1)?
                 .get(next_token as usize)?
                 .to_scalar::<f32>()? as f64;
-            if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
+            if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
                 break;
             }
             sum_logprob += prob.ln();
@@ -165,15 +166,16 @@ struct Args {
     cpu: bool,
 
     #[arg(long)]
-    weights: String,
+    model_id: Option<String>,
 
+    /// The model to use, check out available models: https://huggingface.co/models?search=whisper
+    #[arg(long)]
+    revision: Option<String>,
+
     /// The input to be processed, in wav formats.
-    #[arg(long)]
+    #[arg(long, default_value = "jfk.wav")]
     input: String,
 
-    #[arg(long)]
-    tokenizer_config: String,
-
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
@@ -186,7 +188,8 @@ struct Args {
     filters: String,
 }
 
-fn main() -> Result<()> {
+#[tokio::main]
+async fn main() -> Result<()> {
     let args = Args::parse();
     let device = if args.cpu {
         Device::Cpu
@@ -195,7 +198,35 @@ fn main() -> Result<()> {
     };
     let rng = rand::rngs::StdRng::seed_from_u64(args.seed);
 
-    let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
+    let default_model = "openai/whisper-tiny.en".to_string();
+    let path = std::path::PathBuf::from(default_model.clone());
+    let default_revision = "refs/pr/15".to_string();
+    let (model_id, revision) = match (args.model_id, args.revision) {
+        (Some(model_id), Some(revision)) => (model_id, revision),
+        (Some(model_id), None) => (model_id, "main".to_string()),
+        (None, Some(revision)) => (default_model, revision),
+        (None, None) => (default_model, default_revision),
+    };
+
+    let (config_filename, tokenizer_filename, weights_filename) = if path.exists() {
+        let mut config_filename = path.clone();
+        config_filename.push("config.json");
+        let mut tokenizer_filename = path.clone();
+        tokenizer_filename.push("tokenizer.json");
+        let mut model_filename = path.clone();
+        model_filename.push("model.safetensors");
+        (config_filename, tokenizer_filename, model_filename)
+    } else {
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let api = Api::new()?;
+        (
+            api.get(&repo, "config.json").await?,
+            api.get(&repo, "tokenizer.json").await?,
+            api.get(&repo, "model.safetensors").await?,
+        )
+    };
+    println!("Weights {weights_filename:?}");
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
     let mel_filters = mel_filters.deserialize()?;
@@ -220,10 +251,11 @@ fn main() -> Result<()> {
     let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
     println!("loaded mel: {:?}", mel.dims());
 
-    let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
    let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
-    let model = Whisper::load(&vb, Config::tiny_en())?;
+    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
+    let model = Whisper::load(&vb, config)?;
     let mut dc = Decode {
         model,
         rng,
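That is the last hunk in the example's main file: main is now async under #[tokio::main] because the Hub downloads are awaited, the CLI gains optional --model-id and --revision flags, and the config, tokenizer and weights are fetched from the Hub (or read from a local directory) instead of being passed in as paths. For readers unfamiliar with candle_hub, the stand-alone program below isolates that download pattern; it is only a sketch assembled from the calls shown in the hunks above, not code from the commit, and it assumes the same openai/whisper-tiny.en repo and refs/pr/15 revision used as defaults. The remaining hunks come from the model definition file.

use anyhow::Result;
use candle_hub::{api::Api, Repo, RepoType};

// Minimal sketch, not part of the commit: fetch the three files the whisper
// example needs, mirroring the Api/Repo calls introduced by the diff above.
#[tokio::main]
async fn main() -> Result<()> {
    let repo = Repo::with_revision(
        "openai/whisper-tiny.en".to_string(),
        RepoType::Model,
        "refs/pr/15".to_string(),
    );
    let api = Api::new()?;
    let config_filename = api.get(&repo, "config.json").await?;
    let tokenizer_filename = api.get(&repo, "tokenizer.json").await?;
    let weights_filename = api.get(&repo, "model.safetensors").await?;
    println!("config: {config_filename:?}");
    println!("tokenizer: {tokenizer_filename:?}");
    println!("weights: {weights_filename:?}");
    Ok(())
}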
@@ -2,6 +2,7 @@
 // back when using RUST_LIB_BACKTRACE=1.
 use anyhow::Result;
 use candle::{safetensors::SafeTensors, DType, Device, Shape, Tensor};
+use serde::Deserialize;
 use std::collections::HashMap;
 
 pub struct VarBuilder<'a> {
@@ -76,33 +77,33 @@ impl HiddenAct {
     }
 }
 
-#[derive(Debug, Clone, PartialEq)]
+#[derive(Debug, Clone, PartialEq, Deserialize)]
 pub struct Config {
-    pub n_mels: usize,
-    pub n_audio_ctx: usize,
-    pub n_audio_state: usize,
-    pub n_audio_head: usize,
-    pub n_audio_layer: usize,
-    pub n_vocab: usize,
-    pub n_text_ctx: usize,
-    pub n_text_state: usize,
-    pub n_text_head: usize,
-    pub n_text_layer: usize,
+    pub num_mel_bins: usize,
+    pub max_source_positions: usize,
+    pub d_model: usize,
+    pub encoder_attention_heads: usize,
+    pub encoder_layers: usize,
+    pub vocab_size: usize,
+    pub max_target_positions: usize,
+    // pub n_text_state: usize,
+    pub decoder_attention_heads: usize,
+    pub decoder_layers: usize,
 }
 
 impl Config {
     pub fn tiny_en() -> Self {
         Self {
-            n_mels: 80,
-            n_vocab: 51864,
-            n_audio_ctx: 1500,
-            n_audio_state: 384,
-            n_audio_head: 6,
-            n_audio_layer: 4,
-            n_text_ctx: 448,
-            n_text_state: 384,
-            n_text_head: 6,
-            n_text_layer: 4,
+            num_mel_bins: 80,
+            vocab_size: 51864,
+            max_source_positions: 1500,
+            d_model: 384,
+            encoder_attention_heads: 6,
+            encoder_layers: 4,
+            max_target_positions: 448,
+            // n_text_state: 384,
+            decoder_attention_heads: 6,
+            decoder_layers: 4,
         }
     }
 }
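Renaming the Config fields to the keys used in the Hugging Face config.json and deriving Deserialize is what lets main.rs build the config with serde_json::from_str instead of hard-coding Config::tiny_en(). Below is a minimal sketch of that round trip; the JSON literal is hypothetical but uses the tiny.en values from the hunk above, and any extra keys present in a real config.json are ignored by serde's default behaviour.

use serde::Deserialize;

// Sketch only: the struct mirrors the renamed fields from the hunk above.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Config {
    pub num_mel_bins: usize,
    pub max_source_positions: usize,
    pub d_model: usize,
    pub encoder_attention_heads: usize,
    pub encoder_layers: usize,
    pub vocab_size: usize,
    pub max_target_positions: usize,
    pub decoder_attention_heads: usize,
    pub decoder_layers: usize,
}

fn main() -> anyhow::Result<()> {
    // Hypothetical config.json excerpt with the whisper-tiny.en values.
    let json = r#"{
        "num_mel_bins": 80,
        "max_source_positions": 1500,
        "d_model": 384,
        "encoder_attention_heads": 6,
        "encoder_layers": 4,
        "vocab_size": 51864,
        "max_target_positions": 448,
        "decoder_attention_heads": 6,
        "decoder_layers": 4
    }"#;
    let config: Config = serde_json::from_str(json)?;
    assert_eq!(config.d_model, 384);
    assert_eq!(config.max_target_positions, 448);
    println!("{config:?}");
    Ok(())
}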
@@ -297,10 +298,10 @@ struct MultiHeadAttention {
 
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let query = Linear::load(n_state, n_state, &format!("{p}.query"), vb)?;
-        let value = Linear::load(n_state, n_state, &format!("{p}.value"), vb)?;
-        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.key"), vb)?;
-        let out = Linear::load(n_state, n_state, &format!("{p}.out"), vb)?;
+        let query = Linear::load(n_state, n_state, &format!("{p}.q_proj"), vb)?;
+        let value = Linear::load(n_state, n_state, &format!("{p}.v_proj"), vb)?;
+        let key = Linear::load_no_bias(n_state, n_state, &format!("{p}.k_proj"), vb)?;
+        let out = Linear::load(n_state, n_state, &format!("{p}.out_proj"), vb)?;
        Ok(Self {
             query,
             key,
@@ -360,20 +361,21 @@ struct ResidualAttentionBlock {
 
 impl ResidualAttentionBlock {
     fn load(n_state: usize, n_head: usize, ca: bool, p: &str, vb: &VarBuilder) -> Result<Self> {
-        let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.attn"), vb)?;
-        let attn_ln = LayerNorm::load(n_state, &format!("{p}.attn_ln"), vb)?;
+        let attn = MultiHeadAttention::load(n_state, n_head, &format!("{p}.self_attn"), vb)?;
+        let attn_ln = LayerNorm::load(n_state, &format!("{p}.self_attn_layer_norm"), vb)?;
         let cross_attn = if ca {
             let cross_attn =
-                MultiHeadAttention::load(n_state, n_head, &format!("{p}.cross_attn"), vb)?;
-            let cross_attn_ln = LayerNorm::load(n_state, &format!("{p}.cross_attn_ln"), vb)?;
+                MultiHeadAttention::load(n_state, n_head, &format!("{p}.encoder_attn"), vb)?;
+            let cross_attn_ln =
+                LayerNorm::load(n_state, &format!("{p}.encoder_attn_layer_norm"), vb)?;
             Some((cross_attn, cross_attn_ln))
         } else {
             None
         };
         let n_mlp = n_state * 4;
-        let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.mlp.0"), vb)?;
-        let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.mlp.2"), vb)?;
-        let mlp_ln = LayerNorm::load(n_state, &format!("{p}.mlp_ln"), vb)?;
+        let mlp_linear1 = Linear::load(n_state, n_mlp, &format!("{p}.fc1"), vb)?;
+        let mlp_linear2 = Linear::load(n_mlp, n_state, &format!("{p}.fc2"), vb)?;
+        let mlp_ln = LayerNorm::load(n_state, &format!("{p}.final_layer_norm"), vb)?;
         Ok(Self {
             attn,
             attn_ln,
@@ -426,9 +428,9 @@ pub struct AudioEncoder {
 
 impl AudioEncoder {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let n_state = cfg.n_audio_state;
-        let n_head = cfg.n_audio_head;
-        let n_ctx = cfg.n_audio_ctx;
+        let n_state = cfg.d_model;
+        let n_head = cfg.encoder_attention_heads;
+        let n_ctx = cfg.max_source_positions;
         let cfg1 = ConvConfig {
             padding: 1,
             stride: 1,
@@ -437,15 +439,22 @@ impl AudioEncoder {
             padding: 1,
             stride: 2,
         };
-        let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
+        let conv1 = Conv1D::load(
+            cfg.num_mel_bins,
+            n_state,
+            3,
+            cfg1,
+            &format!("{p}.conv1"),
+            vb,
+        )?;
         let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
         let positional_embedding = sinusoids(n_ctx, n_state)?.to_device(&vb.device)?;
-        let blocks = (0..cfg.n_audio_layer)
+        let blocks = (0..cfg.encoder_layers)
             .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
+                ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.layers.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln_post = LayerNorm::load(n_state, &format!("{p}.ln_post"), vb)?;
+        let ln_post = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
         Ok(Self {
             conv1,
             conv2,
@@ -480,19 +489,19 @@ pub struct TextDecoder {
 
 impl TextDecoder {
     fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
-        let n_state = cfg.n_text_state;
-        let n_head = cfg.n_text_head;
-        let n_ctx = cfg.n_text_ctx;
+        let n_state = cfg.d_model;
+        let n_head = cfg.decoder_attention_heads;
+        let n_ctx = cfg.max_target_positions;
         let token_embedding =
-            Embedding::load(cfg.n_vocab, n_state, &format!("{p}.token_embedding"), vb)?;
+            Embedding::load(cfg.vocab_size, n_state, &format!("{p}.embed_tokens"), vb)?;
         let positional_embedding =
-            vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?;
-        let blocks = (0..cfg.n_text_layer)
+            vb.get((n_ctx, n_state), &format!("{p}.embed_positions.weight"))?;
+        let blocks = (0..cfg.decoder_layers)
             .map(|i| {
-                ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.blocks.{i}"), vb)
+                ResidualAttentionBlock::load(n_state, n_head, true, &format!("{p}.layers.{i}"), vb)
             })
             .collect::<Result<Vec<_>>>()?;
-        let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
+        let ln = LayerNorm::load(n_state, &format!("{p}.layer_norm"), vb)?;
         let mask: Vec<_> = (0..n_ctx)
             .flat_map(|i| (0..n_ctx).map(move |j| if j > i { f32::NEG_INFINITY } else { 0f32 }))
             .collect();
@@ -532,8 +541,8 @@ pub struct Whisper {
 
 impl Whisper {
     pub fn load(vb: &VarBuilder, config: Config) -> Result<Self> {
-        let encoder = AudioEncoder::load("encoder", vb, &config)?;
-        let decoder = TextDecoder::load("decoder", vb, &config)?;
+        let encoder = AudioEncoder::load("model.encoder", vb, &config)?;
+        let decoder = TextDecoder::load("model.decoder", vb, &config)?;
         Ok(Self {
             encoder,
             decoder,
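The remaining renames in this file are mechanical: every tensor prefix is switched from the original OpenAI checkpoint layout to the names used by the Hub's safetensors export, plus the new model. prefix applied in Whisper::load. The constant below is only a reference table assembled from the hunks above, not code from the commit.

// Reference only: old (OpenAI-style) name fragment on the left, the Hugging
// Face safetensors name used after this commit on the right.
const RENAMES: &[(&str, &str)] = &[
    ("encoder", "model.encoder"),
    ("decoder", "model.decoder"),
    ("blocks.{i}", "layers.{i}"),
    ("attn", "self_attn"),
    ("attn_ln", "self_attn_layer_norm"),
    ("cross_attn", "encoder_attn"),
    ("cross_attn_ln", "encoder_attn_layer_norm"),
    ("query", "q_proj"),
    ("key", "k_proj"),
    ("value", "v_proj"),
    ("out", "out_proj"),
    ("mlp.0", "fc1"),
    ("mlp.2", "fc2"),
    ("mlp_ln", "final_layer_norm"),
    ("ln_post", "layer_norm"),
    ("ln", "layer_norm"),
    ("token_embedding", "embed_tokens"),
    ("positional_embedding", "embed_positions.weight"),
];

fn main() {
    for (old, new) in RENAMES {
        println!("{old:<24} -> {new}");
    }
}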