Adding cast + binary kernels.

This commit is contained in:
Nicolas Patry
2023-11-07 23:45:53 +01:00
parent 0c24a885a6
commit 480a3e22e6
7 changed files with 601 additions and 84 deletions

View File

@ -113,11 +113,11 @@ impl BackendStorage for MetalStorage {
debug!("{shape:?} {el:?} {:?}", layout.stride());
let output_buffer = device.new_buffer(el, self.dtype);
// return Ok(Self {
// buffer: output_buffer,
// device: device.clone(),
// dtype,
// });
return Ok(Self {
buffer: output_buffer,
device: device.clone(),
dtype,
});
let function = self
.device
.kernels
@ -185,9 +185,9 @@ impl BackendStorage for MetalStorage {
start.elapsed()
);
let capture = metal::CaptureManager::shared();
capture.stop_capture();
panic!("Done");
// let capture = metal::CaptureManager::shared();
// capture.stop_capture();
// panic!("Done");
Ok(Self {
buffer: output_buffer,
@ -283,7 +283,58 @@ impl BackendStorage for MetalStorage {
}
/// Casts this storage to `dtype` and returns the result as a new storage.
///
/// A fresh buffer sized for `layout.shape().elem_count()` elements of `dtype`
/// is allocated on the same device, and a cast kernel is encoded into a new
/// command buffer which is then committed.
///
/// # Errors
/// Returns an error (converted through `MetalError`) when
/// `candle_metal_kernels::call_cast_contiguous` fails.
///
/// # Panics
/// Panics through `todo!` for every case that is not wired up yet:
/// any dtype pair other than `U32 -> F32`, and any non-contiguous layout.
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
    let device = self.device();
    let el_count = layout.shape().elem_count();
    let mut buffer = device.new_buffer(el_count, dtype);
    let command_buffer = device.command_queue.new_command_buffer();
    if layout.is_contiguous() {
        // NOTE(review): `layout.start_offset()` is not forwarded to the kernel
        // here; presumably a contiguous layout with a non-zero offset would
        // read from the start of the buffer - confirm against the kernel API.
        let kernel_name = match (self.dtype, dtype) {
            (DType::U32, DType::F32) => "cast_u32_f32",
            (left, right) => todo!("to dtype {left:?} - {right:?}"),
        };
        candle_metal_kernels::call_cast_contiguous(
            &device.device,
            &command_buffer,
            &device.kernels,
            kernel_name,
            el_count,
            &self.buffer,
            &mut buffer,
        )
        .map_err(MetalError::from)?;
    } else {
        todo!(
            "TODO Implement the kernel calling cast {:?}-{:?}",
            self.dtype,
            dtype
        );
    }
    // The timer only covers the commit call: the command buffer is committed
    // but this function does not wait on it before returning.
    let start = std::time::Instant::now();
    command_buffer.commit();
    debug!(
        "cast {:?} - {:?} - {:?} - {:?}",
        dtype,
        start.elapsed(),
        self.buffer.length(),
        buffer.length()
    );
    Ok(Self {
        buffer,
        device: device.clone(),
        dtype,
    })
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@ -294,11 +345,11 @@ impl BackendStorage for MetalStorage {
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
// TODO remove
return Ok(Self {
buffer,
device: device.clone(),
dtype,
});
// return Ok(Self {
// buffer,
// device: device.clone(),
// dtype,
// });
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
@ -328,7 +379,7 @@ impl BackendStorage for MetalStorage {
let start = std::time::Instant::now();
command_buffer.commit();
command_buffer.wait_until_completed();
// command_buffer.wait_until_scheduled();
debug!(
"Unary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
@ -344,10 +395,87 @@ impl BackendStorage for MetalStorage {
})
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
debug!("TODO Binary {:?}", B::NAME);
Ok(self.clone())
// todo!()
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = lhs_l.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
use candle_metal_kernels::binary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT,
("badd", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT,
("bsub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT,
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::binary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("badd", DType::F32) => strided::add::FLOAT,
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
&lhs_l.stride(),
lhs_l.start_offset(),
&rhs.buffer,
&rhs_l.stride(),
rhs_l.start_offset(),
&mut buffer,
)
.map_err(MetalError::from)?;
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Binary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
@ -546,25 +674,25 @@ impl MetalStorage {
}
debug!("GEMM");
let command_buffer = self.device.command_queue.new_command_buffer();
encode_gemm::<Float32, Float32, Float32>(
&self.device,
&command_buffer,
transpose_left,
transpose_right,
&self.buffer,
&rhs.buffer,
&mut out_buffer,
m as NSUInteger,
n as NSUInteger,
k as NSUInteger,
alpha,
beta,
)
.map_err(MetalError::from)?;
// let command_buffer = self.device.command_queue.new_command_buffer();
// encode_gemm::<Float32, Float32, Float32>(
// &self.device,
// &command_buffer,
// transpose_left,
// transpose_right,
// &self.buffer,
// &rhs.buffer,
// &mut out_buffer,
// m as NSUInteger,
// n as NSUInteger,
// k as NSUInteger,
// alpha,
// beta,
// )
// .map_err(MetalError::from)?;
command_buffer.commit();
command_buffer.wait_until_scheduled();
// command_buffer.commit();
// command_buffer.wait_until_scheduled();
// println!("lhs {:?} {m} {k}", self.buffer.length());
// println!("rhs {:?} {k} {n}", rhs.buffer.length());
@ -588,18 +716,18 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
println!("{:?}", std::env::current_dir()?);
descriptor.set_capture_device(&device);
let mut dir = std::env::current_dir()?;
dir.push("out.gputrace");
descriptor.set_output_url(dir);
// let capture = metal::CaptureManager::shared();
// let descriptor = metal::CaptureDescriptor::new();
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
// println!("{:?}", std::env::current_dir()?);
// descriptor.set_capture_device(&device);
// let mut dir = std::env::current_dir()?;
// dir.push("out.gputrace");
// descriptor.set_output_url(dir);
capture
.start_capture(&descriptor)
.map_err(MetalError::from)?;
// capture
// .start_capture(&descriptor)
// .map_err(MetalError::from)?;
let command_queue = device.new_command_queue();
// let command_buffer = _command_queue.new_owned_command_buffer();
let kernels = Arc::new(Kernels::new());