mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add avg_pool2d metal implementation for the metal backend (#1869)
* implement metal avg pool 2d * fixX * add suggested precision workaround for the accumulator
This commit is contained in:
@ -1044,8 +1044,46 @@ impl BackendStorage for MetalStorage {
|
||||
crate::bail!("Metal conv_tranpose2d not implemented")
|
||||
}
|
||||
|
||||
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
|
||||
crate::bail!("Metal avg_pool2d not implemented")
|
||||
fn avg_pool2d(
|
||||
&self,
|
||||
inp_l: &Layout,
|
||||
(w_k, h_k): (usize, usize),
|
||||
(w_stride, h_stride): (usize, usize),
|
||||
) -> Result<Self> {
|
||||
let shape = inp_l.shape();
|
||||
let (b_size, channels, width, height) = shape.dims4()?;
|
||||
let strides = inp_l.stride();
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "avg_pool2d_f32",
|
||||
DType::F16 => "avg_pool2d_f16",
|
||||
DType::BF16 => "avg_pool2d_bf16",
|
||||
DType::U8 => "avg_pool2d_u8",
|
||||
DType::U32 => "avg_pool2d_u32",
|
||||
dtype => crate::bail!("Metal avg_pool2d {dtype:?} not implemented"),
|
||||
};
|
||||
let out_w = (width - w_k) / w_stride + 1;
|
||||
let out_h = (height - h_k) / h_stride + 1;
|
||||
let dst_el = out_w * out_h * b_size * channels;
|
||||
let buffer = self.device.new_buffer(dst_el, self.dtype, "avg_pool2d")?;
|
||||
let command_buffers = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_pool2d(
|
||||
&self.device.device,
|
||||
&command_buffers,
|
||||
&self.device.kernels,
|
||||
name,
|
||||
inp_l.dims(),
|
||||
strides,
|
||||
out_w,
|
||||
out_h,
|
||||
w_k,
|
||||
h_k,
|
||||
w_stride,
|
||||
h_stride,
|
||||
&self.buffer,
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
|
||||
}
|
||||
|
||||
fn max_pool2d(
|
||||
@ -1063,14 +1101,14 @@ impl BackendStorage for MetalStorage {
|
||||
DType::BF16 => "max_pool2d_bf16",
|
||||
DType::U8 => "max_pool2d_u8",
|
||||
DType::U32 => "max_pool2d_u32",
|
||||
dtype => crate::bail!("Metal upsample_nearest2d {dtype:?} not implemented"),
|
||||
dtype => crate::bail!("Metal max_pool2d {dtype:?} not implemented"),
|
||||
};
|
||||
let out_w = (width - w_k) / w_stride + 1;
|
||||
let out_h = (height - h_k) / h_stride + 1;
|
||||
let dst_el = out_w * out_h * b_size * channels;
|
||||
let buffer = self.device.new_buffer(dst_el, self.dtype, "max_pool2d")?;
|
||||
let command_buffers = self.device.command_buffer()?;
|
||||
candle_metal_kernels::call_max_pool2d(
|
||||
candle_metal_kernels::call_pool2d(
|
||||
&self.device.device,
|
||||
&command_buffers,
|
||||
&self.device.kernels,
|
||||
|
@ -2,9 +2,6 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
|
||||
|
||||
// https://github.com/huggingface/candle/issues/364
|
||||
fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
let data: Vec<f32> = vec![
|
||||
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
|
@ -206,6 +206,67 @@ kernel void FN_NAME( \
|
||||
upsample_nearest2d<TYPENAME>(w_out, h_out, w_scale, h_scale, dims, strides, src, dst, tid); \
|
||||
} \
|
||||
|
||||
template <typename T, typename A>
|
||||
METAL_FUNC void avg_pool2d(
|
||||
constant size_t &w_k,
|
||||
constant size_t &h_k,
|
||||
constant size_t &w_stride,
|
||||
constant size_t &h_stride,
|
||||
constant size_t *src_dims,
|
||||
constant size_t *src_strides,
|
||||
device const T *src,
|
||||
device T *dst,
|
||||
uint tid [[ thread_position_in_grid ]]
|
||||
) {
|
||||
const size_t c = src_dims[1];
|
||||
const size_t w_in = src_dims[2];
|
||||
const size_t h_in = src_dims[3];
|
||||
|
||||
const size_t w_out = (w_in - w_k) / w_stride + 1;
|
||||
const size_t h_out = (h_in - h_k) / h_stride + 1;
|
||||
if (tid >= src_dims[0] * c * w_out * h_out) {
|
||||
return;
|
||||
}
|
||||
|
||||
const size_t b_idx = tid / (w_out * h_out * c);
|
||||
const size_t c_idx = (tid / (w_out * h_out)) % c;
|
||||
const size_t dst_w = (tid / h_out) % w_out;
|
||||
const size_t dst_h = tid % h_out;
|
||||
|
||||
const size_t src_idx0 = b_idx * src_strides[0];
|
||||
A d = 0;
|
||||
for (size_t w_offset = 0; w_offset < w_k; ++w_offset) {
|
||||
size_t src_w = w_stride * dst_w + w_offset;
|
||||
if (src_w >= w_in){
|
||||
continue;
|
||||
}
|
||||
for (size_t h_offset = 0; h_offset < h_k; ++h_offset) {
|
||||
size_t src_h = h_stride * dst_h + h_offset;
|
||||
if (src_h >= h_in) {
|
||||
continue;
|
||||
}
|
||||
const size_t src_idx = src_idx0 + c_idx * src_strides[1] + src_w * src_strides[2] + src_h * src_strides[3];
|
||||
d += static_cast<A>(src[src_idx]);
|
||||
}
|
||||
}
|
||||
dst[tid] = static_cast<T>(d / (w_k * h_k));
|
||||
}
|
||||
|
||||
#define AVGPOOL2D_OP(TYPENAME, TYPEACC, FN_NAME) \
|
||||
kernel void FN_NAME( \
|
||||
constant size_t &w_k, \
|
||||
constant size_t &h_k, \
|
||||
constant size_t &w_s, \
|
||||
constant size_t &h_s, \
|
||||
constant size_t *src_dims, \
|
||||
constant size_t *src_s, \
|
||||
device const TYPENAME *src, \
|
||||
device TYPENAME *dst, \
|
||||
uint tid [[ thread_position_in_grid ]] \
|
||||
) { \
|
||||
avg_pool2d<TYPENAME, TYPEACC>(w_k, h_k, w_s, h_s, src_dims, src_s, src, dst, tid); \
|
||||
} \
|
||||
|
||||
template <typename T>
|
||||
METAL_FUNC void max_pool2d(
|
||||
constant size_t &w_k,
|
||||
@ -293,3 +354,11 @@ MAXPOOL2D_OP(uint8_t, max_pool2d_u8)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
MAXPOOL2D_OP(bfloat, max_pool2d_bf16)
|
||||
#endif
|
||||
|
||||
AVGPOOL2D_OP(float, float, avg_pool2d_f32)
|
||||
AVGPOOL2D_OP(half, float, avg_pool2d_f16)
|
||||
AVGPOOL2D_OP(uint32_t, uint32_t, avg_pool2d_u32)
|
||||
AVGPOOL2D_OP(uint8_t, uint8_t, avg_pool2d_u8)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
AVGPOOL2D_OP(bfloat, float, avg_pool2d_bf16)
|
||||
#endif
|
@ -1827,7 +1827,7 @@ fn divide(m: usize, b: usize) -> NSUInteger {
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_max_pool2d(
|
||||
pub fn call_pool2d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
|
@ -1369,7 +1369,7 @@ fn index_add() {
|
||||
}
|
||||
}
|
||||
|
||||
fn run_max_pool2d<T: Clone>(
|
||||
fn run_pool2d<T: Clone>(
|
||||
v: &[T],
|
||||
(w_k, h_k): (usize, usize),
|
||||
(w_stride, h_stride): (usize, usize),
|
||||
@ -1386,7 +1386,7 @@ fn run_max_pool2d<T: Clone>(
|
||||
let input = new_buffer(&device, v);
|
||||
let output = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
let kernels = Kernels::new();
|
||||
call_max_pool2d(
|
||||
call_pool2d(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
@ -1417,7 +1417,7 @@ fn max_pool2d_f32() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1434,7 +1434,7 @@ fn max_pool2d_f32() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 2;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1454,7 +1454,7 @@ fn max_pool2d_f16() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1474,7 +1474,7 @@ fn max_pool2d_f16() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 2;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1497,7 +1497,7 @@ fn max_pool2d_bf16() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1517,7 +1517,7 @@ fn max_pool2d_bf16() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 2;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1540,7 +1540,7 @@ fn max_pool2d_u8() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1557,7 +1557,7 @@ fn max_pool2d_u8() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 2;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1577,7 +1577,7 @@ fn max_pool2d_u32() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1594,7 +1594,7 @@ fn max_pool2d_u32() {
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 2;
|
||||
let results = run_max_pool2d(
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
@ -1605,3 +1605,115 @@ fn max_pool2d_u32() {
|
||||
let expected = vec![5, 7, 13, 15];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn avg_pool2d_f32() {
|
||||
// kernel 2 stride 1
|
||||
let v: Vec<f32> = (0..16).map(|v| v as f32).collect();
|
||||
let shape = vec![1, 1, 4, 4];
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
&shape,
|
||||
&strides,
|
||||
"avg_pool2d_f32",
|
||||
);
|
||||
let expected = vec![
|
||||
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
|
||||
];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn avg_pool2d_f16() {
|
||||
// kernel 2 stride 1
|
||||
let v: Vec<f16> = (0..16).map(|v| f16::from_f32(v as f32)).collect();
|
||||
let shape = vec![1, 1, 4, 4];
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
&shape,
|
||||
&strides,
|
||||
"avg_pool2d_f16",
|
||||
);
|
||||
let expected = vec![
|
||||
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
|
||||
]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn avg_pool2d_bf16() {
|
||||
// kernel 2 stride 1
|
||||
let v: Vec<bf16> = (0..16).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
let shape = vec![1, 1, 4, 4];
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
&shape,
|
||||
&strides,
|
||||
"avg_pool2d_bf16",
|
||||
);
|
||||
let expected = vec![
|
||||
2.5000, 3.5000, 4.5000, 6.5000, 7.5000, 8.5000, 10.5000, 11.5000, 12.5000,
|
||||
]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn avg_pool2d_u8() {
|
||||
// kernel 2 stride 1
|
||||
let v: Vec<u8> = (0..16).map(|v| v as u8).collect();
|
||||
let shape = vec![1, 1, 4, 4];
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
&shape,
|
||||
&strides,
|
||||
"avg_pool2d_u8",
|
||||
);
|
||||
let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn avg_pool2d_u32() {
|
||||
// kernel 2 stride 1
|
||||
let v: Vec<u32> = (0..16).map(|v| v as u32).collect();
|
||||
let shape = vec![1, 1, 4, 4];
|
||||
let strides = vec![16, 16, 4, 1];
|
||||
let kernel = 2;
|
||||
let stride = 1;
|
||||
let results = run_pool2d(
|
||||
&v,
|
||||
(kernel, kernel),
|
||||
(stride, stride),
|
||||
&shape,
|
||||
&strides,
|
||||
"avg_pool2d_u32",
|
||||
);
|
||||
let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
Reference in New Issue
Block a user