mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Get the audio-encoder to return some values.
This commit is contained in:
@ -233,13 +233,15 @@ impl Conv1D {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
let (bsize, _, _) = x.shape().r3()?;
|
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
|
||||||
let w = self.weight.broadcast_left(bsize)?.t()?;
|
|
||||||
let x = x.conv1d(&w, self.config.padding, self.config.stride)?;
|
|
||||||
match &self.bias {
|
match &self.bias {
|
||||||
None => Ok(x),
|
None => Ok(x),
|
||||||
Some(bias) => x.broadcast_add(bias),
|
Some(bias) => {
|
||||||
|
let b = bias.shape().r1()?;
|
||||||
|
let bias = bias.reshape((1, b, 1))?;
|
||||||
|
Ok(x.broadcast_add(&bias)?)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -381,11 +383,9 @@ impl ResidualAttentionBlock {
|
|||||||
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
|
fn forward(&self, x: &Tensor, xa: Option<&Tensor>) -> Result<Tensor> {
|
||||||
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
|
let attn = self.attn.forward(&self.attn_ln.forward(x)?, None)?;
|
||||||
let mut x = (x + attn)?;
|
let mut x = (x + attn)?;
|
||||||
// Cross-Attn
|
|
||||||
if let Some((attn, ln)) = &self.cross_attn {
|
if let Some((attn, ln)) = &self.cross_attn {
|
||||||
x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
|
x = (&x + attn.forward(&ln.forward(&x)?, xa)?)?;
|
||||||
}
|
}
|
||||||
// Mlp
|
|
||||||
let mlp = self.mlp_linear2.forward(
|
let mlp = self.mlp_linear2.forward(
|
||||||
&self
|
&self
|
||||||
.mlp_linear1
|
.mlp_linear1
|
||||||
@ -557,8 +557,8 @@ fn main() -> Result<()> {
|
|||||||
|
|
||||||
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
|
let input = unsafe { candle::safetensors::MmapedFile::new(args.input)? };
|
||||||
let input = input.deserialize()?;
|
let input = input.deserialize()?;
|
||||||
let x = input.tensor("x", &device)?.to_dtype(DType::U32)?;
|
let tokens = input.tensor("tokens", &device)?.to_dtype(DType::U32)?;
|
||||||
let xa = input.tensor("xa", &device)?;
|
let mel = input.tensor("mel", &device)?;
|
||||||
|
|
||||||
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
let weights = unsafe { candle::safetensors::MmapedFile::new(args.weights)? };
|
||||||
let weights = weights.deserialize()?;
|
let weights = weights.deserialize()?;
|
||||||
@ -566,8 +566,8 @@ fn main() -> Result<()> {
|
|||||||
let cfg = Config::tiny_en();
|
let cfg = Config::tiny_en();
|
||||||
|
|
||||||
let model = Whisper::load(&vb, &cfg)?;
|
let model = Whisper::load(&vb, &cfg)?;
|
||||||
let logits = model.decoder.forward(&x, &xa)?;
|
let logits = model.forward(&mel, &tokens)?;
|
||||||
println!("{logits}");
|
println!("{logits}");
|
||||||
println!("python logits: {}", input.tensor("logits", &device)?);
|
println!("python logits: {}", input.tensor("dec", &device)?);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user