mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Use candle_nn::LSTM in encodec. (#1051)
* Use candle_nn::LSTM in encodec. * More Encodec implementation. * Decoder implementation.
This commit is contained in:
@ -1,6 +1,6 @@
|
||||
use crate::nn::conv1d_weight_norm;
|
||||
use candle::{DType, IndexOp, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
|
||||
use candle::{DType, IndexOp, Module, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||
|
||||
// Encodec Model
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||
@ -199,25 +199,34 @@ impl EncodecResidualVectorQuantizer {
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
|
||||
#[derive(Debug)]
|
||||
struct EncodecLSTM {
|
||||
layers: Vec<(Tensor, Tensor, Tensor, Tensor)>,
|
||||
layers: Vec<candle_nn::LSTM>,
|
||||
}
|
||||
|
||||
impl EncodecLSTM {
|
||||
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("lstm");
|
||||
let mut layers = vec![];
|
||||
for i in 0..cfg.num_lstm_layers {
|
||||
let w_hh = vb.get((4 * dim, dim), &format!("weight_hh_l{i}"))?;
|
||||
let w_ih = vb.get((4 * dim, dim), &format!("weight_ih_l{i}"))?;
|
||||
let b_hh = vb.get(4 * dim, &format!("bias_hh_l{i}"))?;
|
||||
let b_ih = vb.get(4 * dim, &format!("bias_ih_l{i}"))?;
|
||||
layers.push((w_hh, w_ih, b_hh, b_ih))
|
||||
for layer_idx in 0..cfg.num_lstm_layers {
|
||||
let config = candle_nn::LSTMConfig {
|
||||
layer_idx,
|
||||
..Default::default()
|
||||
};
|
||||
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
|
||||
layers.push(lstm)
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
impl Module for EncodecLSTM {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
use candle_nn::RNN;
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
let states = layer.seq(&xs)?;
|
||||
xs = layer.states_to_tensor(&states)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
@ -247,7 +256,9 @@ impl EncodecConvTranspose1d {
|
||||
bias,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConvTranspose1d {
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
@ -299,7 +310,9 @@ impl EncodecConv1d {
|
||||
conv,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
@ -340,7 +353,9 @@ impl EncodecResnetBlock {
|
||||
shortcut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs.clone();
|
||||
let xs = xs.elu(1.)?;
|
||||
@ -439,8 +454,17 @@ impl EncodecEncoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?;
|
||||
for (resnets, conv) in self.sampling_layers.iter() {
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?;
|
||||
}
|
||||
xs = xs.elu(1.0)?.apply(conv)?;
|
||||
}
|
||||
xs.apply(&self.final_lstm)?
|
||||
.elu(1.0)?
|
||||
.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
@ -507,8 +531,15 @@ impl EncodecDecoder {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
|
||||
for (conv, resnets) in self.sampling_layers.iter() {
|
||||
xs = xs.elu(1.)?.apply(conv)?;
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?
|
||||
}
|
||||
}
|
||||
xs.elu(1.)?.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user