Adapting whisper for Hub use.

Nicolas Patry
2023-07-05 14:35:27 +00:00
parent 4e80319147
commit e85573a4bd
2 changed files with 101 additions and 60 deletions


@@ -7,6 +7,7 @@
 use anyhow::{Error as E, Result};
 use candle::{DType, Device, Tensor};
+use candle_hub::{api::Api, Repo, RepoType};
 use clap::Parser;
 use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;
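The single addition in this hunk is the `candle_hub` import, which brings in the async Hub client the rewritten `main` uses below. A minimal self-contained sketch of that surface, using only the calls that appear in this diff (`Api::new`, `Repo::with_revision`, `api.get`) and assuming `get` returns the local path of the downloaded file, as its use with `Tokenizer::from_file` further down suggests:

use candle_hub::{api::Api, Repo, RepoType};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Same model id and revision as the defaults introduced below.
    let repo = Repo::with_revision(
        "openai/whisper-tiny.en".to_string(),
        RepoType::Model,
        "refs/pr/15".to_string(),
    );
    let api = Api::new()?;
    // Downloads the file (or reuses a cached copy) and hands back its local path.
    let weights = api.get(&repo, "model.safetensors").await?;
    println!("weights cached at {weights:?}");
    Ok(())
}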
@@ -68,7 +69,7 @@ impl Decode {
         let model = &self.model;
         let audio_features = model.encoder.forward(mel)?;
         println!("audio features: {:?}", audio_features.dims());
-        let sample_len = model.config.n_text_ctx / 2;
+        let sample_len = model.config.max_target_positions / 2;
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![SOT_TOKEN];
@@ -112,7 +113,7 @@ impl Decode {
                .softmax(logits.rank() - 1)?
                .get(next_token as usize)?
                .to_scalar::<f32>()? as f64;
-            if next_token == EOT_TOKEN || tokens.len() > model.config.n_text_ctx {
+            if next_token == EOT_TOKEN || tokens.len() > model.config.max_target_positions {
                break;
            }
            sum_logprob += prob.ln();
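`n_text_ctx` was OpenAI's name for the text context length; `max_target_positions` is the name used in the Hugging Face `config.json`, which this example now deserializes directly (see the `serde_json::from_str` call in the last hunk). A hypothetical minimal `Config` illustrating the shape this relies on; only `max_target_positions` and the need for `Deserialize` are confirmed by the diff:

// Hypothetical skeleton; the real Config carries many more fields
// from config.json than shown here.
#[derive(Debug, serde::Deserialize)]
struct Config {
    max_target_positions: usize,
}

fn sample_len(config: &Config) -> usize {
    // Mirrors the decode loop above: sample at most half the text context.
    config.max_target_positions / 2
}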
@@ -165,15 +166,16 @@ struct Args {
     cpu: bool,
 
     #[arg(long)]
-    weights: String,
+    model_id: Option<String>,
+
+    /// The model to use, check out available models: https://huggingface.co/models?search=whisper
+    #[arg(long)]
+    revision: Option<String>,
 
     /// The input to be processed, in wav formats.
-    #[arg(long)]
+    #[arg(long, default_value = "jfk.wav")]
     input: String,
 
-    #[arg(long)]
-    tokenizer_config: String,
-
     /// The seed to use when generating random samples.
     #[arg(long, default_value_t = 299792458)]
     seed: u64,
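Net effect on the CLI: the required `--weights` and `--tokenizer-config` flags are gone, `--model-id` and `--revision` become optional, and `--input` gains a default, so the example now runs with no arguments at all. How the two optional flags resolve is spelled out in the `match` added to `main` further down: both given, they are used as-is; only `--model-id`, it is paired with `main`; only `--revision`, it is paired with the default model; neither, the defaults `openai/whisper-tiny.en` at `refs/pr/15` apply.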
@@ -186,7 +188,8 @@ struct Args {
     filters: String,
 }
 
-fn main() -> Result<()> {
+#[tokio::main]
+async fn main() -> Result<()> {
     let args = Args::parse();
     let device = if args.cpu {
         Device::Cpu
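Turning `main` into an `async fn` under `#[tokio::main]` is what permits the `api.get(...).await` calls in the next hunk: the Hub client introduced above is asynchronous, and the attribute supplies the runtime.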
@@ -195,7 +198,35 @@ fn main() -> Result<()> {
     };
     let rng = rand::rngs::StdRng::seed_from_u64(args.seed);
 
-    let tokenizer = Tokenizer::from_file(args.tokenizer_config).map_err(E::msg)?;
+    let default_model = "openai/whisper-tiny.en".to_string();
+    let path = std::path::PathBuf::from(default_model.clone());
+    let default_revision = "refs/pr/15".to_string();
+    let (model_id, revision) = match (args.model_id, args.revision) {
+        (Some(model_id), Some(revision)) => (model_id, revision),
+        (Some(model_id), None) => (model_id, "main".to_string()),
+        (None, Some(revision)) => (default_model, revision),
+        (None, None) => (default_model, default_revision),
+    };
+
+    let (config_filename, tokenizer_filename, weights_filename) = if path.exists() {
+        let mut config_filename = path.clone();
+        config_filename.push("config.json");
+        let mut tokenizer_filename = path.clone();
+        tokenizer_filename.push("tokenizer.json");
+        let mut model_filename = path.clone();
+        model_filename.push("model.safetensors");
+        (config_filename, tokenizer_filename, model_filename)
+    } else {
+        let repo = Repo::with_revision(model_id, RepoType::Model, revision);
+        let api = Api::new()?;
+        (
+            api.get(&repo, "config.json").await?,
+            api.get(&repo, "tokenizer.json").await?,
+            api.get(&repo, "model.safetensors").await?,
+        )
+    };
+    println!("Weights {weights_filename:?}");
+    let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
     let mel_filters = mel_filters.deserialize()?;
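One subtlety in the hunk above: `path` is derived from `default_model`, not from the resolved `model_id`, so the `path.exists()` branch only fires when a directory literally named `openai/whisper-tiny.en` (holding `config.json`, `tokenizer.json`, and `model.safetensors`) exists relative to the working directory. In every other case, including custom `--model-id` values, all three files come from the Hub.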
@@ -220,10 +251,11 @@ fn main() -> Result<()> {
     let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
     println!("loaded mel: {:?}", mel.dims());
 
-    let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
+    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
     let weights = weights.deserialize()?;
     let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, device);
-    let model = Whisper::load(&vb, Config::tiny_en())?;
+    let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
+    let model = Whisper::load(&vb, config)?;
     let mut dc = Decode {
         model,
         rng,
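With the hard-coded `Config::tiny_en()` gone, the model geometry is read from the checkpoint's own `config.json` at load time, so in principle any Whisper checkpoint whose safetensors layout matches what `Whisper::load` expects can be selected via `--model-id` without touching the code.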