diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index b8b951f0..b24db020 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -91,7 +91,7 @@ impl MetalDevice { metal::MTLCommandBufferStatus::Committed | metal::MTLCommandBufferStatus::Scheduled | metal::MTLCommandBufferStatus::Completed => { - panic!("Alredy committed"); + panic!("Already committed"); } _ => {} } @@ -166,9 +166,6 @@ impl MetalDevice { blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length()); blit.update_fence(&self.fence); blit.end_encoding(); - // drop(command_buffer); - // real.did_modify_range(metal::NSRange::new(0, real.length())); - // println!("Command {:?}", command.status()); // This is necessary, for mmaped safetensors // Because of the unsafe slice cast we're doing. @@ -245,11 +242,7 @@ impl BackendStorage for MetalStorage { DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))), DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))), DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))), - DType::F32 => { - let vec = read_to_vec(&buffer, length / size); - // println!("Got back {:?}", &vec[..1]); - Ok(CpuStorage::F32(vec)) - } + DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))), DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))), } } @@ -302,7 +295,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - // buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -401,7 +393,6 @@ impl BackendStorage for MetalStorage { ) .map_err(MetalError::from)?; } - buffer.did_modify_range(metal::NSRange::new(0, buffer.length())); Ok(Self::new(buffer, device.clone(), dtype)) } @@ -644,21 +635,13 @@ impl BackendStorage for MetalStorage { let kernel_name = match (B::KERNEL, dtype) { ("add", DType::F32) => contiguous::add::FLOAT, - // ("badd", DType::F32) => contiguous::add::FLOAT, ("sub", DType::F32) => contiguous::sub::FLOAT, - //("bsub", DType::F32) => contiguous::sub::FLOAT, ("mul", DType::F32) => contiguous::mul::FLOAT, - // ("bmul", DType::F32) => contiguous::mul::FLOAT, ("div", DType::F32) => contiguous::div::FLOAT, - // ("bdiv", DType::F32) => contiguous::div::FLOAT, ("add", DType::F16) => contiguous::add::HALF, - // ("badd", DType::F16) => contiguous::add::HALF, ("sub", DType::F16) => contiguous::sub::HALF, - // ("bsub", DType::F16) => contiguous::sub::HALF, ("mul", DType::F16) => contiguous::mul::HALF, - // ("bmul", DType::F16) => contiguous::mul::HALF, ("div", DType::F16) => contiguous::div::HALF, - // ("bdiv", DType::F16) => contiguous::div::HALF, (name, dtype) => crate::bail!("Match {name} - {dtype:?}"), }; candle_metal_kernels::call_binary_contiguous( @@ -877,8 +860,6 @@ impl BackendStorage for MetalStorage { lhs_l: &Layout, rhs_l: &Layout, ) -> Result { - // Create descriptors - let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul"); let name = match self.dtype { DType::F32 => "sgemm", @@ -889,8 +870,6 @@ impl BackendStorage for MetalStorage { }; let command_buffer = self.device.command_buffer(); - // println!("MATMUL {b} {m} {n} {k}"); - // println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride()); command_buffer.set_label("matmul"); candle_metal_kernels::call_gemm( &self.device.device, @@ -907,14 +886,11 @@ impl BackendStorage for MetalStorage { &buffer, ) .map_err(MetalError::from)?; - // Create kernel - Ok(Self::new(buffer, self.device.clone(), self.dtype())) } fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> { let command_buffer = self.device.command_buffer(); - // println!("Copy strided"); if src_l.is_contiguous() && self.dtype == dst.dtype() { command_buffer.set_label("copy_contiguous"); let blit = command_buffer.new_blit_command_encoder(); @@ -975,7 +951,6 @@ impl BackendDevice for MetalDevice { type Storage = MetalStorage; fn new(ordinal: usize) -> Result { - // println!("CREATING DEVICE"); let device = metal::Device::all().swap_remove(ordinal); let n = 1; @@ -1024,6 +999,7 @@ impl BackendDevice for MetalDevice { let command_buffer = self.command_buffer(); command_buffer.set_label("zeros"); let blit = command_buffer.new_blit_command_encoder(); + blit.wait_for_fence(&self.fence); blit.fill_buffer( &buffer, metal::NSRange { @@ -1032,6 +1008,7 @@ impl BackendDevice for MetalDevice { }, 0, ); + blit.update_fence(&self.fence); blit.end_encoding(); Ok(MetalStorage::new(buffer, self.clone(), dtype)) }