mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use the stored embeddings.
This commit is contained in:
@ -424,6 +424,7 @@ impl AudioEncoder {
|
|||||||
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
fn load(p: &str, vb: &VarBuilder, cfg: &Config) -> Result<Self> {
|
||||||
let n_state = cfg.n_audio_state;
|
let n_state = cfg.n_audio_state;
|
||||||
let n_head = cfg.n_audio_head;
|
let n_head = cfg.n_audio_head;
|
||||||
|
let n_ctx = cfg.n_audio_ctx;
|
||||||
let cfg1 = ConvConfig {
|
let cfg1 = ConvConfig {
|
||||||
padding: 1,
|
padding: 1,
|
||||||
stride: 1,
|
stride: 1,
|
||||||
@ -434,7 +435,12 @@ impl AudioEncoder {
|
|||||||
};
|
};
|
||||||
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
|
let conv1 = Conv1D::load(cfg.n_mels, n_state, 3, cfg1, &format!("{p}.conv1"), vb)?;
|
||||||
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
|
let conv2 = Conv1D::load(n_state, n_state, 3, cfg2, &format!("{p}.conv2"), vb)?;
|
||||||
let positional_embedding = sinusoids(cfg.n_audio_ctx, n_state)?.to_device(&vb.device)?;
|
/* The positional embeddings could be regenerated via the following. */
|
||||||
|
let positional_embedding = if true {
|
||||||
|
vb.get((n_ctx, n_state), &format!("{p}.positional_embedding"))?
|
||||||
|
} else {
|
||||||
|
sinusoids(n_ctx, n_state)?.to_device(&vb.device)?
|
||||||
|
};
|
||||||
let blocks = (0..cfg.n_audio_layer)
|
let blocks = (0..cfg.n_audio_layer)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
|
ResidualAttentionBlock::load(n_state, n_head, false, &format!("{p}.blocks.{i}"), vb)
|
||||||
|
Reference in New Issue
Block a user