mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Add support for conv_transpose1d for metal backend (#1874)
* first attempt * progress * integrate into metal backend * finish and get test passing * add other dtype support * update transpose1d dtypes supported
This commit is contained in:
@ -1717,3 +1717,219 @@ fn avg_pool2d_u32() {
|
||||
let expected = vec![2, 3, 4, 6, 7, 8, 10, 11, 12];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
fn run_conv_transpose1d<T: Clone>(
|
||||
input: &[T],
|
||||
input_shape: &[usize],
|
||||
input_stride: &[usize],
|
||||
kernel: &[T],
|
||||
kernel_shape: &[usize],
|
||||
kernel_stride: &[usize],
|
||||
dilation: usize,
|
||||
stride: usize,
|
||||
padding: usize,
|
||||
out_padding: usize,
|
||||
name: &'static str,
|
||||
) -> Vec<T> {
|
||||
let device = device();
|
||||
let command_queue = device.new_command_queue();
|
||||
let command_buffer = command_queue.new_command_buffer();
|
||||
|
||||
let c_out = kernel_shape[1];
|
||||
let k_size = kernel_shape[2];
|
||||
let b_size = input_shape[0];
|
||||
let l_in = input_shape[2];
|
||||
let l_out = (l_in - 1) * stride - 2 * padding + dilation * (k_size - 1) + out_padding + 1;
|
||||
let dst_el = c_out * l_out * b_size;
|
||||
|
||||
let input = new_buffer(&device, input);
|
||||
let kernel = new_buffer(&device, kernel);
|
||||
let output = new_buffer(&device, &vec![0.0f32; dst_el]);
|
||||
let kernels = Kernels::new();
|
||||
|
||||
call_conv_transpose1d(
|
||||
&device,
|
||||
command_buffer,
|
||||
&kernels,
|
||||
name,
|
||||
dilation,
|
||||
stride,
|
||||
padding,
|
||||
out_padding,
|
||||
c_out,
|
||||
l_out,
|
||||
b_size,
|
||||
input_shape,
|
||||
input_stride,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
&input,
|
||||
0,
|
||||
&kernel,
|
||||
0,
|
||||
&output,
|
||||
)
|
||||
.unwrap();
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
|
||||
read_to_vec(&output, dst_el)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_f32() {
|
||||
let input = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel = vec![1.0f32, 2.0, 3.0, 4.0];
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_f32",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_f16() {
|
||||
let input: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<f16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect();
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_f16",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
|
||||
.iter()
|
||||
.map(|v| f16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_bf16() {
|
||||
let input: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect();
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<bf16> = vec![1.0, 2.0, 3.0, 4.0]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect();
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_bf16",
|
||||
);
|
||||
|
||||
let expected = vec![1., 4., 10., 20., 25., 24., 16.]
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_u8() {
|
||||
let input: Vec<u8> = vec![1, 2, 3, 4];
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<u8> = vec![1, 2, 3, 4];
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_u8",
|
||||
);
|
||||
|
||||
let expected = vec![1, 4, 10, 20, 25, 24, 16];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conv_transpose1d_u32() {
|
||||
let input: Vec<u32> = vec![1, 2, 3, 4];
|
||||
let input_shape = &[1, 1, 4];
|
||||
let input_stride = &[4, 4, 1];
|
||||
|
||||
let kernel: Vec<u32> = vec![1, 2, 3, 4];
|
||||
let kernel_shape = &[1, 1, 4];
|
||||
let kernel_stride = &[4, 4, 1];
|
||||
|
||||
let results = run_conv_transpose1d(
|
||||
&input,
|
||||
input_shape,
|
||||
input_stride,
|
||||
&kernel,
|
||||
kernel_shape,
|
||||
kernel_stride,
|
||||
1,
|
||||
1,
|
||||
0,
|
||||
0,
|
||||
"conv_transpose1d_u32",
|
||||
);
|
||||
|
||||
let expected = vec![1, 4, 10, 20, 25, 24, 16];
|
||||
assert_eq!(results, expected);
|
||||
}
|
||||
|
Reference in New Issue
Block a user