mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Move some models to candle-transformers so that it's easier to re-use. (#794)
* Move some models to candle-transformers so that they can be shared. * Also move falcon. * Move Llama. * Move whisper (partial).
This commit is contained in:
@ -13,18 +13,18 @@ readme = "README.md"
|
|||||||
accelerate-src = { workspace = true, optional = true }
|
accelerate-src = { workspace = true, optional = true }
|
||||||
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
||||||
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
|
candle-datasets = { path = "../candle-datasets", version = "0.2.1" }
|
||||||
|
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
||||||
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
|
candle-transformers = { path = "../candle-transformers", version = "0.2.1" }
|
||||||
candle-flash-attn = { path = "../candle-flash-attn", version = "0.2.1", optional = true }
|
|
||||||
safetensors = { workspace = true }
|
|
||||||
serde = { workspace = true }
|
|
||||||
serde_json = { workspace = true }
|
|
||||||
num-traits = { workspace = true }
|
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
|
||||||
cudarc = { workspace = true, optional = true }
|
cudarc = { workspace = true, optional = true }
|
||||||
half = { workspace = true, optional = true }
|
half = { workspace = true, optional = true }
|
||||||
image = { workspace = true }
|
image = { workspace = true }
|
||||||
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
|
num-traits = { workspace = true }
|
||||||
rayon = { workspace = true }
|
rayon = { workspace = true }
|
||||||
|
safetensors = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
anyhow = { workspace = true }
|
anyhow = { workspace = true }
|
||||||
|
@ -3,14 +3,13 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
mod model;
|
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||||
|
|
||||||
use anyhow::{anyhow, Error as E, Result};
|
use anyhow::{anyhow, Error as E, Result};
|
||||||
use candle::Tensor;
|
use candle::Tensor;
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
|
||||||
use model::{BertModel, Config, DTYPE};
|
|
||||||
use tokenizers::{PaddingParams, Tokenizer};
|
use tokenizers::{PaddingParams, Tokenizer};
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
|
@ -7,8 +7,7 @@ extern crate accelerate_src;
|
|||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
|
|
||||||
mod model;
|
use candle_transformers::models::bigcode::{Config, GPTBigCode};
|
||||||
use model::{Config, GPTBigCode};
|
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
|
@ -14,8 +14,7 @@ use clap::Parser;
|
|||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
mod model;
|
use candle_transformers::models::falcon::{Config, Falcon};
|
||||||
use model::{Config, Falcon};
|
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
model: Falcon,
|
model: Falcon,
|
||||||
|
@ -21,11 +21,10 @@ use candle_transformers::generation::LogitsProcessor;
|
|||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use std::io::Write;
|
use std::io::Write;
|
||||||
|
|
||||||
mod model;
|
use candle_transformers::models::llama as model;
|
||||||
use model::{Config, Llama, LlamaConfig};
|
use model::{Config, Llama, LlamaConfig};
|
||||||
|
|
||||||
const EOS_TOKEN: &str = "</s>";
|
const EOS_TOKEN: &str = "</s>";
|
||||||
const MAX_SEQ_LEN: usize = 4096;
|
|
||||||
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
const DEFAULT_PROMPT: &str = "My favorite theorem is ";
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
|
@ -10,41 +10,16 @@ extern crate accelerate_src;
|
|||||||
extern crate intel_mkl_src;
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{DType, Device, IndexOp, Tensor};
|
use candle::{Device, IndexOp, Tensor};
|
||||||
use candle_nn::{ops::softmax, VarBuilder};
|
use candle_nn::{ops::softmax, VarBuilder};
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
use rand::{distributions::Distribution, SeedableRng};
|
use rand::{distributions::Distribution, SeedableRng};
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
mod audio;
|
|
||||||
mod model;
|
|
||||||
use model::{Config, Whisper};
|
|
||||||
mod multilingual;
|
mod multilingual;
|
||||||
|
use candle_transformers::models::whisper::{self as m, audio, model};
|
||||||
const DTYPE: DType = DType::F32;
|
use model::{Config, Whisper};
|
||||||
|
|
||||||
// Audio parameters.
|
|
||||||
const SAMPLE_RATE: usize = 16000;
|
|
||||||
const N_FFT: usize = 400;
|
|
||||||
const N_MELS: usize = 80;
|
|
||||||
const HOP_LENGTH: usize = 160;
|
|
||||||
const CHUNK_LENGTH: usize = 30;
|
|
||||||
const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
|
||||||
const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
|
||||||
|
|
||||||
const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
|
||||||
const LOGPROB_THRESHOLD: f64 = -1.0;
|
|
||||||
const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
|
||||||
const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
|
||||||
|
|
||||||
// Tokenizer dependent bits.
|
|
||||||
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
|
||||||
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
|
||||||
const TRANSLATE_TOKEN: &str = "<|translate|>";
|
|
||||||
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
|
||||||
const EOT_TOKEN: &str = "<|endoftext|>";
|
|
||||||
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
|
||||||
|
|
||||||
#[allow(dead_code)]
|
#[allow(dead_code)]
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
@ -94,7 +69,7 @@ impl Decoder {
|
|||||||
timestamps: bool,
|
timestamps: bool,
|
||||||
verbose: bool,
|
verbose: bool,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
let no_timestamps_token = token_id(&tokenizer, m::NO_TIMESTAMPS_TOKEN)?;
|
||||||
// Suppress the notimestamps token when in timestamps mode.
|
// Suppress the notimestamps token when in timestamps mode.
|
||||||
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
// https://github.com/openai/whisper/blob/e8622f9afc4eba139bf796c210f5c01081000472/whisper/decoding.py#L452
|
||||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||||
@ -109,11 +84,11 @@ impl Decoder {
|
|||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||||
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
let sot_token = token_id(&tokenizer, m::SOT_TOKEN)?;
|
||||||
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
let transcribe_token = token_id(&tokenizer, m::TRANSCRIBE_TOKEN)?;
|
||||||
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
|
let translate_token = token_id(&tokenizer, m::TRANSLATE_TOKEN)?;
|
||||||
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
let eot_token = token_id(&tokenizer, m::EOT_TOKEN)?;
|
||||||
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
let no_speech_token = token_id(&tokenizer, m::NO_SPEECH_TOKEN)?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
model,
|
model,
|
||||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||||
@ -220,17 +195,17 @@ impl Decoder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
|
fn decode_with_fallback(&mut self, segment: &Tensor) -> Result<DecodingResult> {
|
||||||
for (i, &t) in TEMPERATURES.iter().enumerate() {
|
for (i, &t) in m::TEMPERATURES.iter().enumerate() {
|
||||||
let dr: Result<DecodingResult> = self.decode(segment, t);
|
let dr: Result<DecodingResult> = self.decode(segment, t);
|
||||||
if i == TEMPERATURES.len() - 1 {
|
if i == m::TEMPERATURES.len() - 1 {
|
||||||
return dr;
|
return dr;
|
||||||
}
|
}
|
||||||
// On errors, we try again with a different temperature.
|
// On errors, we try again with a different temperature.
|
||||||
match dr {
|
match dr {
|
||||||
Ok(dr) => {
|
Ok(dr) => {
|
||||||
let needs_fallback = dr.compression_ratio > COMPRESSION_RATIO_THRESHOLD
|
let needs_fallback = dr.compression_ratio > m::COMPRESSION_RATIO_THRESHOLD
|
||||||
|| dr.avg_logprob < LOGPROB_THRESHOLD;
|
|| dr.avg_logprob < m::LOGPROB_THRESHOLD;
|
||||||
if !needs_fallback || dr.no_speech_prob > NO_SPEECH_THRESHOLD {
|
if !needs_fallback || dr.no_speech_prob > m::NO_SPEECH_THRESHOLD {
|
||||||
return Ok(dr);
|
return Ok(dr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -248,13 +223,13 @@ impl Decoder {
|
|||||||
let mut segments = vec![];
|
let mut segments = vec![];
|
||||||
while seek < content_frames {
|
while seek < content_frames {
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
let time_offset = (seek * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
||||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
let segment_size = usize::min(content_frames - seek, m::N_FRAMES);
|
||||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
let segment_duration = (segment_size * m::HOP_LENGTH) as f64 / m::SAMPLE_RATE as f64;
|
||||||
let dr = self.decode_with_fallback(&mel_segment)?;
|
let dr = self.decode_with_fallback(&mel_segment)?;
|
||||||
seek += segment_size;
|
seek += segment_size;
|
||||||
if dr.no_speech_prob > NO_SPEECH_THRESHOLD && dr.avg_logprob < LOGPROB_THRESHOLD {
|
if dr.no_speech_prob > m::NO_SPEECH_THRESHOLD && dr.avg_logprob < m::LOGPROB_THRESHOLD {
|
||||||
println!("no speech detected, skipping {seek} {dr:?}");
|
println!("no speech detected, skipping {seek} {dr:?}");
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -492,8 +467,8 @@ fn main() -> Result<()> {
|
|||||||
let mut input = std::fs::File::open(input)?;
|
let mut input = std::fs::File::open(input)?;
|
||||||
let (header, data) = wav::read(&mut input)?;
|
let (header, data) = wav::read(&mut input)?;
|
||||||
println!("loaded wav data: {header:?}");
|
println!("loaded wav data: {header:?}");
|
||||||
if header.sampling_rate != SAMPLE_RATE as u32 {
|
if header.sampling_rate != m::SAMPLE_RATE as u32 {
|
||||||
anyhow::bail!("wav file must have a {} sampling rate", SAMPLE_RATE)
|
anyhow::bail!("wav file must have a {} sampling rate", m::SAMPLE_RATE)
|
||||||
}
|
}
|
||||||
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
let data = data.as_sixteen().expect("expected 16 bit wav file");
|
||||||
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
let pcm_data: Vec<_> = data[..data.len() / header.channel_count as usize]
|
||||||
@ -501,14 +476,14 @@ fn main() -> Result<()> {
|
|||||||
.map(|v| *v as f32 / 32768.)
|
.map(|v| *v as f32 / 32768.)
|
||||||
.collect();
|
.collect();
|
||||||
println!("pcm data loaded {}", pcm_data.len());
|
println!("pcm data loaded {}", pcm_data.len());
|
||||||
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters)?;
|
let mel = audio::pcm_to_mel(&pcm_data, &mel_filters);
|
||||||
let mel_len = mel.len();
|
let mel_len = mel.len();
|
||||||
let mel = Tensor::from_vec(mel, (1, N_MELS, mel_len / N_MELS), &device)?;
|
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
|
||||||
println!("loaded mel: {:?}", mel.dims());
|
println!("loaded mel: {:?}", mel.dims());
|
||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
|
let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
|
||||||
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
|
||||||
let mut model = Whisper::load(&vb, config)?;
|
let mut model = Whisper::load(&vb, config)?;
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ pub fn detect_language(model: &mut Whisper, tokenizer: &Tokenizer, mel: &Tensor)
|
|||||||
.iter()
|
.iter()
|
||||||
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
.map(|(t, _)| crate::token_id(tokenizer, &format!("<|{t}|>")))
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let sot_token = crate::token_id(tokenizer, crate::SOT_TOKEN)?;
|
let sot_token = crate::token_id(tokenizer, crate::m::SOT_TOKEN)?;
|
||||||
let audio_features = model.encoder.forward(&mel, true)?;
|
let audio_features = model.encoder.forward(&mel, true)?;
|
||||||
let tokens = Tensor::new(&[[sot_token]], device)?;
|
let tokens = Tensor::new(&[[sot_token]], device)?;
|
||||||
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
let language_token_ids = Tensor::new(language_token_ids.as_slice(), device)?;
|
||||||
|
@ -14,7 +14,11 @@ accelerate-src = { workspace = true, optional = true }
|
|||||||
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
candle-nn = { path = "../candle-nn", version = "0.2.1" }
|
||||||
intel-mkl-src = { workspace = true, optional = true }
|
intel-mkl-src = { workspace = true, optional = true }
|
||||||
|
num-traits = { workspace = true }
|
||||||
rand = { workspace = true }
|
rand = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
tracing = { workspace = true }
|
||||||
wav = { workspace = true }
|
wav = { workspace = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
use anyhow::Result;
|
use candle::{DType, Device, Result, Tensor, D};
|
||||||
use candle::{DType, Device, Tensor, D};
|
|
||||||
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
use candle_nn::{Embedding, LayerNorm, Linear, Module, VarBuilder};
|
||||||
|
|
||||||
const MAX_SEQ_LEN: usize = 5000;
|
const MAX_SEQ_LEN: usize = 5000;
|
||||||
@ -21,7 +20,7 @@ fn layer_norm(size: usize, eps: f64, vb: VarBuilder) -> Result<LayerNorm> {
|
|||||||
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
if let (Ok(weight), Ok(bias)) = (vb.get(size, "gamma"), vb.get(size, "beta")) {
|
||||||
(weight, bias)
|
(weight, bias)
|
||||||
} else {
|
} else {
|
||||||
return Err(err.into());
|
return Err(err);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -82,13 +81,13 @@ impl Default for Config {
|
|||||||
impl Config {
|
impl Config {
|
||||||
pub fn validate(&self) -> Result<()> {
|
pub fn validate(&self) -> Result<()> {
|
||||||
if self.alibi {
|
if self.alibi {
|
||||||
anyhow::bail!("alibi is not supported");
|
candle::bail!("alibi is not supported");
|
||||||
}
|
}
|
||||||
if self.new_decoder_architecture {
|
if self.new_decoder_architecture {
|
||||||
anyhow::bail!("new_decoder_architecture is not supported");
|
candle::bail!("new_decoder_architecture is not supported");
|
||||||
}
|
}
|
||||||
if self.n_head_kv.is_some() {
|
if self.n_head_kv.is_some() {
|
||||||
anyhow::bail!("n_head_kv is not supported");
|
candle::bail!("n_head_kv is not supported");
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
@ -4,7 +4,7 @@ use serde::Deserialize;
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
use super::MAX_SEQ_LEN;
|
pub const MAX_SEQ_LEN: usize = 4096;
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct LlamaConfig {
|
pub struct LlamaConfig {
|
@ -1 +1,5 @@
|
|||||||
|
pub mod bert;
|
||||||
|
pub mod bigcode;
|
||||||
|
pub mod falcon;
|
||||||
|
pub mod llama;
|
||||||
|
pub mod whisper;
|
||||||
|
@ -198,17 +198,13 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
|||||||
mel
|
mel
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn pcm_to_mel<T: Float + std::fmt::Display>(
|
pub fn pcm_to_mel<T: Float + std::fmt::Display>(samples: &[T], filters: &[T]) -> Vec<T> {
|
||||||
samples: &[T],
|
log_mel_spectrogram_(
|
||||||
filters: &[T],
|
|
||||||
) -> anyhow::Result<Vec<T>> {
|
|
||||||
let mel = log_mel_spectrogram_(
|
|
||||||
samples,
|
samples,
|
||||||
filters,
|
filters,
|
||||||
super::N_FFT,
|
super::N_FFT,
|
||||||
super::HOP_LENGTH,
|
super::HOP_LENGTH,
|
||||||
super::N_MELS,
|
super::N_MELS,
|
||||||
false,
|
false,
|
||||||
);
|
)
|
||||||
Ok(mel)
|
|
||||||
}
|
}
|
26
candle-transformers/src/models/whisper/mod.rs
Normal file
26
candle-transformers/src/models/whisper/mod.rs
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
pub mod audio;
|
||||||
|
pub mod model;
|
||||||
|
|
||||||
|
pub const DTYPE: candle::DType = candle::DType::F32;
|
||||||
|
|
||||||
|
// Audio parameters.
|
||||||
|
pub const SAMPLE_RATE: usize = 16000;
|
||||||
|
pub const N_FFT: usize = 400;
|
||||||
|
pub const N_MELS: usize = 80;
|
||||||
|
pub const HOP_LENGTH: usize = 160;
|
||||||
|
pub const CHUNK_LENGTH: usize = 30;
|
||||||
|
pub const N_SAMPLES: usize = CHUNK_LENGTH * SAMPLE_RATE; // 480000 samples in a 30-second chunk
|
||||||
|
pub const N_FRAMES: usize = N_SAMPLES / HOP_LENGTH; // 3000 frames in a mel spectrogram input
|
||||||
|
|
||||||
|
pub const NO_SPEECH_THRESHOLD: f64 = 0.6;
|
||||||
|
pub const LOGPROB_THRESHOLD: f64 = -1.0;
|
||||||
|
pub const TEMPERATURES: [f64; 6] = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0];
|
||||||
|
pub const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||||
|
|
||||||
|
// Tokenizer dependent bits.
|
||||||
|
pub const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||||
|
pub const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||||
|
pub const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||||
|
pub const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||||
|
pub const EOT_TOKEN: &str = "<|endoftext|>";
|
||||||
|
pub const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
Reference in New Issue
Block a user