mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 20:22:49 +00:00
Add a flag to select the dtype used in metavoice. (#1805)
This commit is contained in:
@@ -562,6 +562,7 @@ pub mod gpt {
|
||||
ln_f: Norm,
|
||||
lm_heads: Vec<Linear>,
|
||||
cfg: Config,
|
||||
+dtype: DType,
|
||||
}
|
||||
|
||||
impl Model {
|
||||
@@ -596,6 +597,7 @@ pub mod gpt {
|
||||
ln_f,
|
||||
lm_heads,
|
||||
cfg,
|
||||
+dtype: vb.dtype(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -608,7 +610,7 @@ pub mod gpt {
|
||||
let (b, _num_hierarchies, t) = idx.dims3()?;
|
||||
let pos = Tensor::arange(0u32, t as u32, device)?;
|
||||
let pos_emb = pos.apply(&self.wpe)?;
|
||||
-let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), DType::F32, device)?;
|
||||
+let mut tok_emb = Tensor::zeros((b, t, self.cfg.n_embd), self.dtype, device)?;
|
||||
for (wte_idx, wte) in self.wtes.iter().enumerate() {
|
||||
let emb = idx.i((.., wte_idx, ..))?.apply(wte)?;
|
||||
tok_emb = (tok_emb + emb)?;
|
||||
@@ -847,10 +849,11 @@ pub mod transformer {
|
||||
}
|
||||
let norm = rms_norm(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
|
||||
let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
|
||||
+let dtype = vb.dtype();
|
||||
let spk_cond_mask = Tensor::cat(
|
||||
&[
|
||||
-Tensor::ones((1, 1, cfg.dim), DType::F32, vb.device())?,
|
||||
-Tensor::zeros((1, 1, cfg.dim), DType::F32, vb.device())?,
|
||||
+Tensor::ones((1, 1, cfg.dim), dtype, vb.device())?,
|
||||
+Tensor::zeros((1, 1, cfg.dim), dtype, vb.device())?,
|
||||
],
|
||||
0,
|
||||
)?;
|
||||
@@ -887,6 +890,7 @@ pub mod transformer {
|
||||
.apply(&self.speaker_cond_pos)?
|
||||
.broadcast_mul(&self.spk_cond_mask)?,
|
||||
)?;
|
||||
+let mask = mask.to_dtype(xs.dtype())?;
|
||||
for layer in self.layers.iter_mut() {
|
||||
xs = layer.forward(&xs, pos, &mask)?
|
||||
}
|
||||
|
Reference in New Issue
Block a user