diff --git a/candle-core/benches/bench_main.rs b/candle-core/benches/bench_main.rs
index 2e1816fd..30e17f30 100644
--- a/candle-core/benches/bench_main.rs
+++ b/candle-core/benches/bench_main.rs
@@ -2,11 +2,11 @@ mod benchmarks;
 use criterion::criterion_main;
 
 criterion_main!(
-    benchmarks::affine::benches,
+    //benchmarks::affine::benches,
     benchmarks::matmul::benches,
-    benchmarks::random::benches,
-    benchmarks::where_cond::benches,
-    benchmarks::conv_transpose2d::benches,
-    benchmarks::qmatmul::benches,
-    benchmarks::unary::benches
+    //benchmarks::random::benches,
+    //benchmarks::where_cond::benches,
+    //benchmarks::conv_transpose2d::benches,
+    //benchmarks::qmatmul::benches,
+    //benchmarks::unary::benches
 );
diff --git a/candle-metal-kernels/src/kernels/gemm.metal b/candle-metal-kernels/src/kernels/gemm.metal
index 06bf6037..0cab25c2 100644
--- a/candle-metal-kernels/src/kernels/gemm.metal
+++ b/candle-metal-kernels/src/kernels/gemm.metal
@@ -44,22 +44,13 @@ constant uint K [[function_constant(2)]];
 constant bool A_trans [[function_constant(10)]];
 constant bool B_trans [[function_constant(11)]];
 
-// Define the memory layout of the matrix block.
-constant ushort M_group [[function_constant(200)]];
-constant ushort N_group [[function_constant(201)]];
-constant ushort K_group [[function_constant(202)]];
-
 constant bool prefer_async_copy [[function_constant(206)]];
 constant bool ideal_grouping [[function_constant(207)]];
 
+constant bool batched [[function_constant(100)]];
+
 constant ushort A_leading_dim = A_trans ? M : K;
 constant ushort B_leading_dim = B_trans ? K : N;
 
-constant ushort A_leading_block_dim = A_trans ? M_group : K_group;
-constant ushort B_leading_block_dim = B_trans ? K_group : N_group;
-
-// Thresholds that mark the matrix edge.
-constant uint M_edge = M - (M % M_group);
-constant uint N_edge = N - (N % N_group);
 
 // The layout of threads within a SIMD matrix.
 //
@@ -123,28 +114,28 @@ METAL_FUNC void multiply_accumulate(
   thread simdgroup_matrix_storage *C_sram,
   ushort k
 ) {
-#pragma clang loop unroll(full)
-  for (ushort m = 0; m < M_register; m += 8) {
-    ushort2 origin(0, m);
-    auto A = get_sram(A_sram, 8, origin);
-    A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
-  }
-#pragma clang loop unroll(full)
-  for (ushort n = 0; n < N_register; n += 8) {
-    ushort2 origin(n, 0);
-    auto B = get_sram(B_sram, N_register, origin);
-    B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
-  }
-#pragma clang loop unroll(full)
-  for (ushort m = 0; m < M_register; m += 8) {
-#pragma clang loop unroll(full)
-    for (ushort n = 0; n < N_register; n += 8) {
-      auto A = get_sram(A_sram, 8, ushort2(0, m));
-      auto B = get_sram(B_sram, N_register, ushort2(n, 0));
-      auto C = get_sram(C_sram, N_register, ushort2(n, m));
-      C->multiply(*A, *B);
+    #pragma clang loop unroll(full)
+    for (ushort m = 0; m < M_register; m += 8) {
+        ushort2 origin(0, m);
+        auto A = get_sram(A_sram, 8, origin);
+        A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
+    }
+    #pragma clang loop unroll(full)
+    for (ushort n = 0; n < N_register; n += 8) {
+        ushort2 origin(n, 0);
+        auto B = get_sram(B_sram, N_register, origin);
+        B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
+    }
+    #pragma clang loop unroll(full)
+    for (ushort m = 0; m < M_register; m += 8) {
+        #pragma clang loop unroll(full)
+        for (ushort n = 0; n < N_register; n += 8) {
+            auto A = get_sram(A_sram, 8, ushort2(0, m));
+            auto B = get_sram(B_sram, N_register, ushort2(n, 0));
+            auto C = get_sram(C_sram, N_register, ushort2(n, m));
+            C->multiply(*A, *B);
+        }
+    }
-    }
-  }
 }
 
 // One multiply-accumulate loop iteration, or 8 dot products.
