Add avg_pool2d metal implementation for the metal backend (#1869)

* implement metal avg pool 2d

* fixX

* add suggested precision workaround for the accumulator
This commit is contained in:
Thomas Santerre
2024-03-18 13:50:14 -04:00
committed by GitHub
parent 58605252e8
commit 04a61a9c72
5 changed files with 236 additions and 20 deletions

View File

@ -2,9 +2,6 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
if dev.is_metal() {
return Ok(());
}
let data: Vec<f32> = vec![
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
];