mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
Get the audio-encoder to return some values.
This commit is contained in:
@ -233,13 +233,15 @@ impl Conv1D {
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
||||
let (bsize, _, _) = x.shape().r3()?;
|
||||
let w = self.weight.broadcast_left(bsize)?.t()?;
|
||||
let x = x.conv1d(&w, self.config.padding, self.config.stride)?;
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => x.broadcast_add(bias),
|
||||
Some(bias) => {
|
||||
let b = bias.shape().r1()?;
|
||||
let bias = bias.reshape((1, b, 1))?;
|
||||
Ok(x.broadcast_add(&bias)?)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -381,11 +383,9 @@ impl ResidualAttentionBlock {
|
||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
|
||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
|
||||
let mut x = (x + attn)?;
|
||||
// Cross-Attn
|
||||
if let Some((attn, ln)) = &self.cross_attn {
|
||||
x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
|
||||
}
|
||||
// Mlp
|
||||
let mlp = self.mlp_linear2.forward(
|
||||
&self
|
||||
.mlp_linear1
|
||||
@ -557,8 +557,8 @@ fn main() -> Result<()> {
|
||||
|
||||
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
|
||||
let input = input.deserialize()?;
|
||||
let x = input.tensor("x", &device)?.to_dtype(DType::U32)?;
|
||||
let xa = input.tensor("xa", &device)?;
|
||||
let tokens = input.tensor("tokens", &device)?.to_dtype(DType::U32)?;
|
||||
let mel = input.tensor("mel", &device)?;
|
||||
|
||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||
let weights = weights.deserialize()?;
|
||||
@ -566,8 +566,8 @@ fn main() -> Result<()> {
|
||||
let cfg = Config::tiny_en();
|
||||
|
||||
let model = Whisper::load(&vb, &cfg)?;
|
||||
let logits = model.decoder.forward(&x, &xa)?;
|
||||
let logits = model.forward(&mel, &tokens)?;
|
||||
println!("{logits}");
|
||||
println!("python logits: {}", input.tensor("logits", &device)?);
|
||||
println!("python logits: {}", input.tensor("dec", &device)?);
|
||||
Ok(())
|
||||
}
|
||||
|
Reference in New Issue
Block a user