Replicate the sot-token logic from the Python implementation more accurately. (#491)

* Replicate the sot-token logic from the Python implementation more accurately.

* Add a flag to control the timestamp mode.
Laurent Mazare
2023-08-17 16:59:36 +01:00
committed by GitHub
parent 5f30c1e1e0
commit 3164cd24fa
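
For context, the decoder prompt now starts with the sot token, followed by an optional task token, the language token if one was detected or specified, and the no-timestamps marker when timestamps are disabled. Below is a minimal standalone sketch of that ordering; the `build_prefix` helper and the numeric token ids are illustrative only, the actual code resolves ids through `token_id(&tokenizer, ...)`.

```rust
// Sketch of the prompt-prefix ordering introduced by this commit.
// `build_prefix` and the numeric ids are made up for illustration.
#[derive(Clone, Copy)]
enum Task {
    Transcribe,
    Translate,
}

fn build_prefix(
    sot: u32,
    transcribe: u32,
    translate: u32,
    no_timestamps: u32,
    language: Option<u32>,
    task: Option<Task>,
    timestamps: bool,
) -> Vec<u32> {
    let mut tokens = vec![sot];
    match task {
        Some(Task::Transcribe) => tokens.push(transcribe),
        Some(Task::Translate) => tokens.push(translate),
        // No task token at all, same as the Python implementation.
        None => {}
    }
    if let Some(language) = language {
        tokens.push(language);
    }
    if !timestamps {
        tokens.push(no_timestamps);
    }
    tokens
}

fn main() {
    // e.g. translate task, a language token, timestamps disabled:
    let prefix = build_prefix(50258, 50359, 50358, 50363, Some(50259), Some(Task::Translate), false);
    assert_eq!(prefix, vec![50258, 50358, 50259, 50363]);
}
```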


@@ -41,6 +41,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 // Tokenizer dependent bits.
 const SOT_TOKEN: &str = "<|startoftranscript|>";
 const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+const TRANSLATE_TOKEN: &str = "<|translate|>";
+const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
 const EOT_TOKEN: &str = "<|endoftext|>";
 const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
@@ -66,12 +68,16 @@ struct Segment {
 struct Decoder {
     model: Whisper,
     rng: rand::rngs::StdRng,
+    task: Option<Task>,
+    timestamps: bool,
     tokenizer: Tokenizer,
     suppress_tokens: Tensor,
     sot_token: u32,
     transcribe_token: u32,
+    translate_token: u32,
     eot_token: u32,
     no_speech_token: u32,
+    no_timestamps_token: u32,
     language_token: Option<u32>,
 }
@@ -82,6 +88,8 @@ impl Decoder {
         seed: u64,
         device: &Device,
         language_token: Option<u32>,
+        task: Option<Task>,
+        timestamps: bool,
     ) -> Result<Self> {
         let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
             .map(|i| {
@@ -95,18 +103,24 @@ impl Decoder {
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
         let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
         let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
+        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
         let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
         let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             tokenizer,
+            task,
+            timestamps,
             suppress_tokens,
             sot_token,
             transcribe_token,
+            translate_token,
             eot_token,
             no_speech_token,
             language_token,
+            no_timestamps_token,
         })
     }
@@ -118,10 +132,19 @@ impl Decoder {
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![self.sot_token];
-        if let Some(language_token) = self.language_token {
-            tokens.push(language_token)
+        match self.task {
+            Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+            Some(Task::Translate) => tokens.push(self.translate_token),
+            None => {
+                // Nothing in this case, same as the Python implementation.
+            }
+        }
+        if let Some(language_token) = self.language_token {
+            tokens.push(language_token);
+        }
+        if !self.timestamps {
+            tokens.push(self.no_timestamps_token);
         }
-        tokens.push(self.transcribe_token);
 
         for i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
@@ -240,6 +263,12 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Task {
+    Transcribe,
+    Translate,
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum WhichModel {
     Tiny,
@@ -313,6 +342,15 @@ struct Args {
     /// Language.
     #[arg(long)]
     language: Option<String>,
+
+    /// Task, when no task is specified, the input tokens contain only the sot token which can
+    /// improve things when in no-timestamp mode.
+    #[arg(long)]
+    task: Option<Task>,
+
+    /// Timestamps mode, this is not fully implemented yet.
+    #[arg(long)]
+    timestamps: bool,
 }
@@ -414,7 +452,15 @@ fn main() -> Result<()> {
             anyhow::bail!("a language cannot be set for non-multilingual models")
         }
     };
-    let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
+    let mut dc = Decoder::new(
+        model,
+        tokenizer,
+        args.seed,
+        &device,
+        language_token,
+        args.task,
+        args.timestamps,
+    )?;
     dc.run(&mel)?;
     Ok(())
 }
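
With these flags the example can now be driven from the command line, e.g. `cargo run --example whisper --release -- --task translate` (the `--example whisper` invocation path is assumed here). Clap's `ValueEnum` derive exposes the variants as lowercase values, so `--task` accepts `transcribe` or `translate`, and passing `--timestamps` enables the still-partial timestamp mode.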