mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
cleanup.
This commit is contained in:
@ -91,7 +91,7 @@ impl MetalDevice {
|
||||
metal::MTLCommandBufferStatus::Committed
|
||||
| metal::MTLCommandBufferStatus::Scheduled
|
||||
| metal::MTLCommandBufferStatus::Completed => {
|
||||
panic!("Alredy committed");
|
||||
panic!("Already committed");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
@ -166,9 +166,6 @@ impl MetalDevice {
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
// drop(command_buffer);
|
||||
// real.did_modify_range(metal::NSRange::new(0, real.length()));
|
||||
// println!("Command {:?}", command.status());
|
||||
|
||||
// This is necessary for mmapped safetensors,
|
||||
// Because of the unsafe slice cast we're doing.
|
||||
@ -245,11 +242,7 @@ impl BackendStorage for MetalStorage {
|
||||
DType::I64 => Ok(CpuStorage::I64(read_to_vec(&buffer, length / size))),
|
||||
DType::F16 => Ok(CpuStorage::F16(read_to_vec(&buffer, length / size))),
|
||||
DType::BF16 => Ok(CpuStorage::BF16(read_to_vec(&buffer, length / size))),
|
||||
DType::F32 => {
|
||||
let vec = read_to_vec(&buffer, length / size);
|
||||
// println!("Got back {:?}", &vec[..1]);
|
||||
Ok(CpuStorage::F32(vec))
|
||||
}
|
||||
DType::F32 => Ok(CpuStorage::F32(read_to_vec(&buffer, length / size))),
|
||||
DType::F64 => Ok(CpuStorage::F64(read_to_vec(&buffer, length / size))),
|
||||
}
|
||||
}
|
||||
@ -302,7 +295,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
// buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -401,7 +393,6 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
buffer.did_modify_range(metal::NSRange::new(0, buffer.length()));
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
@ -644,21 +635,13 @@ impl BackendStorage for MetalStorage {
|
||||
|
||||
let kernel_name = match (B::KERNEL, dtype) {
|
||||
("add", DType::F32) => contiguous::add::FLOAT,
|
||||
// ("badd", DType::F32) => contiguous::add::FLOAT,
|
||||
("sub", DType::F32) => contiguous::sub::FLOAT,
|
||||
//("bsub", DType::F32) => contiguous::sub::FLOAT,
|
||||
("mul", DType::F32) => contiguous::mul::FLOAT,
|
||||
// ("bmul", DType::F32) => contiguous::mul::FLOAT,
|
||||
("div", DType::F32) => contiguous::div::FLOAT,
|
||||
// ("bdiv", DType::F32) => contiguous::div::FLOAT,
|
||||
("add", DType::F16) => contiguous::add::HALF,
|
||||
// ("badd", DType::F16) => contiguous::add::HALF,
|
||||
("sub", DType::F16) => contiguous::sub::HALF,
|
||||
// ("bsub", DType::F16) => contiguous::sub::HALF,
|
||||
("mul", DType::F16) => contiguous::mul::HALF,
|
||||
// ("bmul", DType::F16) => contiguous::mul::HALF,
|
||||
("div", DType::F16) => contiguous::div::HALF,
|
||||
// ("bdiv", DType::F16) => contiguous::div::HALF,
|
||||
(name, dtype) => crate::bail!("Match {name} - {dtype:?}"),
|
||||
};
|
||||
candle_metal_kernels::call_binary_contiguous(
|
||||
@ -877,8 +860,6 @@ impl BackendStorage for MetalStorage {
|
||||
lhs_l: &Layout,
|
||||
rhs_l: &Layout,
|
||||
) -> Result<Self> {
|
||||
// Create descriptors
|
||||
|
||||
let buffer = self.device.new_buffer(b * m * n, self.dtype, "matmul");
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "sgemm",
|
||||
@ -889,8 +870,6 @@ impl BackendStorage for MetalStorage {
|
||||
};
|
||||
|
||||
let command_buffer = self.device.command_buffer();
|
||||
// println!("MATMUL {b} {m} {n} {k}");
|
||||
// println!("strides {:?} {:?}", lhs_l.stride(), rhs_l.stride());
|
||||
command_buffer.set_label("matmul");
|
||||
candle_metal_kernels::call_gemm(
|
||||
&self.device.device,
|
||||
@ -907,14 +886,11 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
// Create kernel
|
||||
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let command_buffer = self.device.command_buffer();
|
||||
// println!("Copy strided");
|
||||
if src_l.is_contiguous() && self.dtype == dst.dtype() {
|
||||
command_buffer.set_label("copy_contiguous");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
@ -975,7 +951,6 @@ impl BackendDevice for MetalDevice {
|
||||
type Storage = MetalStorage;
|
||||
|
||||
fn new(ordinal: usize) -> Result<Self> {
|
||||
// println!("CREATING DEVICE");
|
||||
let device = metal::Device::all().swap_remove(ordinal);
|
||||
|
||||
let n = 1;
|
||||
@ -1024,6 +999,7 @@ impl BackendDevice for MetalDevice {
|
||||
let command_buffer = self.command_buffer();
|
||||
command_buffer.set_label("zeros");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.wait_for_fence(&self.fence);
|
||||
blit.fill_buffer(
|
||||
&buffer,
|
||||
metal::NSRange {
|
||||
@ -1032,6 +1008,7 @@ impl BackendDevice for MetalDevice {
|
||||
},
|
||||
0,
|
||||
);
|
||||
blit.update_fence(&self.fence);
|
||||
blit.end_encoding();
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
Reference in New Issue
Block a user