diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs
index 12f56d50..f8de6c71 100644
--- a/candle-core/src/metal_backend.rs
+++ b/candle-core/src/metal_backend.rs
@@ -796,101 +796,37 @@ impl BackendStorage for MetalStorage {
     ) -> Result<Self> {
         // Create descriptors
-        let (type_id, size) = match self.dtype {
-            DType::F32 => (
-                metal::mps::MPS_FLOATBIT_ENCODING | 32,
-                core::mem::size_of::<f32>() as NSUInteger,
-            ),
-            DType::F16 => (
-                metal::mps::MPS_FLOATBIT_ENCODING | 16,
-                core::mem::size_of::<f16>() as NSUInteger,
-            ),
-            dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
+        let buffer = self.device.new_buffer(b * m * n, self.dtype);
+        let name = match self.dtype {
+            DType::F32 => "sgemm",
+            DType::F16 => "hgemm",
+            dtype => {
+                return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
+            }
         };
-        let lhs_stride = lhs_l.stride();
-        let rhs_stride = rhs_l.stride();
-        let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
-        let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
-        let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
-        let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
-        // The a tensor has dims batching, k, n (rhs)
-        let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
-            false
-        } else if lhs_m1 == m && lhs_m2 == 1 {
-            true
-        } else {
-            Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?
-        };
-        let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
-            false
-        } else if rhs_m1 == k && rhs_m2 == 1 {
-            true
-        } else {
-            Err(MetalError::MatMulNonContiguous {
-                lhs_stride: lhs_stride.to_vec(),
-                rhs_stride: rhs_stride.to_vec(),
-                mnk: (m, n, k),
-            })?
-        };
-        let b = b as NSUInteger;
-        let m = m as NSUInteger;
-        let n = n as NSUInteger;
-        let k = k as NSUInteger;
-
-        let left_matrix = self.matrix(
-            (b, m, k),
-            transpose_left,
-            size,
-            lhs_l.start_offset() as NSUInteger * size,
-            type_id,
-        )?;
-        let right_matrix = rhs.matrix(
-            (b, k, n),
-            transpose_right,
-            size,
-            rhs_l.start_offset() as NSUInteger * size,
-            type_id,
-        )?;
-        let (result_matrix, out_buffer) =
-            self.device
-                .new_matrix((b, m, n), size, type_id, self.dtype)?;
         let command_buffer = self.device.command_buffer();
-
-        let alpha = 1.0f64;
-        let beta = 0.0f64;
-        // Create kernel
-        let matrix_multiplication = MatrixMultiplication::init(
-            &self.device,
-            transpose_left,
-            transpose_right,
-            m,
-            n,
-            k,
-            alpha,
-            beta,
-        )
-        .ok_or_else(|| {
-            MetalError::from("Failed to create matrix multiplication kernel".to_string())
-        })?;
-
-        // Encode kernel to command buffer
-        matrix_multiplication.encode_to_command_buffer(
-            &command_buffer,
-            &left_matrix,
-            &right_matrix,
-            &result_matrix,
-        );
         command_buffer.set_label("matmul");
+        candle_metal_kernels::call_gemm(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            (b, m, n, k),
+            &lhs_l.stride(),
+            lhs_l.start_offset(),
+            &self.buffer,
+            &rhs_l.stride(),
+            rhs_l.start_offset(),
+            &rhs.buffer,
+            &buffer,
+        )
+        .map_err(MetalError::from)?;
+        // Create kernel
         drop(command_buffer);
         self.device.commit();
-        Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
+        Ok(Self::new(buffer, self.device.clone(), self.dtype()))
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
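Note, for illustration only and not part of the patch: the stride checks that used to pick MPS's transpose_left/transpose_right above now live inside candle_metal_kernels::call_gemm, where they set MFA's A_trans/B_trans function constants (see the lib.rs hunk below). A minimal sketch of that mapping as a standalone helper; the name is_transposed and the Option return type are assumptions made here for clarity:

fn is_transposed(stride_minor: usize, stride_major: usize, rows: usize, cols: usize) -> Option<bool> {
    if stride_minor == 1 && stride_major == cols {
        Some(false) // contiguous row-major: no transpose needed
    } else if stride_minor == rows && stride_major == 1 {
        Some(true) // column-major view of the same data: pass it to the kernel as transposed
    } else {
        None // non-contiguous: call_gemm currently hits todo!() for this case
    }
}

For the lhs (m x k) this corresponds to is_transposed(lhs_m1, lhs_m2, m, k); for the rhs (k x n), to is_transposed(rhs_m1, rhs_m2, k, n).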
diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index d3312f0e..6d01145d 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -183,7 +183,7 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
 
 #[derive(Debug, PartialEq)]
 pub enum Value {
-    U32(u32),
+    USize(usize),
     Bool(bool),
     F32(f32),
     U16(u16),
@@ -193,7 +193,7 @@ impl std::hash::Hash for Value {
     fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
         match self {
             Value::F32(v) => v.to_bits().hash(state),
-            Value::U32(v) => v.hash(state),
+            Value::USize(v) => v.hash(state),
             Value::U16(v) => v.hash(state),
             Value::Bool(v) => v.hash(state),
         }
@@ -203,7 +203,7 @@ impl Value {
     fn data_type(&self) -> MTLDataType {
         match self {
-            Value::U32(_) => MTLDataType::UInt,
+            Value::USize(_) => MTLDataType::UInt,
             Value::F32(_) => MTLDataType::Float,
             Value::U16(_) => MTLDataType::UShort,
             Value::Bool(_) => MTLDataType::Bool,
@@ -227,9 +227,9 @@ impl ConstantValues {
         for (index, value) in &self.0 {
             let ty = value.data_type();
             match value {
-                Value::U32(v) => {
+                Value::USize(v) => {
                     f.set_constant_value_at_index(
-                        v as *const u32 as *const c_void,
+                        v as *const usize as *const c_void,
                         ty,
                         *index as u64,
                     );
@@ -824,11 +824,39 @@ pub fn call_gemm(
     rhs_buffer: &Buffer,
     output: &Buffer,
 ) -> Result<(), MetalKernelError> {
-    let a_trans = false;
-    let b_trans = false;
+    assert!(rhs_stride.len() >= 2);
+    assert!(lhs_stride.len() >= 2);
+    let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
+    let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
+    let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
+    let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
+    let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
+        false
+    } else if lhs_m1 == m && lhs_m2 == 1 {
+        true
+    } else {
+        todo!();
+        // Err(MetalError::MatMulNonContiguous {
+        //     lhs_stride: lhs_stride.to_vec(),
+        //     rhs_stride: rhs_stride.to_vec(),
+        //     mnk: (m, n, k),
+        // })?
+    };
+    let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
+        false
+    } else if rhs_m1 == k && rhs_m2 == 1 {
+        true
+    } else {
+        todo!();
+        // Err(MetalError::MatMulNonContiguous {
+        //     lhs_stride: lhs_stride.to_vec(),
+        //     rhs_stride: rhs_stride.to_vec(),
+        //     mnk: (m, n, k),
+        // })?
+    };
     let d_trans = false;
-    let alpha = 1.0;
-    let beta = 0.0;
+    let alpha = 1.0f32;
+    let beta = 0.0f32;
     let batched = b > 1;
     let fused_activation = false;
     let fused_bias = false;
@@ -838,9 +866,9 @@ pub fn call_gemm(
     let m_splits = 2;
     let n_splits = 2;
     let constants = Some(ConstantValues::new(vec![
-        (0, Value::U32(m as u32)),
-        (1, Value::U32(n as u32)),
-        (2, Value::U32(k as u32)),
+        (0, Value::USize(m)),
+        (1, Value::USize(n)),
+        (2, Value::USize(k)),
         (10, Value::Bool(a_trans)),
         (11, Value::Bool(b_trans)),
         (13, Value::Bool(d_trans)),
@@ -861,7 +889,7 @@ pub fn call_gemm(
         (211, Value::U16(n_splits)),
         (50_001, Value::Bool(fused_bias)),
     ]));
-    println!("Constants {constants:?}");
+    // println!("Constants {constants:?}");
     let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
     let m_group = m_simd * m_splits;
     let n_group = n_simd * n_splits;
@@ -895,35 +923,34 @@ pub fn call_gemm(
     let encoder = command_buffer.new_compute_command_encoder();
     encoder.set_compute_pipeline_state(&pipeline);
-    println!("Threadgroup {block_bytes}");
-    encoder.set_threadgroup_memory_length(block_bytes.into(), 0);
+    // println!("Threadgroup {block_bytes}");
+    encoder.set_threadgroup_memory_length(0, block_bytes.into());
     encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
     encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
     encoder.set_buffer(2, Some(output), 0);
     // TODO Tensor D
     let grid_z = b;
-    let byte_stride_a: usize = *lhs_stride.get(lhs_stride.len() - 3).unwrap_or(&0) * bytes as usize;
-    let byte_stride_b = *rhs_stride.get(rhs_stride.len() - 3).unwrap_or(&0) * bytes as usize;
-    let byte_stride_c = m * n * bytes as usize;
-    // TODO byte_stride_d
-    let byte_stride_d = 0;
+    if batched {
+        let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
+        let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
+        let byte_stride_c = m * n * bytes as usize;
+        // TODO byte_stride_d
+        let byte_stride_d = 0;
 
-    let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
-    for i in 0..b {
-        buffer.push((i * byte_stride_a) as u64);
-        buffer.push((i * byte_stride_b) as u64);
-        buffer.push((i * byte_stride_c) as u64);
-        buffer.push((i * byte_stride_d) as u64);
+        let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
+        for i in 0..b {
+            buffer.push((i * byte_stride_a) as u64);
+            buffer.push((i * byte_stride_b) as u64);
+            buffer.push((i * byte_stride_c) as u64);
+            buffer.push((i * byte_stride_d) as u64);
+        }
+        encoder.set_bytes(
+            10,
+            buffer.len() as NSUInteger * core::mem::size_of::<u64>() as NSUInteger,
+            buffer.as_ptr() as *const NSUInteger as *const c_void,
+        );
     }
-    println!("A {:?}", lhs_buffer.read_to_vec::<f32>(12));
-    println!("B {:?}", rhs_buffer.read_to_vec::<f32>(24));
-    println!("buffer {:?}", buffer);
-    encoder.set_bytes(
-        10,
-        buffer.len() as NSUInteger,
-        buffer.as_ptr() as *const NSUInteger as *const c_void,
-    );
 
     let grid_size = MTLSize {
         width: divide(n, n_group.into()),
@@ -935,7 +962,7 @@
         height: 1,
         depth: 1,
     };
-    println!("grid size {grid_size:?} group size {group_size:?}");
+    // println!("grid size {grid_size:?} group size {group_size:?}");
     encoder.dispatch_thread_groups(grid_size, group_size);
     encoder.end_encoding();
diff --git a/candle-metal-kernels/src/libMetalFlashAttention.metallib b/candle-metal-kernels/src/libMetalFlashAttention.metallib
index 8c8ce692..f5116ca6 100644
Binary files a/candle-metal-kernels/src/libMetalFlashAttention.metallib and b/candle-metal-kernels/src/libMetalFlashAttention.metallib differ
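Note, for illustration only and not part of the patch: when batched, both the Rust encoder above and the Swift test below pass a small table at buffer index 10 containing four u64 byte offsets (A, B, C, D) per grid.z slice. A rough standalone sketch of that table for the row-major shapes used in the new tests; the helper name and signature are assumptions made here:

fn batch_offsets(b: usize, m: usize, n: usize, k: usize, elem: usize) -> Vec<u64> {
    // Per-batch byte strides for A (m x k), B (k x n), C (m x n); D is unused here.
    let (sa, sb, sc, sd) = (m * k * elem, n * k * elem, m * n * elem, 0);
    (0..b)
        .flat_map(|i| [i * sa, i * sb, i * sc, i * sd])
        .map(|v| v as u64)
        .collect()
}

// With f32 elements and (b, m, n, k) = (2, 2, 4, 3):
// batch_offsets(2, 2, 4, 3, 4) == [0, 0, 0, 0, 24, 48, 32, 0]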
diff --git a/candle-metal-kernels/src/test.swift b/candle-metal-kernels/src/test.swift
new file mode 100644
index 00000000..65749501
--- /dev/null
+++ b/candle-metal-kernels/src/test.swift
@@ -0,0 +1,211 @@
+
+import Metal
+import MetalPerformanceShadersGraph
+
+
+
+let type = MTLDataType.float;
+let dataType = type;
+var B = 2;
+var M = 2;
+var N = 4;
+var K = 3;
+var A_trans = false;
+var B_trans = false;
+var D_trans = false;
+var alpha = Float(1.0);
+var beta = Float(0.0);
+var batched = B > 1;
+var fused_activation = false;
+var fused_bias = false;
+let constants = MTLFunctionConstantValues()
+constants.setConstantValue(&M, type: .uint, index: 0)
+constants.setConstantValue(&N, type: .uint, index: 1)
+constants.setConstantValue(&K, type: .uint, index: 2)
+constants.setConstantValue(&A_trans, type: .bool, index: 10)
+constants.setConstantValue(&B_trans, type: .bool, index: 11)
+constants.setConstantValue(&D_trans, type: .bool, index: 13)
+constants.setConstantValue(&alpha, type: .float, index: 20)
+constants.setConstantValue(&beta, type: .float, index: 21)
+constants.setConstantValue(&batched, type: .bool, index: 100)
+constants.setConstantValue(&fused_activation, type: .bool, index: 101)
+constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
+
+
+var M_simd = UInt16(16)
+var N_simd = UInt16(16)
+var K_simd = UInt16(32)
+var M_splits = UInt16(2)
+var N_splits = UInt16(2)
+constants.setConstantValue(&M_simd, type: .ushort, index: 200)
+constants.setConstantValue(&N_simd, type: .ushort, index: 201)
+constants.setConstantValue(&K_simd, type: .ushort, index: 202)
+constants.setConstantValue(&M_splits, type: .ushort, index: 210)
+constants.setConstantValue(&N_splits, type: .ushort, index: 211)
+
+let M_group = M_simd * M_splits
+let N_group = N_simd * N_splits
+
+// Satisfy Metal API validation.
+#if DEBUG
+do {
+  var garbage: SIMD4<UInt64> = .zero
+  constants.setConstantValue(&garbage, type: .bool, index: 102)
+  constants.setConstantValue(&garbage, type: .bool, index: 103)
+  constants.setConstantValue(&garbage, type: .bool, index: 113)
+  constants.setConstantValue(&garbage, type: .bool, index: 50000)
+}
+#endif
+print(constants)
+
+let device = MTLCopyAllDevices().first!
+device.shouldMaximizeConcurrentCompilation = true
+
+var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
+libraryURL.append(component: "src")
+libraryURL.append(component: "libMetalFlashAttention.metallib")
+let library = try! device.makeLibrary(URL: libraryURL)
+
+var name: String
+  switch dataType {
+  case .half: name = "hgemm"
+  case .float: name = "sgemm"
+  default: fatalError()
+  }
+let function = try! library.makeFunction(
+  name: name, constantValues: constants)
+
+let A_block_length = M_group * K_simd
+let B_block_length = K_simd * N_group
+
+var blockElements = A_block_length + B_block_length;
+if (M % 8 != 0) && (N % 8 != 0) {
+  let C_block_length = M_group * N_group;
+  blockElements = max(C_block_length, blockElements)
+}
+if fused_bias {
+  if D_trans {
+    blockElements = max(blockElements, M_group)
+  } else {
+    blockElements = max(blockElements, N_group)
+  }
+}
+// let blockBytes = blockElements * UInt16(dataType.size)
+let elementSize = 4
+let blockBytes = blockElements * UInt16(elementSize)
+
+func ceilDivide(target: Int, granularity: UInt16) -> Int {
+  (target + Int(granularity) - 1) / Int(granularity)
+}
+var gridSize = MTLSize(
+  width: ceilDivide(target: N, granularity: N_group),
+  height: ceilDivide(target: M, granularity: M_group),
+  depth: 1)
+let groupSize = MTLSize(
+  width: Int(32 * M_splits * N_splits),
+  height: 1,
+  depth: 1)
+
+let commandQueue = device.makeCommandQueue()!
+let commandBuffer = commandQueue.makeCommandBuffer()!
+let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
+let pipeline = try device.makeComputePipelineState(function: function)
+
+let threadgroupMemoryLength = blockBytes;
+print(threadgroupMemoryLength)
+encoder.setComputePipelineState(pipeline)
+encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
+
+
+let rowsA = M;
+let columnsA = K;
+let rowsB = K;
+let columnsB = N;
+let rowsC = M;
+let columnsC = N;
+var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
+
+var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
+
+var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
+for i in 0..<arrayA.count {
+  arrayA[i] = Float(i)
+}
+for i in 0..<arrayB.count {
+  arrayB[i] = Float(i)
+}
+
+let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])
+
+let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])
+
+let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])
+
+print(arrayA)
+print(arrayB)
+
+
+encoder.setBuffer(bufferA, offset: 0, index: 0)
+encoder.setBuffer(bufferB, offset: 0, index: 1)
+encoder.setBuffer(bufferC, offset: 0, index: 2)
+var gridZ: Int = B
+if batched{
+  func byteStride(shape: [Int]) -> Int {
+    let rank = shape.count
+    var output = elementSize * shape[rank - 2] * shape[rank - 1]
+    if shape.dropLast(2).reduce(1, *) == 1 {
+      output = 0
+    }
+    return output
+  }
+  let byteStrideA = M*K*elementSize
+  let byteStrideB = N*K*elementSize
+  let byteStrideC = M*N*elementSize
+
+  let byteStrideD = 0
+  // if let shapeD = tensors.d?.shape {
+  //   let rank = shapeD.count
+  //   byteStrideD = elementSize * shapeD[rank - 1]
+  //   if shapeD.dropLast(1).reduce(1, *) == 1 {
+  //     byteStrideD = 0
+  //   }
+  // }
+  withUnsafeTemporaryAllocation(
+    of: SIMD4<UInt64>.self, capacity: gridZ
+  ) { buffer in
+    for i in 0..<buffer.count {
+      buffer[i] = SIMD4(
+        UInt64(truncatingIfNeeded: i * byteStrideA),
+        UInt64(truncatingIfNeeded: i * byteStrideB),
+        UInt64(truncatingIfNeeded: i * byteStrideC),
+        UInt64(truncatingIfNeeded: i * byteStrideD))
+    }
+
+    let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
+    assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
+    encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
+    print("BATCHED")
+    print(buffer)
+  }
+}
+gridSize.depth = gridZ
+
+
+print(gridSize, groupSize)
+encoder.dispatchThreadgroups(
+  gridSize, threadsPerThreadgroup: groupSize
+)
+encoder.endEncoding()
+commandBuffer.commit()
+
+commandBuffer.waitUntilCompleted()
+var contents = bufferC!.contents();
+
+var count = B * rowsA * columnsB;
+
+var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
+
+var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
+
+print(Array(bufferedPointer))
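Note, for illustration only and not part of the patch: the expected vector asserted in the new tests.rs case below can be checked against a naive matmul. A plain reference sketch for the single-batch, row-major case; run_gemm and approx are the existing test helpers, while naive_gemm is a name assumed here:

fn naive_gemm(m: usize, n: usize, k: usize, lhs: &[f32], rhs: &[f32]) -> Vec<f32> {
    // Row-major C[i][j] = sum over l of A[i][l] * B[l][j]
    let mut out = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for l in 0..k {
                out[i * n + j] += lhs[i * k + l] * rhs[l * n + j];
            }
        }
    }
    out
}

// With lhs = 0..6 and rhs = 0..12 (the ascending inputs used by the test and by test.swift),
// naive_gemm(2, 4, 3, ...) gives [20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0].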
diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs
index 5805206b..136935e2 100644
--- a/candle-metal-kernels/src/tests.rs
+++ b/candle-metal-kernels/src/tests.rs
@@ -774,6 +774,16 @@ fn run_gemm(
 
 #[test]
 fn gemm() {
+    let (b, m, n, k) = (1, 2, 4, 3);
+    let lhs_stride = vec![m * k, k, 1];
+    let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
+    let rhs_stride = vec![n * k, n, 1];
+    let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
+    let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
+    assert_eq!(
+        approx(results, 4),
+        vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
+    );
     let (b, m, n, k) = (2, 2, 4, 3);
     let lhs_stride = vec![m * k, k, 1];
     let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();