mirror of
https://github.com/huggingface/candle.git
synced 2025-06-21 12:20:46 +00:00
Add the attention block. (#846)
* Add the attention block. * Add more to clipnext.
This commit is contained in:
@ -78,7 +78,7 @@ fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result<Ten
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CrossAttention {
|
||||
pub struct CrossAttention {
|
||||
to_q: nn::Linear,
|
||||
to_k: nn::Linear,
|
||||
to_v: nn::Linear,
|
||||
@ -94,7 +94,7 @@ struct CrossAttention {
|
||||
|
||||
impl CrossAttention {
|
||||
// Defaults should be heads = 8, dim_head = 64, context_dim = None
|
||||
fn new(
|
||||
pub fn new(
|
||||
vs: nn::VarBuilder,
|
||||
query_dim: usize,
|
||||
context_dim: Option<usize>,
|
||||
@ -205,7 +205,7 @@ impl CrossAttention {
|
||||
self.reshape_batch_dim_to_heads(&xs)
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
pub fn forward(&self, xs: &Tensor, context: Option<&Tensor>) -> Result<Tensor> {
|
||||
let _enter = self.span.enter();
|
||||
let query = self.to_q.forward(xs)?;
|
||||
let context = context.unwrap_or(xs).contiguous()?;
|
||||
|
Reference in New Issue
Block a user