mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Sinusoid embeddings.
This commit is contained in:
@ -307,6 +307,21 @@ impl ResidualAttentionBlock {
|
||||
}
|
||||
}
|
||||
|
||||
fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
||||
let max_timescale = 10000f32;
|
||||
let log_timescale_increment = max_timescale.ln() / (channels / 2 - 1) as f32;
|
||||
let inv_timescales: Vec<_> = (0..channels / 2)
|
||||
.map(|i| (i as f32 * (-log_timescale_increment)).exp())
|
||||
.collect();
|
||||
let arange: Vec<_> = (0..length).map(|c| c as f32).collect();
|
||||
let inv_timescales = Tensor::new(inv_timescales.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
|
||||
let arange = Tensor::new(arange.as_slice(), &Device::Cpu)?.unsqueeze(1)?;
|
||||
let sh = (length, channels / 2);
|
||||
let scaled_time = (arange.broadcast_as(sh)? * inv_timescales.broadcast_as(sh)?)?;
|
||||
let sincos = Tensor::cat(&[scaled_time.sin()?, scaled_time.cos()?], 1)?;
|
||||
Ok(sincos)
|
||||
}
|
||||
|
||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
|
||||
struct AudioEncoder {
|
||||
conv1: Linear, // TODO
|
||||
@ -322,7 +337,7 @@ impl AudioEncoder {
|
||||
let n_head = cfg.n_audio_head;
|
||||
let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?;
|
||||
let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?;
|
||||
let positional_embedding = Tensor::new(&[0u32], &vb.device)?; // TODO
|
||||
let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state)?.to_device(&vb.device)?;
|
||||
let blocks = (0..cfg.n_audio_layer)
|
||||
.map(|i| {
|
||||
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
|
||||
|
Reference in New Issue
Block a user