mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Small cleanup.
This commit is contained in:
@ -435,10 +435,10 @@ impl AudioEncoder {
|
|||||||
};
|
};
|
||||||
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
|
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
|
||||||
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
|
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
|
||||||
/* The positional embeddings could be regenerated via the following. */
|
|
||||||
let positional_embedding = if true {
|
let positional_embedding = if true {
|
||||||
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
|
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
|
||||||
} else {
|
} else {
|
||||||
|
/* The positional embeddings could be regenerated via the following. */
|
||||||
sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
|
sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
|
||||||
};
|
};
|
||||||
let blocks = (0..cfg.n_audio_layer)
|
let blocks = (0..cfg.n_audio_layer)
|
||||||
@ -474,7 +474,6 @@ struct TextDecoder {
|
|||||||
positional_embedding: Tensor,
|
positional_embedding: Tensor,
|
||||||
blocks: Vec<ResidualAttentionBlock>,
|
blocks: Vec<ResidualAttentionBlock>,
|
||||||
ln: LayerNorm,
|
ln: LayerNorm,
|
||||||
mask: Tensor,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TextDecoder {
|
impl TextDecoder {
|
||||||
@ -492,13 +491,11 @@ impl TextDecoder {
|
|||||||
})
|
})
|
||||||
.collect::<Result<Vec<_>>>()?;
|
.collect::<Result<Vec<_>>>()?;
|
||||||
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
|
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
|
||||||
let mask = Tensor::new(&[0u32], &vb.device)?; // TODO
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
token_embedding,
|
token_embedding,
|
||||||
positional_embedding,
|
positional_embedding,
|
||||||
blocks,
|
blocks,
|
||||||
ln,
|
ln,
|
||||||
mask,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {
|
||||||
|
Reference in New Issue
Block a user