Mirror of https://github.com/huggingface/candle.git (synced 2025-06-15 10:26:33 +00:00)
Replicate the sot-token logic from the Python implementation more acc… (#491)
* Replicate the sot-token logic from the Python implementation more accurately.
* Add a flag to control the timestamp mode.
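For context, here is a minimal, self-contained sketch (not the example's actual code) of the prompt-prefix ordering the diff below implements: the start-of-transcript token, then an optional task token, then an optional language token, then the no-timestamps token unless timestamp mode is on. The token ids used in main() are illustrative stand-ins; the real ones are looked up through the Whisper tokenizer.

// Sketch of the prompt-prefix construction introduced by this commit.
#[derive(Clone, Copy, Debug)]
enum Task {
    Transcribe,
    Translate,
}

fn build_prefix(
    sot_token: u32,
    transcribe_token: u32,
    translate_token: u32,
    no_timestamps_token: u32,
    language_token: Option<u32>,
    task: Option<Task>,
    timestamps: bool,
) -> Vec<u32> {
    let mut tokens = vec![sot_token];
    match task {
        Some(Task::Transcribe) => tokens.push(transcribe_token),
        Some(Task::Translate) => tokens.push(translate_token),
        // No task given: keep only the sot token, matching the Python implementation.
        None => {}
    }
    if let Some(language_token) = language_token {
        tokens.push(language_token);
    }
    if !timestamps {
        tokens.push(no_timestamps_token);
    }
    tokens
}

fn main() {
    // Hypothetical token ids, for illustration only.
    let prefix = build_prefix(50258, 50359, 50358, 50363, Some(50259), Some(Task::Transcribe), false);
    println!("{prefix:?}");
}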
@@ -41,6 +41,8 @@ const COMPRESSION_RATIO_THRESHOLD: f64 = 2.4;
 // Tokenizer dependent bits.
 const SOT_TOKEN: &str = "<|startoftranscript|>";
 const TRANSCRIBE_TOKEN: &str = "<|transcribe|>";
+const TRANSLATE_TOKEN: &str = "<|translate|>";
+const NO_TIMESTAMPS_TOKEN: &str = "<|notimestamps|>";
 const EOT_TOKEN: &str = "<|endoftext|>";
 const NO_SPEECH_TOKEN: &str = "<|nocaptions|>";
 
@@ -66,12 +68,16 @@ struct Segment {
 struct Decoder {
     model: Whisper,
     rng: rand::rngs::StdRng,
+    task: Option<Task>,
+    timestamps: bool,
     tokenizer: Tokenizer,
     suppress_tokens: Tensor,
     sot_token: u32,
     transcribe_token: u32,
+    translate_token: u32,
     eot_token: u32,
     no_speech_token: u32,
+    no_timestamps_token: u32,
     language_token: Option<u32>,
 }
 
@@ -82,6 +88,8 @@ impl Decoder {
         seed: u64,
         device: &Device,
         language_token: Option<u32>,
+        task: Option<Task>,
+        timestamps: bool,
     ) -> Result<Self> {
         let suppress_tokens: Vec<f32> = (0..model.config.vocab_size as u32)
             .map(|i| {
@@ -95,18 +103,24 @@ impl Decoder {
         let suppress_tokens = Tensor::new(suppress_tokens.as_slice(), device)?;
         let sot_token = token_id(&tokenizer, SOT_TOKEN)?;
         let transcribe_token = token_id(&tokenizer, TRANSCRIBE_TOKEN)?;
+        let translate_token = token_id(&tokenizer, TRANSLATE_TOKEN)?;
+        let no_timestamps_token = token_id(&tokenizer, NO_TIMESTAMPS_TOKEN)?;
         let eot_token = token_id(&tokenizer, EOT_TOKEN)?;
         let no_speech_token = token_id(&tokenizer, NO_SPEECH_TOKEN)?;
         Ok(Self {
             model,
             rng: rand::rngs::StdRng::seed_from_u64(seed),
             tokenizer,
+            task,
+            timestamps,
             suppress_tokens,
             sot_token,
             transcribe_token,
+            translate_token,
             eot_token,
             no_speech_token,
             language_token,
+            no_timestamps_token,
         })
     }
 
@@ -118,10 +132,19 @@ impl Decoder {
         let mut sum_logprob = 0f64;
         let mut no_speech_prob = f64::NAN;
         let mut tokens = vec![self.sot_token];
-        if let Some(language_token) = self.language_token {
-            tokens.push(language_token)
+        match self.task {
+            Some(Task::Transcribe) => tokens.push(self.transcribe_token),
+            Some(Task::Translate) => tokens.push(self.translate_token),
+            None => {
+                // Nothing in this case, same as the Python implementation.
+            }
+        }
+        if let Some(language_token) = self.language_token {
+            tokens.push(language_token);
+        }
+        if !self.timestamps {
+            tokens.push(self.no_timestamps_token);
         }
-        tokens.push(self.transcribe_token);
         for i in 0..sample_len {
             let tokens_t = Tensor::new(tokens.as_slice(), mel.device())?;
 
@@ -240,6 +263,12 @@ pub fn token_id(tokenizer: &Tokenizer, token: &str) -> candle::Result<u32> {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum Task {
+    Transcribe,
+    Translate,
+}
+
 #[derive(Clone, Copy, Debug, ValueEnum)]
 enum WhichModel {
     Tiny,
@@ -313,6 +342,15 @@ struct Args {
     /// Language.
     #[arg(long)]
     language: Option<String>,
+
+    /// Task, when no task is specified, the input tokens contain only the sot token which can
+    /// improve things when in no-timestamp mode.
+    #[arg(long)]
+    task: Option<Task>,
+
+    /// Timestamps mode, this is not fully implemented yet.
+    #[arg(long)]
+    timestamps: bool,
 }
 
 fn main() -> Result<()> {
@@ -414,7 +452,15 @@ fn main() -> Result<()> {
             anyhow::bail!("a language cannot be set for non-multilingual models")
         }
     };
-    let mut dc = Decoder::new(model, tokenizer, args.seed, &device, language_token)?;
+    let mut dc = Decoder::new(
+        model,
+        tokenizer,
+        args.seed,
+        &device,
+        language_token,
+        args.task,
+        args.timestamps,
+    )?;
     dc.run(&mel)?;
     Ok(())
 }
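With the new flags wired through to Decoder::new, the task and timestamp mode can be picked from the command line. A hypothetical invocation, assuming the example is built as the whisper example in candle-examples:

cargo run --example whisper --release -- --task translate --timestamps

clap's ValueEnum derive maps the Task variants to lowercase values, so --task accepts transcribe or translate.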