From d3f05eae8c4f2df186b46e433be101ac39fceca5 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sun, 10 Sep 2023 09:40:27 +0100
Subject: [PATCH] Move some models to candle-transformers so that it's easier
 to re-use. (#794)

* Move some models to candle-transformers so that they can be shared.

* Also move falcon.

* Move Llama.

* Move whisper (partial).
---
 candle-examples/Cargo.toml                    | 12 ++--
 candle-examples/examples/bert/main.rs         |  3 +-
 candle-examples/examples/bigcode/main.rs      |  3 +-
 candle-examples/examples/falcon/main.rs       |  3 +-
 candle-examples/examples/llama/main.rs        |  3 +-
 candle-examples/examples/whisper/main.rs      | 71 ++++++-------------
 .../examples/whisper/multilingual.rs          |  2 +-
 candle-transformers/Cargo.toml                |  4 ++
 .../src/models/bert.rs                        |  0
 .../src/models/bigcode.rs                     |  0
 .../src/models/falcon.rs                      | 11 ++-
 .../src/models/llama.rs                       |  2 +-
 candle-transformers/src/models/mod.rs         |  6 +-
 .../src/models}/whisper/audio.rs              | 10 +--
 candle-transformers/src/models/whisper/mod.rs | 26 +++++++
 .../src/models}/whisper/model.rs              |  0
 16 files changed, 78 insertions(+), 78 deletions(-)
 rename candle-examples/examples/bert/model.rs => candle-transformers/src/models/bert.rs (100%)
 rename candle-examples/examples/bigcode/model.rs => candle-transformers/src/models/bigcode.rs (100%)
 rename candle-examples/examples/falcon/model.rs => candle-transformers/src/models/falcon.rs (98%)
 rename candle-examples/examples/llama/model.rs => candle-transformers/src/models/llama.rs (99%)
 rename {candle-examples/examples => candle-transformers/src/models}/whisper/audio.rs (97%)
 create mode 100644 candle-transformers/src/models/whisper/mod.rs
 rename {candle-examples/examples => candle-transformers/src/models}/whisper/model.rs (100%)

diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml
index 6f8792a3..eb552b88 100644
--- a/candle-examples/Cargo.toml
+++ b/candle-examples/Cargo.toml
@@ -13,18 +13,18 @@ readme = "README.md"
 accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
 candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
+candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
 candle-nn = { path = "../candle-nn", version = "0.2.1" }
 candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
-candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
-safetensors = { workspace = true }
-serde = { workspace = true }
-serde_json = { workspace = true }
-num-traits = { workspace = true }
-intel-mkl-src = { workspace = true, optional = true }
 cudarc = { workspace = true, optional = true }
 half = { workspace = true, optional = true }
 image = { workspace = true }
+intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
 rayon = { workspace = true }
+safetensors = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }

 [dev-dependencies]
 anyhow = { workspace = true }
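Note: with this split, a crate outside the repository can reuse the shared model definitions through a single dependency instead of copying the per-example `model.rs` files. A minimal sketch of such a downstream Cargo.toml, using the crate names and versions that appear in this patch (availability on crates.io is assumed):

```toml
[dependencies]
anyhow = "1.0"
candle-core = "0.2.1"
candle-nn = "0.2.1"
candle-transformers = "0.2.1"
```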
diff --git a/candle-examples/examples/bert/main.rs b/candle-examples/examples/bert/main.rs
index 6cee66ee..9d0eccdf 100644
--- a/candle-examples/examples/bert/main.rs
+++ b/candle-examples/examples/bert/main.rs
@@ -3,14 +3,13 @@ extern crate intel_mkl_src;
 #[cfg(feature = "accelerate")]
 extern crate accelerate_src;

-mod model;
+use candle_transformers::models::bert::{BertModel, Config, DTYPE};

 use anyhow::{anyhow, Error as E, Result};
 use candle::Tensor;
 use candle_nn::VarBuilder;
 use clap::Parser;
 use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
-use model::{BertModel, Config, DTYPE};
 use tokenizers::{PaddingParams, Tokenizer};

 #[derive(Parser, Debug)]
diff --git a/candle-examples/examples/bigcode/main.rs b/candle-examples/examples/bigcode/main.rs
index 652cd47f..3540f75d 100644
--- a/candle-examples/examples/bigcode/main.rs
+++ b/candle-examples/examples/bigcode/main.rs
@@ -7,8 +7,7 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;

-mod model;
-use model::{Config, GPTBigCode};
+use candle_transformers::models::bigcode::{Config, GPTBigCode};

 use candle::{DType, Device, Tensor};
 use candle_nn::VarBuilder;
diff --git a/candle-examples/examples/falcon/main.rs b/candle-examples/examples/falcon/main.rs
index 05507f08..c45fe545 100644
--- a/candle-examples/examples/falcon/main.rs
+++ b/candle-examples/examples/falcon/main.rs
@@ -14,8 +14,7 @@ use clap::Parser;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use tokenizers::Tokenizer;

-mod model;
-use model::{Config, Falcon};
+use candle_transformers::models::falcon::{Config, Falcon};

 struct TextGeneration {
     model: Falcon,
diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs
index 6f8766d4..db3d216c 100644
--- a/candle-examples/examples/llama/main.rs
+++ b/candle-examples/examples/llama/main.rs
@@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use std::io::Write;

-mod model;
+use candle_transformers::models::llama as model;
 use model::{Config, Llama, LlamaConfig};

 const EOS_TOKEN: &str = "</s>";
-const MAX_SEQ_LEN: usize = 4096;
 const DEFAULT_PROMPT: &str = "My favorite theorem is ";

 #[derive(Parser, Debug)]
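Note: with the `mod model;` declarations gone, each example imports its model from the shared crate, and so can any external consumer. As a rough, untested sketch of what that unlocks, here is Falcon loaded from a safetensors checkpoint outside this repository; the `model.safetensors` path is hypothetical, and `Config::falcon7b()` plus the `Falcon::load` signature are assumed to match the falcon example and model code in this tree:

```rust
use candle::{DType, Device};
use candle_nn::VarBuilder;
use candle_transformers::models::falcon::{Config, Falcon};

fn main() -> anyhow::Result<()> {
    let device = Device::Cpu;
    // Hypothetical checkpoint path; any Falcon-7B safetensors file would do.
    let weights = unsafe { candle::safetensors::MmapedFile::new("model.safetensors")? };
    let weights = weights.deserialize()?;
    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
    let config = Config::falcon7b();
    config.validate()?;
    let _model = Falcon::load(vb, config)?;
    Ok(())
}
```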
diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs
index dbe9cc8d..c71d562a 100644
--- a/candle-examples/examples/whisper/main.rs
+++ b/candle-examples/examples/whisper/main.rs
@@ -10,41 +10,16 @@ extern crate accelerate_src;
 extern crate intel_mkl_src;

 use anyhow::{Error as E, Result};
-use candle::{DType, Device, IndexOp, Tensor};
+use candle::{Device, IndexOp, Tensor};
 use candle_nn::{ops::softmax, VarBuilder};
 use clap::{Parser, ValueEnum};
 use hf_hub::{api::sync::Api, Repo, RepoType};
 use rand::{distributions::Distribution, SeedableRng};
 use tokenizers::Tokenizer;

-mod audio;
-mod model;
-use model::{Config, Whisper};
 mod multilingual;
-
-const DTYPE: DType = DType::F32;
-
-// Audio parameters.
-const SAMPLE_RATE: usize = 16000;
-const N_FFT: usize = 400;
-const N_MELS: usize = 80;
-const HOP_LENGTH: usize = 160;
-const CHUNK_LENGTH: usize = 30;
-const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
-const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
-
-const NO_SPEECH_THRESHOLD: f64 = 0.6;
-const LOGPROB_THRESHOLD: f64 = -1.0;
-const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
-const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
-
-// Tokenizer dependent bits.
-const SOT_TOKEN: &str = "<|startoftranscript|>";
-const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
-const TRANSLATE_TOKEN: &str = "<|translate|>";
-const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
-const EOT_TOKEN: &str = "<|endoftext|>";
-const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
+use candle_transformers::models::whisper::{self as m, audio, model};
+use model::{Config, Whisper};

 #[allow(dead_code)]
 #[derive(Debug, Clone)]
@@ -94,7 +69,7 @@ impl Decoder {
         timestamps: bool,
         verbose: bool,
     ) -> Result<Self> {
-        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
+        let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
         // Suppress the notimestamps token when in timestamps mode.
         // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
         let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@@ -109,11 +84,11 @@ impl Decoder {
             })
             .collect();
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
-        let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
-        let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
-        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
-        let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
-        let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
+        let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
+        let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
+        let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
+        let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
@@ -220,17 +195,17 @@ impl Decoder {
     }

     fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
-        for (i, &t) in TEMPERATURES.iter().enumerate() {
+        for (i, &t) in m::TEMPERATURES.iter().enumerate() {
             let dr: Result<DecodingResult> = self.decode(segment, t);
-            if i == TEMPERATURES.len() - 1 {
+            if i == m::TEMPERATURES.len() - 1 {
                 return dr;
             }
             // On errors, we try again with a different temperature.
             match dr {
                 Ok(dr) => {
-                    let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
-                        || dr.avg_logprob < LOGPROB_THRESHOLD;
-                    if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
+                    let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
+                        || dr.avg_logprob < m::LOGPROB_THRESHOLD;
+                    if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
                         return Ok(dr);
                     }
                 }
@@ -248,13 +223,13 @@ impl Decoder {
         let mut segments = vec![];
         while seek < content_frames {
             let start = std::time::Instant::now();
-            let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
-            let segment_size = usize::min(content_frames - seek, N_FRAMES);
+            let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
+            let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
             let mel_segment = mel.narrow(2, seek, segment_size)?;
-            let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
+            let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
             let dr = self.decode_with_fallback(&mel_segment)?;
             seek += segment_size;
-            if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
+            if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
                 println!("no speech detected, skipping {seek} {dr:?}");
                 continue;
             }
@@ -492,8 +467,8 @@ fn main() -> Result<()> {
     let mut input = std::fs::File::open(input)?;
     let (header, data) = wav::read(&mut input)?;
     println!("loaded wav data: {header:?}");
-    if header.sampling_rate != SAMPLE_RATE as u32 {
-        anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
+    if header.sampling_rate != m::SAMPLE_RATE as u32 {
+        anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
     }
     let data = data.as_sixteen().expect("expected 16 bit wav file");
     let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
         .iter()
         .map(|v| *v as f32 / 32768.)
         .collect();
     println!("pcm data loaded {}", pcm_data.len());
-    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
+    let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
     let mel_len = mel.len();
-    let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
+    let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
     println!("loaded mel: {:?}", mel.dims());

     let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
     let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+    let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let mut model = Whisper::load(&vb, config)?;
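Note: the audio geometry the example used to hard-code now lives on the library surface, so every consumer derives the same numbers from one place. A small sanity-check sketch using only the constants this patch exports:

```rust
use candle_transformers::models::whisper as m;

fn main() {
    // 30 s of 16 kHz audio per chunk, one mel frame per 160-sample hop.
    assert_eq!(m::N_SAMPLES, m::CHUNK_LENGTH * m::SAMPLE_RATE); // 480_000
    assert_eq!(m::N_FRAMES, m::N_SAMPLES / m::HOP_LENGTH); // 3_000
    let frame_seconds = m::HOP_LENGTH as f64 / m::SAMPLE_RATE as f64;
    println!("one mel frame covers {frame_seconds} seconds"); // 0.01
}
```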
diff --git a/candle-examples/examples/whisper/multilingual.rs b/candle-examples/examples/whisper/multilingual.rs
index bc0bae1f..a82b09ef 100644
--- a/candle-examples/examples/whisper/multilingual.rs
+++ b/candle-examples/examples/whisper/multilingual.rs
@@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
         .iter()
         .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
         .collect::<Result<Vec<_>>>()?;
-    let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
+    let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
     let audio_features = model.encoder.forward(&mel, true)?;
     let tokens = Tensor::new(&[[sot_token]], device)?;
     let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml
index a05b9bb7..6b2087cb 100644
--- a/candle-transformers/Cargo.toml
+++ b/candle-transformers/Cargo.toml
@@ -14,7 +14,11 @@ accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
 candle-nn = { path = "../candle-nn", version = "0.2.1" }
 intel-mkl-src = { workspace = true, optional = true }
+num-traits = { workspace = true }
 rand = { workspace = true }
+serde = { workspace = true }
+serde_json = { workspace = true }
+tracing = { workspace = true }
 wav = { workspace = true }

 [features]
diff --git a/candle-examples/examples/bert/model.rs b/candle-transformers/src/models/bert.rs
similarity index 100%
rename from candle-examples/examples/bert/model.rs
rename to candle-transformers/src/models/bert.rs
diff --git a/candle-examples/examples/bigcode/model.rs b/candle-transformers/src/models/bigcode.rs
similarity index 100%
rename from candle-examples/examples/bigcode/model.rs
rename to candle-transformers/src/models/bigcode.rs
diff --git a/candle-examples/examples/falcon/model.rs b/candle-transformers/src/models/falcon.rs
similarity index 98%
rename from candle-examples/examples/falcon/model.rs
rename to candle-transformers/src/models/falcon.rs
index b638dd51..6ede136a 100644
--- a/candle-examples/examples/falcon/model.rs
+++ b/candle-transformers/src/models/falcon.rs
@@ -1,5 +1,4 @@
-use anyhow::Result;
-use candle::{DType, Device, Tensor, D};
+use candle::{DType, Device, Result, Tensor, D};
 use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};

 const MAX_SEQ_LEN: usize = 5000;
@@ -21,7 +20,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
             if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
                 (weight, bias)
             } else {
-                return Err(err.into());
+                return Err(err);
             }
         }
     };
@@ -82,13 +81,13 @@ impl Default for Config {
 impl Config {
     pub fn validate(&self) -> Result<()> {
         if self.alibi {
-            anyhow::bail!("alibi is not supported");
+            candle::bail!("alibi is not supported");
         }
         if self.new_decoder_architecture {
-            anyhow::bail!("new_decoder_architecture is not supported");
+            candle::bail!("new_decoder_architecture is not supported");
         }
         if self.n_head_kv.is_some() {
-            anyhow::bail!("n_head_kv is not supported");
+            candle::bail!("n_head_kv is not supported");
         }
         Ok(())
     }
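Note: the switch from `anyhow::Result` to `candle::Result` matters for library code, since it stops forcing an error-handling crate on downstream users, while binaries keep using `?` because `anyhow::Error` wraps any `std::error::Error`. A sketch of the resulting split; the `check_supported` helper is illustrative, not part of the patch:

```rust
use candle::Result;

// Library-side code reports failures through candle's own error type.
fn check_supported(alibi: bool) -> Result<()> {
    if alibi {
        candle::bail!("alibi is not supported");
    }
    Ok(())
}

// Binary-side code can keep anyhow and still use `?` on candle errors.
fn main() -> anyhow::Result<()> {
    check_supported(false)?;
    Ok(())
}
```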
diff --git a/candle-examples/examples/llama/model.rs b/candle-transformers/src/models/llama.rs
similarity index 99%
rename from candle-examples/examples/llama/model.rs
rename to candle-transformers/src/models/llama.rs
index 275856e0..eed4df5e 100644
--- a/candle-examples/examples/llama/model.rs
+++ b/candle-transformers/src/models/llama.rs
@@ -4,7 +4,7 @@ use serde::Deserialize;
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};

-use super::MAX_SEQ_LEN;
+pub const MAX_SEQ_LEN: usize = 4096;

 #[derive(Deserialize)]
 pub struct LlamaConfig {
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 8b137891..1b3dcf25 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -1 +1,5 @@
-
+pub mod bert;
+pub mod bigcode;
+pub mod falcon;
+pub mod llama;
+pub mod whisper;
diff --git a/candle-examples/examples/whisper/audio.rs b/candle-transformers/src/models/whisper/audio.rs
similarity index 97%
rename from candle-examples/examples/whisper/audio.rs
rename to candle-transformers/src/models/whisper/audio.rs
index 2ceed065..4e01de32 100644
--- a/candle-examples/examples/whisper/audio.rs
+++ b/candle-transformers/src/models/whisper/audio.rs
@@ -198,17 +198,13 @@ fn log_mel_spectrogram_(
     mel
 }

-pub fn pcm_to_mel<T: Float>(
-    samples: &[T],
-    filters: &[T],
-) -> anyhow::Result<Vec<f32>> {
-    let mel = log_mel_spectrogram_(
+pub fn pcm_to_mel<T: Float>(samples: &[T], filters: &[T]) -> Vec<f32> {
+    log_mel_spectrogram_(
         samples,
         filters,
         super::N_FFT,
         super::HOP_LENGTH,
         super::N_MELS,
         false,
-    );
-    Ok(mel)
+    )
 }
diff --git a/candle-transformers/src/models/whisper/mod.rs b/candle-transformers/src/models/whisper/mod.rs
new file mode 100644
index 00000000..7dc8107b
--- /dev/null
+++ b/candle-transformers/src/models/whisper/mod.rs
@@ -0,0 +1,26 @@
+pub mod audio;
+pub mod model;
+
+pub const DTYPE: candle::DType = candle::DType::F32;
+
+// Audio parameters.
+pub const SAMPLE_RATE: usize = 16000;
+pub const N_FFT: usize = 400;
+pub const N_MELS: usize = 80;
+pub const HOP_LENGTH: usize = 160;
+pub const CHUNK_LENGTH: usize = 30;
+pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
+pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
+
+pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
+pub const LOGPROB_THRESHOLD: f64 = -1.0;
+pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
+pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
+
+// Tokenizer dependent bits.
+pub const SOT_TOKEN: &str = "<|startoftranscript|>";
+pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+pub const TRANSLATE_TOKEN: &str = "<|translate|>";
+pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
+pub const EOT_TOKEN: &str = "<|endoftext|>";
+pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
diff --git a/candle-examples/examples/whisper/model.rs b/candle-transformers/src/models/whisper/model.rs
similarity index 100%
rename from candle-examples/examples/whisper/model.rs
rename to candle-transformers/src/models/whisper/model.rs
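Note: end to end, the moved whisper pieces compose like this. A minimal sketch assuming the caller already has f32 PCM samples and the 80-bin mel filterbank (the example ships it as a bundled `melfilters.bytes` file); `pcm_to_mel` is now infallible and returns the flattened mel buffer directly:

```rust
use candle::{Device, Tensor};
use candle_transformers::models::whisper::{self as m, audio};

fn mel_from_pcm(pcm: &[f32], mel_filters: &[f32], device: &Device) -> candle::Result<Tensor> {
    // Compute the log-mel spectrogram from raw PCM samples.
    let mel = audio::pcm_to_mel(pcm, mel_filters);
    let mel_len = mel.len();
    // Shape it as (batch=1, n_mels=80, n_frames) for the whisper encoder.
    Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), device)
}
```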