mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Better version ?
This commit is contained in:
@ -96,6 +96,7 @@ impl MetalDevice {
|
||||
.map(|i| {
|
||||
// println!("Creating command buffer {i}");
|
||||
let command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
command_buffer.set_label(&format!("num {i}"));
|
||||
command_buffer.enqueue();
|
||||
command_buffer
|
||||
})
|
||||
@ -157,7 +158,7 @@ impl MetalDevice {
|
||||
for sub in &mut *subbuffers {
|
||||
if Arc::strong_count(sub) == 1 {
|
||||
// println!("Reusing tensor {size} {name}");
|
||||
// return sub.clone();
|
||||
return sub.clone();
|
||||
}
|
||||
}
|
||||
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
|
||||
@ -177,7 +178,7 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
|
||||
self._new_buffer(size, MTLResourceOptions::StorageModeShared, "managed")
|
||||
}
|
||||
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
|
||||
@ -185,19 +186,22 @@ impl MetalDevice {
|
||||
let tmp = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModeManaged,
|
||||
metal::MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let real = self._new_buffer(
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModePrivate,
|
||||
"with_data",
|
||||
);
|
||||
let command = self.command_buffer();
|
||||
let blit = command.new_blit_command_encoder();
|
||||
let command_buffer = self.command_buffer();
|
||||
command_buffer.set_label("with_data");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("with_data_blit");
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.end_encoding();
|
||||
command.commit();
|
||||
real.did_modify_range(metal::NSRange::new(0, real.length()));
|
||||
command_buffer.commit();
|
||||
drop(command_buffer);
|
||||
// real.did_modify_range(metal::NSRange::new(0, real.length()));
|
||||
// println!("Command {:?}", command.status());
|
||||
|
||||
// self.commit();
|
||||
@ -220,15 +224,29 @@ impl MetalDevice {
|
||||
dtype: DType,
|
||||
) -> Result<(Matrix, Arc<Buffer>)> {
|
||||
let elem_count = (b * m * n) as usize;
|
||||
let out_buffer = self.new_buffer(elem_count, dtype, "matrix");
|
||||
let buffer = self.new_buffer(elem_count, dtype, "matrix");
|
||||
let command_buffer = self.command_buffer();
|
||||
command_buffer.set_label("zeros_matmul");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
metal::NSRange {
|
||||
location: 0,
|
||||
length: buffer.length(),
|
||||
},
|
||||
0,
|
||||
);
|
||||
blit.end_encoding();
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
|
||||
let result_descriptor =
|
||||
MatrixDescriptor::init_multiple(m, n, b, n * size, m * n * size, type_id);
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(&out_buffer, 0, &result_descriptor)
|
||||
let result_matrix = Matrix::init_with_buffer_descriptor(&buffer, 0, &result_descriptor)
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
Ok((result_matrix, out_buffer))
|
||||
Ok((result_matrix, buffer))
|
||||
}
|
||||
|
||||
pub fn capture<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||
@ -298,11 +316,13 @@ impl BackendStorage for MetalStorage {
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length());
|
||||
{
|
||||
let command_buffer = self.device.command_buffer();
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
}
|
||||
self.device.wait_until_completed();
|
||||
|
||||
@ -550,8 +570,9 @@ impl BackendStorage for MetalStorage {
|
||||
let shape = layout.shape();
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype, "todtype");
|
||||
device.wait_until_completed();
|
||||
let command_buffer = device.command_buffer();
|
||||
if layout.is_contiguous() {
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
@ -593,8 +614,10 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.set_label("to_dtype");
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
device.wait_until_completed();
|
||||
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
@ -606,6 +629,7 @@ impl BackendStorage for MetalStorage {
|
||||
let el_count = shape.elem_count();
|
||||
let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
|
||||
let command_buffer = device.command_buffer();
|
||||
command_buffer.set_label(B::KERNEL);
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
use candle_metal_kernels::unary::contiguous;
|
||||
|
||||
@ -695,7 +719,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.set_label("unary");
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
@ -962,7 +985,6 @@ impl BackendStorage for MetalStorage {
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
// Create descriptors
|
||||
|
||||
let (type_id, size) = match self.dtype {
|
||||
DType::F32 => (
|
||||
metal::mps::MPS_FLOATBIT_ENCODING | 32,
|
||||
@ -1028,9 +1050,11 @@ impl BackendStorage for MetalStorage {
|
||||
.new_matrix((b, m, n), size, type_id, self.dtype)?;
|
||||
|
||||
let command_buffer = self.device.command_buffer();
|
||||
command_buffer.set_label("matmul");
|
||||
|
||||
let alpha = 1.0f64;
|
||||
let beta = 0.0f64;
|
||||
// let beta = f64::MIN;
|
||||
let beta = 1.0;
|
||||
// Create kernel
|
||||
let matrix_multiplication = MatrixMultiplication::init(
|
||||
&self.device,
|
||||
@ -1045,6 +1069,8 @@ impl BackendStorage for MetalStorage {
|
||||
.ok_or_else(|| {
|
||||
MetalError::from("Failed to create matrix multiplication kernel".to_string())
|
||||
})?;
|
||||
matrix_multiplication.set_batch_size(b);
|
||||
matrix_multiplication.set_batch_start(0);
|
||||
|
||||
// Encode kernel to command buffer
|
||||
matrix_multiplication.encode_to_command_buffer(
|
||||
@ -1053,7 +1079,6 @@ impl BackendStorage for MetalStorage {
|
||||
&right_matrix,
|
||||
&result_matrix,
|
||||
);
|
||||
command_buffer.set_label("matmul");
|
||||
command_buffer.commit();
|
||||
out_buffer.did_modify_range(metal::NSRange::new(0, out_buffer.length()));
|
||||
// println!("========= MATMUL {:?}", Arc::strong_count(&out_buffer));
|
||||
@ -1062,9 +1087,11 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
// println!("Copy strided");
|
||||
if src_l.is_contiguous() && self.dtype == dst.dtype() {
|
||||
command_buffer.set_label("copy_contiguous");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("copy_contiguous");
|
||||
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
@ -1100,8 +1127,6 @@ impl BackendStorage for MetalStorage {
|
||||
command_buffer.set_label("copy_strided");
|
||||
}
|
||||
command_buffer.commit();
|
||||
dst.buffer
|
||||
.did_modify_range(metal::NSRange::new(0, dst.buffer.length()));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -1157,13 +1182,14 @@ impl BackendDevice for MetalDevice {
|
||||
// println!("CREATING DEVICE");
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
let n = 50;
|
||||
let n = 64;
|
||||
let command_queue = device.new_command_queue();
|
||||
|
||||
let command_buffers = (0..n)
|
||||
.map(|_| {
|
||||
.map(|i| {
|
||||
let command_buffer = command_queue.new_command_buffer().to_owned();
|
||||
command_buffer.enqueue();
|
||||
command_buffer.set_label(&format!("num {i}"));
|
||||
command_buffer
|
||||
})
|
||||
.collect();
|
||||
@ -1198,6 +1224,7 @@ impl BackendDevice for MetalDevice {
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros");
|
||||
let command_buffer = self.command_buffer();
|
||||
command_buffer.set_label("zeros");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
@ -1208,7 +1235,6 @@ impl BackendDevice for MetalDevice {
|
||||
0,
|
||||
);
|
||||
blit.end_encoding();
|
||||
|
||||
command_buffer.commit();
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
|
@ -144,6 +144,7 @@ impl RotaryEmbedding {
|
||||
let freqs = t.matmul(&inv_freq)?;
|
||||
let sin = freqs.sin()?;
|
||||
let cos = freqs.cos()?;
|
||||
// todo!("{}", sin);
|
||||
Ok(Self { sin, cos })
|
||||
}
|
||||
|
||||
@ -272,10 +273,10 @@ impl MHA {
|
||||
}
|
||||
|
||||
fn forward(&mut self, xs: &Tensor, mask: Option<&Tensor>) -> Result<Tensor> {
|
||||
let view = xs.to_string();
|
||||
if view.contains("NaN") {
|
||||
panic!("NaN");
|
||||
}
|
||||
// let view = xs.to_string();
|
||||
// if view.contains("NaN") {
|
||||
// panic!("NaN");
|
||||
// }
|
||||
let _enter = self.span.enter();
|
||||
let (b_size, seq_len, _n_embd) = xs.dims3()?;
|
||||
let qkv = self
|
||||
|
Reference in New Issue
Block a user