mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
chore: switch to buffer
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
use metal::{
|
use metal::{
|
||||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLResourceOptions, MTLSize,
|
||||||
|
NSUInteger,
|
||||||
};
|
};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::ffi::c_void;
|
use std::ffi::c_void;
|
||||||
@ -1359,17 +1360,20 @@ pub fn call_gemm(
|
|||||||
// TODO byte_stride_d
|
// TODO byte_stride_d
|
||||||
let byte_stride_d = 0;
|
let byte_stride_d = 0;
|
||||||
|
|
||||||
let buffer: Vec<u64> = vec![
|
let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
|
||||||
byte_stride_a as _,
|
for i in 0..b {
|
||||||
byte_stride_b as _,
|
buffer.push((i * byte_stride_a) as u64);
|
||||||
byte_stride_c as _,
|
buffer.push((i * byte_stride_b) as u64);
|
||||||
byte_stride_d as _,
|
buffer.push((i * byte_stride_c) as u64);
|
||||||
];
|
buffer.push((i * byte_stride_d) as u64);
|
||||||
encoder.set_bytes(
|
}
|
||||||
10,
|
|
||||||
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
let matrix_offsets = device.new_buffer_with_data(
|
||||||
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
buffer.as_ptr() as *const NSUInteger as *const c_void,
|
||||||
|
(buffer.len() * core::mem::size_of::<u64>()) as NSUInteger,
|
||||||
|
MTLResourceOptions::StorageModePrivate,
|
||||||
);
|
);
|
||||||
|
encoder.set_buffer(10, Some(&matrix_offsets), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
let grid_size = MTLSize {
|
let grid_size = MTLSize {
|
||||||
|
Binary file not shown.
Reference in New Issue
Block a user