mirror of https://github.com/huggingface/candle.git
Use multiple transformer layers in the same cross-attn blocks. (#653)

* Use multiple transformer layers in the same cross-attn blocks.
* Make the context contiguous if required.
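The first change is not visible in the hunk below. As a rough sketch only, running a stack of transformer blocks inside one cross-attention spatial block (rather than a single hard-coded layer) might look like the following; the `SpatialTransformer` and `BasicTransformerBlock` names and their fields are assumptions for illustration, not the crate's actual definitions.

    use candle_core::{Result, Tensor};

    struct BasicTransformerBlock { /* self-attn, cross-attn, feed-forward (omitted) */ }

    impl BasicTransformerBlock {
        fn forward(&self, xs: &Tensor, _context: Option<&Tensor>) -> Result<Tensor> {
            // Placeholder body for the sketch; a real block would apply
            // self-attention, cross-attention, and a feed-forward layer.
            Ok(xs.clone())
        }
    }

    struct SpatialTransformer {
        // Several blocks instead of exactly one.
        transformer_blocks: Vec<BasicTransformerBlock>,
    }

    impl SpatialTransformer {
        fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
            let mut xs = xs.clone();
            // Run every configured block in sequence, feeding each block's
            // output into the next.
            for block in self.transformer_blocks.iter() {
                xs = block.forward(&xs, context)?;
            }
            Ok(xs)
        }
    }

The second change, making the context contiguous, is the hunk shown below.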
@@ -208,9 +208,9 @@ impl CrossAttention {
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
-        let context = context.unwrap_or(xs);
-        let key = self.to_k.forward(context)?;
-        let value = self.to_v.forward(context)?;
+        let context = context.unwrap_or(xs).contiguous()?;
+        let key = self.to_k.forward(&context)?;
+        let value = self.to_v.forward(&context)?;
         let query = self.reshape_heads_to_batch_dim(&query)?;
         let key = self.reshape_heads_to_batch_dim(&key)?;
         let value = self.reshape_heads_to_batch_dim(&value)?;
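The reason for the `.contiguous()?` call: a context tensor that arrives as a transposed or reshaped view may not be laid out contiguously in memory, and the key/value projections want a contiguous input. A minimal sketch of the behavior, using candle's public tensor API (the shapes are made up for the example):

    use candle_core::{Device, Result, Tensor};

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        let xs = Tensor::randn(0f32, 1f32, (2, 6, 4), &dev)?;
        // Transposing the last two dims produces a view over the same
        // storage, which is generally not contiguous.
        let context = xs.t()?; // shape (2, 4, 6)
        assert!(!context.is_contiguous());
        // Mirroring the patched code: force a contiguous layout. If the
        // tensor is already contiguous, candle returns it without copying,
        // so the call is cheap in the common case.
        let context = context.contiguous()?;
        assert!(context.is_contiguous());
        Ok(())
    }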