Add the attention block. (#846)

* Add the attention block.

* Add more to clipnext.
Laurent Mazare
2023-09-14 16:40:09 +02:00
committed by GitHub
parent 286f01db14
commit a0c6d5548c
4 changed files with 98 additions and 8 deletions


@@ -78,7 +78,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
 }
 
 #[derive(Debug)]
-struct CrossAttention {
+pub struct CrossAttention {
     to_q: nn::Linear,
     to_k: nn::Linear,
     to_v: nn::Linear,
@@ -94,7 +94,7 @@ struct CrossAttention {
 
 impl CrossAttention {
     // Defaults should be heads = 8, dim_head = 64, context_dim = None
-    fn new(
+    pub fn new(
         vs: nn::VarBuilder,
         query_dim: usize,
         context_dim: Option<usize>,
@@ -205,7 +205,7 @@ impl CrossAttention {
         self.reshape_batch_dim_to_heads(&xs)
     }
 
-    fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
+    pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
         let _enter = self.span.enter();
         let query = self.to_q.forward(xs)?;
         let context = context.unwrap_or(xs).contiguous()?;
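
Taken together, these hunks promote `CrossAttention` and its `new`/`forward` methods from module-private to `pub`, so the layer can be constructed and called from other modules. A minimal sketch of what that enables follows; the import path, the constructor arguments beyond the three parameters visible in the diff, and the tensor shape are assumptions for illustration, not part of this commit:

use candle::{DType, Device, Result, Tensor};
use candle_nn::VarBuilder;
// Path is an assumption; import `CrossAttention` from wherever this file lives.
use crate::attention::CrossAttention;

fn demo(vs: VarBuilder) -> Result<Tensor> {
    // Hypothetical construction: query_dim = 320, no cross-attention context,
    // and the documented defaults heads = 8, dim_head = 64. Any parameters
    // beyond the three shown in the diff are assumptions about the signature.
    let attn = CrossAttention::new(vs, 320, None, 8, 64)?;
    // A dummy (batch, seq_len, query_dim) input; passing context = None makes
    // the layer attend over `xs` itself, i.e. plain self-attention.
    let xs = Tensor::zeros((1, 77, 320), DType::F32, &Device::Cpu)?;
    attn.forward(&xs, None)
}

Note the design choice visible in the last hunk: `forward` takes `Option<&Tensor>`, and `context.unwrap_or(xs)` falls back to the input itself, so one type serves both self-attention (`None`) and cross-attention (`Some(context)`).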