Move some models to candle-transformers so that it's easier to re-use. (#794)

* Move some models to candle-transformers so that they can be shared.

* Also move falcon.

* Move Llama.

* Move whisper (partial).
This commit is contained in:
Laurent Mazare
2023-09-10 09:40:27 +01:00
committed by GitHub
parent 258ac32c38
commit d3f05eae8c
16 changed files with 78 additions and 78 deletions

View File

@ -13,18 +13,18 @@ readme = "README.md"
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-datasets = { path = "../candle-datasets", version = "0.2.1" } candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
candle-nn = { path = "../candle-nn", version = "0.2.1" } candle-nn = { path = "../candle-nn", version = "0.2.1" }
candle-transformers = { path = "../candle-transformers", version = "0.2.1" } candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
num-traits = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
cudarc = { workspace = true, optional = true } cudarc = { workspace = true, optional = true }
half = { workspace = true, optional = true } half = { workspace = true, optional = true }
image = { workspace = true } image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true } rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
[dev-dependencies] [dev-dependencies]
anyhow = { workspace = true } anyhow = { workspace = true }

View File

@ -3,14 +3,13 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")] #[cfg(feature = "accelerate")]
extern crate accelerate_src; extern crate accelerate_src;
mod model; use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{anyhow, Error as E, Result}; use anyhow::{anyhow, Error as E, Result};
use candle::Tensor; use candle::Tensor;
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use clap::Parser; use clap::Parser;
use hf_hub::{api::sync::Api, Cache, Repo, RepoType}; use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use model::{BertModel, Config, DTYPE};
use tokenizers::{PaddingParams, Tokenizer}; use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)] #[derive(Parser, Debug)]

View File

@ -7,8 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use clap::Parser; use clap::Parser;
mod model; use candle_transformers::models::bigcode::{Config, GPTBigCode};
use model::{Config, GPTBigCode};
use candle::{DType, Device, Tensor}; use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;

View File

