diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..7f997cb0 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,13 @@ +# Changelog +This documents the main changes to the `candle` crate. + +## Unreleased +### Added +- Add a group parameter to convolutions + [566](https://github.com/huggingface/candle/pull/566). +- New dtype: int64 + [563](https://github.com/huggingface/candle/pull/563). +- Handling of the GGUF file format. + [559](https://github.com/huggingface/candle/pull/559). + +## v0.1.2 - 2023-08-21 diff --git a/Cargo.toml b/Cargo.toml index 7957c038..d391ee7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,7 @@ exclude = [ ] [workspace.package] -version = "0.1.2" +version = "0.1.3" edition = "2021" description = "Minimalist ML framework." repository = "https://github.com/huggingface/candle" diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index b190c55e..3b3e4eb7 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -12,7 +12,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } byteorder = { workspace = true } -candle-kernels = { path = "../candle-kernels", version = "0.1.2", optional = true } +candle-kernels = { path = "../candle-kernels", version = "0.1.3", optional = true } cudarc = { workspace = true, optional = true } gemm = { workspace = true } half = { workspace = true } diff --git a/candle-core/examples/basics.rs b/candle-core/examples/basics.rs index efce913a..9d4734de 100644 --- a/candle-core/examples/basics.rs +++ b/candle-core/examples/basics.rs @@ -11,7 +11,7 @@ fn main() -> Result<()> { let inp = Tensor::randn(0f32, 1., (2, 320, 96, 96), &Device::Cpu)?; let w = Tensor::randn(0f32, 1., (320, 320, 3, 3), &Device::Cpu)?; let start = std::time::Instant::now(); - let res = inp.conv2d(&w, 0, 1); + let res = inp.conv2d(&w, 0, 1, 1)?; println!("{:?}", start.elapsed()); println!("{res:?}"); Ok(()) diff --git a/candle-core/examples/cpu_benchmarks.rs b/candle-core/examples/cpu_benchmarks.rs index d7f60f81..1ebd9b75 100644 --- a/candle-core/examples/cpu_benchmarks.rs +++ b/candle-core/examples/cpu_benchmarks.rs @@ -40,7 +40,7 @@ impl Benchmark for Conv1d { } fn run_one(d: &Self::PreProcessData) -> Result { - d.0.conv1d(&d.1, 0, 1) + d.0.conv1d(&d.1, 0, 1, 1) } const ITERS: usize = 5; @@ -59,7 +59,7 @@ impl Benchmark for Conv2d { } fn run_one(d: &Self::PreProcessData) -> Result { - d.0.conv2d(&d.1, 0, 1) + d.0.conv2d(&d.1, 0, 1, 1) } const ITERS: usize = 1; diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 12febb60..ac435488 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -11,7 +11,7 @@ fn main() -> Result<()> { let device = Device::new_cuda(0)?; let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?; let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?; - let res = t.conv2d(&w, 1, 1)?; + let res = t.conv2d(&w, 1, 1, 1)?; println!("{res:?}"); Ok(()) } diff --git a/candle-core/src/conv.rs b/candle-core/src/conv.rs index e3fea861..d4b7a76d 100644 --- a/candle-core/src/conv.rs +++ b/candle-core/src/conv.rs @@ -1,3 +1,5 @@ +use crate::{op::BackpropOp, op::Op, Error, Result, Tensor}; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct ParamsConv1D { pub(crate) b_size: usize, @@ -51,3 +53,113 @@ impl ParamsConv2D { vec![self.b_size, self.c_out, self.out_h(), self.out_w()] } } + +impl Tensor { + fn conv1d_single_group(&self, kernel: &Self, params: &ParamsConv1D) -> Result { + let storage = + self.storage() + .conv1d(self.layout(), &kernel.storage(), kernel.layout(), params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D { + arg, + kernel, + padding: params.padding, + stride: params.stride, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 1D convolution over the input tensor. + pub fn conv1d( + &self, + kernel: &Self, + padding: usize, + stride: usize, + groups: usize, + ) -> Result { + let (c_out, c_in_k, k_size) = kernel.dims3()?; + let (b_size, c_in, l_in) = self.dims3()?; + if c_in != c_in_k * groups { + Err(Error::Conv1dInvalidArgs { + inp_shape: self.shape().clone(), + k_shape: kernel.shape().clone(), + padding, + stride, + msg: "the number of in-channels on the input doesn't match the kernel size", + } + .bt())? + } + + let params = ParamsConv1D { + b_size, + l_in, + c_out, + c_in, + k_size, + padding, + stride, + }; + if groups == 1 { + self.conv1d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let blocks = blocks + .iter() + .map(|block| block.conv1d_single_group(kernel, ¶ms)) + .collect::>>()?; + Tensor::cat(&blocks, 1) + } + } + + fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result { + let storage = + self.storage() + .conv2d(self.layout(), &kernel.storage(), kernel.layout(), params)?; + let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { + arg, + kernel, + padding: params.padding, + stride: params.stride, + }); + let out_dims = params.out_dims(); + Ok(crate::tensor::from_storage(storage, out_dims, op, false)) + } + + /// Applies a 2D convolution over the input tensor. + pub fn conv2d( + &self, + kernel: &Self, + padding: usize, + stride: usize, + groups: usize, + ) -> Result { + let (b_size, c_in, i_h, i_w) = self.dims4()?; + let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; + if c_in != c_in_k * groups { + crate::bail!( + "in_channel mismatch between input ({c_in}, groups {groups}) and kernel ({c_in_k})" + ) + } + let params = ParamsConv2D { + b_size, + i_h, + i_w, + k_h, + k_w, + c_out, + c_in, + padding, + stride, + }; + if groups == 1 { + self.conv2d_single_group(kernel, ¶ms) + } else { + let blocks = self.chunk(groups, 1)?; + let blocks = blocks + .iter() + .map(|block| block.conv2d_single_group(kernel, ¶ms)) + .collect::>>()?; + Tensor::cat(&blocks, 1) + } + } +} diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index a4b9795b..46f9c53f 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -124,7 +124,7 @@ macro_rules! broadcast_binary_op { } /// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides. -fn from_storage>( +pub(crate) fn from_storage>( storage: Storage, shape: S, op: BackpropOp, @@ -787,72 +787,6 @@ impl Tensor { self.cmp(rhs, CmpOp::Le) } - /// Applies a 1D convolution over the input tensor. - pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result { - let (c_out, c_in_k, k_size) = kernel.dims3()?; - let (b_size, c_in, l_in) = self.dims3()?; - if c_in != c_in_k { - Err(Error::Conv1dInvalidArgs { - inp_shape: self.shape().clone(), - k_shape: kernel.shape().clone(), - padding, - stride, - msg: "the number of in-channels on the input doesn't match the kernel size", - } - .bt())? - } - let params = crate::conv::ParamsConv1D { - b_size, - l_in, - c_out, - c_in, - k_size, - padding, - stride, - }; - let storage = - self.storage() - .conv1d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?; - let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D { - arg, - kernel, - padding, - stride, - }); - let out_dims = params.out_dims(); - Ok(from_storage(storage, out_dims, op, false)) - } - - pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result { - let (b_size, c_in, i_h, i_w) = self.dims4()?; - let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?; - if c_in != c_in_k { - crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})") - } - let params = crate::conv::ParamsConv2D { - b_size, - i_h, - i_w, - k_h, - k_w, - c_out, - c_in, - padding, - stride, - }; - let storage = - self.storage() - .conv2d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?; - let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D { - arg, - kernel, - padding, - stride, - }); - let out_dims = params.out_dims(); - Ok(from_storage(storage, out_dims, op, false)) - } - pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result { let (n, c, _h, _w) = self.dims4()?; let op = BackpropOp::new1(self, Op::UpsampleNearest2D); @@ -1920,7 +1854,7 @@ impl Tensor { } } - fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { + pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> { self.storage.read().unwrap() } diff --git a/candle-core/tests/conv_tests.rs b/candle-core/tests/conv_tests.rs index c777fec7..d09fa344 100644 --- a/candle-core/tests/conv_tests.rs +++ b/candle-core/tests/conv_tests.rs @@ -33,13 +33,13 @@ fn conv1d(dev: &Device) -> Result<()> { dev, )? .reshape((2, 4, 3))?; - let res = t.conv1d(&w, 0, 1)?; + let res = t.conv1d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 2, 3]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [2.6357, -1.3336, 4.1393, -1.1784, 3.5675, 0.5069] ); - let res = t.conv1d(&w, /*padding*/ 1, 1)?; + let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?; assert_eq!(res.dims(), [1, 2, 5]); // Same as pytorch default padding: use zeros. assert_eq!( @@ -52,13 +52,13 @@ fn conv1d(dev: &Device) -> Result<()> { fn conv1d_small(dev: &Device) -> Result<()> { let t = Tensor::new(&[0.4056f32, -0.8689, -0.0773, -1.5630], dev)?.reshape((1, 1, 4))?; let w = Tensor::new(&[1f32, 0., 0.], dev)?.reshape((1, 1, 3))?; - let res = t.conv1d(&w, 0, 1)?; + let res = t.conv1d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 2]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, [0.4056, -0.8689] ); - let res = t.conv1d(&w, /*padding*/ 1, 1)?; + let res = t.conv1d(&w, /*padding*/ 1, 1, 1)?; assert_eq!(res.dims(), [1, 1, 4]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, @@ -109,7 +109,7 @@ fn conv2d(dev: &Device) -> Result<()> { )?; let t = t.reshape((1, 4, 5, 5))?; let w = w.reshape((2, 4, 3, 3))?; - let res = t.conv2d(&w, 0, 1)?; + let res = t.conv2d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 2, 3, 3]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, @@ -143,7 +143,7 @@ fn conv2d_small(dev: &Device) -> Result<()> { let w = Tensor::new(&[-0.9259f32, 1.3017], dev)?; let t = t.reshape((1, 2, 3, 3))?; let w = w.reshape((1, 2, 1, 1))?; - let res = t.conv2d(&w, 0, 1)?; + let res = t.conv2d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 3, 3]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, @@ -162,7 +162,7 @@ fn conv2d_smaller(dev: &Device) -> Result<()> { let w = Tensor::new(&[1f32, 1., 1., 1., 1., 1., 1., 1., 1.], dev)?; let t = t.reshape((1, 1, 3, 3))?; let w = w.reshape((1, 1, 3, 3))?; - let res = t.conv2d(&w, 0, 1)?; + let res = t.conv2d(&w, 0, 1, 1)?; assert_eq!(res.dims(), [1, 1, 1, 1]); assert_eq!( test_utils::to_vec1_round(&res.flatten_all()?, 4)?, diff --git a/candle-datasets/Cargo.toml b/candle-datasets/Cargo.toml index 88b81311..d4a34b01 100644 --- a/candle-datasets/Cargo.toml +++ b/candle-datasets/Cargo.toml @@ -11,8 +11,8 @@ readme = "README.md" [dependencies] byteorder = { workspace = true } -candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.1.2" } +candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.1.3" } hf-hub = { workspace = true} intel-mkl-src = { workspace = true, optional = true } memmap2 = { workspace = true } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 24ad47f2..bbd7c3b0 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -11,11 +11,11 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" } -candle-datasets = { path = "../candle-datasets", version = "0.1.2" } -candle-nn = { path = "../candle-nn", version = "0.1.2" } -candle-transformers = { path = "../candle-transformers", version = "0.1.2" } -candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.2", optional = true } +candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" } +candle-datasets = { path = "../candle-datasets", version = "0.1.3" } +candle-nn = { path = "../candle-nn", version = "0.1.3" } +candle-transformers = { path = "../candle-transformers", version = "0.1.3" } +candle-flash-attn = { path = "../candle-flash-attn", version = "0.1.3", optional = true } safetensors = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-examples/examples/musicgen/encodec_model.rs b/candle-examples/examples/musicgen/encodec_model.rs index 9c966497..e7712bf3 100644 --- a/candle-examples/examples/musicgen/encodec_model.rs +++ b/candle-examples/examples/musicgen/encodec_model.rs @@ -274,14 +274,22 @@ impl EncodecConv1d { in_c, out_c, kernel_size, - Conv1dConfig { padding: 0, stride }, + Conv1dConfig { + padding: 0, + stride, + groups: 1, + }, vb.pp("conv"), )?, NormType::None => conv1d( in_c, out_c, kernel_size, - Conv1dConfig { padding: 0, stride }, + Conv1dConfig { + padding: 0, + stride, + groups: 1, + }, vb.pp("conv"), )?, }; diff --git a/candle-examples/examples/stable-diffusion/resnet.rs b/candle-examples/examples/stable-diffusion/resnet.rs index 94f436c8..172a9359 100644 --- a/candle-examples/examples/stable-diffusion/resnet.rs +++ b/candle-examples/examples/stable-diffusion/resnet.rs @@ -66,6 +66,7 @@ impl ResnetBlock2D { let conv_cfg = nn::Conv2dConfig { stride: 1, padding: 1, + groups: 1, }; let norm1 = nn::group_norm(config.groups, in_channels, config.eps, vs.pp("norm1"))?; let conv1 = conv2d(in_channels, out_channels, 3, conv_cfg, vs.pp("conv1"))?; @@ -79,6 +80,7 @@ impl ResnetBlock2D { let conv_cfg = nn::Conv2dConfig { stride: 1, padding: 0, + groups: 1, }; Some(conv2d( in_channels, diff --git a/candle-examples/examples/stable-diffusion/unet_2d.rs b/candle-examples/examples/stable-diffusion/unet_2d.rs index 6f568113..eb2dbf10 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d.rs @@ -112,8 +112,8 @@ impl UNet2DConditionModel { let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim; let time_embed_dim = b_channels * 4; let conv_cfg = nn::Conv2dConfig { - stride: 1, padding: 1, + ..Default::default() }; let conv_in = conv2d(in_channels, b_channels, 3, conv_cfg, vs.pp("conv_in"))?; diff --git a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs index b7adb2c0..65341e74 100644 --- a/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs +++ b/candle-examples/examples/stable-diffusion/unet_2d_blocks.rs @@ -24,7 +24,11 @@ impl Downsample2D { padding: usize, ) -> Result { let conv = if use_conv { - let config = nn::Conv2dConfig { stride: 2, padding }; + let config = nn::Conv2dConfig { + stride: 2, + padding, + ..Default::default() + }; let conv = conv2d(in_channels, out_channels, 3, config, vs.pp("conv"))?; Some(conv) } else { diff --git a/candle-examples/examples/stable-diffusion/vae.rs b/candle-examples/examples/stable-diffusion/vae.rs index abba39fa..aa8e13a0 100644 --- a/candle-examples/examples/stable-diffusion/vae.rs +++ b/candle-examples/examples/stable-diffusion/vae.rs @@ -51,8 +51,8 @@ impl Encoder { config: EncoderConfig, ) -> Result { let conv_cfg = nn::Conv2dConfig { - stride: 1, padding: 1, + ..Default::default() }; let conv_in = nn::conv2d( in_channels, @@ -182,8 +182,8 @@ impl Decoder { let n_block_out_channels = config.block_out_channels.len(); let last_block_out_channels = *config.block_out_channels.last().unwrap(); let conv_cfg = nn::Conv2dConfig { - stride: 1, padding: 1, + ..Default::default() }; let conv_in = nn::conv2d( in_channels, diff --git a/candle-examples/examples/whisper/model.rs b/candle-examples/examples/whisper/model.rs index 553bd93b..4ccc79f7 100644 --- a/candle-examples/examples/whisper/model.rs +++ b/candle-examples/examples/whisper/model.rs @@ -308,10 +308,12 @@ impl AudioEncoder { let cfg1 = Conv1dConfig { padding: 1, stride: 1, + groups: 1, }; let cfg2 = Conv1dConfig { padding: 1, stride: 2, + groups: 1, }; let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; diff --git a/candle-examples/examples/yolo-v3/darknet.rs b/candle-examples/examples/yolo-v3/darknet.rs index d0392308..de8fcf09 100644 --- a/candle-examples/examples/yolo-v3/darknet.rs +++ b/candle-examples/examples/yolo-v3/darknet.rs @@ -128,7 +128,11 @@ fn conv(vb: VarBuilder, index: usize, p: usize, b: &Block) -> Result<(usize, Bl) } Some(_) | None => (None, true), }; - let conv_cfg = candle_nn::Conv2dConfig { stride, padding }; + let conv_cfg = candle_nn::Conv2dConfig { + stride, + padding, + groups: 1, + }; let conv = if bias { conv2d(p, filters, size, conv_cfg, vb.pp(&format!("conv_{index}")))? } else { diff --git a/candle-examples/examples/yolo-v8/main.rs b/candle-examples/examples/yolo-v8/main.rs index 616e04ed..3b9c1ce9 100644 --- a/candle-examples/examples/yolo-v8/main.rs +++ b/candle-examples/examples/yolo-v8/main.rs @@ -101,7 +101,11 @@ impl ConvBlock { padding: Option, ) -> Result { let padding = padding.unwrap_or(k / 2); - let cfg = Conv2dConfig { padding, stride }; + let cfg = Conv2dConfig { + padding, + stride, + groups: 1, + }; let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; Ok(Self { conv, bn }) diff --git a/candle-flash-attn/Cargo.toml b/candle-flash-attn/Cargo.toml index f88a88d5..b0efaf52 100644 --- a/candle-flash-attn/Cargo.toml +++ b/candle-flash-attn/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-flash-attn" -version = "0.1.2" +version = "0.1.3" edition = "2021" description = "Flash attention layer for the candle ML framework." @@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0" readme = "README.md" [dependencies] -candle = { path = "../candle-core", features = ["cuda"], version = "0.1.2", package = "candle-core" } +candle = { path = "../candle-core", features = ["cuda"], version = "0.1.3", package = "candle-core" } half = { version = "2.3.1", features = ["num-traits"] } [build-dependencies] @@ -21,4 +21,4 @@ rayon = "1.7.0" [dev-dependencies] anyhow = { version = "1", features = ["backtrace"] } -candle-nn = { path = "../candle-nn", version = "0.1.2", features = ["cuda"] } +candle-nn = { path = "../candle-nn", version = "0.1.3", features = ["cuda"] } diff --git a/candle-kernels/Cargo.toml b/candle-kernels/Cargo.toml index a3f55c3d..6144e2d5 100644 --- a/candle-kernels/Cargo.toml +++ b/candle-kernels/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "candle-kernels" -version = "0.1.2" +version = "0.1.3" edition = "2021" description = "CUDA kernels for Candle" diff --git a/candle-nn/Cargo.toml b/candle-nn/Cargo.toml index b3e9c0bf..7cd1d7a2 100644 --- a/candle-nn/Cargo.toml +++ b/candle-nn/Cargo.toml @@ -11,7 +11,7 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" } +candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" } thiserror = { workspace = true } intel-mkl-src = { workspace = true, optional = true } safetensors = { workspace = true } diff --git a/candle-nn/src/conv.rs b/candle-nn/src/conv.rs index df9818ab..204402c3 100644 --- a/candle-nn/src/conv.rs +++ b/candle-nn/src/conv.rs @@ -5,6 +5,7 @@ use candle::{Result, Tensor}; pub struct Conv1dConfig { pub padding: usize, pub stride: usize, + pub groups: usize, } impl Default for Conv1dConfig { @@ -12,6 +13,7 @@ impl Default for Conv1dConfig { Self { padding: 0, stride: 1, + groups: 1, } } } @@ -39,7 +41,12 @@ impl Conv1d { impl crate::Module for Conv1d { fn forward(&self, x: &Tensor) -> Result { - let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?; + let x = x.conv1d( + &self.weight, + self.config.padding, + self.config.stride, + self.config.groups, + )?; match &self.bias { None => Ok(x), Some(bias) => { @@ -55,6 +62,7 @@ impl crate::Module for Conv1d { pub struct Conv2dConfig { pub padding: usize, pub stride: usize, + pub groups: usize, } impl Default for Conv2dConfig { @@ -62,6 +70,7 @@ impl Default for Conv2dConfig { Self { padding: 0, stride: 1, + groups: 1, } } } @@ -90,7 +99,12 @@ impl Conv2d { impl crate::Module for Conv2d { fn forward(&self, x: &Tensor) -> Result { - let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?; + let x = x.conv2d( + &self.weight, + self.config.padding, + self.config.stride, + self.config.groups, + )?; match &self.bias { None => Ok(x), Some(bias) => { diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index 1a64cc17..45ab38c0 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -15,7 +15,7 @@ crate-type = ["cdylib"] doc = false [dependencies] -candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" } +candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" } half = { workspace = true } pyo3 = { version = "0.19.0", features = ["extension-module"] } diff --git a/candle-transformers/Cargo.toml b/candle-transformers/Cargo.toml index 92b1137a..5c4c8860 100644 --- a/candle-transformers/Cargo.toml +++ b/candle-transformers/Cargo.toml @@ -11,8 +11,8 @@ readme = "README.md" [dependencies] accelerate-src = { workspace = true, optional = true } -candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" } -candle-nn = { path = "../candle-nn", version = "0.1.2" } +candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.1.3" } intel-mkl-src = { workspace = true, optional = true } rand = { workspace = true } wav = { workspace = true } diff --git a/candle-wasm-examples/llama2-c/Cargo.toml b/candle-wasm-examples/llama2-c/Cargo.toml index a43578cd..370708bd 100644 --- a/candle-wasm-examples/llama2-c/Cargo.toml +++ b/candle-wasm-examples/llama2-c/Cargo.toml @@ -9,8 +9,8 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.1.2" } +candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.1.3" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/whisper/Cargo.toml b/candle-wasm-examples/whisper/Cargo.toml index 5d777011..f404af55 100644 --- a/candle-wasm-examples/whisper/Cargo.toml +++ b/candle-wasm-examples/whisper/Cargo.toml @@ -9,8 +9,8 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.1.2" } +candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.1.3" } num-traits = { workspace = true } tokenizers = { workspace = true, features = ["unstable_wasm"] } diff --git a/candle-wasm-examples/whisper/src/model.rs b/candle-wasm-examples/whisper/src/model.rs index 3470c3d6..aea993f5 100644 --- a/candle-wasm-examples/whisper/src/model.rs +++ b/candle-wasm-examples/whisper/src/model.rs @@ -295,10 +295,12 @@ impl AudioEncoder { let cfg1 = Conv1dConfig { padding: 1, stride: 1, + groups: 1, }; let cfg2 = Conv1dConfig { padding: 1, stride: 2, + groups: 1, }; let conv1 = conv1d(cfg.num_mel_bins, n_state, 3, cfg1, vb.pp("conv1"))?; let conv2 = conv1d(n_state, n_state, 3, cfg2, vb.pp("conv2"))?; diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml index ef9498ee..b565c04b 100644 --- a/candle-wasm-examples/yolo/Cargo.toml +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -9,8 +9,8 @@ categories.workspace = true license.workspace = true [dependencies] -candle = { path = "../../candle-core", version = "0.1.2", package = "candle-core" } -candle-nn = { path = "../../candle-nn", version = "0.1.2" } +candle = { path = "../../candle-core", version = "0.1.3", package = "candle-core" } +candle-nn = { path = "../../candle-nn", version = "0.1.3" } num-traits = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/candle-wasm-examples/yolo/src/model.rs b/candle-wasm-examples/yolo/src/model.rs index 50fd100c..7e40fcfc 100644 --- a/candle-wasm-examples/yolo/src/model.rs +++ b/candle-wasm-examples/yolo/src/model.rs @@ -97,7 +97,11 @@ impl ConvBlock { padding: Option, ) -> Result { let padding = padding.unwrap_or(k / 2); - let cfg = Conv2dConfig { padding, stride }; + let cfg = Conv2dConfig { + padding, + stride, + groups: 1, + }; let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?; let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?; Ok(Self { conv, bn })