diff --git a/candle-core/src/metal_backend.rs b/candle-core/src/metal_backend.rs index 6da2e2a9..d62bb159 100644 --- a/candle-core/src/metal_backend.rs +++ b/candle-core/src/metal_backend.rs @@ -4,7 +4,6 @@ use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT}; use crate::{CpuStorage, DType, Layout, Result, Shape}; use candle_metal_kernels; use candle_metal_kernels::Kernels; -use core::mem; use half::f16; use metal; use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}; @@ -94,7 +93,7 @@ impl MetalDevice { let option = metal::MTLResourceOptions::StorageModeManaged; self.device.new_buffer_with_data( data.as_ptr() as *const core::ffi::c_void, - (data.len() * mem::size_of::()) as NSUInteger, + core::mem::size_of_val(data) as NSUInteger, option, ) } @@ -123,13 +122,11 @@ impl BackendStorage for MetalStorage { } fn to_cpu_storage(&self) -> Result { - // let start = std::time::Instant::now(); self.device.wait_until_completed(); - // println!("Wait took {:?}", start.elapsed()); match self.dtype { DType::U8 => Ok(CpuStorage::U8( - self.buffer.read_to_vec(self.buffer.length() as usize / 1), + self.buffer.read_to_vec(self.buffer.length() as usize), )), DType::U32 => Ok(CpuStorage::U32( self.buffer.read_to_vec(self.buffer.length() as usize / 4), @@ -200,11 +197,11 @@ impl BackendStorage for MetalStorage { ) .unwrap(); } - return Ok(Self { + Ok(Self { buffer, device: device.clone(), dtype, - }); + }) } fn powf(&self, _: &Layout, _: f64) -> Result { @@ -499,10 +496,10 @@ impl BackendStorage for MetalStorage { kernel_name, lhs_l.dims(), &self.buffer, - &lhs_l.stride(), + lhs_l.stride(), lhs_l.start_offset() * self.dtype.size_in_bytes(), &rhs.buffer, - &rhs_l.stride(), + rhs_l.stride(), rhs_l.start_offset() * rhs.dtype.size_in_bytes(), &mut buffer, ) @@ -535,7 +532,7 @@ impl BackendStorage for MetalStorage { &command_buffer, &device.kernels, "where_u8_f32", - &dims, + dims, &self.buffer, ( layout.stride(), @@ -853,7 +850,7 @@ impl BackendStorage for MetalStorage { kernel_name, src_l.dims(), &self.buffer, - &src_l.stride(), + src_l.stride(), src_l.start_offset() * self.dtype.size_in_bytes(), &mut dst.buffer, dst_offset * dst.dtype.size_in_bytes(),