mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Replicate the sot-token logic from the Python implementation more acc… (#491)
* Replicate the sot-token logic from the Python implementation more accurately. * Add a flag to control the timestamp mode.
This commit is contained in:
@ -41,6 +41,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
|
||||
// Tokenizer dependent bits.
|
||||
const SOT_TOKEN: &str = "<|startoftranscript|>";
|
||||
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
|
||||
const TRANSLATE_TOKEN: &str = "<|translate|>";
|
||||
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
|
||||
const EOT_TOKEN: &str = "<|endoftext|>";
|
||||
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
|
||||
|
||||
@ -66,12 +68,16 @@ struct Segment {
|
||||
struct Decoder {
|
||||
model: Whisper,
|
||||
rng: rand::rngs::StdRng,
|
||||
task: Option<Task>,
|
||||
timestamps: bool,
|
||||
tokenizer: Tokenizer,
|
||||
suppress_tokens: Tensor,
|
||||
sot_token: u32,
|
||||
transcribe_token: u32,
|
||||
translate_token: u32,
|
||||
eot_token: u32,
|
||||
no_speech_token: u32,
|
||||
no_timestamps_token: u32,
|
||||
language_token: Option<u32>,
|
||||
}
|
||||
|
||||
@ -82,6 +88,8 @@ impl Decoder {
|
||||
seed: u64,
|
||||
device: &Device,
|
||||
language_token: Option<u32>,
|
||||
task: Option<Task>,
|
||||
timestamps: bool,
|
||||
) -> Result<Self> {
|
||||
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
|
||||
.map(|i| {
|
||||
@ -95,18 +103,24 @@ impl Decoder {
|
||||
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
|
||||
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
|
||||
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
|
||||
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
|
||||
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
|
||||
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
|
||||
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
|
||||
Ok(Self {
|
||||
model,
|
||||
rng: rand::rngs::StdRng::seed_from_u64(seed),
|
||||
tokenizer,
|
||||
task,
|
||||
timestamps,
|
||||
suppress_tokens,
|
||||
sot_token,
|
||||
transcribe_token,
|
||||
translate_token,
|
||||
eot_token,
|
||||
no_speech_token,
|
||||
language_token,
|
||||
no_timestamps_token,
|
||||
})
|
||||
}
|
||||
|
||||
@ -118,10 +132,19 @@ impl Decoder {
|
||||
let mut sum_logprob = 0f64;
|
||||
let mut no_speech_prob = f64::NAN;
|
||||
let mut tokens = vec![self.sot_token];
|
||||
if let Some(language_token) = self.language_token {
|
||||
tokens.push(language_token)
|
||||
match self.task {
|
||||
Some(Task::Transcribe) => tokens.push(self.transcribe_token),
|
||||
Some(Task::Translate) => tokens.push(self.translate_token),
|
||||
None => {
|
||||
// Nothing in this case, same as the Python implementation.
|
||||
}
|
||||
}
|
||||
if let Some(language_token) = self.language_token {
|
||||
tokens.push(language_token);
|
||||
}
|
||||
if !self.timestamps {
|
||||
tokens.push(self.no_timestamps_token);
|
||||
}
|
||||
tokens.push(self.transcribe_token);
|
||||
for i in 0..sample_len {
|
||||
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
|
||||
|
||||
@ -240,6 +263,12 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Task {
|
||||
Transcribe,
|
||||
Translate,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum WhichModel {
|
||||
Tiny,
|
||||
@ -313,6 +342,15 @@ struct Args {
|
||||
/// Language.
|
||||
#[arg(long)]
|
||||
language: Option<String>,
|
||||
|
||||
/// Task, when no task is specified, the input tokens contain only the sot token which can
|
||||
/// improve things when in no-timestamp mode.
|
||||
#[arg(long)]
|
||||
task: Option<Task>,
|
||||
|
||||
/// Timestamps mode, this is not fully implemented yet.
|
||||
#[arg(long)]
|
||||
timestamps: bool,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
@ -414,7 +452,15 @@ fn main() -> Result<()> {
|
||||
anyhow::bail!("a language cannot be set for non-multilingual models")
|
||||
}
|
||||
};
|
||||
let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
|
||||
let mut dc = Decoder::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
&device,
|
||||
language_token,
|
||||
args.task,
|
||||
args.timestamps,
|
||||
)?;
|
||||
dc.run(&mel)?;
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user