mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Move some models to candle-transformers so that it's easier to re-use. (#794)
* Move some models to candle-transformers so that they can be shared. * Also move falcon. * Move Llama. * Move whisper (partial).
This commit is contained in:
@ -13,18 +13,18 @@ readme = "README.md"
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
||||
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
||||
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
|
||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
cudarc = { workspace = true, optional = true }
|
||||
half = { workspace = true, optional = true }
|
||||
image = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = { workspace = true }
|
||||
|
@ -3,14 +3,13 @@ extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
mod model;
|
||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||
|
||||
use anyhow::{anyhow, Error as E, Result};
|
||||
use candle::Tensor;
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||
use model::{BertModel, Config, DTYPE};
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
|
@ -7,8 +7,7 @@ extern crate accelerate_src;
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
mod model;
|
||||
use model::{Config, GPTBigCode};
|
||||
use candle_transformers::models::bigcode::{Config, GPTBigCode};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
|
@ -14,8 +14,7 @@ use clap::Parser;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod model;
|
||||
use model::{Config, Falcon};
|
||||
use candle_transformers::models::falcon::{Config, Falcon};
|
||||
|
||||
struct TextGeneration {
|
||||
model: Falcon,
|
||||
|
@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::io::Write;
|
||||
|
||||
mod model;
|
||||
use candle_transformers::models::llama as model;
|
||||
use model::{Config, Llama, LlamaConfig};
|
||||
|
||||
const EOS_TOKEN: &str = "</s>";
|
||||
const MAX_SEQ_LEN: usize = 4096;
|
||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
|
@ -10,41 +10,16 @@ extern crate accelerate_src;
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use candle::{DType, Device, IndexOp, Tensor};
|
||||
use candle::{Device, IndexOp, Tensor};
|
||||
use candle_nn::{ops::softmax, VarBuilder};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
mod audio;
|
||||
mod model;
|
||||
use model::{Config, Whisper};
|
||||
mod multilingual;
|
||||
|
||||
const DTYPE: DType = DType::F32;
|
||||
|
||||
// Audio parameters.
|
||||
const SAMPLE_RATE: usize = 16000;
|
||||
const N_FFT: usize = 400;
|
||||
const N_MELS: usize = 80;
|
||||
const HOP_LENGTH: usize = 160;
|
||||
const CHUNK_LENGTH: usize = 30;
|
||||
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
||||
|
||||
const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
||||
const LOGPROB_THRESHOLD: f64 = -1.0;
|
||||
const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
||||
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||
|
||||
// Tokenizer dependent bits.
|
||||
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||
const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||
const EOT_TOKEN: &str = "<|endoftext|>";
|
||||
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||
use candle_transformers::models::whisper::{self as m, audio, model};
|
||||
use model::{Config, Whisper};
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[derive(Debug, Clone)]
|
||||
@ -94,7 +69,7 @@ impl Decoder {
|
||||
timestamps: bool,
|
||||
verbose: bool,
|
||||
) -> Result<Self> {
|
||||
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
||||
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
||||
// Suppress the notimestamps token when in timestamps mode.
|
||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||
@ -109,11 +84,11 @@ impl Decoder {
|
||||
})
|
||||
.collect();
|
||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||
@ -220,17 +195,17 @@ impl Decoder {
|
||||
}
|
||||
|
||||
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
|
||||
for (i, &t) in TEMPERATURES.iter().enumerate() {
|
||||
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
|
||||
let dr: Result<DecodingResult> = self.decode(segment, t);
|
||||
if i == TEMPERATURES.len() - 1 {
|
||||
if i == m::TEMPERATURES.len() - 1 {
|
||||
return dr;
|
||||
}
|
||||
// On errors, we try again with a different temperature.
|
||||
match dr {
|
||||
Ok(dr) => {
|
||||
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|
||||
|| dr.avg_logprob < LOGPROB_THRESHOLD;
|
||||
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
|
||||
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|
||||
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
|
||||
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
|
||||
return Ok(dr);
|
||||
}
|
||||
}
|
||||
@ -248,13 +223,13 @@ impl Decoder {
|
||||
let mut segments = vec![];
|
||||
while seek < content_frames {
|
||||
let start = std::time::Instant::now();
|
||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
|
||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
||||
let dr = self.decode_with_fallback(&mel_segment)?;
|
||||
seek += segment_size;
|
||||
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
||||
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
|
||||
println!("no speech detected, skipping {seek} {dr:?}");
|
||||
continue;
|
||||
}
|
||||
@ -492,8 +467,8 @@ fn main() -> Result<()> {
|
||||
let mut input = std::fs::File::open(input)?;
|
||||
let (header, data) = wav::read(&mut input)?;
|
||||
println!("loaded wav data: {header:?}");
|
||||
if header.sampling_rate != SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
|
||||
if header.sampling_rate != m::SAMPLE_RATE as u32 {
|
||||
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
|
||||
}
|
||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
||||
@ -501,14 +476,14 @@ fn main() -> Result<()> {
|
||||
.map(|v| *v as f32 / 32768.)
|
||||
.collect();
|
||||
println!("pcm data loaded {}", pcm_data.len());
|
||||
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
|
||||
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
|
||||
let mel_len = mel.len();
|
||||
let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
|
||||
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||
let weights = weights.deserialize()?;
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
||||
let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
|
||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||
let mut model = Whisper::load(&vb, config)?;
|
||||
|
||||
|
@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
|
||||
.iter()
|
||||
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
|
||||
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
|
||||
let audio_features = model.encoder.forward(&mel, true)?;
|
||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||
|
@ -14,7 +14,11 @@ accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
||||
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
num-traits = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
wav = { workspace = true }
|
||||
|
||||
[features]
|
||||
|
@ -1,5 +1,4 @@
|
||||
use anyhow::Result;
|
||||
use candle::{DType, Device, Tensor, D};
|
||||
use candle::{DType, Device, Result, Tensor, D};
|
||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||
|
||||
const MAX_SEQ_LEN: usize = 5000;
|
||||
@ -21,7 +20,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||
(weight, bias)
|
||||
} else {
|
||||
return Err(err.into());
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -82,13 +81,13 @@ impl Default for Config {
|
||||
impl Config {
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.alibi {
|
||||
anyhow::bail!("alibi is not supported");
|
||||
candle::bail!("alibi is not supported");
|
||||
}
|
||||
if self.new_decoder_architecture {
|
||||
anyhow::bail!("new_decoder_architecture is not supported");
|
||||
candle::bail!("new_decoder_architecture is not supported");
|
||||
}
|
||||
if self.n_head_kv.is_some() {
|
||||
anyhow::bail!("n_head_kv is not supported");
|
||||
candle::bail!("n_head_kv is not supported");
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -4,7 +4,7 @@ use serde::Deserialize;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::MAX_SEQ_LEN;
|
||||
pub const MAX_SEQ_LEN: usize = 4096;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct LlamaConfig {
|
@ -1 +1,5 @@
|
||||
|
||||
pub mod bert;
|
||||
pub mod bigcode;
|
||||
pub mod falcon;
|
||||
pub mod llama;
|
||||
pub mod whisper;
|
||||
|
@ -198,17 +198,13 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
||||
mel
|
||||
}
|
||||
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
||||
samples: &[T],
|
||||
filters: &[T],
|
||||
) -> anyhow::Result<Vec<T>> {
|
||||
let mel = log_mel_spectrogram_(
|
||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
|
||||
log_mel_spectrogram_(
|
||||
samples,
|
||||
filters,
|
||||
super::N_FFT,
|
||||
super::HOP_LENGTH,
|
||||
super::N_MELS,
|
||||
false,
|
||||
);
|
||||
Ok(mel)
|
||||
)
|
||||
}
|
26
candle-transformers/src/models/whisper/mod.rs
Normal file
26
candle-transformers/src/models/whisper/mod.rs
Normal file
@ -0,0 +1,26 @@
|
||||
pub mod audio;
|
||||
pub mod model;
|
||||
|
||||
pub const DTYPE: candle::DType = candle::DType::F32;
|
||||
|
||||
// Audio parameters.
|
||||
pub const SAMPLE_RATE: usize = 16000;
|
||||
pub const N_FFT: usize = 400;
|
||||
pub const N_MELS: usize = 80;
|
||||
pub const HOP_LENGTH: usize = 160;
|
||||
pub const CHUNK_LENGTH: usize = 30;
|
||||
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
||||
|
||||
pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
||||
pub const LOGPROB_THRESHOLD: f64 = -1.0;
|
||||
pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
||||
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||
|
||||
// Tokenizer dependent bits.
|
||||
pub const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||
pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||
pub const EOT_TOKEN: &str = "<|endoftext|>";
|
||||
pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
Reference in New Issue
Block a user