Add a cuda kernel for avg-pool2d. (#440)

* Add a cuda kernel for avg-pool2d.

* Avoid running out of bounds.

* Finish wiring the avg pool kernel + add some testing.

* Support for max-pool + testing.
Laurent Mazare committed 2023-08-14 12:32:05 +01:00 (committed by GitHub)
parent 34f4b3187e
commit a094dc503d
3 changed files with 253 additions and 19 deletions
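
Only the test-file diff is reproduced below; the pooling kernels themselves land in candle-kernels and the CUDA backend wiring, which are not shown in this excerpt. As a rough, illustrative sketch of the computation those kernels perform (not code from this commit; the helper name and signature below are made up), a device-agnostic reference over one contiguous (C, H, W) sample looks like this:

```rust
/// Illustrative reference for 2D average pooling over one sample stored as a
/// contiguous (C, H, W) slice. Hypothetical helper for this write-up only,
/// not code from the commit.
fn avg_pool2d_ref(
    src: &[f32],
    (c, h, w): (usize, usize, usize),
    (kh, kw): (usize, usize),
    (sh, sw): (usize, usize),
) -> Vec<f32> {
    // Only windows that fit entirely inside the input contribute (no padding).
    let (oh, ow) = ((h - kh) / sh + 1, (w - kw) / sw + 1);
    let mut dst = vec![0f32; c * oh * ow];
    for ci in 0..c {
        for oy in 0..oh {
            for ox in 0..ow {
                let mut sum = 0f32;
                for ky in 0..kh {
                    for kx in 0..kw {
                        let (y, x) = (oy * sh + ky, ox * sw + kx);
                        sum += src[ci * h * w + y * w + x];
                    }
                }
                // Average pooling divides by the window area; a max-pool kernel
                // keeps a running maximum here instead.
                dst[ci * oh * ow + oy * ow + ox] = sum / (kh * kw) as f32;
            }
        }
    }
    dst
}
```

For the 4x4 input used in the avg_pool2d test below, the top-left (2, 2) window holds 1, 1, 0, 0, which is where the 0.5 in the expected [[0.5, 1.], [1., 1.]] comes from.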


@@ -1,25 +1,22 @@
mod test_utils;
-use candle_core::{Device, IndexOp, Tensor};
+use candle_core::{Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
-#[test]
-fn avg_pool2d() -> anyhow::Result<()> {
+fn avg_pool2d(dev: &Device) -> Result<()> {
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];
-let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
+let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[0.5f32, 1.], [1., 1.]]);
Ok(())
}
-#[test]
-fn max_pool2d() -> anyhow::Result<()> {
+fn max_pool2d(dev: &Device) -> Result<()> {
let data: Vec<f32> = vec![
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
];
-let t = Tensor::from_vec(data, (1, 1, 4, 4), &Device::Cpu)?;
+let t = Tensor::from_vec(data, (1, 1, 4, 4), dev)?;
let pool = t.max_pool2d((2, 2), (2, 2))?.squeeze(0)?.squeeze(0)?;
assert_eq!(pool.to_vec2::<f32>()?, [[2f32, 3.], [5., 1.]]);
@@ -35,8 +32,7 @@ print(t.flatten())
res = torch.nn.functional.avg_pool2d(t, 2)
print(res)
*/
-#[test]
-fn avg_pool2d_pytorch() -> anyhow::Result<()> {
+fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
let t = Tensor::new(
&[
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
@@ -44,7 +40,7 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> {
1.2279, -0.9287, -1.7030, 0.1370, 0.6047, 0.3770, -0.6266, 0.3529, 2.2013, -0.6836,
0.2477, 1.3127,
],
-&Device::Cpu,
+dev,
)?
.reshape((1, 2, 4, 4))?;
let pool = t.avg_pool2d((2, 2), (2, 2))?.squeeze(0)?;
@@ -61,7 +57,7 @@ fn avg_pool2d_pytorch() -> anyhow::Result<()> {
}
#[test]
-fn upsample_nearest2d() -> anyhow::Result<()> {
+fn upsample_nearest2d() -> Result<()> {
let t = Tensor::arange(0f32, 6f32, &Device::Cpu)?.reshape((1, 1, 2, 3))?;
let upsampled = t.upsample_nearest2d(4, 6)?.i(0)?.i(0)?;
assert_eq!(
@@ -79,3 +75,11 @@ fn upsample_nearest2d() -> anyhow::Result<()> {
);
Ok(())
}
+test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);
+test_device!(
+avg_pool2d_pytorch,
+avg_pool2d_pytorch_cpu,
+avg_pool2d_pytorch_gpu
+);
+test_device!(max_pool2d, max_pool2d_cpu, max_pool2d_gpu);
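
The tests above now take a `&Device` argument and are registered through the `test_device!` macro from the tests' local `test_utils` module, so each one runs against the CPU backend and, when the `cuda` feature is enabled, against the new CUDA kernels as well. A minimal sketch of what such a macro can expand to (assumed shape; the actual helper in the repository may differ):

```rust
use candle_core::{Device, Result};

// Sketch of a device-dispatch test helper (assumed shape): one device-generic
// test function becomes a plain CPU test plus a feature-gated CUDA test.
macro_rules! test_device {
    ($fn_name:ident, $test_cpu:ident, $test_cuda:ident) => {
        #[test]
        fn $test_cpu() -> Result<()> {
            $fn_name(&Device::Cpu)
        }

        #[cfg(feature = "cuda")]
        #[test]
        fn $test_cuda() -> Result<()> {
            $fn_name(&Device::new_cuda(0)?)
        }
    };
}
```

Under that expansion, `test_device!(avg_pool2d, avg_pool2d_cpu, avg_pool2d_gpu);` yields an `avg_pool2d_cpu` test calling `avg_pool2d(&Device::Cpu)` and an `avg_pool2d_gpu` test that runs the same checks on a CUDA device.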