More whisper sampling.

This commit is contained in:
laurent
2023-07-04 22:18:07 +01:00
parent 80f25e6fbb
commit 9fe7a42895

View File

@ -619,13 +619,20 @@ impl Decode {
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
let model = &self.model;
let audio_features = model.encoder.forward(mel)?;
println!("audio features: {:?}", audio_features.dims());
let sample_len = model.config.n_text_ctx / 2;
let mut sum_logprob = 0f64;
let no_speech_prob = f64::NAN;
let mut tokens: Vec<u32> = vec![]; // TODO: get initial tokens
// TODO: 50257 is the start of transcipt token, be more principled about get initial tokens
let mut tokens: Vec<u32> = vec![50257];
for _i in 0..sample_len {
let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
// Insert a batch dim.
let tokens_t = tokens_t.unsqueeze(0)?;
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
let logits = logits.squeeze(0)?;
let (seq_len, _) = logits.shape().r2()?;
let logits = logits.get(seq_len - 1)?;
let next_token = if t > 0f64 {
let prs = (&logits / t)?.softmax(logits.rank() - 1)?;
let logits_v: Vec<f32> = prs.to_vec1()?;
@ -698,7 +705,6 @@ fn main() -> Result<()> {
let input = input.deserialize()?;
let mel = input.tensor("mel", &device)?;
println!("loaded mel: {:?}", mel.dims());
let mel = if mel.rank() > 2 { mel.squeeze(0)? } else { mel };
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
let weights = weights.deserialize()?;
@ -710,14 +716,13 @@ fn main() -> Result<()> {
tokenizer,
};
let (_, content_frames) = mel.shape().r2()?;
let content_frames = content_frames - N_SAMPLES;
let (_, _, content_frames) = mel.shape().r3()?;
let mut seek = 0;
let mut segments = vec![];
while seek < content_frames {
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let segment_size = usize::min(content_frames - seek, N_FRAMES);
let mel_segment = mel.narrow(1, seek, segment_size)?;
let mel_segment = mel.narrow(2, seek, segment_size)?;
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
let dr = dc.decode_with_fallback(&mel_segment)?;
seek += segment_size;