Add some group parameter to convolutions. (#566)

* Add some group parameter to convolutions.

* Avoid some unnecessary groups checks.

* Move the tensor convolution bits.

* Properh handling of groups.

* Bump the crate version.

* And add a changelog.
This commit is contained in:
Laurent Mazare
2023-08-23 12:58:55 +01:00
committed by GitHub
parent 4ee1cf038a
commit aba1e90797
30 changed files with 216 additions and 113 deletions

View File

@ -11,7 +11,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
thiserror = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
safetensors = { workspace = true }

View File

@ -5,6 +5,7 @@ use candle::{Result, Tensor};
pub struct Conv1dConfig {
pub padding: usize,
pub stride: usize,
pub groups: usize,
}
impl Default for Conv1dConfig {
@ -12,6 +13,7 @@ impl Default for Conv1dConfig {
Self {
padding: 0,
stride: 1,
groups: 1,
}
}
}
@ -39,7 +41,12 @@ impl Conv1d {
impl crate::Module for Conv1d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
let x = x.conv1d(
&self.weight,
self.config.padding,
self.config.stride,
self.config.groups,
)?;
match &self.bias {
None => Ok(x),
Some(bias) => {
@ -55,6 +62,7 @@ impl crate::Module for Conv1d {
pub struct Conv2dConfig {
pub padding: usize,
pub stride: usize,
pub groups: usize,
}
impl Default for Conv2dConfig {
@ -62,6 +70,7 @@ impl Default for Conv2dConfig {
Self {
padding: 0,
stride: 1,
groups: 1,
}
}
}
@ -90,7 +99,12 @@ impl Conv2d {
impl crate::Module for Conv2d {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
let x = x.conv2d(
&self.weight,
self.config.padding,
self.config.stride,
self.config.groups,
)?;
match &self.bias {
None => Ok(x),
Some(bias) => {