mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions. * Avoid some unnecessary groups checks. * Move the tensor convolution bits. * Properh handling of groups. * Bump the crate version. * And add a changelog.
This commit is contained in:
@ -11,7 +11,7 @@ readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { path = "../candle-core", version = "0.1.2", package = "candle-core" }
|
||||
candle = { path = "../candle-core", version = "0.1.3", package = "candle-core" }
|
||||
thiserror = { workspace = true }
|
||||
intel-mkl-src = { workspace = true, optional = true }
|
||||
safetensors = { workspace = true }
|
||||
|
@ -5,6 +5,7 @@ use candle::{Result, Tensor};
|
||||
pub struct Conv1dConfig {
|
||||
pub padding: usize,
|
||||
pub stride: usize,
|
||||
pub groups: usize,
|
||||
}
|
||||
|
||||
impl Default for Conv1dConfig {
|
||||
@ -12,6 +13,7 @@ impl Default for Conv1dConfig {
|
||||
Self {
|
||||
padding: 0,
|
||||
stride: 1,
|
||||
groups: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -39,7 +41,12 @@ impl Conv1d {
|
||||
|
||||
impl crate::Module for Conv1d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv1d(&self.weight, self.config.padding, self.config.stride)?;
|
||||
let x = x.conv1d(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.stride,
|
||||
self.config.groups,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => {
|
||||
@ -55,6 +62,7 @@ impl crate::Module for Conv1d {
|
||||
pub struct Conv2dConfig {
|
||||
pub padding: usize,
|
||||
pub stride: usize,
|
||||
pub groups: usize,
|
||||
}
|
||||
|
||||
impl Default for Conv2dConfig {
|
||||
@ -62,6 +70,7 @@ impl Default for Conv2dConfig {
|
||||
Self {
|
||||
padding: 0,
|
||||
stride: 1,
|
||||
groups: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -90,7 +99,12 @@ impl Conv2d {
|
||||
|
||||
impl crate::Module for Conv2d {
|
||||
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||
let x = x.conv2d(&self.weight, self.config.padding, self.config.stride)?;
|
||||
let x = x.conv2d(
|
||||
&self.weight,
|
||||
self.config.padding,
|
||||
self.config.stride,
|
||||
self.config.groups,
|
||||
)?;
|
||||
match &self.bias {
|
||||
None => Ok(x),
|
||||
Some(bias) => {
|
||||
|
Reference in New Issue
Block a user