Better version?

This commit is contained in:
Nicolas Patry
2023-12-13 12:09:20 +01:00
parent 87dc559817
commit a9d0657432
2 changed files with 52 additions and 25 deletions

View File

@ -96,6 +96,7 @@ impl MetalDevice {
.map(|i| {
// println!("Creating command buffer {i}");
let command_buffer = self.command_queue.new_command_buffer().to_owned();
command_buffer.set_label(&format!("num {i}"));
command_buffer.enqueue();
command_buffer
})
@ -157,7 +158,7 @@ impl MetalDevice {
for sub in &mut *subbuffers {
if Arc::strong_count(sub) == 1 {
// println!("Reusing tensor {size} {name}");
// return sub.clone();
return sub.clone();
}
}
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
@ -177,7 +178,7 @@ impl MetalDevice {
}
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
}
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
@ -185,19 +186,22 @@ impl MetalDevice {
let tmp = self.device.new_buffer_with_data(
data.as_ptr() as *const core::ffi::c_void,
size,
metal::MTLResourceOptions::StorageModeManaged,
metal::MTLResourceOptions::StorageModeShared,
);
let real = self._new_buffer(
size,
metal::MTLResourceOptions::StorageModePrivate,
"with_data",
);
let command = self.command_buffer();
let blit = command.new_blit_command_encoder();
let command_buffer = self.command_buffer();
command_buffer.set_label("with_data");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("with_data_blit");
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.end_encoding();
command.commit();
real.did_modify_range(metal::NSRange::new(0, real.length()));
command_buffer.commit();
drop(command_buffer);
// real.did_modify_range(metal::NSRange::new(0, real.length()));
// println!("Command {:?}", command.status());
// self.commit();
@ -220,15 +224,29 @@ impl MetalDevice {
dtype: DType,
) -> Result<(Matrix, Arc<Buffer>)> {
let elem_count = (b * m * n) as usize;
let out_buffer = self.new_buffer(elem_count, dtype, "matrix");
let buffer = self.new_buffer(elem_count, dtype, "matrix");
let command_buffer = self.command_buffer();
command_buffer.set_label("zeros_matmul");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
metal::NSRange {
location: 0,
length: buffer.length(),
},
0,
);
blit.end_encoding();
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
let result_descriptor =
MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
let result_matrix = Matrix::init_with_buffer_descriptor(&buffer, 0, &result_descriptor)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
Ok((result_matrix, out_buffer))
Ok((result_matrix, buffer))
}
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
@ -298,11 +316,13 @@ impl BackendStorage for MetalStorage {
let buffer = self.device.new_buffer_managed(self.buffer.length());
{
let command_buffer = self.device.command_buffer();
command_buffer.set_label("to_cpu");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("blit_to_cpu");
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
blit.end_encoding();
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
}
self.device.wait_until_completed();
@ -550,8 +570,9 @@ impl BackendStorage for MetalStorage {
let shape = layout.shape();
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, "todtype");
device.wait_until_completed();
let command_buffer = device.command_buffer();
if layout.is_contiguous() {
if layout.is_contiguous() && layout.start_offset() == 0 {
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(DType::U32, DType::U8) => "cast_u32_u8",
@ -593,8 +614,10 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
command_buffer.set_label("to_dtype");
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
device.wait_until_completed();
Ok(Self::new(buffer, device.clone(), dtype))
}
@ -606,6 +629,7 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
let command_buffer = device.command_buffer();
command_buffer.set_label(B::KERNEL);
if layout.is_contiguous() && layout.start_offset() == 0 {
use candle_metal_kernels::unary::contiguous;
@ -695,7 +719,6 @@ impl BackendStorage for MetalStorage {
)
.map_err(MetalError::from)?;
}
command_buffer.set_label("unary");
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype))
@ -962,7 +985,6 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
// Create descriptors
let (type_id, size) = match self.dtype {
DType::F32 => (
metal::mps::MPS_FLOATBIT_ENCODING | 32,
@ -1028,9 +1050,11 @@ impl BackendStorage for MetalStorage {
.new_matrix((b, m, n), size, type_id, self.dtype)?;
let command_buffer = self.device.command_buffer();
command_buffer.set_label("matmul");
let alpha = 1.0f64;
let beta = 0.0f64;
// let beta = f64::MIN;
let beta = 1.0;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
@ -1045,6 +1069,8 @@ impl BackendStorage for MetalStorage {
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
matrix_multiplication.set_batch_size(b);
matrix_multiplication.set_batch_start(0);
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
@ -1053,7 +1079,6 @@ impl BackendStorage for MetalStorage {
&right_matrix,
&result_matrix,
);
command_buffer.set_label("matmul");
command_buffer.commit();
out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length()));
// println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer));
@ -1062,9 +1087,11 @@ impl BackendStorage for MetalStorage {
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer();
// println!("Copy strided");
if src_l.is_contiguous() && self.dtype == dst.dtype() {
command_buffer.set_label("copy_contiguous");
let blit = command_buffer.new_blit_command_encoder();
blit.set_label("copy_contiguous");
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
@ -1100,8 +1127,6 @@ impl BackendStorage for MetalStorage {
command_buffer.set_label("copy_strided");
}
command_buffer.commit();
dst.buffer
.did_modify_range(metal::NSRange::new(0, dst.buffer.length()));
Ok(())
}
}
@ -1157,13 +1182,14 @@ impl BackendDevice for MetalDevice {
// println!("CREATING DEVICE");
let device = metal::Device::all().swap_remove(ordinal);
let n = 50;
let n = 64;
let command_queue = device.new_command_queue();
let command_buffers = (0..n)
.map(|_| {
.map(|i| {
let command_buffer = command_queue.new_command_buffer().to_owned();
command_buffer.enqueue();
command_buffer.set_label(&format!("num {i}"));
command_buffer
})
.collect();
@ -1198,6 +1224,7 @@ impl BackendDevice for MetalDevice {
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros");
let command_buffer = self.command_buffer();
command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder();
blit.fill_buffer(
&buffer,
@ -1208,7 +1235,6 @@ impl BackendDevice for MetalDevice {
0,
);
blit.end_encoding();
command_buffer.commit();
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(MetalStorage::new(buffer, self.clone(), dtype))

View File

@ -144,6 +144,7 @@ impl RotaryEmbedding {
let freqs = t.matmul(&inv_freq)?;
let sin = freqs.sin()?;
let cos = freqs.cos()?;
// todo!("{}", sin);
Ok(Self { sin, cos })
}
@ -272,10 +273,10 @@ impl MHA {
}
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
let view = xs.to_string();
if view.contains("NaN") {
panic!("NaN");
}
// let view = xs.to_string();
// if view.contains("NaN") {
// panic!("NaN");
// }
let _enter = self.span.enter();
let (b_size, seq_len, _n_embd) = xs.dims3()?;
let qkv = self