From d5902840e0f5e1e18d32965496bd1bc4911f67e2 Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Sun, 21 Jan 2024 17:32:21 +0100
Subject: [PATCH] Improve reduce perf and add contiguous impl

---
 candle-core/benches/bench_main.rs        |   2 +-
 candle-core/benches/benchmarks/mod.rs    |   1 +
 candle-core/benches/benchmarks/reduce.rs |  36 +++
 candle-core/src/metal_backend.rs         |  64 +++-
 candle-metal-kernels/src/lib.rs          |   7 +-
 candle-metal-kernels/src/reduce.metal    | 378 ++++++++++++++++++-----
 candle-metal-kernels/src/tests.rs        |  17 +-
 7 files changed, 409 insertions(+), 96 deletions(-)
 create mode 100644 candle-core/benches/benchmarks/reduce.rs

diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs
index 4425f2fb..07668c81 100644
--- a/candle-core/benches/bench_main.rs
+++ b/candle-core/benches/bench_main.rs
@@ -1,4 +1,4 @@
 mod benchmarks;
 
 use criterion::criterion_main;
-criterion_main!(benchmarks::matmul::benches);
+criterion_main!(benchmarks::reduce::benches);
\ No newline at end of file
diff --git a/candle-core/benches/benchmarks/mod.rs b/candle-core/benches/benchmarks/mod.rs
index 1344770d..8ee75ea9 100644
--- a/candle-core/benches/benchmarks/mod.rs
+++ b/candle-core/benches/benchmarks/mod.rs
@@ -1,4 +1,5 @@
 pub(crate) mod matmul;
+pub(crate) mod reduce;
 
 use candle_core::{Device, Result};
 
diff --git a/candle-core/benches/benchmarks/reduce.rs b/candle-core/benches/benchmarks/reduce.rs
new file mode 100644
index 00000000..f44a1730
--- /dev/null
+++ b/candle-core/benches/benchmarks/reduce.rs
@@ -0,0 +1,36 @@
+use candle_core::{DType, Tensor};
+use criterion::{black_box, criterion_group, Criterion, Throughput};
+use std::time::Instant;
+use crate::benchmarks::{bench_name, device, BenchDevice};
+
+fn run(a: &Tensor) {
+    a.sum(2).unwrap();
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+    let b = 1;
+    let m = 2048;
+    let k = 2048;
+
+    let device = device().unwrap();
+
+    let a = Tensor::rand(-1000.0f32, 1000.0f32, (b, m, k), &device).unwrap();
+
+    let flops = b * m * k * DType::F32.size_in_bytes();
+
+    let mut group = c.benchmark_group(bench_name("reduce"));
+    group.throughput(Throughput::Bytes(flops as u64));
+    group.bench_function("iter", move |b| {
+        b.iter_custom(|iters| {
+            let start = Instant::now();
+            for _i in 0..iters {
+                run(black_box(&a));
+            }
+            device.sync().unwrap();
+            start.elapsed()
+        })
+    });
+    group.finish();
+}
+
+criterion_group!(benches, criterion_benchmark);
diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index c1c4aa4b..6d5232ce 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -491,6 +491,7 @@ impl BackendStorage for MetalStorage {
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
         let device = self.device.clone();
+
        let src_stride = layout.stride();
        let src_dims = layout.shape().dims();
        // Source dims and strides with the sum dims at the end.
@@ -504,13 +505,72 @@ impl BackendStorage for MetalStorage {
                stride.push(src_stride[dim_idx]);
            }
        }
+
+        if layout.is_contiguous() {
+            let (name, check_empty, return_index) = match (op, self.dtype) {
+                (ReduceOp::Sum, DType::F32) => ("fast_sum_f32", false, false),
+                (ReduceOp::Min, DType::F32) => ("fast_min_f32", true, false),
+                (ReduceOp::Max, DType::F32) => ("fast_max_f32", true, false),
+                //(ReduceOp::ArgMin, DType::F32) => ("fast_argmin_f32", true, true),
+                //(ReduceOp::ArgMax, DType::F32) => ("fast_argmax_f32", true, true),
+                (ReduceOp::Sum, DType::U32) => ("fast_sum_u32", false, false),
+                (ReduceOp::Min, DType::U32) => ("fast_min_u32", true, false),
+                (ReduceOp::Max, DType::U32) => ("fast_max_u32", true, false),
+                //(ReduceOp::ArgMin, DType::U32) => ("fast_argmin_u32", true, true),
+                //(ReduceOp::ArgMax, DType::U32) => ("fast_argmax_u32", true, true),
+                (ReduceOp::Sum, DType::F16) => ("fast_sum_f16", false, false),
+                (ReduceOp::Min, DType::F16) => ("fast_min_f16", true, false),
+                (ReduceOp::Max, DType::F16) => ("fast_max_f16", true, false),
+                //(ReduceOp::ArgMin, DType::F16) => ("fast_argmin_f16", true, true),
+                //(ReduceOp::ArgMax, DType::F16) => ("fast_argmax_f16", true, true),
+                (ReduceOp::Sum, DType::BF16) => ("fast_sum_bf16", false, false),
+                (ReduceOp::Min, DType::BF16) => ("fast_min_bf16", true, false),
+                (ReduceOp::Max, DType::BF16) => ("fast_max_bf16", true, false),
+                //(ReduceOp::ArgMin, DType::BF16) => ("fast_argmin_bf16", true, true),
+                //(ReduceOp::ArgMax, DType::BF16) => ("fast_argmax_bf16", true, true),
+                (ReduceOp::Sum, DType::I64) => ("fast_sum_i64", false, false),
+                (ReduceOp::Min, DType::I64) => ("fast_min_i64", true, false),
+                (ReduceOp::Max, DType::I64) => ("fast_max_i64", true, false),
+                //(ReduceOp::ArgMin, DType::I64) => ("fast_argmin_i64", true, true),
+                //(ReduceOp::ArgMax, DType::I64) => ("fast_argmax_i64", true, true),
+                (ReduceOp::Sum, DType::U8) => ("fast_sum_u8", false, false),
+                (ReduceOp::Min, DType::U8) => ("fast_min_u8", true, false),
+                (ReduceOp::Max, DType::U8) => ("fast_max_u8", true, false),
+                //(ReduceOp::ArgMin, DType::U8) => ("fast_argmin_u8", true, true),
+                //(ReduceOp::ArgMax, DType::U8) => ("fast_argmax_u8", true, true),
+                //(k, dtype) => crate::bail!("Metal reduce op {k:?} {dtype:?} not implemented"),
+                _ => ("fall back to strided impl", false, false)
+            };
+
+            if name != "fall back to strided impl" {
+                if check_empty && layout.shape().elem_count() == 0 {
+                    Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+                }
+
+
+                let buffer = device.new_buffer(dst_el, self.dtype, "reduce")?;
+                let command_buffer = self.device.command_buffer()?;
+                candle_metal_kernels::call_reduce_contiguous(
+                    &device.device,
+                    &command_buffer,
+                    &device.kernels,
+                    name,
+                    layout.shape().elem_count(),
+                    dst_el,
+                    &self.buffer,
+                    layout.start_offset() * self.dtype.size_in_bytes(),
+                    &buffer,
+                )
+                .map_err(MetalError::from)?;
+                return Ok(Self::new(buffer, device, self.dtype));
+            }
+        }
+
        for &dim_idx in sum_dims.iter() {
            dims.push(src_dims[dim_idx]);
            stride.push(src_stride[dim_idx]);
        }

-        // The reduction loop requires the shared array to be properly initialized and for
-        // this we want the number of threads to be a power of two.
        let (name, check_empty, return_index) = match (op, self.dtype) {
            (ReduceOp::Sum, DType::F32) => ("fast_sum_f32_strided", false, false),
            (ReduceOp::Min, DType::F32) => ("fast_min_f32_strided", true, false),
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 5d34f61a..2815e74c 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -568,7 +568,6 @@ pub fn call_reduce_contiguous(
     let elements_to_sum = length / out_length;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -597,7 +596,6 @@ pub fn call_reduce_contiguous(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
@@ -619,7 +617,6 @@ pub fn call_reduce_strided(
     let elements_to_sum = length / out_length;
 
     let encoder = command_buffer.new_compute_command_encoder();
-    encoder.wait_for_fence(&kernels.fence);
     encoder.set_compute_pipeline_state(&pipeline);
 
     set_params!(
@@ -630,7 +627,8 @@
         encoder,
         (
             length,
             num_dims,
             dims,
             strides,
             elements_to_sum,
             (input, input_offset),
-            output
+            output,
+            out_length
         )
     );
 
@@ -655,7 +653,6 @@ pub fn call_reduce_strided(
     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
-    encoder.update_fence(&kernels.fence);
     encoder.end_encoding();
     Ok(())
 }
diff --git a/candle-metal-kernels/src/reduce.metal b/candle-metal-kernels/src/reduce.metal
index 83a56f0a..d297a527 100644
--- a/candle-metal-kernels/src/reduce.metal
+++ b/candle-metal-kernels/src/reduce.metal
@@ -1,16 +1,15 @@
 #include <metal_stdlib>
+#include <metal_limits>
 using namespace metal;
 
-#define MAX(x, y) ((x) > (y) ? (x) : (y))
-#define MIN(x, y) ((x) < (y) ? (x) : (y))
-
 METAL_FUNC uint get_strided_index(
     uint idx,
-    constant size_t &num_dims,
-    constant size_t *dims,
-    constant size_t *strides
+    constant const size_t &num_dims,
+    constant const size_t *dims,
+    constant const size_t *strides
 ) {
     uint strided_i = 0;
+    #pragma clang loop unroll(full)
     for (uint d = 0; d < num_dims; d++) {
         uint dim_idx = num_dims - 1 - d;
         strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
@@ -19,8 +18,270 @@ METAL_FUNC uint get_strided_index(
     return strided_i;
 }
 
-constant int THREADGROUP_SIZE = 2048;
+#define impl_reduction_op(name, op, init_val) \
+template <typename T> \
+struct name { \
+    \
+    static constexpr constant T init = init_val; \
+    \
+    METAL_FUNC T operator()(thread const T &a, thread const T &b) const { \
+        return op; \
+    } \
+    \
+    METAL_FUNC T operator()(threadgroup const T &a, threadgroup const T &b) const { \
+        return op; \
+    } \
+    \
+    METAL_FUNC T operator()(device const T &a, device const T &b) const { \
+        return op; \
+    } \
+    \
+    METAL_FUNC T operator()(T a, T b) { \
+        return op; \
+    } \
+} \
+
+impl_reduction_op(Sum, a + b, 0);
+impl_reduction_op(Mul, a * b, 1);
+impl_reduction_op(Min, a < b ? a : b, numeric_limits<T>::max());
+impl_reduction_op(Max, a > b ? a : b, numeric_limits<T>::min());
+#undef impl_reduction_op
+
+static constant constexpr int THREADGROUP_SIZE = 2048;
+
+template <typename T, typename ReductionOp, uint BLOCKSIZE>
+METAL_FUNC void load_from_global(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    threadgroup T shared[BLOCKSIZE],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    ReductionOp op;
+
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = start_idx + el_to_sum_per_block;
+    size_t idx = start_idx + tid;
+    while (idx < stop_idx) {
+        size_t strided_i = get_strided_index(idx, num_dims, dims, strides);
+        shared[tid] = op(shared[tid], src[strided_i]);
+        idx += block_dim;
+    }
+    threadgroup_barrier(mem_flags::mem_none);
+}
+
+template <typename T, typename ReductionOp, uint BLOCKSIZE>
+METAL_FUNC void load_from_global(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    threadgroup T shared[BLOCKSIZE],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    ReductionOp op;
+
+    size_t start_idx = dst_id * el_to_sum_per_block;
+    size_t stop_idx = start_idx + el_to_sum_per_block;
+    size_t idx = start_idx + tid;
+    while (idx < stop_idx) {
+        shared[tid] = op(shared[tid], src[idx]);
+        idx += block_dim;
+    }
+    threadgroup_barrier(mem_flags::mem_none);
+}
+
+template <typename T, typename ReductionOp, uint BLOCKSIZE>
+METAL_FUNC void threadgroup_reduce(
+    threadgroup T shared[BLOCKSIZE],
+    uint tid [[thread_index_in_threadgroup]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    ReductionOp op;
+    if (BLOCKSIZE >= 64) {
+        if (block_dim >= 64) {
+            shared[tid] = op(shared[tid], shared[tid + 32]);
+        }
+    }
+    if (BLOCKSIZE >= 32) {
+        if (block_dim >= 32) {
+            shared[tid] = op(shared[tid], shared[tid + 16]);
+        }
+    }
+    if (BLOCKSIZE >= 16) {
+        if (block_dim >= 16) {
+            shared[tid] = op(shared[tid], shared[tid + 8]);
+        }
+    }
+    if (BLOCKSIZE >= 8) {
+        if (block_dim >= 8) {
+            shared[tid] = op(shared[tid], shared[tid + 4]);
+        }
+    }
+    if (BLOCKSIZE >= 4) {
+        if (block_dim >= 4) {
+            shared[tid] = op(shared[tid], shared[tid + 2]);
+        }
+    }
+    if (BLOCKSIZE >= 2) {
+        if (block_dim >= 2) {
+            shared[tid] = op(shared[tid], shared[tid + 1]);
+        }
+    }
+}
+
+// Inspired by "Optimizing Parallel Reduction in CUDA" by Mark Harris
+template<
+    typename T,
+    typename ReductionOp,
+    uint BLOCKSIZE,
+    bool STRIDED
+>
+METAL_FUNC void block_reduce(
+    constant size_t &num_dims,
+    constant size_t *dims,
+    constant size_t *strides,
+    constant size_t &el_to_sum_per_block,
+    device const T *src,
+    device T *dst,
+    constant uint &num_elements,
+    threadgroup T shared[BLOCKSIZE],
+    uint id [[ thread_position_in_grid ]],
+    uint tid [[ thread_index_in_threadgroup ]],
+    uint dst_id [[ threadgroup_position_in_grid ]],
+    uint block_dim [[ threads_per_threadgroup ]]
+) {
+    ReductionOp op;
+
+    shared[tid] = ReductionOp::init;
+
+    if (STRIDED) {
+        load_from_global<T, ReductionOp, BLOCKSIZE>(
+            num_dims,
+            dims,
+            strides,
+            el_to_sum_per_block,
+            src,
+            shared,
+            tid,
+            dst_id,
+            block_dim
+        );
+    } else {
+        load_from_global<T, ReductionOp, BLOCKSIZE>(
+            num_dims,
+            dims,
+            el_to_sum_per_block,
+            src,
+            shared,
+            tid,
+            dst_id,
+            block_dim
+        );
+    }
+
+    if (BLOCKSIZE >= 1024) {
+        if (tid < 512 && block_dim >= 1024) {
+            shared[tid] = op(shared[tid], shared[tid + 512]);
+            threadgroup_barrier(mem_flags::mem_none);
+        }
+    }
+    if (BLOCKSIZE >= 512) {
+        if (tid < 256 && block_dim >= 512) {
+            shared[tid] = op(shared[tid], shared[tid + 256]);
+            threadgroup_barrier(mem_flags::mem_none);
+        }
+    }
+    if (BLOCKSIZE >= 256) {
+        if (tid < 128 && block_dim >= 256) {
+            shared[tid] = op(shared[tid], shared[tid + 128]);
+            threadgroup_barrier(mem_flags::mem_none);
+        }
+    }
+    if (BLOCKSIZE >= 128) {
+        if (tid < 64 && block_dim >= 128) {
+            shared[tid] = op(shared[tid], shared[tid + 64]);
+            threadgroup_barrier(mem_flags::mem_none);
+        }
+    }
+    if (tid < 32) {
+        threadgroup_reduce<T, ReductionOp, BLOCKSIZE>(shared, tid, block_dim);
+        threadgroup_barrier(mem_flags::mem_none);
+    }
+
+    if (tid == 0) {
+        dst[dst_id] = shared[tid];
+    }
+}
+
+#define MAX(x, y) ((x) > (y) ? (x) : (y))
+#define MIN(x, y) ((x) < (y) ? (x) : (y))
+
+static constant constexpr int BLOCKSIZE = 2048;
+
+#define REDUCE(OP, NAME, T) \
+kernel void NAME( \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    constant uint &num_elements, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup T shared[BLOCKSIZE]; \
+    block_reduce<T, OP<T>, BLOCKSIZE, false>( \
+        num_dims, \
+        dims, \
+        strides, \
+        el_to_sum_per_block, \
+        src, \
+        dst, \
+        num_elements, \
+        shared, \
+        id, \
+        tid, \
+        dst_id, \
+        block_dim); \
+} \
+kernel void NAME##_strided( \
+    constant size_t &num_dims, \
+    constant size_t *dims, \
+    constant size_t *strides, \
+    constant size_t &el_to_sum_per_block, \
+    device const T *src, \
+    device T *dst, \
+    constant uint &num_elements, \
+    uint id [[ thread_position_in_grid ]], \
+    uint tid [[ thread_index_in_threadgroup ]], \
+    uint dst_id [[ threadgroup_position_in_grid ]], \
+    uint block_dim [[ threads_per_threadgroup ]] \
+) { \
+    threadgroup T shared[BLOCKSIZE]; \
+    block_reduce<T, OP<T>, BLOCKSIZE, true>( \
+        num_dims, \
+        dims, \
+        strides, \
+        el_to_sum_per_block, \
+        src, \
+        dst, \
+        num_elements, \
+        shared, \
+        id, \
+        tid, \
+        dst_id, \
+        block_dim); \
+} \
 
 #define ARGMIN(NAME, T, MAXVALUE) \
 kernel void NAME( \
@@ -140,59 +401,6 @@ kernel void NAME( \
     } \
 } \
 
-#define REDUCE(FN, NAME, T, START) \
-kernel void NAME( \
-    constant size_t &num_dims, \
-    constant size_t *dims, \
-    constant size_t *strides, \
-    constant size_t &el_to_sum_per_block, \
-    device const T *src, \
-    device T *dst, \
-    uint id [[ thread_position_in_grid ]], \
-    uint tid [[ thread_index_in_threadgroup ]], \
-    uint dst_id [[ threadgroup_position_in_grid ]], \
-    uint block_dim [[ threads_per_threadgroup ]] \
-) { \
-    \
-    threadgroup T shared_memory[THREADGROUP_SIZE]; \
-    \
-    shared_memory[tid] = START; \
-    /* \
-    // Elements summed in this block range from dst_id * el_to_sum_per_block \
-    // to (dst_id + 1) * el_to_sum_per_block. \
-    */ \
-    size_t start_idx = dst_id * el_to_sum_per_block; \
-    size_t stop_idx = start_idx + el_to_sum_per_block; \
-    size_t idx = start_idx + tid; \
-    while (idx < stop_idx) { \
-        /* \
-        // TODO: Fast version for the contiguous case. \
-        */ \
-        size_t strided_i = get_strided_index(idx, num_dims, dims, strides); \
-        T x = shared_memory[tid]; \
-        T y = src[strided_i]; \
-        shared_memory[tid] = FN; \
-        idx += block_dim; \
-    } \
-    \
-    threadgroup_barrier(mem_flags::mem_none); \
-    \
-    /* \
-    // reduction in shared memory \
-    */ \
-    for (uint s = block_dim / 2; s > 0; s >>= 1) { \
-        if (tid < s) { \
-            T x = shared_memory[tid]; \
-            T y = shared_memory[tid + s]; \
-            shared_memory[tid] = FN; \
-        } \
-        threadgroup_barrier(mem_flags::mem_none); \
-    } \
-    \
-    dst[dst_id] = shared_memory[0]; \
-} \
-
-
 #define SOFTMAX(NAME, T) \
 kernel void NAME( \
     constant size_t &src_numel, \
@@ -258,23 +466,24 @@ kernel void NAME(
         dst[idx] *= inv_acc; \
         idx += block_dim; \
     } \
-} \
+}
+
+REDUCE(Sum, fast_sum_f32, float)
+REDUCE(Sum, fast_sum_u32, uint)
+REDUCE(Sum, fast_sum_f16, half)
+REDUCE(Sum, fast_sum_u8, uint8_t)
+REDUCE(Mul, fast_mul_f32, float)
+REDUCE(Mul, fast_mul_u32, uint)
+REDUCE(Mul, fast_mul_f16, half)
+REDUCE(Max, fast_max_f32, float)
+REDUCE(Max, fast_max_u32, uint)
+REDUCE(Max, fast_max_f16, half)
+REDUCE(Max, fast_max_u8, uint8_t)
+REDUCE(Min, fast_min_f32, float)
+REDUCE(Min, fast_min_u32, uint)
+REDUCE(Min, fast_min_f16, half)
+REDUCE(Min, fast_min_u8, uint8_t)
 
-REDUCE(x + y, fast_sum_f32_strided, float, 0)
-REDUCE(x + y, fast_sum_u32_strided, uint, 0)
-REDUCE(x + y, fast_sum_f16_strided, half, 0)
-REDUCE(x + y, fast_sum_u8_strided, uint8_t, 0)
-REDUCE(x * y, fast_mul_f32_strided, float, 1)
-REDUCE(x * y, fast_mul_u32_strided, uint, 1)
-REDUCE(x * y, fast_mul_f16_strided, half, 1)
-REDUCE(MAX(x, y), fast_max_f32_strided, float, -HUGE_VALF)
-REDUCE(MAX(x, y), fast_max_u32_strided, uint, 0)
-REDUCE(MAX(x, y), fast_max_f16_strided, half, -HUGE_VALH)
-REDUCE(MAX(x, y), fast_max_u8_strided, uint8_t, 0)
-REDUCE(MIN(x, y), fast_min_f32_strided, float, HUGE_VALF)
-REDUCE(MIN(x, y), fast_min_u32_strided, uint, 0xFFFFFFFF)
-REDUCE(MIN(x, y), fast_min_f16_strided, half, HUGE_VALH)
-REDUCE(MIN(x, y), fast_min_u8_strided, uint8_t, 0xFF)
 ARGMIN(fast_argmin_f32_strided, float, HUGE_VALF)
 ARGMIN(fast_argmin_f16_strided, half, HUGE_VALH)
 ARGMIN(fast_argmin_u32_strided, uint, 0xFFFFFFFF)
@@ -288,18 +497,21 @@
 SOFTMAX(softmax_f32, float)
 SOFTMAX(softmax_f16, half)
 
 #if __METAL_VERSION__ >= 220
-REDUCE(x + y, fast_sum_i64_strided, int64_t, 0)
-REDUCE(MIN(x, y), fast_min_i64_strided, int64_t, INT_MAX)
-REDUCE(MAX(x, y), fast_max_i64_strided, int64_t, INT_MIN)
+REDUCE(Sum, fast_sum_i64, int64_t)
+REDUCE(Mul, fast_mul_i64, int64_t)
+REDUCE(Min, fast_min_i64, int64_t)
+REDUCE(Max, fast_max_i64, int64_t)
+
 ARGMIN(fast_argmin_i64_strided, int64_t, INT_MAX)
 ARGMAX(fast_argmax_i64_strided, int64_t, INT_MIN)
 #endif
 
 #if __METAL_VERSION__ >= 310
-REDUCE(x + y, fast_sum_bf16, bfloat, 0)
-REDUCE(x * y, fast_mul_bf16, bfloat, 1)
-REDUCE(MAX(x, y), fast_max_bf16, bfloat, -HUGE_VALBF)
-REDUCE(MIN(x, y), fast_min_bf16, bfloat, HUGE_VALBF)
+REDUCE(Sum, fast_sum_bf16, bfloat)
+REDUCE(Mul, fast_mul_bf16, bfloat)
+REDUCE(Max, fast_max_bf16, bfloat)
+REDUCE(Min, fast_min_bf16, bfloat)
+
 ARGMIN(fast_argmin_bf16, bfloat, HUGE_VALBF)
 ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
 SOFTMAX(softmax_bf16, bfloat)
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index c955abca..5d41977f 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -509,7 +509,7 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
     let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
     let dims = vec![v.len()];
     let strides = vec![1];
-    call_reduce_strided(
+    let result = call_reduce_strided(
         &device,
         command_buffer,
         &kernels,
@@ -520,10 +520,17 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
         &input,
         0,
         &output,
-    )
-    .unwrap();
-    command_buffer.commit();
-    command_buffer.wait_until_completed();
+    );
+    match result {
+        Ok(_) => {
+            command_buffer.commit();
+            command_buffer.wait_until_completed();
+        }
+        Err(e) => {
+            println!("Error: {}", e);
+            panic!("damn!");
+        },
+    }
 
     read_to_vec(&output, out_length)
 }
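
Note on exercising the new fast path: with the hunks above, reduce_op routes contiguous layouts to call_reduce_contiguous while non-contiguous layouts still go through the *_strided kernels. The sketch below is not part of the patch; it shows what a companion test in candle-metal-kernels/src/tests.rs could look like, assuming the same test helpers the existing run_reduce already relies on (device(), new_buffer(), read_to_vec(), Kernels::new()) and the "fast_sum_f32" kernel name matched in metal_backend.rs.

// Hypothetical test helper mirroring run_reduce above, but driving the
// contiguous entry point instead of the strided one.
fn run_reduce_contiguous<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let input = new_buffer(&device, v);
    let options = MTLResourceOptions::StorageModeManaged;
    let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);

    // One threadgroup per output element; each folds v.len() / out_length
    // consecutive inputs, so no dims/strides buffers are bound on this path.
    call_reduce_contiguous(
        &device,
        command_buffer,
        &kernels,
        name,
        v.len(),
        out_length,
        &input,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec(&output, out_length)
}

#[test]
fn reduce_sum_contiguous() {
    // Sum 1..=8 down to a single element: 36.0.
    let v: Vec<f32> = (1..=8u32).map(|x| x as f32).collect();
    let results = run_reduce_contiguous(&v, 1, "fast_sum_f32");
    assert_eq!(results, vec![36.0]);
}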