Add the conv1d layer (but not the op).

This commit is contained in:
laurent
2023-07-04 10:01:05 +01:00
parent d71b31144d
commit e6b01d0c18

View File

@ -172,6 +172,79 @@ impl Linear {
}
}
/// Settings attached to a convolution layer.
///
/// The [`Default`] value uses a padding of `0` and a stride of `1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ConvConfig {
    /// Padding setting (`0` by default).
    padding: usize,
    /// Stride setting (`1` by default).
    stride: usize,
}

impl Default for ConvConfig {
    /// Returns the configuration with no padding and a stride of one.
    fn default() -> Self {
        const NO_PADDING: usize = 0;
        const UNIT_STRIDE: usize = 1;

        Self {
            padding: NO_PADDING,
            stride: UNIT_STRIDE,
        }
    }
}
/// One-dimensional convolution layer: a kernel tensor, an optional bias tensor and the
/// `ConvConfig` the layer was created with.
struct Conv1D {
    /// Kernel tensor; `load` and `load_no_bias` request it with shape
    /// `(out_channels, in_channels, kernel_size)`.
    weight: Tensor,
    /// Bias tensor requested with shape `out_channels` by `load`; `None` when the layer
    /// was built through `load_no_bias`.
    bias: Option<Tensor>,
    /// Padding/stride settings handed in by the caller at load time.
    config: ConvConfig,
}
impl Conv1D {
    /// Builds a layer with both a kernel and a bias.
    ///
    /// The kernel is read from `"{p}.weight"` with shape
    /// `(out_channels, in_channels, kernel_size)`, then the bias is read from `"{p}.bias"`
    /// with shape `out_channels`. `config` is stored unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `vb.get` for either tensor.
    fn load(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        config: ConvConfig,
        p: &str,
        vb: &VarBuilder,
    ) -> Result<Self> {
        // The kernel is fetched first (through the bias-free constructor), then the bias.
        let mut layer = Self::load_no_bias(in_channels, out_channels, kernel_size, config, p, vb)?;
        let bias_path = format!("{p}.bias");
        layer.bias = Some(vb.get(out_channels, &bias_path)?);
        Ok(layer)
    }

    /// Builds a layer that only has a kernel; `bias` is left as `None`.
    ///
    /// The kernel is read from `"{p}.weight"` with shape
    /// `(out_channels, in_channels, kernel_size)`. `config` is stored unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by `vb.get`.
    fn load_no_bias(
        in_channels: usize,
        out_channels: usize,
        kernel_size: usize,
        config: ConvConfig,
        p: &str,
        vb: &VarBuilder,
    ) -> Result<Self> {
        let weight_path = format!("{p}.weight");
        let kernel_shape = (out_channels, in_channels, kernel_size);
        let weight = vb.get(kernel_shape, &weight_path)?;
        Ok(Self {
            weight,
            bias: None,
            config,
        })
    }

    /// Applies the layer to `x`, which must be a rank-3 tensor (enforced by `r3()`).
    ///
    /// The convolution itself is not implemented yet: as a stand-in, the kernel is
    /// broadcast over the leading dimension of `x`, transposed and multiplied with `x`.
    /// The bias, when present, is then broadcast-added. `self.config` is not consulted
    /// here.
    ///
    /// # Errors
    ///
    /// Propagates any error from the shape check or from the tensor operations.
    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        let (batch, _, _) = x.shape().r3()?;
        // TODO: Add the conv1d operation
        let kernel = self.weight.broadcast_left(batch)?.t()?;
        let out = x.matmul(&kernel)?;
        if let Some(bias) = self.bias.as_ref() {
            out.broadcast_add(bias)
        } else {
            Ok(out)
        }
    }
}
/// State of a dropout layer: a single `f64` parameter.
struct Dropout {
    /// NOTE(review): presumably the drop probability; how it is used is not visible in
    /// this part of the file, so confirm against the `impl`.
    pr: f64,
}
@ -341,8 +414,8 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
struct AudioEncoder {
conv1: Linear, // TODO
conv2: Linear, // TODO
conv1: Conv1D,
conv2: Conv1D,
positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>,
ln_post: LayerNorm,
@ -352,8 +425,16 @@ impl AudioEncoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.n_audio_state;
let n_head = cfg.n_audio_head;
let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?;
let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?;
let cfg1 = ConvConfig {
padding: 1,
stride: 1,
};
let cfg2 = ConvConfig {
padding: 1,
stride: 2,
};
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state)?.to_device(&vb.device)?;
let blocks = (0..cfg.n_audio_layer)
.map(|i| {