@ -14,8 +14,7 @@ use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
mod model; use candle_transformers::models::falcon::{Config, Falcon};
use model::{Config, Falcon};
struct TextGeneration { struct TextGeneration {
model: Falcon, model: Falcon,

View File

@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use std::io::Write; use std::io::Write;
mod model; use candle_transformers::models::llama as model;
use model::{Config, Llama, LlamaConfig}; use model::{Config, Llama, LlamaConfig};
const EOS_TOKEN: &str = "</s>"; const EOS_TOKEN: &str = "</s>";
const MAX_SEQ_LEN: usize = 4096;
const DEFAULT_PROMPT: &str = "My favorite theorem is "; const DEFAULT_PROMPT: &str = "My favorite theorem is ";
#[derive(Parser, Debug)] #[derive(Parser, Debug)]

View File

@ -10,41 +10,16 @@ extern crate accelerate_src;
extern crate intel_mkl_src; extern crate intel_mkl_src;
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor}; use candle::{Device, IndexOp, Tensor};
use candle_nn::{ops::softmax, VarBuilder}; use candle_nn::{ops::softmax, VarBuilder};
use clap::{Parser, ValueEnum}; use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType}; use hf_hub::{api::sync::Api, Repo, RepoType};
use rand::{distributions::Distribution, SeedableRng}; use rand::{distributions::Distribution, SeedableRng};
use tokenizers::Tokenizer; use tokenizers::Tokenizer;
mod audio;
mod model;
use model::{Config, Whisper};
mod multilingual; mod multilingual;
use candle_transformers::models::whisper::{self as m, audio, model};
const DTYPE: DType = DType::F32; use model::{Config, Whisper};
// Audio parameters.
const SAMPLE_RATE: usize = 16000;
const N_FFT: usize = 400;
const N_MELS: usize = 80;
const HOP_LENGTH: usize = 160;
const CHUNK_LENGTH: usize = 30;
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
const NO_SPEECH_THRESHOLD: f64 = 0.6;
const LOGPROB_THRESHOLD: f64 = -1.0;
const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const TRANSLATE_TOKEN: &str = "<|translate|>";
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -94,7 +69,7 @@ impl Decoder {
timestamps: bool, timestamps: bool,
verbose: bool, verbose: bool,
) -> Result<Self> { ) -> Result<Self> {
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?; let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
// Suppress the notimestamps token when in timestamps mode. // Suppress the notimestamps token when in timestamps mode.
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452 // https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32) let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
@ -109,11 +84,11 @@ impl Decoder {
}) })
.collect(); .collect();
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?; let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?; let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?; let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?; let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?; let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?; let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
Ok(Self { Ok(Self {
model, model,
rng: rand::rngs::StdRng::seed_from_u64(seed), rng: rand::rngs::StdRng::seed_from_u64(seed),
@ -220,17 +195,17 @@ impl Decoder {
} }
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> { fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
for (i, &t) in TEMPERATURES.iter().enumerate() { for (i, &t) in m::TEMPERATURES.iter().enumerate() {
let dr: Result<DecodingResult> = self.decode(segment, t); let dr: Result<DecodingResult> = self.decode(segment, t);
if i == TEMPERATURES.len() - 1 { if i == m::TEMPERATURES.len() - 1 {
return dr; return dr;
} }
// On errors, we try again with a different temperature. // On errors, we try again with a different temperature.
match dr { match dr {
Ok(dr) => { Ok(dr) => {
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|| dr.avg_logprob < LOGPROB_THRESHOLD; || dr.avg_logprob < m::LOGPROB_THRESHOLD;
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD { if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
return Ok(dr); return Ok(dr);
} }
} }
@ -248,13 +223,13 @@ impl Decoder {
let mut segments = vec![]; let mut segments = vec![];
while seek < content_frames { while seek < content_frames {
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES); let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
let mel_segment = mel.narrow(2, seek, segment_size)?; let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64; let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
let dr = self.decode_with_fallback(&mel_segment)?; let dr = self.decode_with_fallback(&mel_segment)?;
seek += segment_size; seek += segment_size;
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD { if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
println!("no speech detected, skipping {seek} {dr:?}"); println!("no speech detected, skipping {seek} {dr:?}");
continue; continue;
} }
@ -492,8 +467,8 @@ fn main() -> Result<()> {
let mut input = std::fs::File::open(input)?; let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?; let (header, data) = wav::read(&mut input)?;
println!("loaded wav data: {header:?}"); println!("loaded wav data: {header:?}");
if header.sampling_rate != SAMPLE_RATE as u32 { if header.sampling_rate != m::SAMPLE_RATE as u32 {
anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE) anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
} }
let data = data.as_sixteen().expect("expected 16 bit wav file"); let data = data.as_sixteen().expect("expected 16 bit wav file");
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize] let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
@ -501,14 +476,14 @@ fn main() -> Result<()> {
.map(|v| *v as f32 / 32768.) .map(|v| *v as f32 / 32768.)
.collect(); .collect();
println!("pcm data loaded {}", pcm_data.len()); println!("pcm data loaded {}", pcm_data.len());
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?; let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
let mel_len = mel.len(); let mel_len = mel.len();
let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?; let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims()); println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
let weights = weights.deserialize()?; let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device); let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = Whisper::load(&vb, config)?; let mut model = Whisper::load(&vb, config)?;

View File

@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
.iter() .iter()
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>"))) .map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?; let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
let audio_features = model.encoder.forward(&mel, true)?; let audio_features = model.encoder.forward(&mel, true)?;
let tokens = Tensor::new(&[[sot_token]], device)?; let tokens = Tensor::new(&[[sot_token]], device)?;
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?; let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;

View File

@ -14,7 +14,11 @@ accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" } candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.1" } candle-nn = { path = "../candle-nn", version = "0.2.1" }
intel-mkl-src = { workspace = true, optional = true } intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true } rand = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }
wav = { workspace = true } wav = { workspace = true }
[features] [features]

View File

@ -1,5 +1,4 @@
use anyhow::Result; use candle::{DType, Device, Result, Tensor, D};
use candle::{DType, Device, Tensor, D};
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder}; use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
const MAX_SEQ_LEN: usize = 5000; const MAX_SEQ_LEN: usize = 5000;
@ -21,7 +20,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) { if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
(weight, bias) (weight, bias)
} else { } else {
return Err(err.into()); return Err(err);
} }
} }
}; };
@ -82,13 +81,13 @@ impl Default for Config {
impl Config { impl Config {
pub fn validate(&self) -> Result<()> { pub fn validate(&self) -> Result<()> {
if self.alibi { if self.alibi {
anyhow::bail!("alibi is not supported"); candle::bail!("alibi is not supported");
} }
if self.new_decoder_architecture { if self.new_decoder_architecture {
anyhow::bail!("new_decoder_architecture is not supported"); candle::bail!("new_decoder_architecture is not supported");
} }
if self.n_head_kv.is_some() { if self.n_head_kv.is_some() {
anyhow::bail!("n_head_kv is not supported"); candle::bail!("n_head_kv is not supported");
} }
Ok(()) Ok(())
} }

View File

@ -4,7 +4,7 @@ use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use super::MAX_SEQ_LEN; pub const MAX_SEQ_LEN: usize = 4096;
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct LlamaConfig { pub struct LlamaConfig {

View File

@ -1 +1,5 @@
pub mod bert;
pub mod bigcode;
pub mod falcon;
pub mod llama;
pub mod whisper;

View File

@ -198,17 +198,13 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
mel mel
} }
pub fn pcm_to_mel<T: Float + std::fmt::Display>( pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
samples: &[T], log_mel_spectrogram_(
filters: &[T],
) -> anyhow::Result<Vec<T>> {
let mel = log_mel_spectrogram_(
samples, samples,
filters, filters,
super::N_FFT, super::N_FFT,
super::HOP_LENGTH, super::HOP_LENGTH,
super::N_MELS, super::N_MELS,
false, false,
); )
Ok(mel)
} }

View File

@ -0,0 +1,26 @@
pub mod audio;
pub mod model;
pub const DTYPE: candle::DType = candle::DType::F32;
// Audio parameters.
pub const SAMPLE_RATE: usize = 16000;
pub const N_FFT: usize = 400;
pub const N_MELS: usize = 80;
pub const HOP_LENGTH: usize = 160;
pub const CHUNK_LENGTH: usize = 30;
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
pub const LOGPROB_THRESHOLD: f64 = -1.0;
pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits.
pub const SOT_TOKEN: &str = "<|startoftranscript|>";
pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
pub const EOT_TOKEN: &str = "<|endoftext|>";
pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";