chore: final

This commit is contained in:
FL33TW00D
2024-01-22 15:15:19 +00:00
parent 73d79e6092
commit b6afb46601
2 changed files with 10 additions and 15 deletions

View File

@ -1,7 +1,6 @@
use metal::{ use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState, Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize, Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
NSUInteger,
}; };
use std::collections::HashMap; use std::collections::HashMap;
use std::ffi::c_void; use std::ffi::c_void;
@ -1360,21 +1359,17 @@ pub fn call_gemm(
// TODO byte_stride_d // TODO byte_stride_d
let byte_stride_d = 0; let byte_stride_d = 0;
let mut buffer: Vec<u64> = Vec::with_capacity(b * 4); let buffer: Vec<u64> = vec![
for i in 0..b { byte_stride_a as _,
buffer.push((i * byte_stride_a) as u64); byte_stride_b as _,
buffer.push((i * byte_stride_b) as u64); byte_stride_c as _,
buffer.push((i * byte_stride_c) as u64); byte_stride_d as _,
buffer.push((i * byte_stride_d) as u64); ];
} encoder.set_bytes(
10,
let matrix_offsets = device.new_buffer_with_data(
buffer.as_ptr() as *const c_void,
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger, (buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
MTLResourceOptions::StorageModeManaged, buffer.as_ptr() as *const NSUInteger as *const c_void,
); );
encoder.set_buffer(10, Some(&matrix_offsets), 0);
encoder.use_resource(&matrix_offsets, metal::MTLResourceUsage::Read);
} }
let grid_size = MTLSize { let grid_size = MTLSize {