Small cleanup.

This commit is contained in:
laurent
2023-07-04 13:21:59 +01:00
parent 599160605c
commit 99b83773b5

View File

@ -435,10 +435,10 @@ impl AudioEncoder {
}; };
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?; let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
/* The positional embeddings could be regenerated via the following. */
let positional_embedding = if true { let positional_embedding = if true {
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))? vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
} else { } else {
/* The positional embeddings could be regenerated via the following. */
sinusoids(n_ctx, n_state)?.to_device(&vb.device)? sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
}; };
let blocks = (0..cfg.n_audio_layer) let blocks = (0..cfg.n_audio_layer)
@ -474,7 +474,6 @@ struct TextDecoder {
positional_embedding: Tensor, positional_embedding: Tensor,
blocks: Vec<ResidualAttentionBlock>, blocks: Vec<ResidualAttentionBlock>,
ln: LayerNorm, ln: LayerNorm,
mask: Tensor,
} }
impl TextDecoder { impl TextDecoder {
@ -492,13 +491,11 @@ impl TextDecoder {
}) })
.collect::<Result<Vec<_>>>()?; .collect::<Result<Vec<_>>>()?;
let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?; let ln = LayerNorm::load(n_state, &format!("{p}.ln"), vb)?;
let mask = Tensor::new(&[0u32], &vb.device)?; // TODO
Ok(Self { Ok(Self {
token_embedding, token_embedding,
positional_embedding, positional_embedding,
blocks, blocks,
ln, ln,
mask,
}) })
} }
fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> { fn forward(&self, x: &Tensor, xa: &Tensor) -> Result<Tensor> {