mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
FastViT fixes. (#2452)
* correct optional SE layer dimensions. * head_dim instead of num_heads is 32. * update test example output.
This commit is contained in:
@ -339,8 +339,8 @@ fn positional_encoding(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
fn attention(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
|
||||
let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?;
|
||||
let proj = linear(dim, dim, vb.pp("proj"))?;
|
||||
let num_heads = 32;
|
||||
let head_dim = dim / num_heads;
|
||||
let head_dim = 32;
|
||||
let num_heads = dim / head_dim;
|
||||
let scale = (head_dim as f64).powf(-0.5);
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
@ -434,7 +434,7 @@ fn fastvit_patch_embed(
|
||||
) -> Result<Func<'static>> {
|
||||
let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?;
|
||||
let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?;
|
||||
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("proj.0.se"));
|
||||
let se = squeeze_and_excitation(out_channels, out_channels / 4, vb.pp("proj.0.se"));
|
||||
let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?;
|
||||
|
||||
Ok(Func::new(move |xs| {
|
||||
|
Reference in New Issue
Block a user