mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add the conv1d layer (but not the op).
This commit is contained in:
@ -172,6 +172,79 @@ impl Linear {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Settings for a convolution layer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ConvConfig {
    /// Amount of padding requested for the convolution.
    padding: usize,
    /// Stride requested for the convolution.
    stride: usize,
}

impl Default for ConvConfig {
    /// Returns a configuration with no padding (`0`) and a stride of `1`.
    fn default() -> Self {
        Self {
            padding: 0,
            stride: 1,
        }
    }
}
|
||||||
|
|
||||||
|
struct Conv1D {
|
||||||
|
weight: Tensor,
|
||||||
|
bias: Option<Tensor>,
|
||||||
|
config: ConvConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Conv1D {
|
||||||
|
fn load(
|
||||||
|
in_channels: usize,
|
||||||
|
out_channels: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: ConvConfig,
|
||||||
|
p: &str,
|
||||||
|
vb: &VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb.get(
|
||||||
|
(out_channels, in_channels, kernel_size),
|
||||||
|
&format!("{p}.weight"),
|
||||||
|
)?;
|
||||||
|
let bias = vb.get(out_channels, &format!("{p}.bias"))?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
bias: Some(bias),
|
||||||
|
config,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_no_bias(
|
||||||
|
in_channels: usize,
|
||||||
|
out_channels: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: ConvConfig,
|
||||||
|
p: &str,
|
||||||
|
vb: &VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let weight = vb.get(
|
||||||
|
(out_channels, in_channels, kernel_size),
|
||||||
|
&format!("{p}.weight"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
weight,
|
||||||
|
bias: None,
|
||||||
|
config,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||||
|
let (bsize, _, _) = x.shape().r3()?;
|
||||||
|
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||||
|
// TODO: Add the conv1d operation
|
||||||
|
let x = x.matmul(&w)?;
|
||||||
|
match &self.bias {
|
||||||
|
None => Ok(x),
|
||||||
|
Some(bias) => x.broadcast_add(bias),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Dropout layer settings.
struct Dropout {
    // NOTE(review): presumably the drop probability — confirm against the
    // `impl Dropout` that lives outside this view.
    pr: f64,
}
|
||||||
@ -341,8 +414,8 @@ fn sinusoids(length: usize, channels: usize) -> Result<Tensor> {
|
|||||||
|
|
||||||
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
|
// https://github.com/openai/whisper/blob/f572f2161ba831bae131364c3bffdead7af6d210/whisper/model.py#L143
|
||||||
struct AudioEncoder {
|
struct AudioEncoder {
|
||||||
conv1: Linear, // TODO
|
conv1: Conv1D,
|
||||||
conv2: Linear, // TODO
|
conv2: Conv1D,
|
||||||
positional_embedding: Tensor,
|
positional_embedding: Tensor,
|
||||||
blocks: Vec<ResidualAttentionBlock>,
|
blocks: Vec<ResidualAttentionBlock>,
|
||||||
ln_post: LayerNorm,
|
ln_post: LayerNorm,
|
||||||
@ -352,8 +425,16 @@ impl AudioEncoder {
|
|||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let n_state = cfg.n_audio_state;
|
let n_state = cfg.n_audio_state;
|
||||||
let n_head = cfg.n_audio_head;
|
let n_head = cfg.n_audio_head;
|
||||||
let conv1 = Linear::load(cfg.n_mels, n_state, &format!("{p}.conv1"), vb)?;
|
let cfg1 = ConvConfig {
|
||||||
let conv2 = Linear::load(n_state, n_state, &format!("{p}.conv2"), vb)?;
|
padding: 1,
|
||||||
|
stride: 1,
|
||||||
|
};
|
||||||
|
let cfg2 = ConvConfig {
|
||||||
|
padding: 1,
|
||||||
|
stride: 2,
|
||||||
|
};
|
||||||
|
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
|
||||||
|
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
|
||||||
let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state)?.to_device(&vb.device)?;
|
let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state)?.to_device(&vb.device)?;
|
||||||
let blocks = (0..cfg.n_audio_layer)
|
let blocks = (0..cfg.n_audio_layer)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
|
Reference in New Issue
Block a user