Add a flag to select the dtype used in metavoice. (#1805)

This commit is contained in:
Laurent Mazare
2024-03-05 12:16:00 +01:00
committed by GitHub
parent bd9ab9bc04
commit 8a99cf7dd2
3 changed files with 35 additions and 15 deletions

View File

@ -297,6 +297,7 @@ impl VectorQuantization {
/// Holds a sequence of [`VectorQuantization`] layers together with the dtype
/// used when building the decode accumulator.
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
/// One `VectorQuantization` per quantizer index, built in index order from
/// `0..cfg.num_quantizers()`.
layers: Vec<VectorQuantization>,
/// Dtype captured from the `VarBuilder` at construction time; `decode` uses it
/// for the initial zero tensor instead of a hard-coded `DType::F32`.
dtype: DType,
}
impl ResidualVectorQuantizer {
@ -305,7 +306,10 @@ impl ResidualVectorQuantizer {
let layers = (0..cfg.num_quantizers())
.map(|i| VectorQuantization::new(cfg, vb.pp(i)))
.collect::<Result<Vec<_>>>()?;
Ok(Self { layers })
Ok(Self {
layers,
dtype: vb.dtype(),
})
}
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
@ -321,7 +325,7 @@ impl ResidualVectorQuantizer {
}
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
let mut quantized_out = Tensor::zeros((), self.dtype, codes.device())?;
let ncodes = codes.dim(0)?;
if ncodes > self.layers.len() {
candle::bail!(