This commit is contained in:
Nicolas Patry
2023-12-15 01:41:14 +01:00
parent ece4c69a68
commit 40c3e1bd5a

View File

@ -91,7 +91,7 @@ impl MetalDevice {
metal::MTLCommandBufferStatus::Committed metal::MTLCommandBufferStatus::Committed
| metal::MTLCommandBufferStatus::Scheduled | metal::MTLCommandBufferStatus::Scheduled
| metal::MTLCommandBufferStatus::Completed => { | metal::MTLCommandBufferStatus::Completed => {
panic!("Alredy committed"); panic!("Already committed");
} }
_ => {} _ => {}
} }
@ -166,9 +166,6 @@ impl MetalDevice {
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
blit.update_fence(&self.fence); blit.update_fence(&self.fence);
blit.end_encoding(); blit.end_encoding();
// drop(command_buffer);
// real.did_modify_range(metal::NSRange::new(0, real.length()));
// println!("Command {:?}", command.status());
// This is necessary, for mmaped safetensors // This is necessary, for mmaped safetensors
// Because of the unsafe slice cast we're doing. // Because of the unsafe slice cast we're doing.
@ -245,11 +242,7 @@ impl BackendStorage for MetalStorage {
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))), DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))), DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))), DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
DType::F32 => { DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
let vec = read_to_vec(&buffer, length / size);
// println!("Got back {:?}", &vec[..1]);
Ok(CpuStorage::F32(vec))
}
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))), DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
} }
} }
@ -302,7 +295,6 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
@ -401,7 +393,6 @@ impl BackendStorage for MetalStorage {
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
} }
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
Ok(Self::new(buffer, device.clone(), dtype)) Ok(Self::new(buffer, device.clone(), dtype))
} }
@ -644,21 +635,13 @@ impl BackendStorage for MetalStorage {
let kernel_name = match (B::KERNEL, dtype) { let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT, ("add", DType::F32) => contiguous::add::FLOAT,
// ("badd", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT, ("sub", DType::F32) => contiguous::sub::FLOAT,
//("bsub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT, ("mul", DType::F32) => contiguous::mul::FLOAT,
// ("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT,
// ("bdiv", DType::F32) => contiguous::div::FLOAT,
("add", DType::F16) => contiguous::add::HALF, ("add", DType::F16) => contiguous::add::HALF,
// ("badd", DType::F16) => contiguous::add::HALF,
("sub", DType::F16) => contiguous::sub::HALF, ("sub", DType::F16) => contiguous::sub::HALF,
// ("bsub", DType::F16) => contiguous::sub::HALF,
("mul", DType::F16) => contiguous::mul::HALF, ("mul", DType::F16) => contiguous::mul::HALF,
// ("bmul", DType::F16) => contiguous::mul::HALF,
("div", DType::F16) => contiguous::div::HALF, ("div", DType::F16) => contiguous::div::HALF,
// ("bdiv", DType::F16) => contiguous::div::HALF,
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"), (name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
}; };
candle_metal_kernels::call_binary_contiguous( candle_metal_kernels::call_binary_contiguous(
@ -877,8 +860,6 @@ impl BackendStorage for MetalStorage {
lhs_l: &Layout, lhs_l: &Layout,
rhs_l: &Layout, rhs_l: &Layout,
) -> Result<Self> { ) -> Result<Self> {
// Create descriptors
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul"); let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
let name = match self.dtype { let name = match self.dtype {
DType::F32 => "sgemm", DType::F32 => "sgemm",
@ -889,8 +870,6 @@ impl BackendStorage for MetalStorage {
}; };
let command_buffer = self.device.command_buffer(); let command_buffer = self.device.command_buffer();
// println!("MATMUL {b} {m} {n} {k}");
// println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride());
command_buffer.set_label("matmul"); command_buffer.set_label("matmul");
candle_metal_kernels::call_gemm( candle_metal_kernels::call_gemm(
&self.device.device, &self.device.device,
@ -907,14 +886,11 @@ impl BackendStorage for MetalStorage {
&buffer, &buffer,
) )
.map_err(MetalError::from)?; .map_err(MetalError::from)?;
// Create kernel
Ok(Self::new(buffer, self.device.clone(), self.dtype())) Ok(Self::new(buffer, self.device.clone(), self.dtype()))
} }
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let command_buffer = self.device.command_buffer(); let command_buffer = self.device.command_buffer();
// println!("Copy strided");
if src_l.is_contiguous() && self.dtype == dst.dtype() { if src_l.is_contiguous() && self.dtype == dst.dtype() {
command_buffer.set_label("copy_contiguous"); command_buffer.set_label("copy_contiguous");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
@ -975,7 +951,6 @@ impl BackendDevice for MetalDevice {
type Storage = MetalStorage; type Storage = MetalStorage;
fn new(ordinal: usize) -> Result<Self> { fn new(ordinal: usize) -> Result<Self> {
// println!("CREATING DEVICE");
let device = metal::Device::all().swap_remove(ordinal); let device = metal::Device::all().swap_remove(ordinal);
let n = 1; let n = 1;
@ -1024,6 +999,7 @@ impl BackendDevice for MetalDevice {
let command_buffer = self.command_buffer(); let command_buffer = self.command_buffer();
command_buffer.set_label("zeros"); command_buffer.set_label("zeros");
let blit = command_buffer.new_blit_command_encoder(); let blit = command_buffer.new_blit_command_encoder();
blit.wait_for_fence(&self.fence);
blit.fill_buffer( blit.fill_buffer(
&buffer, &buffer,
metal::NSRange { metal::NSRange {
@ -1032,6 +1008,7 @@ impl BackendDevice for MetalDevice {
}, },
0, 0,
); );
blit.update_fence(&self.fence);
blit.end_encoding(); blit.end_encoding();
Ok(MetalStorage::new(buffer, self.clone(), dtype)) Ok(MetalStorage::new(buffer, self.clone(), dtype))
} }