mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Clippy pass.
This commit is contained in:
@ -59,6 +59,8 @@ impl From<String> for MetalError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct MetalDevice {
|
pub struct MetalDevice {
|
||||||
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
/// Raw metal device: <https://developer.apple.com/documentation/metal/mtldevice?language=objc>
|
||||||
@ -103,7 +105,7 @@ pub struct MetalDevice {
|
|||||||
///
|
///
|
||||||
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
|
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
|
||||||
/// (strong_count = 1).
|
/// (strong_count = 1).
|
||||||
buffers: Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>,
|
buffers: AllocatedBuffers,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::fmt::Debug for MetalDevice {
|
impl std::fmt::Debug for MetalDevice {
|
||||||
@ -258,7 +260,7 @@ impl MetalDevice {
|
|||||||
let newbuffers = subbuffers
|
let newbuffers = subbuffers
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|s| Arc::strong_count(s) > 1)
|
.filter(|s| Arc::strong_count(s) > 1)
|
||||||
.map(|s| Arc::clone(s))
|
.map(Arc::clone)
|
||||||
.collect();
|
.collect();
|
||||||
*subbuffers = newbuffers;
|
*subbuffers = newbuffers;
|
||||||
}
|
}
|
||||||
@ -270,7 +272,7 @@ impl MetalDevice {
|
|||||||
let capture = metal::CaptureManager::shared();
|
let capture = metal::CaptureManager::shared();
|
||||||
let descriptor = metal::CaptureDescriptor::new();
|
let descriptor = metal::CaptureDescriptor::new();
|
||||||
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
|
||||||
descriptor.set_capture_device(&self);
|
descriptor.set_capture_device(self);
|
||||||
descriptor.set_output_url(path);
|
descriptor.set_output_url(path);
|
||||||
|
|
||||||
capture
|
capture
|
||||||
@ -1021,10 +1023,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
&self.device.kernels,
|
&self.device.kernels,
|
||||||
name,
|
name,
|
||||||
(b, m, n, k),
|
(b, m, n, k),
|
||||||
&lhs_l.stride(),
|
lhs_l.stride(),
|
||||||
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
lhs_l.start_offset() * self.dtype.size_in_bytes(),
|
||||||
&self.buffer,
|
&self.buffer,
|
||||||
&rhs_l.stride(),
|
rhs_l.stride(),
|
||||||
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
rhs_l.start_offset() * rhs.dtype.size_in_bytes(),
|
||||||
&rhs.buffer,
|
&rhs.buffer,
|
||||||
&buffer,
|
&buffer,
|
||||||
@ -1274,11 +1276,7 @@ impl BackendDevice for MetalDevice {
|
|||||||
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
||||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
||||||
}?;
|
}?;
|
||||||
Ok(Self::Storage::new(
|
Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
|
||||||
buffer.into(),
|
|
||||||
self.clone(),
|
|
||||||
storage.dtype(),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn rand_uniform(
|
fn rand_uniform(
|
||||||
|
@ -574,7 +574,6 @@ fn run_reduce<T: Clone>(v: &[T], out_length: usize, name: &'static str) -> Vec<T
|
|||||||
|
|
||||||
let options = MTLResourceOptions::StorageModeManaged;
|
let options = MTLResourceOptions::StorageModeManaged;
|
||||||
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
let output = device.new_buffer((out_length * core::mem::size_of::<T>()) as u64, options);
|
||||||
let num_dims = 1;
|
|
||||||
let dims = vec![v.len()];
|
let dims = vec![v.len()];
|
||||||
let strides = vec![1];
|
let strides = vec![1];
|
||||||
call_reduce_strided(
|
call_reduce_strided(
|
||||||
|
@ -226,17 +226,17 @@ impl candle::CustomOp1 for SoftmaxLastDim {
|
|||||||
|
|
||||||
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
let last_dim = layout.dims()[layout.shape().rank() - 1];
|
||||||
let elem_count = layout.shape().elem_count();
|
let elem_count = layout.shape().elem_count();
|
||||||
let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
|
let output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
|
||||||
candle_metal_kernels::call_last_softmax(
|
candle_metal_kernels::call_last_softmax(
|
||||||
device.metal_device(),
|
device.metal_device(),
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
&kernels,
|
kernels,
|
||||||
name,
|
name,
|
||||||
elem_count,
|
elem_count,
|
||||||
last_dim,
|
last_dim,
|
||||||
storage.buffer(),
|
storage.buffer(),
|
||||||
layout.start_offset() * storage.dtype().size_in_bytes(),
|
layout.start_offset() * storage.dtype().size_in_bytes(),
|
||||||
&mut output,
|
&output,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
|
||||||
|
Reference in New Issue
Block a user