Clippy pass.

This commit is contained in:
Nicolas Patry
2023-12-18 15:22:43 +01:00
parent 064ba17bd7
commit 03641293ee
3 changed files with 11 additions and 14 deletions

View File

@ -59,6 +59,8 @@ impl From<String> for MetalError {
} }
} }
type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
#[derive(Clone)] #[derive(Clone)]
pub struct MetalDevice { pub struct MetalDevice {
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc> /// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
@ -103,7 +105,7 @@ pub struct MetalDevice {
/// ///
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
/// (strong_count = 1). /// (strong_count = 1).
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>, buffers: AllocatedBuffers,
} }
impl std::fmt::Debug for MetalDevice { impl std::fmt::Debug for MetalDevice {
@ -258,7 +260,7 @@ impl MetalDevice {
let newbuffers = subbuffers let newbuffers = subbuffers
.iter() .iter()
.filter(|s| Arc::strong_count(s) > 1) .filter(|s| Arc::strong_count(s) > 1)
.map(|s| Arc::clone(s)) .map(Arc::clone)
.collect(); .collect();
*subbuffers = newbuffers; *subbuffers = newbuffers;
} }
@ -270,7 +272,7 @@ impl MetalDevice {
let capture = metal::CaptureManager::shared(); let capture = metal::CaptureManager::shared();
let descriptor = metal::CaptureDescriptor::new(); let descriptor = metal::CaptureDescriptor::new();
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument); descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
descriptor.set_capture_device(&self); descriptor.set_capture_device(self);
descriptor.set_output_url(path); descriptor.set_output_url(path);
capture capture
@ -1021,10 +1023,10 @@ impl BackendStorage for MetalStorage {
&self.device.kernels, &self.device.kernels,
name, name,
(b, m, n, k), (b, m, n, k),
&lhs_l.stride(), lhs_l.stride(),
lhs_l.start_offset() * self.dtype.size_in_bytes(), lhs_l.start_offset() * self.dtype.size_in_bytes(),
&self.buffer, &self.buffer,
&rhs_l.stride(), rhs_l.stride(),
rhs_l.start_offset() * rhs.dtype.size_in_bytes(), rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
&rhs.buffer, &rhs.buffer,
&buffer, &buffer,
@ -1274,11 +1276,7 @@ impl BackendDevice for MetalDevice {
CpuStorage::F32(storage) => self.new_buffer_with_data(storage), CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
CpuStorage::F64(storage) => self.new_buffer_with_data(storage), CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
}?; }?;
Ok(Self::Storage::new( Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
buffer.into(),
self.clone(),
storage.dtype(),
))
} }
fn rand_uniform( fn rand_uniform(

View File

@ -574,7 +574,6 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
let options = MTLResourceOptions::StorageModeManaged; let options = MTLResourceOptions::StorageModeManaged;
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options); let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
let num_dims = 1;
let dims = vec![v.len()]; let dims = vec![v.len()];
let strides = vec![1]; let strides = vec![1];
call_reduce_strided( call_reduce_strided(

View File

@ -226,17 +226,17 @@ impl candle::CustomOp1 for SoftmaxLastDim {
let last_dim = layout.dims()[layout.shape().rank() - 1]; let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count(); let elem_count = layout.shape().elem_count();
let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?; let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
candle_metal_kernels::call_last_softmax( candle_metal_kernels::call_last_softmax(
device.metal_device(), device.metal_device(),
&command_buffer, &command_buffer,
&kernels, kernels,
name, name,
elem_count, elem_count,
last_dim, last_dim,
storage.buffer(), storage.buffer(),
layout.start_offset() * storage.dtype().size_in_bytes(), layout.start_offset() * storage.dtype().size_in_bytes(),
&mut output, &output,
) )
.unwrap(); .unwrap();
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype()); let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());