Add options to use local files + specify a custom repo or branch. (#1973)

Laurent Mazare
2024-03-31 09:32:50 +02:00
committed by GitHub
parent eead1dcead
commit f9954b73ba
2 changed files with 40 additions and 16 deletions
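The hunks shown below are model-side simplifications that ride along with this commit; the local-file and custom repo/branch options themselves live in the example's argument parsing, which is not part of this excerpt. As a rough sketch of how such options are commonly wired up in candle examples with clap and hf-hub (the flag names and the default model id below are illustrative assumptions, not taken from this diff):

use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};

#[derive(Parser)]
struct Args {
    /// Load weights from a local file instead of the hub (hypothetical flag name).
    #[arg(long)]
    weight_file: Option<std::path::PathBuf>,

    /// Hub repo to fetch from (default value is illustrative).
    #[arg(long, default_value = "vikhyatk/moondream2")]
    model_id: String,

    /// Branch or revision on the hub.
    #[arg(long, default_value = "main")]
    revision: String,
}

fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let weights = match args.weight_file {
        // A local file skips the hub entirely.
        Some(path) => path,
        // Otherwise fetch from the requested repo at the requested revision.
        None => {
            let repo = Api::new()?.repo(Repo::with_revision(
                args.model_id,
                RepoType::Model,
                args.revision,
            ));
            repo.get("model.safetensors")?
        }
    };
    println!("loading weights from {}", weights.display());
    Ok(())
}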

@@ -19,11 +19,8 @@ impl Config {
 fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
     let dim = q.dim(D::Minus1)?;
     let scale_factor = 1.0 / (dim as f64).sqrt();
-    let k = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
-    let mut attn_weights = (q.contiguous()?.matmul(&k)? * scale_factor)?;
-    attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?.contiguous()?;
-    let attn_weights = attn_weights.matmul(&v.contiguous()?)?;
-    Ok(attn_weights)
+    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
+    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
 }
 
 #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
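The new helper body leans on Tensor::t(), candle's shorthand for transposing the last two dimensions, and returns the final Result directly instead of binding it and wrapping it in Ok. A minimal standalone sketch of the rewritten function, assuming only the candle-core and candle-nn crates and random inputs:

use candle_core::{Device, Result, Tensor, D};

fn scaled_dot_product_attention(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let dim = q.dim(D::Minus1)?;
    let scale_factor = 1.0 / (dim as f64).sqrt();
    // t() transposes the last two dims, replacing the old explicit
    // transpose(D::Minus2, D::Minus1)? + contiguous()? pair.
    let attn_weights = (q.matmul(&k.t()?)? * scale_factor)?;
    candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(v)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (batch, heads, seq_len, head_dim); the shapes are chosen for illustration.
    let q = Tensor::randn(0f32, 1f32, (1, 8, 16, 64), &dev)?;
    let k = Tensor::randn(0f32, 1f32, (1, 8, 16, 64), &dev)?;
    let v = Tensor::randn(0f32, 1f32, (1, 8, 16, 64), &dev)?;
    let out = scaled_dot_product_attention(&q, &k, &v)?;
    assert_eq!(out.dims(), &[1, 8, 16, 64]);
    Ok(())
}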
@@ -101,10 +98,15 @@ impl Module for Attention {
             .apply(&self.qkv)?
             .reshape((b, n, 3, self.num_heads, self.head_dim))?
             .permute((2, 0, 3, 1, 4))?;
-        let (q, k, v) = (qkv.i(0)?, qkv.i(1)?, qkv.i(2)?);
-        let attn_weights = scaled_dot_product_attention(&q, &k, &v)?;
-        let attn_weights = attn_weights.transpose(1, 2)?.reshape((b, n, c))?;
-        attn_weights.apply(&self.proj)
+        let (q, k, v) = (
+            qkv.i(0)?.contiguous()?,
+            qkv.i(1)?.contiguous()?,
+            qkv.i(2)?.contiguous()?,
+        );
+        scaled_dot_product_attention(&q, &k, &v)?
+            .transpose(1, 2)?
+            .reshape((b, n, c))?
+            .apply(&self.proj)
     }
 }
 
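The .contiguous() copies removed from the helper do not vanish: indexing the permuted qkv tensor yields strided views, so the copies now happen once at the q/k/v split instead of on every operand inside the attention function. A small check of that behaviour, assuming only candle-core (the shapes are illustrative):

use candle_core::{Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (b, n, 3, heads, head_dim) permuted to (3, b, heads, n, head_dim),
    // mirroring the qkv layout in the hunk above.
    let qkv = Tensor::randn(0f32, 1f32, (2, 4, 3, 8, 16), &dev)?
        .permute((2, 0, 3, 1, 4))?;
    let q = qkv.i(0)?;
    // Indexing the permuted tensor gives a view over strided data.
    assert!(!q.is_contiguous());
    // One copy here, none needed inside scaled_dot_product_attention.
    let q = q.contiguous()?;
    assert!(q.is_contiguous());
    Ok(())
}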
@@ -275,11 +277,11 @@ impl Module for VisionEncoder {
         let (p1, p2) = (14, 14);
         let h = hp1 / p1;
         let w = wp2 / p2;
-        let xs = xs
-            .reshape((b, c, h, p1, h, p2))?
+        xs.reshape((b, c, h, p1, h, p2))?
             .permute((0, 2, 4, 1, 3, 5))?
-            .reshape((b, h * w, c * p1 * p2))?;
-        xs.apply(&self.encoder)?.apply(&self.projection)
+            .reshape((b, h * w, c * p1 * p2))?
+            .apply(&self.encoder)?
+            .apply(&self.projection)
     }
 }
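The reshape/permute pair here is the standard ViT patchify trick: it carves a (b, c, h*14, w*14) image into h*w flattened patches of c*14*14 values each. Note that the first reshape in the diff uses h for the fifth dimension, which only coincides with w for square inputs; the sketch below, assuming only candle-core, spells the intended shapes out with w:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A 3-channel input holding 2x2 patches of 14x14 pixels (illustrative sizes).
    let (b, c, hp1, wp2) = (1, 3, 28, 28);
    let xs = Tensor::randn(0f32, 1f32, (b, c, hp1, wp2), &dev)?;
    let (p1, p2) = (14, 14);
    let (h, w) = (hp1 / p1, wp2 / p2);
    let patches = xs
        .reshape((b, c, h, p1, w, p2))? // w here; the diff writes h (equal to w for square inputs)
        .permute((0, 2, 4, 1, 3, 5))?   // move the patch grid dims in front of the per-patch data
        .reshape((b, h * w, c * p1 * p2))?;
    // One row per patch: (1, 4, 588) with 588 = 3 * 14 * 14.
    assert_eq!(patches.dims(), &[b, h * w, c * p1 * p2]);
    Ok(())
}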