Get the audio-encoder to return some values.

This commit is contained in:
laurent
2023-07-04 14:06:09 +01:00
parent b3d4d0fd0f
commit c3739d001b

View File

@ -233,13 +233,15 @@ impl Conv1D {
})
}
// NOTE(review): this region looks like a diff hunk whose +/- markers were lost,
// so two variants of `Conv1D::forward` appear interleaved below. Which lines are
// the removed side and which are the added side is inferred from the duplicated
// signature and duplicated `Some(bias)` arm -- confirm against the real commit.
//
// Variant A (presumably the pre-commit code): reads the leading dimension of the
// rank-3 input, broadcasts the weight on the left by that size, transposes it,
// and runs the convolution with that expanded weight.
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
let (bsize, _, _) = x.shape().r3()?;
let w = self.weight.broadcast_left(bsize)?.t()?;
let x = x.conv1d(&w, self.config.padding, self.config.stride)?;
// Variant B (presumably the post-commit code): passes the stored weight straight
// to `conv1d` together with the configured padding and stride.
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
// Both variants then add the bias only when one is present.
match &self.bias {
None => Ok(x),
// Variant A arm: broadcast-add the bias tensor as stored.
Some(bias) => x.broadcast_add(bias),
// Variant B arm: the bias must be rank 1 (`r1` fails otherwise); it is reshaped
// to (1, b, 1) before the broadcast add.
// NOTE(review): presumably this lines the bias up with the channel axis of a
// (batch, channels, length) conv output -- confirm the conv1d output layout.
Some(bias) => {
let b = bias.shape().r1()?;
let bias = bias.reshape((1, b, 1))?;
Ok(x.broadcast_add(&bias)?)
}
}
}
}
@ -381,11 +383,9 @@ impl ResidualAttentionBlock {
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
let mut x = (x + attn)?;
// Cross-Attn
if let Some((attn, ln)) = &self.cross_attn {
x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
}
// Mlp
let mlp = self.mlp_linear2.forward(
&self
.mlp_linear1
@ -557,8 +557,8 @@ fn main() -> Result<()> {
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
let input = input.deserialize()?;
let x = input.tensor("x", &device)?.to_dtype(DType::U32)?;
let xa = input.tensor("xa", &device)?;
let tokens = input.tensor("tokens", &device)?.to_dtype(DType::U32)?;
let mel = input.tensor("mel", &device)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
let weights = weights.deserialize()?;
@ -566,8 +566,8 @@ fn main() -> Result<()> {
let cfg = Config::tiny_en();
let model = Whisper::load(&vb, &cfg)?;
let logits = model.decoder.forward(&x, &xa)?;
let logits = model.forward(&mel, &tokens)?;
println!("{logits}");
println!("python logits: {}", input.tensor("logits", &device)?);
println!("python logits: {}", input.tensor("dec", &device)?);
Ok(())
}