Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)
Remove unwrap().

Replace panicking `unwrap()` calls on `try_write()` in the Metal backend with error propagation: the buffer and command-buffer helpers now return `Result`, and lock failures surface as a new `MetalError::LockError`.
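The core of the change is a small pattern: `RwLock::try_write()` fails with a `TryLockError<Guard>` that is generic over the guard type, so it cannot be stored directly in an error enum; converting it into a guard-free `LockError` lets every `try_write().unwrap()` become `try_write().map_err(...)?`. Below is a minimal, self-contained sketch of that pattern; the `Demo` type and `bump` function are illustrative names, not candle API.

```rust
use std::sync::{RwLock, TryLockError};

/// Guard-free error, mirroring the `LockError` added in this commit.
#[derive(Debug)]
enum LockError {
    Poisoned(String),
    WouldBlock,
}

impl<T> From<TryLockError<T>> for LockError {
    fn from(value: TryLockError<T>) -> Self {
        match value {
            TryLockError::Poisoned(p) => LockError::Poisoned(p.to_string()),
            TryLockError::WouldBlock => LockError::WouldBlock,
        }
    }
}

struct Demo {
    counter: RwLock<usize>,
}

impl Demo {
    /// Before: `self.counter.try_write().unwrap()` panics on contention or
    /// poisoning. After: the failure propagates as a `Result` instead.
    fn bump(&self) -> Result<usize, LockError> {
        let mut guard = self.counter.try_write().map_err(LockError::from)?;
        *guard += 1;
        Ok(*guard)
    }
}

fn main() {
    let demo = Demo { counter: RwLock::new(0) };
    println!("{:?}", demo.bump()); // Ok(1)
}
```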
@@ -8,7 +8,26 @@ use metal;
 use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::path::Path;
-use std::sync::{Arc, RwLock};
+use std::sync::{Arc, RwLock, TryLockError};
+
+/// Simple way to catch lock error without
+/// depending on T
+#[derive(thiserror::Error, Debug)]
+pub enum LockError {
+    #[error("{0}")]
+    Poisoned(String),
+    #[error("Would block")]
+    WouldBlock,
+}
+
+impl<T> From<TryLockError<T>> for MetalError {
+    fn from(value: TryLockError<T>) -> Self {
+        match value {
+            TryLockError::Poisoned(p) => MetalError::LockError(LockError::Poisoned(p.to_string())),
+            TryLockError::WouldBlock => MetalError::LockError(LockError::WouldBlock),
+        }
+    }
+}
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -24,6 +43,8 @@ pub enum MetalError {
         rhs_stride: Vec<usize>,
         mnk: (usize, usize, usize),
     },
+    #[error("{0:?}")]
+    LockError(LockError),
 }
 
 impl From<String> for MetalError {
@@ -106,10 +127,13 @@ impl MetalDevice {
         &self.command_queue
     }
 
-    pub fn command_buffer(&self) -> CommandBuffer {
-        let mut command_buffer_lock = self.command_buffer.try_write().unwrap();
+    pub fn command_buffer(&self) -> Result<CommandBuffer> {
+        let mut command_buffer_lock = self.command_buffer.try_write().map_err(MetalError::from)?;
         let mut command_buffer = command_buffer_lock.to_owned();
-        let mut index = self.command_buffer_index.try_write().unwrap();
+        let mut index = self
+            .command_buffer_index
+            .try_write()
+            .map_err(MetalError::from)?;
         if *index > self.compute_per_buffer {
             command_buffer.commit();
             command_buffer = self.command_queue.new_command_buffer().to_owned();
@@ -117,11 +141,11 @@ impl MetalDevice {
             *index = 0;
         }
         *index += 1;
-        command_buffer
+        Ok(command_buffer)
     }
 
-    pub fn wait_until_completed(&self) {
-        let mut command_buffer = self.command_buffer.try_write().unwrap();
+    pub fn wait_until_completed(&self) -> Result<()> {
+        let mut command_buffer = self.command_buffer.try_write().map_err(MetalError::from)?;
         match command_buffer.status() {
             metal::MTLCommandBufferStatus::Committed
             | metal::MTLCommandBufferStatus::Scheduled
@@ -133,6 +157,7 @@ impl MetalDevice {
         command_buffer.commit();
         command_buffer.wait_until_completed();
         *command_buffer = self.command_queue.new_command_buffer().to_owned();
+        Ok(())
     }
 
     pub fn kernels(&self) -> &Kernels {
@@ -148,7 +173,12 @@ impl MetalDevice {
     /// This means the buffer data cannot be read on the CPU directly.
     ///
     /// [`name`] is only used to keep track of the resource origin in case of bugs
-    pub fn new_buffer(&self, element_count: usize, dtype: DType, name: &str) -> Arc<Buffer> {
+    pub fn new_buffer(
+        &self,
+        element_count: usize,
+        dtype: DType,
+        name: &str,
+    ) -> Result<Arc<Buffer>> {
         let size = (element_count * dtype.size_in_bytes()) as NSUInteger;
         self.allocate_buffer(size, MTLResourceOptions::StorageModePrivate, name)
     }
@@ -158,7 +188,7 @@ impl MetalDevice {
     /// This means the buffer can be read on the CPU but will require manual
     /// synchronization when the CPU memory is modified
     /// Used as a bridge to gather data back from the GPU
-    pub fn new_buffer_managed(&self, size: NSUInteger) -> Arc<Buffer> {
+    pub fn new_buffer_managed(&self, size: NSUInteger) -> Result<Arc<Buffer>> {
         self.allocate_buffer(size, MTLResourceOptions::StorageModeManaged, "managed")
     }
 
@@ -168,7 +198,7 @@ impl MetalDevice {
     /// This method will block the computation because of the
     /// lack of lifetime management through the GPU.
     /// Internal comment for technical details.
-    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Arc<Buffer> {
+    pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
         let size = core::mem::size_of_val(data) as NSUInteger;
         let tmp = self.device.new_buffer_with_data(
             data.as_ptr() as *const core::ffi::c_void,
@@ -179,8 +209,8 @@ impl MetalDevice {
             size,
             metal::MTLResourceOptions::StorageModePrivate,
             "with_data",
-        );
-        let command_buffer = self.command_buffer();
+        )?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("with_data");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -196,8 +226,8 @@ impl MetalDevice {
         // Putting this wait forces the GPU buffer to be filled
         // with the actual data allowing the CPU storage todo
         // deallocate properly.
-        self.wait_until_completed();
-        real
+        self.wait_until_completed()?;
+        Ok(real)
     }
 
     /// The critical allocator algorithm
@@ -206,13 +236,13 @@ impl MetalDevice {
         size: NSUInteger,
         option: MTLResourceOptions,
         _name: &str,
-    ) -> Arc<Buffer> {
-        let mut buffers = self.buffers.try_write().unwrap();
+    ) -> Result<Arc<Buffer>> {
+        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
 
         for sub in &mut *subbuffers {
             if Arc::strong_count(sub) == 1 {
-                return sub.clone();
+                return Ok(sub.clone());
             }
         }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
@@ -226,8 +256,7 @@ impl MetalDevice {
                 .collect();
             *subbuffers = newbuffers;
         }
-
-        new_buffer
+        Ok(new_buffer)
     }
 
     /// Create a metal GPU capture trace on [`path`].
@@ -279,9 +308,9 @@ impl BackendStorage for MetalStorage {
                 self.dtype
             );
         }
-        let buffer = self.device.new_buffer_managed(self.buffer.length());
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         {
-            let command_buffer = self.device.command_buffer();
+            let command_buffer = self.device.command_buffer()?;
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
@@ -290,7 +319,7 @@ impl BackendStorage for MetalStorage {
             blit.update_fence(&self.device.fence);
             blit.end_encoding();
         }
-        self.device.wait_until_completed();
+        self.device.wait_until_completed()?;
 
         match self.dtype {
             DType::U8 => Ok(CpuStorage::U8(read_to_vec(&buffer, length / size))),
@@ -310,8 +339,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "affine");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "affine")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "affine_f32",
@@ -361,8 +390,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "powf");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "powf")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "powf_f32",
@@ -410,8 +439,8 @@ impl BackendStorage for MetalStorage {
         let el = shape.elem_count();
         let dtype = self.dtype;
 
-        let buffer = device.new_buffer(el, self.dtype, "elu");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(el, self.dtype, "elu")?;
+        let command_buffer = self.device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let name = match self.dtype {
                 DType::F32 => "elu_f32",
@@ -497,8 +526,8 @@ impl BackendStorage for MetalStorage {
         if dtype == DType::U32 {
             crate::bail!("Implement return index reduce op");
         }
-        let buffer = device.new_buffer(dst_el, dtype, "reduce");
-        let command_buffer = self.device.command_buffer();
+        let buffer = device.new_buffer(dst_el, dtype, "reduce")?;
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_reduce_contiguous(
             &device.device,
             &command_buffer,
@@ -523,8 +552,8 @@ impl BackendStorage for MetalStorage {
         let device = self.device();
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, "todtype");
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, "todtype")?;
+        let command_buffer = device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let kernel_name = match (self.dtype, dtype) {
                 (DType::U32, DType::F32) => "cast_u32_f32",
@@ -576,8 +605,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = layout.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         command_buffer.set_label(B::KERNEL);
         if layout.is_contiguous() && layout.start_offset() == 0 {
             use candle_metal_kernels::unary::contiguous;
@@ -681,8 +710,8 @@ impl BackendStorage for MetalStorage {
         let dtype = self.dtype;
         let shape = lhs_l.shape();
         let el_count = shape.elem_count();
-        let buffer = device.new_buffer(el_count, dtype, B::KERNEL);
-        let command_buffer = device.command_buffer();
+        let buffer = device.new_buffer(el_count, dtype, B::KERNEL)?;
+        let command_buffer = device.command_buffer()?;
         if (lhs_l.is_contiguous() && lhs_l.start_offset() == 0)
             && (rhs_l.is_contiguous() && rhs_l.start_offset() == 0)
             && &B::KERNEL[..1] != "b"
@@ -758,8 +787,8 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el = shape.elem_count();
         let dtype = t.dtype;
-        let buffer = self.device.new_buffer(el, dtype, "where");
-        let command_buffer = self.device.command_buffer();
+        let buffer = self.device.new_buffer(el, dtype, "where")?;
+        let command_buffer = self.device.command_buffer()?;
         if t.dtype() != f.dtype() {
             crate::bail!("Invalid ternary different dtypes for values");
         }
@@ -875,13 +904,13 @@ impl BackendStorage for MetalStorage {
         let dst_el = ids_el * left_size * right_size;
         let dtype = self.dtype;
         let device = self.device();
-        let buffer = device.new_buffer(dst_el, dtype, "index_select");
+        let buffer = device.new_buffer(dst_el, dtype, "index_select")?;
         let name = match (ids.dtype, self.dtype) {
            (DType::U32, DType::F32) => "is_u32_f32",
            (DType::U32, DType::F16) => "is_u32_f16",
            (left, right) => crate::bail!("index select metal {left:?} {right:?}"),
         };
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         candle_metal_kernels::call_index_select(
             &device.device,
             &command_buffer,
@@ -916,7 +945,7 @@ impl BackendStorage for MetalStorage {
         lhs_l: &Layout,
         rhs_l: &Layout,
     ) -> Result<Self> {
-        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
+        let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul")?;
         let name = match self.dtype {
             DType::F32 => "sgemm",
             DType::F16 => "hgemm",
@@ -925,7 +954,7 @@ impl BackendStorage for MetalStorage {
             }
         };
 
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         command_buffer.set_label("matmul");
         candle_metal_kernels::call_gemm(
             &self.device.device,
@@ -946,7 +975,7 @@ impl BackendStorage for MetalStorage {
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
-        let command_buffer = self.device.command_buffer();
+        let command_buffer = self.device.command_buffer()?;
         if src_l.is_contiguous() && self.dtype == dst.dtype() {
             command_buffer.set_label("copy_contiguous");
             let blit = command_buffer.new_blit_command_encoder();
@@ -1047,8 +1076,8 @@ impl BackendDevice for MetalDevice {
     }
 
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros");
-        let command_buffer = self.command_buffer();
+        let buffer = self.new_buffer(shape.elem_count(), dtype, "zeros")?;
+        let command_buffer = self.command_buffer()?;
         command_buffer.set_label("zeros");
         let blit = command_buffer.new_blit_command_encoder();
         blit.wait_for_fence(&self.fence);
@@ -1080,7 +1109,7 @@ impl BackendDevice for MetalDevice {
             CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
             CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
-        };
+        }?;
         Ok(Self::Storage::new(
             buffer.into(),
             self.clone(),
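One detail worth calling out in the last hunk above: the `};` → `}?;` change applies a single `?` to the value of the whole `match` expression, so every arm's `Result` is propagated through one operator instead of each arm handling its own error. A self-contained sketch of the idiom (the `parse_flag` function is hypothetical):

```rust
use std::num::ParseIntError;

// `match` is an expression; when every arm yields a Result, one trailing
// `?` after the closing brace propagates the error for all arms at once.
fn parse_flag(s: &str) -> Result<bool, ParseIntError> {
    let n = match s {
        "yes" => Ok(1),
        "no" => Ok(0),
        other => other.parse::<i32>(),
    }?;
    Ok(n != 0)
}

fn main() {
    assert_eq!(parse_flag("yes"), Ok(true));
    assert!(parse_flag("maybe").is_err());
}
```

The remaining hunks touch a second file, the Metal branch of the `SoftmaxLastDim` custom op, where the now-fallible device helpers pick up the same `?` treatment: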
@@ -210,7 +210,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
     ) -> Result<(candle::MetalStorage, Shape)> {
         use candle::{backend::BackendStorage, DType};
         let device = storage.device();
-        let command_buffer = device.command_buffer();
+        let command_buffer = device.command_buffer()?;
         let kernels = device.kernels();
         let name = match storage.dtype() {
             DType::F32 => "softmax_f32",
@@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
 
         let last_dim = layout.dims()[layout.shape().rank() - 1];
         let elem_count = layout.shape().elem_count();
-        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
+        let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
         candle_metal_kernels::call_last_softmax(
             device.metal_device(),
             &command_buffer,
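Beyond the error plumbing, the `allocate_buffer` hunk shows the reuse scheme behind the `/// The critical allocator algorithm` doc: cached buffers are keyed by `(size, option)`, and an `Arc` whose `strong_count` is 1 is held only by the cache itself, so it can be handed back out instead of allocating a fresh buffer. A minimal, self-contained sketch of that idea; the `Pool` type is illustrative, while the real code keys on `NSUInteger` sizes and `MTLResourceOptions`:

```rust
use std::collections::HashMap;
use std::sync::Arc;

struct Pool {
    // Keyed by size only here; the real allocator keys on (size, options).
    buffers: HashMap<usize, Vec<Arc<Vec<u8>>>>,
}

impl Pool {
    fn allocate(&mut self, size: usize) -> Arc<Vec<u8>> {
        let subbuffers = self.buffers.entry(size).or_insert_with(Vec::new);
        for sub in subbuffers.iter() {
            // strong_count == 1 means only the pool still holds this Arc,
            // so no live user can observe the reuse.
            if Arc::strong_count(sub) == 1 {
                return sub.clone();
            }
        }
        let new_buffer = Arc::new(vec![0u8; size]);
        subbuffers.push(new_buffer.clone());
        new_buffer
    }
}

fn main() {
    let mut pool = Pool { buffers: HashMap::new() };
    let a = pool.allocate(16);
    let b = pool.allocate(16); // `a` is still alive, so this is a fresh allocation
    drop(a);
    let c = pool.allocate(16); // reuses the buffer `a` released
    assert_eq!(pool.buffers[&16].len(), 2);
    let _ = (b, c);
}
```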