Use multiple transformer layers in the same cross-attn block. (#653)

* Use multiple transformer layers in the same cross-attn block.

* Make the context contiguous if required.
Author: Laurent Mazare
Date: 2023-08-29 11:13:43 +01:00 (committed via GitHub)
Parent: d0a330448d
Commit: 62ef494dc1
4 changed files with 43 additions and 22 deletions


@@ -208,9 +208,9 @@ impl CrossAttention {
     fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
-        let context = context.unwrap_or(xs);
-        let key = self.to_k.forward(context)?;
-        let value = self.to_v.forward(context)?;
+        let context = context.unwrap_or(xs).contiguous()?;
+        let key = self.to_k.forward(&context)?;
+        let value = self.to_v.forward(&context)?;
         let query = self.reshape_heads_to_batch_dim(&query)?;
         let key = self.reshape_heads_to_batch_dim(&key)?;
         let value = self.reshape_heads_to_batch_dim(&value)?;
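
The hunk above makes the cross-attention context contiguous before the key/value projections: when the context arrives as a strided view (for example after a transpose), the matmuls inside the linear layers may need, or run faster on, a contiguous memory layout. The following standalone sketch illustrates the same pattern; it is not the stable-diffusion code touched by this PR, and the crate imports (candle_core, candle_nn), the Module re-export, and the tensor shapes are assumptions for illustration.

use candle_core::{Device, Result, Tensor};
use candle_nn::{Linear, Module};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Stand-in for a cross-attention context of shape (batch, seq, dim),
    // built through a transpose so it is a strided, non-contiguous view.
    let context = Tensor::randn(0f32, 1f32, (2, 64, 77), &dev)?.transpose(1, 2)?;
    assert!(!context.is_contiguous());

    // Hypothetical key projection, analogous to `self.to_k` in the diff
    // (weight layout is (out_features, in_features); no bias).
    let w_k = Tensor::randn(0f32, 1f32, (64, 64), &dev)?;
    let to_k = Linear::new(w_k, None);

    // Mirror the patch: copy the strided view into a contiguous layout
    // before the projection, then forward as usual.
    let context = context.contiguous()?;
    let key = to_k.forward(&context)?;
    println!("key dims: {:?}", key.dims()); // expected: [2, 77, 64]
    Ok(())
}

If the tensor is already contiguous, `contiguous()` is cheap (no data copy), so guarding the call behind an explicit `is_contiguous()` check, as the commit message's "if required" suggests, is optional in a sketch like this.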