Fix the musicgen example. (#724)

* Fix the musicgen example.

* Retrieve the weights from the hub.
This commit is contained in:
Laurent Mazare
2023-09-03 15:50:39 +02:00
committed by GitHub
parent f7980e07e0
commit bbec527bb9
5 changed files with 62 additions and 134 deletions

View File

@ -1,7 +1,6 @@
use crate::nn::{conv1d, conv1d_weight_norm, Conv1d, Conv1dConfig, VarBuilder};
use anyhow::Result;
use candle::{DType, IndexOp, Tensor};
use candle_nn::Module;
use crate::nn::conv1d_weight_norm;
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{conv1d, Conv1d, Conv1dConfig, Module, VarBuilder};
// Encodec Model
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
@ -183,7 +182,7 @@ impl EncodecResidualVectorQuantizer {
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
if codes.dim(0)? != self.layers.len() {
anyhow::bail!(
candle::bail!(
"codes shape {:?} does not match the number of quantization layers {}",
codes.shape(),
self.layers.len()
@ -321,7 +320,7 @@ impl EncodecResnetBlock {
let h = dim / cfg.compress;
let mut layer = Layer::new(vb.pp("block"));
if dilations.len() != 2 {
anyhow::bail!("expected dilations of size 2")
candle::bail!("expected dilations of size 2")
}
// TODO: Apply dilations!
layer.inc();