Remove unwrap().

This commit is contained in:
Nicolas Patry
2023-12-15 12:23:28 +01:00
parent 8b5059e951
commit aa04015098
2 changed files with 77 additions and 48 deletions

View File

@ -210,7 +210,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
) -> Result<(candle::MetalStorage, Shape)> {
use candle::{backend::BackendStorage, DType};
let device = storage.device();
let command_buffer = device.command_buffer();
let command_buffer = device.command_buffer()?;
let kernels = device.kernels();
let name = match storage.dtype() {
DType::F32 => "softmax_f32",
@ -226,7 +226,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
let last_dim = layout.dims()[layout.shape().rank() - 1];
let elem_count = layout.shape().elem_count();
let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax");
let mut output = device.new_buffer(elem_count, storage.dtype(), "softmax")?;
candle_metal_kernels::call_last_softmax(
device.metal_device(),
&command_buffer,