@@ -162,28 +153,28 @@ METAL_FUNC void multiply_accumulate(
   thread simdgroup_matrix_storage *C_sram,
   ushort k
 ) {
-#pragma clang loop unroll(full)
-  for (ushort m = 0; m < M_register; m += 8) {
-    ushort2 origin(0, m);
-    auto A = get_sram(A_sram, 8, origin);
-    A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
-  }
-#pragma clang loop unroll(full)
-  for (ushort n = 0; n < N_register; n += 8) {
-    ushort2 origin(n, 0);
-    auto B = get_sram(B_sram, N_register, origin);
-    B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
-  }
-#pragma clang loop unroll(full)
-  for (ushort m = 0; m < M_register; m += 8) {
-#pragma clang loop unroll(full)
-    for (ushort n = 0; n < N_register; n += 8) {
-      auto A = get_sram(A_sram, 8, ushort2(0, m));
-      auto B = get_sram(B_sram, N_register, ushort2(n, 0));
-      auto C = get_sram(C_sram, N_register, ushort2(n, m));
-      C->multiply(*A, *B);
+    #pragma clang loop unroll(full)
+    for (ushort m = 0; m < M_register; m += 8) {
+        ushort2 origin(0, m);
+        auto A = get_sram(A_sram, 8, origin);
+        A->load(A_src, A_leading_dim, ushort2(k, m), A_trans);
+    }
+    #pragma clang loop unroll(full)
+    for (ushort n = 0; n < N_register; n += 8) {
+        ushort2 origin(n, 0);
+        auto B = get_sram(B_sram, N_register, origin);
+        B->load(B_src, B_leading_dim, ushort2(n, k), B_trans);
+    }
+    #pragma clang loop unroll(full)
+    for (ushort m = 0; m < M_register; m += 8) {
+        #pragma clang loop unroll(full)
+        for (ushort n = 0; n < N_register; n += 8) {
+            auto A = get_sram(A_sram, 8, ushort2(0, m));
+            auto B = get_sram(B_sram, N_register, ushort2(n, 0));
+            auto C = get_sram(C_sram, N_register, ushort2(n, m));
+            C->multiply(*A, *B);
+        }
+    }
-    }
-  }
 }
 
 // Metal function arguments.
