Use the stored embeddings.

This commit is contained in:
laurent
2023-07-04 13:13:28 +01:00
parent 0d99b43792
commit 599160605c

View File

@ -424,6 +424,7 @@ impl AudioEncoder {
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> { fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
let n_state = cfg.n_audio_state; let n_state = cfg.n_audio_state;
let n_head = cfg.n_audio_head; let n_head = cfg.n_audio_head;
let n_ctx = cfg.n_audio_ctx;
let cfg1 = ConvConfig { let cfg1 = ConvConfig {
padding: 1, padding: 1,
stride: 1, stride: 1,
@ -434,7 +435,12 @@ impl AudioEncoder {
}; };
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?; let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?; let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state)?.to_device(&vb.device)?; /* The positional embeddings could be regenerated via the following. */
let positional_embedding = if true {
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
} else {
sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
};
let blocks = (0..cfg.n_audio_layer) let blocks = (0..cfg.n_audio_layer)
.map(|i| { .map(|i| {
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb) ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)