mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 11:37:11 +00:00
Add some group parameter to convolutions. (#566)
* Add some group parameter to convolutions. * Avoid some unnecessary groups checks. * Move the tensor convolution bits. * Properh handling of groups. * Bump the crate version. * And add a changelog.
This commit is contained in:
@ -124,7 +124,7 @@ macro_rules! broadcast_binary_op {
|
||||
}
|
||||
|
||||
/// Creates a fresh tensor structure based on a storage and a shape, this uses contiguous strides.
|
||||
fn from_storage<S: Into<Shape>>(
|
||||
pub(crate) fn from_storage<S: Into<Shape>>(
|
||||
storage: Storage,
|
||||
shape: S,
|
||||
op: BackpropOp,
|
||||
@ -787,72 +787,6 @@ impl Tensor {
|
||||
self.cmp(rhs, CmpOp::Le)
|
||||
}
|
||||
|
||||
/// Applies a 1D convolution over the input tensor.
|
||||
pub fn conv1d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (c_out, c_in_k, k_size) = kernel.dims3()?;
|
||||
let (b_size, c_in, l_in) = self.dims3()?;
|
||||
if c_in != c_in_k {
|
||||
Err(Error::Conv1dInvalidArgs {
|
||||
inp_shape: self.shape().clone(),
|
||||
k_shape: kernel.shape().clone(),
|
||||
padding,
|
||||
stride,
|
||||
msg: "the number of in-channels on the input doesn't match the kernel size",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
let params = crate::conv::ParamsConv1D {
|
||||
b_size,
|
||||
l_in,
|
||||
c_out,
|
||||
c_in,
|
||||
k_size,
|
||||
padding,
|
||||
stride,
|
||||
};
|
||||
let storage =
|
||||
self.storage()
|
||||
.conv1d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv1D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
pub fn conv2d(&self, kernel: &Self, padding: usize, stride: usize) -> Result<Self> {
|
||||
let (b_size, c_in, i_h, i_w) = self.dims4()?;
|
||||
let (c_out, c_in_k, k_h, k_w) = kernel.dims4()?;
|
||||
if c_in != c_in_k {
|
||||
crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
|
||||
}
|
||||
let params = crate::conv::ParamsConv2D {
|
||||
b_size,
|
||||
i_h,
|
||||
i_w,
|
||||
k_h,
|
||||
k_w,
|
||||
c_out,
|
||||
c_in,
|
||||
padding,
|
||||
stride,
|
||||
};
|
||||
let storage =
|
||||
self.storage()
|
||||
.conv2d(self.layout(), &kernel.storage(), kernel.layout(), ¶ms)?;
|
||||
let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::Conv2D {
|
||||
arg,
|
||||
kernel,
|
||||
padding,
|
||||
stride,
|
||||
});
|
||||
let out_dims = params.out_dims();
|
||||
Ok(from_storage(storage, out_dims, op, false))
|
||||
}
|
||||
|
||||
pub fn upsample_nearest2d(&self, target_h: usize, target_w: usize) -> Result<Self> {
|
||||
let (n, c, _h, _w) = self.dims4()?;
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest2D);
|
||||
@ -1920,7 +1854,7 @@ impl Tensor {
|
||||
}
|
||||
}
|
||||
|
||||
fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
|
||||
self.storage.read().unwrap()
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user