mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add a flag to select the dtype used in metavoice. (#1805)
This commit is contained in:
@ -297,6 +297,7 @@ impl VectorQuantization {
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ResidualVectorQuantizer {
|
||||
layers: Vec<VectorQuantization>,
|
||||
dtype: DType,
|
||||
}
|
||||
|
||||
impl ResidualVectorQuantizer {
|
||||
@ -305,7 +306,10 @@ impl ResidualVectorQuantizer {
|
||||
let layers = (0..cfg.num_quantizers())
|
||||
.map(|i| VectorQuantization::new(cfg, vb.pp(i)))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self { layers })
|
||||
Ok(Self {
|
||||
layers,
|
||||
dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
@ -321,7 +325,7 @@ impl ResidualVectorQuantizer {
|
||||
}
|
||||
|
||||
pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||
let mut quantized_out = Tensor::zeros((), self.dtype, codes.device())?;
|
||||
let ncodes = codes.dim(0)?;
|
||||
if ncodes > self.layers.len() {
|
||||
candle::bail!(
|
||||
|
Reference in New Issue
Block a user