Automatic mask generation (#779)

* A few more contiguous fixes for cuda.

* Mask generation.

* Generic bbox.

* Generate all the masks.
This commit is contained in:
Laurent Mazare
2023-09-08 19:11:34 +01:00
committed by GitHub
parent 158ff3c609
commit 0906acab91
7 changed files with 125 additions and 26 deletions

View File

@ -45,9 +45,9 @@ impl Attention {
}
fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
let q = self.q_proj.forward(q)?;
let k = self.k_proj.forward(k)?;
let v = self.v_proj.forward(v)?;
let q = self.q_proj.forward(&q.contiguous()?)?;
let k = self.k_proj.forward(&k.contiguous()?)?;
let v = self.v_proj.forward(&v.contiguous()?)?;
let q = self.separate_heads(&q)?;
let k = self.separate_heads(&k)?;