@@ -191,19 +182,19 @@ METAL_FUNC void multiply_accumulate(
 // A: the left-hand side matrix
 // - dimensions: M x K
 //               K x M (transposed)
-// - memory precision: memA
-// - register precision: regA
+// - memory precision: T
+// - register precision: T
 //
 // B: the right-hand side matrix
 // - dimensions: K x N
 //               N x K (transposed)
-// - memory precision: memB
-// - register precision: regB
+// - memory precision: U
+// - register precision: U
 //
 // C: the output matrix, alternatively the dot product accumulator
 // - dimensions: M x N
-// - memory precision: memC
-// - register precision: regC
+// - memory precision: V
+// - register precision: V
 //
 // threadgroup_block: the chunk of threadgroup memory allocated at runtime
 // - ideally 10 KB or less
@@ -211,28 +202,35 @@ METAL_FUNC void multiply_accumulate(
 template <
   typename T,
   typename U = T,
-  ushort M_block_dim,
-  ushort N_block_dim,
-  ushort K_block_dim,
-  ushort M_split,
-  ushort N_split
+  typename V = U,
+  ushort M_group,
+  ushort N_group,
+  ushort K_group,
+  ushort M_splits,
+  ushort N_splits,
+  ushort M_register = M_group / M_splits,
+  ushort N_register = N_group / N_splits
 >
 void gemm_impl(
   device T *A [[buffer(0)]],
   device U *B [[buffer(1)]],
-  device U *C [[buffer(2)]],
+  device V *C [[buffer(2)]],
 
   threadgroup uchar *threadgroup_block [[threadgroup(0)]],
+  constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
 
   uint3 gid [[threadgroup_position_in_grid]],
   ushort sidx [[simdgroup_index_in_threadgroup]],
   ushort lane_id [[thread_index_in_simdgroup]]
 ) {
-  constexpr ushort M_register = M_block_dim / M_split;
-  constexpr ushort N_register = N_block_dim / N_split;
-  constexpr ushort threadgroup_size = 32 * M_split * N_split;
+  const ushort A_leading_block_dim = A_trans ? M_group : K_group;
+  const ushort B_leading_block_dim = B_trans ? K_group : N_group;
 
-  const ushort iteration_start = prefer_async_copy ? 0 : (K - (K % K_group));
+  // Thresholds that mark the matrix edge.
+  const uint M_edge = M - (M % M_group);
+  const uint N_edge = N - (N % N_group);
+
+  const ushort async_iter_start = prefer_async_copy ? 0 : (K - (K % K_group));
 
   // Find the number of elements in the final block. If the matrix
   // dimensions are perfectly divisibly by block dimensions, we don't want
@@ -249,9 +247,16 @@ void gemm_impl(
   const ushort M_shift = (M < M_group) ? 0 : M_register - M_remainder;
   const ushort N_shift = (N < N_group) ? 0 : N_register - N_remainder;
 
+  if (batched) {
+    ulong3 offsets = matrix_offsets[0].xyz * gid.z;
+    A = (device T*)((device uchar*)A + offsets[0]);
+    B = (device U*)((device uchar*)B + offsets[1]);
+    C = (device V*)((device uchar*)C + offsets[2]);
+  }
+
   auto A_block = (threadgroup T*)(threadgroup_block);
-  auto B_block = (threadgroup U*)(threadgroup_block + (M*K));
-  ushort2 sid(sidx % N_split, sidx / N_split);
+  auto B_block = (threadgroup U*)(threadgroup_block + (M * K));
+  ushort2 sid(sidx % N_splits, sidx / N_splits);
   ushort2 morton_offset = morton_order(lane_id);
 
   // Return early if the SIMD is out of bounds.
@@ -266,8 +271,8 @@ void gemm_impl(
       N_offset + sid.x * N_register >= N) {
     return;
   }
-  ushort2 offset_in_group(sid.x * M_register + morton_offset.x,
-                          sid.y * N_register + morton_offset.y);
+  ushort2 offset_in_group(sid.x * N_register + morton_offset.x,
+                          sid.y * M_register + morton_offset.y);
 
   // Shift the matrix block within bounds, if possible.
 if ((M_shift != 0) && (gid.y * M_group >= M_edge)) {
@@ -277,91 +282,98 @@ void gemm_impl(
     M_offset -= M_shift;
   }
   if ((N_shift != 0) && (gid.x * N_group >= N_edge)) {
     N_offset -= N_shift;
   }
 
-  simdgroup_matrix_storage C_sram[(M_register / 8) * (N_register / 8)];
+  simdgroup_matrix_storage C_sram[(M_register / 8) * (N_register / 8)];
 
   // Initialize the accumulator.
 #pragma clang loop unroll(full)
   for (ushort m = 0; m < M_register; m += 8) {
 #pragma clang loop unroll(full)
     for (ushort n = 0; n < N_register; n += 8) {
-      ushort2 origin(n, m);
+      ushort2 origin(m, n);
       auto C = get_sram(C_sram, N_register, origin);
-      *C = simdgroup_matrix_storage(0);
+      *C = simdgroup_matrix_storage(0);
     }
   }
 
-  // Perform the iterations where async copy is avoided.
-  for (uint k = 0; k < iteration_start; k += 8) {
+  #pragma clang loop unroll(full)
+  for (uint k = 0; k < async_iter_start; k += 8) {
     uint2 A_offset(k, M_offset);
     uint2 B_offset(N_offset, k);
     A_offset += uint2(morton_offset.x, offset_in_group.y);
    B_offset += uint2(offset_in_group.x, morton_offset.y);
 
-    auto A_src = simdgroup_matrix_storage::apply_offset(
-      A, A_leading_dim, A_offset, A_trans);
-    auto B_src = simdgroup_matrix_storage::apply_offset(
-      B, N, B_offset, B_trans);
+    auto A_src = simdgroup_matrix_storage::apply_offset(A, A_leading_dim, A_offset, A_trans);
+    auto B_src = simdgroup_matrix_storage::apply_offset(B, B_leading_dim, B_offset, B_trans);
 
     simdgroup_matrix_storage A_sram[M_register / 8];
     simdgroup_matrix_storage B_sram[N_register / 8];
-    multiply_accumulate(
-      A_src, B_src, A_sram, B_sram, C_sram, 0);
+    multiply_accumulate(A_src, B_src, A_sram, B_sram, C_sram, 0);
   }
-
-  // Perform the iterations where async copy is used.
-  for (uint k = iteration_start; k < K; k += K_group) {
-    // Launch an async copy from device to threadgroup memory.
-    if (sidx == 0) {
+  if (!prefer_async_copy) {
+    #pragma clang loop unroll(full)
+    for (uint k = 0; k < K; k += K_group) {
       uint2 A_offset(k, M_offset);
       uint2 B_offset(N_offset, k);
-      auto A_src = simdgroup_matrix_storage::apply_offset(
-        A, A_leading_dim, A_offset, A_trans);
-      auto B_src = simdgroup_matrix_storage::apply_offset(
-        B, N, B_offset, B_trans);
+      A_offset += uint2(morton_offset.x, offset_in_group.y);
+      B_offset += uint2(offset_in_group.x, morton_offset.y);
 
-      ushort M_tile_dimension = min(uint(M_group), M - M_offset);
-      ushort N_tile_dimension = min(uint(N_group), N - N_offset);
-      ushort K_tile_dimension = min(uint(K_group), K - k);
-      ushort K_tile_padded = min(uint(K_group), (K + K_remainder_padded - K_remainder) - k);
+      auto A_src = simdgroup_matrix_storage::apply_offset(A, A_leading_dim, A_offset, A_trans);
+      auto B_src = simdgroup_matrix_storage::apply_offset(B, B_leading_dim, B_offset, B_trans);
 
-      ushort2 A_tile_src(K_tile_dimension, M_tile_dimension);
-      ushort2 B_tile_src(N_tile_dimension, K_tile_dimension);
-      ushort2 A_tile_dst(K_tile_padded, M_tile_dimension);
-      ushort2 B_tile_dst(N_tile_dimension, K_tile_padded);
-
-      simdgroup_event events[2];
-      events[0].async_copy(A_block, A_leading_block_dim, A_tile_dst,
-                           A_src, A_leading_dim, A_tile_src, A_trans);
-      events[1].async_copy(B_block, B_leading_block_dim, B_tile_dst,
-                           B_src, B_leading_dim, B_tile_src, B_trans);
-      simdgroup_event::wait(2, events);
+      simdgroup_matrix_storage A_sram[M_register / 8];
+      simdgroup_matrix_storage B_sram[N_register / 8];
+      multiply_accumulate(A_src, B_src, A_sram, B_sram, C_sram, 0);
    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    ushort2 A_block_offset(morton_offset.x, offset_in_group.y);
-    ushort2 B_block_offset(offset_in_group.x, morton_offset.y);
-    auto A_block_src = simdgroup_matrix_storage::apply_offset(
-      A_block, A_leading_block_dim, A_block_offset, A_trans);
-    auto B_block_src = simdgroup_matrix_storage::apply_offset(
-      B_block, B_leading_block_dim, B_block_offset, B_trans);
-
-    simdgroup_matrix_storage A_sram[(M_register / 8) * (K_block_dim / 8)];
-    simdgroup_matrix_storage B_sram[(K_block_dim / 8) * (N_register / 8)];
+  } else {
+    // Perform the iterations where async copy is used.
 #pragma clang loop unroll(full)
-    for (ushort k = 0; k < K_remainder_padded; k += 8) {
-      multiply_accumulate(
-        A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
-    }
+    for (uint k = async_iter_start; k < K; k += K_group) {
+      // Launch an async copy from device to threadgroup memory.
+      if (sidx == 0) {
+        uint2 A_offset(k, M_offset);
+        uint2 B_offset(N_offset, k);
+        auto A_src = simdgroup_matrix_storage::apply_offset(A, A_leading_dim, A_offset, A_trans);
+        auto B_src = simdgroup_matrix_storage::apply_offset(B, B_leading_dim, B_offset, B_trans);
 
-    // Will there be any iterations after this one?
-    if (k + K_group < K) {
-      // If so, we haven't reached the edge of either input matrix yet.
-      #pragma clang loop unroll(full)
-      for (ushort k = K_remainder_padded; k < K_group; k += 8) {
-        multiply_accumulate(
-          A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
+        ushort M_tile_dimension = min(uint(M_group), M - M_offset);
+        ushort N_tile_dimension = min(uint(N_group), N - N_offset);
+        ushort K_tile_dimension = min(uint(K_group), K - k);
+        ushort K_tile_padded = min(uint(K_group), (K + K_remainder_padded - K_remainder) - k);
+
+        ushort2 A_tile_src(K_tile_dimension, M_tile_dimension);
+        ushort2 B_tile_src(N_tile_dimension, K_tile_dimension);
+        ushort2 A_tile_dst(K_tile_padded, M_tile_dimension);
+        ushort2 B_tile_dst(N_tile_dimension, K_tile_padded);
+
+        simdgroup_event events[2];
+        events[0].async_copy(A_block, A_leading_block_dim, A_tile_dst, A_src, A_leading_dim, A_tile_src, A_trans);
+        events[1].async_copy(B_block, B_leading_block_dim, B_tile_dst, B_src, B_leading_dim, B_tile_src, B_trans);
+        simdgroup_event::wait(2, events);
      }
      threadgroup_barrier(mem_flags::mem_threadgroup);
+
+      ushort2 A_block_offset(morton_offset.x, offset_in_group.y);
+      ushort2 B_block_offset(offset_in_group.x, morton_offset.y);
+      auto A_block_src = simdgroup_matrix_storage::apply_offset(A_block, A_leading_block_dim, A_block_offset, A_trans);
+      auto B_block_src = simdgroup_matrix_storage::apply_offset(B_block, B_leading_block_dim, B_block_offset, B_trans);
+
+      simdgroup_matrix_storage A_sram[(M_register / 8) * (K_group / 8)];
+      simdgroup_matrix_storage B_sram[(K_group / 8) * (N_register / 8)];
+
+      #pragma clang loop unroll(full)
+      for (ushort k = 0; k < K_remainder_padded; k += 8) {
+        multiply_accumulate(A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
+      }
+
+      // Will there be any iterations after this one?
+      if (k + K_group < K) {
+        // If so, we haven't reached the edge of either input matrix yet.
+        #pragma clang loop unroll(full)
+        for (ushort k = K_remainder_padded; k < K_group; k += 8) {
+          multiply_accumulate(A_block_src, B_block_src, A_sram, B_sram, C_sram, k);
+        }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+      }
    }
  }
 
@@ -384,9 +396,8 @@ void gemm_impl(
     }
   } else {
     // Slow path for when memory must be handled more carefully.
-    auto C_block = (threadgroup U*)(threadgroup_block);
-    auto C_block_dst = simdgroup_matrix_storage::apply_offset(
-      C_block, N_group, offset_in_group);
+    auto C_block = (threadgroup V*)(threadgroup_block);
+    auto C_block_dst = simdgroup_matrix_storage::apply_offset(C_block, N_group, offset_in_group);
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
     // Write the accumulator to threadgroup memory.
@@ -405,9 +416,8 @@ void gemm_impl(
     if (sidx == 0) {
       uint2 C_offset(gid.x * N_group, gid.y * M_group);
       ushort2 C_tile(min(uint(N_group), N - C_offset.x),
-                     min(uint(M_group), M - C_offset.y));
-      auto C_dst = simdgroup_matrix_storage::apply_offset(
-        C, N, C_offset);
+                     min(uint(M_group), M - C_offset.y));
+      auto C_dst = simdgroup_matrix_storage::apply_offset(C, N, C_offset);
 
       // If we shift successfully, the garbage zone moves from the bottom right
       // to the top left.
@@ -419,8 +429,7 @@ void gemm_impl(
       if ((N_shift != 0) && (C_offset.x >= N_edge)) {
         C_block_shift.x = N_shift;
       }
-      C_block = simdgroup_matrix_storage::apply_offset(
-        C_block, N_group, C_block_shift);
+      C_block = simdgroup_matrix_storage::apply_offset(C_block, N_group, C_block_shift);
     }
 
     simdgroup_event event;
@@ -435,34 +444,19 @@ kernel void hgemm(
   device half *C [[buffer(2)]],
 
   threadgroup uchar *threadgroup_block [[threadgroup(0)]],
+  constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
 
   uint3 gid [[threadgroup_position_in_grid]],
   ushort sidx [[simdgroup_index_in_threadgroup]],
   ushort lane_id [[thread_index_in_simdgroup]]
 ) {
   if (ideal_grouping) {
-    gemm_impl<
-      half,
-      half,
-      32,
-      32,
-      32,
-      1,
-      1
-    >(
-      A, B, C, threadgroup_block, gid, sidx, lane_id
+    gemm_impl(
+      A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
     );
   } else {
-    gemm_impl<
-      half,
-      half,
-      48,
-      48,
-      32,
-      1,
-      1
-    >(
-      A, B, C, threadgroup_block, gid, sidx, lane_id
+    gemm_impl(
+      A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
    );
  }
 }
@@ -473,40 +467,17 @@ kernel void sgemm(
   device float *C [[buffer(2)]],
 
   threadgroup uchar *threadgroup_block [[threadgroup(0)]],
+  constant ulong4 *matrix_offsets [[buffer(10), function_constant(batched)]],
 
   uint3 gid [[threadgroup_position_in_grid]],
   ushort sidx [[simdgroup_index_in_threadgroup]],
   ushort lane_id [[thread_index_in_simdgroup]]
 ) {
+  gemm_impl(
+    A, B, C, threadgroup_block, matrix_offsets, gid, sidx, lane_id
+  );
+  /*
   if (prefer_async_copy) {
-    // TODO: figure out correct splits
-    if (ideal_grouping) {
-      gemm_impl<
-        float,
-        float,
-        32,
-        32,
-        32,
-        2,
-        2
-      >(
-        A, B, C, threadgroup_block, gid, sidx, lane_id
-      );
-    } else {
-      gemm_impl<
-        float,
-        float,
-        48,
-        48,
-        24,
-        2,
-        2
-      >(
-        A, B, C, threadgroup_block, gid, sidx, lane_id
-      );
-    }
-  } else {
-    // TODO: figure out correct splits
     constexpr ushort M_split = 1;
     constexpr ushort N_split = 1;
     if (ideal_grouping) {
@@ -534,5 +505,34 @@ kernel void sgemm(
         A, B, C, threadgroup_block, gid, sidx, lane_id
       );
     }
+  } else {
+    constexpr ushort M_split = 2;
+    constexpr ushort N_split = 2;
+    if (ideal_grouping) {
+      gemm_impl<
+        float,
+        float,
+        32,
+        32,
+        8,
+        M_split,
+        N_split
+      >(
+        A, B, C, threadgroup_block, gid, sidx, lane_id
+      );
+    } else {
+      gemm_impl<
+        float,
+        float,
+        32,
+        32,
+        100,
+        M_split,
+        N_split
+      >(
+        A, B, C, threadgroup_block, gid, sidx, lane_id
+      );
+    }
  }
+  */
 }
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 1f01a82f..379f6227 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -1476,19 +1476,27 @@ pub fn call_gemm(
 ) -> Result<(), MetalKernelError> {
     let prefer_async_copy = !device.supports_family(MTLGPUFamily::Apple9);
-    let mut ideal_grouping = false;
     let mut actual_groups: usize = 1;
     actual_groups *= divide(m, 48) as usize;
     actual_groups *= divide(n, 48) as usize;
     actual_groups *= b;
     let core_count = get_device_core_count(device);
-    println!("Core count: {}", core_count);
     let ideal_grouping = if name == "sgemm" {
         actual_groups <= core_count * 6
     } else {
         actual_groups <= core_count * 9
     };
+
+    let mut blockdim = (32, 32, 32);
+    if !ideal_grouping {
+        if name == "sgemm" {
+            blockdim = (48, 48, 24);
+        } else {
+            blockdim = (48, 48, 32);
+        }
+    }
+
     assert!(rhs_stride.len() >= 2);
     assert!(lhs_stride.len() >= 2);
 
     let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
@@ -1525,52 +1533,45 @@ pub fn call_gemm(
     let alpha = 1.0f32;
     let beta = 0.0f32;
     let batched = b > 1;
+    println!("batched: {batched}");
     let fused_activation = false;
     let fused_bias = false;
-    let (m_simd, n_simd, k_simd, m_splits, n_splits) = if m == 1 {
-        let m_simd = 8;
-        let n_simd = 8;
-        let k_simd = 64;
-        let m_splits = 1;
-        let n_splits = 1;
-        (m_simd, n_simd, k_simd, m_splits, n_splits)
-    } else {
-        let m_simd = 40;
-        let n_simd = 40;
-        let k_simd = 32;
-        let m_splits = 1;
-        let n_splits = 1;
-        (m_simd, n_simd, k_simd, m_splits, n_splits)
-    };
+
     let constants = Some(ConstantValues::new(vec![
         (0, Value::USize(m)),
         (1, Value::USize(n)),
         (2, Value::USize(k)),
         (10, Value::Bool(a_trans)),
         (11, Value::Bool(b_trans)),
-        (13, Value::Bool(d_trans)),
-        (20, Value::F32(alpha)),
-        (21, Value::F32(beta)),
+        //(13, Value::Bool(d_trans)),
+        //(20, Value::F32(alpha)),
+        //(21, Value::F32(beta)),
         (100, Value::Bool(batched)),
-        (101, Value::Bool(fused_activation)),
+        //(101, Value::Bool(fused_activation)),
         // Garbage
         (102, Value::Bool(false)),
         (103, Value::Bool(false)),
         (113, Value::Bool(false)),
         (50_000, Value::Bool(false)),
         // End garbage
-        (200, Value::U16(32)),
-        (201, Value::U16(32)),
-        (202, Value::U16(32)),
+        //(200, Value::U16(blockdim.0)),
+        //(201, Value::U16(blockdim.1)),
+        //(202, Value::U16(blockdim.2)),
         (206, Value::Bool(prefer_async_copy)),
         (207, Value::Bool(ideal_grouping)),
-        (210, Value::U16(m_splits)),
-        (211, Value::U16(n_splits)),
-        (50_001, Value::Bool(fused_bias)),
+        //(210, Value::U16(m_splits)),
+        //(211, Value::U16(n_splits)),
+        //(50_001, Value::Bool(fused_bias)),
     ]));
     let pipeline = kernels.load_pipeline_with_constants(device, Source::Candle, name, constants)?;
-    let m_group = m_simd * m_splits;
-    let n_group = n_simd * n_splits;
+
+    let m_group: u16 = 32;
+    let n_group: u16 = 32;
+    let m_splits: u16 = 2;
+    let n_splits: u16 = 2;
+    let k_simd: u16 = 32;
+    let m_simd = m_group / m_splits;
+    let n_simd = n_group / n_splits;
 
     let a_block_length = m_group * k_simd;
     let b_block_length = k_simd * n_group;
@@ -1580,6 +1581,7 @@ pub fn call_gemm(
         let c_block_length = m_group * n_group;
         block_elements = std::cmp::max(c_block_length, block_elements)
     }
+    /*
     if fused_bias {
         if d_trans {
             block_elements = std::cmp::max(block_elements, m_group);
         } else {
             block_elements = std::cmp::max(block_elements, n_group);
         }
     }
+    */
     let bytes = match name {
         "sgemm" => 4,
         "hgemm" => 2,
@@ -1600,7 +1603,7 @@ pub fn call_gemm(
 
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
-    encoder.set_threadgroup_memory_length(0, block_bytes.into());
+    encoder.set_threadgroup_memory_length(0, block_bytes as NSUInteger);
     encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
     encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
     encoder.set_buffer(2, Some(output), 0);
@@ -1614,7 +1617,7 @@ pub fn call_gemm(
 
     // TODO byte_stride_d
     let byte_stride_d = 0;
 
-    let buffer: Vec<u64> = vec![
+    let buffer: [u64; 4] = [
        byte_stride_a as _,
        byte_stride_b as _,
        byte_stride_c as _,
diff --git a/candle-metal-kernels/src/libraries/candle.metallib b/candle-metal-kernels/src/libraries/candle.metallib
index 1a9df376..7ff3fba9 100644
Binary files a/candle-metal-kernels/src/libraries/candle.metallib and b/candle-metal-kernels/src/libraries/candle.metallib differ
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 8c38e74a..f255e8e2 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -1100,6 +1100,11 @@ fn gemm() {
     let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
     let rhs_stride = vec![n * k, n, 1];
     let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+    println!("lhs: {lhs:?}");
+    println!("lhs_stride: {lhs_stride:?}");
+    println!("rhs: {rhs:?}");
+    println!("rhs_stride: {rhs_stride:?}");
+
     let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
     assert_eq!(
         approx(results, 4),
@@ -1111,6 +1116,11 @@ fn gemm() {
     let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
     let rhs_stride = vec![n * k, n, 1];
     let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+    println!("lhs: {lhs:?}");
+    println!("lhs_stride: {lhs_stride:?}");
+    println!("rhs: {rhs:?}");
+    println!("rhs_stride: {rhs_stride:?}");
+
     let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0);
     assert_eq!(
         approx(results, 4),