mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
More whisper sampling.
This commit is contained in:
@ -619,13 +619,20 @@ impl Decode {
|
||||
fn decode(&mut self, mel: &Tensor, t: f64) -> Result<DecodingResult> {
|
||||
let model = &self.model;
|
||||
let audio_features = model.encoder.forward(mel)?;
|
||||
println!("audio features: {:?}", audio_features.dims());
|
||||
let sample_len = model.config.n_text_ctx / 2;
|
||||
let mut sum_logprob = 0f64;
|
||||
let no_speech_prob = f64::NAN;
|
||||
let mut tokens: Vec<u32> = vec![]; // TODO: get initial tokens
|
||||
// TODO: 50257 is the start-of-transcript token; be more principled about getting the initial tokens.
|
||||
let mut tokens: Vec<u32> = vec![50257];
|
||||
for _i in 0..sample_len {
|
||||
let tokens_t = Tensor::new(tokens.as_slice(), &mel.device())?;
|
||||
// Insert a batch dim.
|
||||
let tokens_t = tokens_t.unsqueeze(0)?;
|
||||
let logits = model.decoder.forward(&tokens_t, &audio_features)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let (seq_len, _) = logits.shape().r2()?;
|
||||
let logits = logits.get(seq_len - 1)?;
|
||||
let next_token = if t > 0f64 {
|
||||
let prs = (&logits / t)?.softmax(logits.rank() - 1)?;
|
||||
let logits_v: Vec<f32> = prs.to_vec1()?;
|
||||
@ -698,7 +705,6 @@ fn main() -> Result<()> {
|
||||
let input = input.deserialize()?;
|
||||
let mel = input.tensor("mel", &device)?;
|
||||
println!("loaded mel: {:?}", mel.dims());
|
||||
let mel = if mel.rank() > 2 { mel.squeeze(0)? } else { mel };
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
@ -710,14 +716,13 @@ fn main() -> Result<()> {
|
||||
tokenizer,
|
||||
};
|
||||
|
||||
let (_, content_frames) = mel.shape().r2()?;
|
||||
let content_frames = content_frames - N_SAMPLES;
|
||||
let (_, _, content_frames) = mel.shape().r3()?;
|
||||
let mut seek = 0;
|
||||
let mut segments = vec![];
|
||||
while seek < content_frames {
|
||||
let time_offset = (seek * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let segment_size = usize::min(content_frames - seek, N_FRAMES);
|
||||
let mel_segment = mel.narrow(1, seek, segment_size)?;
|
||||
let mel_segment = mel.narrow(2, seek, segment_size)?;
|
||||
let segment_duration = (segment_size * HOP_LENGTH) as f64 / SAMPLE_RATE as f64;
|
||||
let dr = dc.decode_with_fallback(&mel_segment)?;
|
||||
seek += segment_size;
|
||||
|
Reference in New Issue
Block a user