Working with merging encoders and using fences.

This commit is contained in:
Nicolas Patry
2023-12-14 16:05:33 +01:00
parent 931432ed55
commit 361f2ad2af
5 changed files with 279 additions and 94 deletions
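The pattern applied throughout: every call_* helper now waits on a shared metal::Fence before encoding and updates it after dispatching, so successive compute encoders on the same command buffer see each other's writes in order. A minimal Rust sketch of that per-launch pattern (a sketch only; encode_pass is a hypothetical name, while the encoder and fence calls are the ones used in the diff below):

fn encode_pass(
    command_buffer: &metal::CommandBufferRef,
    pipeline: &metal::ComputePipelineState,
    fence: &metal::Fence,
    thread_group_count: metal::MTLSize,
    thread_group_size: metal::MTLSize,
) {
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.wait_for_fence(fence); // wait for the previous pass's writes
    encoder.set_compute_pipeline_state(pipeline);
    // ... bind buffers / set_params! here ...
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.update_fence(fence); // publish this pass's writes to the next waiter
    encoder.end_encoding();
}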

View File

@@ -184,19 +184,21 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
type Libraries = HashMap<Source, Library>;
type Pipelines = HashMap<(&'static str, Option<ConstantValues>), ComputePipelineState>;
#[derive(Debug, Default)]
#[derive(Debug)]
pub struct Kernels {
libraries: RwLock<Libraries>,
pipelines: RwLock<Pipelines>,
fence: metal::Fence,
}
impl Kernels {
pub fn new() -> Self {
pub fn new(fence: metal::Fence) -> Self {
let libraries = RwLock::new(Libraries::new());
let pipelines = RwLock::new(Pipelines::new());
Self {
libraries,
pipelines,
fence,
}
}
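Kernels now owns the fence, so construction changes accordingly; a minimal usage sketch (assuming the metal crate's Device::system_default() and new_fence()):

let device = metal::Device::system_default().expect("no Metal device found");
let fence = device.new_fence(); // one fence shared by every kernel launch
let kernels = Kernels::new(fence);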
@@ -304,12 +306,14 @@ pub fn call_unary_contiguous(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
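For context, linear_split is defined elsewhere in this file; a sketch of the 1-D split it performs, under the assumption that it ceil-divides the element count by the pipeline's maximum threadgroup width:

fn linear_split(
    pipeline: &metal::ComputePipelineState,
    length: usize,
) -> (metal::MTLSize, metal::MTLSize) {
    // Assumes length > 0; clamp the threadgroup width to the work size.
    let width = pipeline.max_total_threads_per_threadgroup().min(length as u64);
    let count = (length as u64 + width - 1) / width; // ceil-divide covers every element
    let thread_group_count = metal::MTLSize { width: count, height: 1, depth: 1 };
    let thread_group_size = metal::MTLSize { width, height: 1, depth: 1 };
    (thread_group_count, thread_group_size)
}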
@@ -331,6 +335,7 @@ pub fn call_unary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -350,6 +355,7 @@ pub fn call_unary_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -368,6 +374,7 @@ pub fn call_binary_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Binary, kernel_name.0)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, left, right, output));
@@ -375,6 +382,7 @@ pub fn call_binary_contiguous(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -399,6 +407,7 @@ pub fn call_binary_strided(
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
let width: usize = shape.iter().product();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -420,6 +429,7 @@ pub fn call_binary_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -438,12 +448,14 @@ pub fn call_cast_contiguous(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, (input, input_offset), output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -463,6 +475,7 @@ pub fn call_cast_strided(
let pipeline = kernels.load_pipeline(device, Source::Cast, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
@@ -482,6 +495,7 @@ pub fn call_cast_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -501,6 +515,7 @@ pub fn call_reduce_contiguous(
let elements_to_sum = length / out_length;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -527,6 +542,7 @@ pub fn call_reduce_contiguous(
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -544,6 +560,7 @@ pub fn call_last_softmax(
) -> Result<(), MetalKernelError> {
let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (length, elements_to_sum, input, output));
@@ -569,6 +586,7 @@ pub fn call_last_softmax(
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -588,12 +606,14 @@ pub fn call_affine(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, add, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -616,6 +636,7 @@ pub fn call_affine_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -634,6 +655,7 @@ pub fn call_affine_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -652,12 +674,14 @@ pub fn call_powf(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -679,6 +703,7 @@ pub fn call_powf_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -696,6 +721,7 @@ pub fn call_powf_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -714,12 +740,14 @@ pub fn call_elu(
let pipeline = kernels.load_pipeline(device, Source::Affine, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (size, mul, input, output));
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -741,6 +769,7 @@ pub fn call_elu_strided(
let size: usize = shape.iter().product();
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -758,6 +787,7 @@ pub fn call_elu_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -779,6 +809,7 @@ pub fn call_where_cond_strided(
let pipeline = kernels.load_pipeline(device, Source::Ternary, name)?;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
let size: usize = shape.iter().product();
@@ -803,6 +834,7 @@ pub fn call_where_cond_strided(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, size);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -829,6 +861,7 @@ pub fn call_index_select(
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
@@ -848,6 +881,7 @@ pub fn call_index_select(
let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())
}
@@ -1045,6 +1079,7 @@ pub fn call_gemm(
let block_bytes = block_elements * bytes;
let encoder = command_buffer.new_compute_command_encoder();
encoder.wait_for_fence(&kernels.fence);
encoder.set_compute_pipeline_state(&pipeline);
// println!("Threadgroup {block_bytes}");
encoder.set_threadgroup_memory_length(0, block_bytes.into());
@@ -1087,6 +1122,7 @@ pub fn call_gemm(
};
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.update_fence(&kernels.fence);
encoder.end_encoding();
Ok(())

View File

@@ -0,0 +1,209 @@
import Metal
import MetalPerformanceShadersGraph
let type = MTLDataType.float;
let dataType = type;
var B = 2;
var M = 2;
var N = 2;
var K = 2;
var A_trans = false;
var B_trans = false;
var D_trans = false;
var alpha = Float(1.0);
var beta = Float(0.0);
var batched = B > 1;
var fused_activation = false;
var fused_bias = false;
let constants = MTLFunctionConstantValues()
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&A_trans, type: .bool, index: 10)
constants.setConstantValue(&B_trans, type: .bool, index: 11)
constants.setConstantValue(&D_trans, type: .bool, index: 13)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
constants.setConstantValue(&batched, type: .bool, index: 100)
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
var M_simd = UInt16(16)
var N_simd = UInt16(16)
var K_simd = UInt16(32)
var M_splits = UInt16(2)
var N_splits = UInt16(2)
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)
let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
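// With the values above each threadgroup owns a 32x32 output tile
// (M_group = N_group = 16 * 2), split across 2x2 SIMD-groups.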
// Satisfy Metal API validation.
#if DEBUG
do {
var garbage: SIMD4<UInt64> = .zero
constants.setConstantValue(&garbage, type: .bool, index: 102)
constants.setConstantValue(&garbage, type: .bool, index: 103)
constants.setConstantValue(&garbage, type: .bool, index: 113)
constants.setConstantValue(&garbage, type: .bool, index: 50000)
}
#endif
let device = MTLCopyAllDevices().first!
device.shouldMaximizeConcurrentCompilation = true
var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
libraryURL.append(component: "src")
libraryURL.append(component: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: libraryURL)
var name: String
switch dataType {
case .half: name = "hgemm"
case .float: name = "sgemm"
default: fatalError()
}
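// BLAS-style kernel names: hgemm = half precision, sgemm = single precision.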
let function = try! library.makeFunction(
name: name, constantValues: constants)
let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group
var blockElements = A_block_length + B_block_length;
if (M % 8 != 0) && (N % 8 != 0) {
let C_block_length = M_group * N_group;
blockElements = max(C_block_length, blockElements)
}
if fused_bias {
if D_trans {
blockElements = max(blockElements, M_group)
} else {
blockElements = max(blockElements, N_group)
}
}
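// Threadgroup memory must fit the A and B tiles; when M and N are both
// not multiples of 8 the C tile is staged there too, and a fused bias
// requires at least M_group (D transposed) or N_group elements.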
// let blockBytes = blockElements * UInt16(dataType.size)
let elementSize = 4
let blockBytes = blockElements * UInt16(elementSize)
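// elementSize is hard-coded to 4 bytes since dataType is .float here;
// the commented-out line above shows the generic intent.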
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
var gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_group),
height: ceilDivide(target: M, granularity: M_group),
depth: 1)
let groupSize = MTLSize(
width: Int(32 * M_splits * N_splits),
height: 1,
depth: 1)
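// One SIMD-group is 32 threads; a threadgroup runs M_splits * N_splits
// SIMD-groups, so 32 * 2 * 2 = 128 threads here.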
let commandQueue = device.makeCommandQueue()!
let threadgroupMemoryLength = blockBytes;
let rowsA = M;
let columnsA = K;
let rowsB = K;
let columnsB = N;
let rowsC = M;
let columnsC = N;
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
var arrayD = [Float](repeating: 0, count: B * rowsC * columnsC)
for i in 0..<arrayA.count {
arrayA[i] = Float(i)
}
for i in 0..<arrayB.count {
arrayB[i] = Float(i)
}
let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])!
let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])!
let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
let bufferD = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])!
let pipeline = try device.makeComputePipelineState(function: function)
func call(bufferA: MTLBuffer, bufferB: MTLBuffer, bufferC: MTLBuffer) {
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
encoder.setComputePipelineState(pipeline)
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
let gridZ: Int = B
if batched {
func byteStride(shape: [Int]) -> Int {
let rank = shape.count
var output = elementSize * shape[rank - 2] * shape[rank - 1]
if shape.dropLast(2).reduce(1, *) == 1 {
output = 0
}
return output
}
let byteStrideA = M*K*elementSize
let byteStrideB = N*K*elementSize
let byteStrideC = M*N*elementSize
let byteStrideD = 0
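// One SIMD4<UInt64> of byte offsets (A, B, C, D) per batch element is
// packed below and bound at buffer index 10 for the kernel's batched path.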
withUnsafeTemporaryAllocation(
of: SIMD4<UInt64>.self, capacity: gridZ
) { buffer in
for i in 0..<buffer.count {
buffer[i] = SIMD4(
UInt64(truncatingIfNeeded: i * byteStrideA),
UInt64(truncatingIfNeeded: i * byteStrideB),
UInt64(truncatingIfNeeded: i * byteStrideC),
UInt64(truncatingIfNeeded: i * byteStrideD))
}
let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
}
}
gridSize.depth = gridZ
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize
)
encoder.endEncoding()
}
// A shared event orders the two command buffers: the first signals value 2
// when it finishes, the second waits for it before running. `event` was
// never defined in the original file; an MTLSharedEvent is assumed here.
let event = device.makeSharedEvent()!
var commandBuffer = commandQueue.makeCommandBuffer()!
call(bufferA: bufferA, bufferB: bufferB, bufferC: bufferC)
commandBuffer.encodeSignalEvent(event, value: 2)
commandBuffer.commit()
commandBuffer = commandQueue.makeCommandBuffer()!
commandBuffer.encodeWaitForEvent(event, value: 2)
call(bufferA: bufferA, bufferB: bufferC, bufferC: bufferD)
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
var contents = bufferC.contents();
var count = B * rowsA * columnsB;
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print("First matmul is OK", Array(bufferedPointer))
contents = bufferD.contents();
count = B * rowsA * columnsB;
typedPointer = contents.bindMemory(to: Float.self, capacity: count)
bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print("This should be filled", Array(bufferedPointer))