Replicate the sot-token logic from the Python implementation more acc… (#491)

* Replicate the sot-token logic from the Python implementation more accurately.

* Add a flag to control the timestamp mode.
This commit is contained in:
Laurent Mazare
2023-08-17 16:59:36 +01:00
committed by GitHub
parent 5f30c1e1e0
commit 3164cd24fa

View File

@ -41,6 +41,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
// Tokenizer dependent bits.
const SOT_TOKEN: &str = "<|startoftranscript|>";
const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
const TRANSLATE_TOKEN: &str = "<|translate|>";
const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
const EOT_TOKEN: &str = "<|endoftext|>";
const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
@ -66,12 +68,16 @@ struct Segment {
struct Decoder {
model: Whisper,
rng: rand::rngs::StdRng,
task: Option<Task>,
timestamps: bool,
tokenizer: Tokenizer,
suppress_tokens: Tensor,
sot_token: u32,
transcribe_token: u32,
translate_token: u32,
eot_token: u32,
no_speech_token: u32,
no_timestamps_token: u32,
language_token: Option<u32>,
}
@ -82,6 +88,8 @@ impl Decoder {
seed: u64,
device: &Device,
language_token: Option<u32>,
task: Option<Task>,
timestamps: bool,
) -> Result<Self> {
let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
.map(|i| {
@ -95,18 +103,24 @@ impl Decoder {
let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
Ok(Self {
model,
rng: rand::rngs::StdRng::seed_from_u64(seed),
tokenizer,
task,
timestamps,
suppress_tokens,
sot_token,
transcribe_token,
translate_token,
eot_token,
no_speech_token,
language_token,
no_timestamps_token,
})
}
@ -118,10 +132,19 @@ impl Decoder {
let mut sum_logprob = 0f64;
let mut no_speech_prob = f64::NAN;
let mut tokens = vec![self.sot_token];
if let Some(language_token) = self.language_token {
tokens.push(language_token)
match self.task {
Some(Task::Transcribe) => tokens.push(self.transcribe_token),
Some(Task::Translate) => tokens.push(self.translate_token),
None => {
// Nothing in this case, same as the Python implementation.
}
}
if let Some(language_token) = self.language_token {
tokens.push(language_token);
}
if !self.timestamps {
tokens.push(self.no_timestamps_token);
}
tokens.push(self.transcribe_token);
for i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
@ -240,6 +263,12 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Task {
Transcribe,
Translate,
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
Tiny,
@ -313,6 +342,15 @@ struct Args {
/// Language.
#[arg(long)]
language: Option<String>,
/// Task, when no task is specified, the input tokens contain only the sot token which can
/// improve things when in no-timestamp mode.
#[arg(long)]
task: Option<Task>,
/// Timestamps mode, this is not fully implemented yet.
#[arg(long)]
timestamps: bool,
}
fn main() -> Result<()> {
@ -414,7 +452,15 @@ fn main() -> Result<()> {
anyhow::bail!("a language cannot be set for non-multilingual models")
}
};
let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
let mut dc = Decoder::new(
model,
tokenizer,
args.seed,
&device,
language_token,
args.task,
args.timestamps,
)?;
dc.run(&mel)?;
Ok(())
}