From 89b525b5e758218179dd32293e7167e3aae1b28f Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Wed, 11 Oct 2023 19:24:32 +0200
Subject: [PATCH] Convmixer (#1073)

* Only optimize float tensors.

* Use full tensors for zeros and ones.

* Add a benchmark for the matmul slowness.

* Add the convmixer model.

* Proper adaptive pooling.
---
 candle-nn/examples/cpu_benchmarks.rs        |  4 +-
 candle-transformers/src/models/convmixer.rs | 82 +++++++++++++++++++++
 candle-transformers/src/models/mod.rs       |  1 +
 3 files changed, 85 insertions(+), 2 deletions(-)
 create mode 100644 candle-transformers/src/models/convmixer.rs

diff --git a/candle-nn/examples/cpu_benchmarks.rs b/candle-nn/examples/cpu_benchmarks.rs
index e58ea727..6007ff6c 100644
--- a/candle-nn/examples/cpu_benchmarks.rs
+++ b/candle-nn/examples/cpu_benchmarks.rs
@@ -185,8 +185,8 @@ impl Benchmark for Matmul {
     type PreProcessData = (Tensor, Tensor);
     type RunResult = Tensor;
     fn preprocess() -> Result<Self::PreProcessData> {
-        let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
-        let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
+        let lhs = Tensor::randn(0f32, 1., (1024 * 4, 1024 * 4), &Device::Cpu)?;
+        let rhs = Tensor::randn(0f32, 1., (1024 * 4, 1), &Device::Cpu)?;
         Ok((lhs, rhs))
     }
 
diff --git a/candle-transformers/src/models/convmixer.rs b/candle-transformers/src/models/convmixer.rs
new file mode 100644
index 00000000..76245f37
--- /dev/null
+++ b/candle-transformers/src/models/convmixer.rs
@@ -0,0 +1,82 @@
+use candle::Result;
+use candle_nn::{batch_norm, Conv2dConfig, Module, VarBuilder};
+
+#[allow(clippy::many_single_char_names)]
+fn conv2d_same(
+    i: usize,
+    o: usize,
+    k: usize,
+    c: Conv2dConfig,
+    vb: VarBuilder,
+) -> Result<impl Module> {
+    let conv2d = candle_nn::conv2d(i, o, k, c, vb)?;
+    let s = c.stride;
+    let module = candle_nn::func(move |xs| {
+        let ih = xs.dim(2)?;
+        let iw = xs.dim(3)?;
+        let oh = (ih + s - 1) / s;
+        let ow = (iw + s - 1) / s;
+        let pad_h = usize::max((oh - 1) * s + k - ih, 0);
+        let pad_w = usize::max((ow - 1) * s + k - iw, 0);
+        if pad_h > 0 || pad_w > 0 {
+            xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?
+                .pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?
+                .apply(&conv2d)
+        } else {
+            xs.apply(&conv2d)
+        }
+    });
+    Ok(module)
+}
+
+fn block(dim: usize, kernel_size: usize, vb: VarBuilder) -> Result<impl Module> {
+    let conv2d_cfg = Conv2dConfig {
+        groups: dim,
+        ..Default::default()
+    };
+    let vb_fn = vb.pp(0).pp("fn");
+    let conv1 = conv2d_same(dim, dim, kernel_size, conv2d_cfg, vb_fn.pp(0))?;
+    let bn1 = batch_norm(dim, 1e-5, vb_fn.pp(2))?;
+    let conv2 = candle_nn::conv2d(dim, dim, 1, Default::default(), vb.pp(1))?;
+    let bn2 = batch_norm(dim, 1e-5, vb.pp(3))?;
+    Ok(candle_nn::func(move |xs| {
+        let ys = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
+        (xs + ys)?.apply(&conv2)?.gelu_erf()?.apply(&bn2)
+    }))
+}
+
+fn convmixer(
+    nclasses: usize,
+    dim: usize,
+    depth: usize,
+    kernel_size: usize,
+    patch_size: usize,
+    vb: VarBuilder,
+) -> Result<candle_nn::Func<'static>> {
+    let conv2d_cfg = Conv2dConfig {
+        stride: patch_size,
+        ..Default::default()
+    };
+    let conv1 = candle_nn::conv2d(3, dim, patch_size, conv2d_cfg, vb.pp(0))?;
+    let bn1 = batch_norm(dim, 1e-5, vb.pp(2))?;
+    let blocks: Vec<_> = (0..depth)
+        .map(|index| block(dim, kernel_size, vb.pp(3 + index)))
+        .collect::<Result<Vec<_>>>()?;
+    let fc = candle_nn::linear(dim, nclasses, vb.pp(25))?;
+    Ok(candle_nn::func(move |xs| {
+        let mut xs = xs.apply(&conv1)?.gelu_erf()?.apply(&bn1)?;
+        for block in blocks.iter() {
+            xs = xs.apply(block)?
+        }
+        // This performs the adaptive average pooling with a target size of (1, 1).
+        xs.mean(3)?.mean(2)?.apply(&fc)
+    }))
+}
+
+pub fn c1536_20(nclasses: usize, vb: VarBuilder) -> Result<candle_nn::Func<'static>> {
+    convmixer(nclasses, 1536, 20, 9, 7, vb)
+}
+
+pub fn c1024_20(nclasses: usize, vb: VarBuilder) -> Result<candle_nn::Func<'static>> {
+    convmixer(nclasses, 1024, 20, 9, 14, vb)
+}
diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs
index 81044112..aa9ea81a 100644
--- a/candle-transformers/src/models/mod.rs
+++ b/candle-transformers/src/models/mod.rs
@@ -1,5 +1,6 @@
 pub mod bert;
 pub mod bigcode;
+pub mod convmixer;
 pub mod dinov2;
 pub mod efficientnet;
 pub mod falcon;
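
A minimal usage sketch, not part of the patch above: it builds the ConvMixer-1536/20 constructor added by this commit and runs a forward pass. The zero-initialized VarBuilder, the 224x224 dummy input, and the 1000-class output are assumptions made for illustration; `convmixer::c1536_20`, `candle_nn::VarBuilder`, and `Module::forward` come from the code in this patch and from candle-nn.

// Sketch only: exercises the new convmixer module with dummy weights.
use candle::{DType, Device, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convmixer;

fn main() -> Result<()> {
    let device = Device::Cpu;
    // All-zero weights just to exercise the graph; in practice one would load a
    // real checkpoint, e.g. with `VarBuilder::from_mmaped_safetensors`.
    let vb = VarBuilder::zeros(DType::F32, &device);
    // nclasses = 1000 is an assumption; dim 1536, depth 20, kernel 9, patch 7
    // come from the c1536_20 definition in the patch.
    let model = convmixer::c1536_20(1000, vb)?;
    // Dummy (batch, channels, height, width) image tensor.
    let image = Tensor::randn(0f32, 1., (1, 3, 224, 224), &device)?;
    let logits = model.forward(&image)?;
    println!("{:?}", logits.dims()); // expected: [1, 1000]
    Ok(())
}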