mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Segment Anything - process images (#766)
* Start processing images.
* Add LayerNorm2d.
* Properly use LayerNorm2d.
* Tweak eps.
* Use LayerNorm on inputs with a rank different from 3.
* Window partitioning.
* Fix a couple todos.
* More todos.
* Hard-code the einsums.
* More padding support.
* Some sizes tweaks.
* Use the hub to get the weights.
* Use a batch matmul.
* Tweaks.
* More fixes.
* Get some predictions to be generated.
This commit is contained in:
@@ -36,7 +36,8 @@ impl Attention {
|
||||
fn separate_heads(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let (b, n, c) = x.dims3()?;
|
||||
x.reshape((b, n, self.num_heads, c / self.num_heads))?
|
||||
.transpose(1, 2)
|
||||
.transpose(1, 2)?
|
||||
.contiguous()
|
||||
}
|
||||
|
||||
fn recombine_heads(&self, x: &Tensor) -> Result<Tensor> {
|
||||
@@ -102,8 +103,12 @@ impl TwoWayAttentionBlock {
|
||||
2,
|
||||
vb.pp("cross_attn_image_to_token"),
|
||||
)?;
|
||||
// TODO: use relu in this mlp
|
||||
let mlp = crate::MlpBlock::new(embedding_dim, mlp_dim, vb.pp("mlp"))?;
|
||||
let mlp = crate::MlpBlock::new(
|
||||
embedding_dim,
|
||||
mlp_dim,
|
||||
candle_nn::Activation::Relu,
|
||||
vb.pp("mlp"),
|
||||
)?;
|
||||
Ok(Self {
|
||||
self_attn,
|
||||
norm1,
|
||||
@@ -126,7 +131,7 @@ impl TwoWayAttentionBlock {
|
||||
) -> Result<(Tensor, Tensor)> {
|
||||
// Self attention block
|
||||
let queries = if self.skip_first_layer_pe {
|
||||
self.self_attn.forward(queries, keys, queries)?
|
||||
self.self_attn.forward(queries, queries, queries)?
|
||||
} else {
|
||||
let q = (queries + query_pe)?;
|
||||
let attn_out = self.self_attn.forward(&q, &q, queries)?;
|
||||
|
Reference in New Issue
Block a user