Mirror of https://github.com/huggingface/candle.git, synced 2025-06-18 03:28:50 +00:00

Commit: Updated everything and output a trace.
@@ -12,6 +12,7 @@ readme = "README.md"
 [dependencies]
 accelerate-src = { workspace = true, optional = true }
 byteorder = { workspace = true }
+tracing = { workspace = true }
 candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
 candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
 metal = { workspace = true, optional = true}

@@ -3,7 +3,7 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
 use crate::{CpuStorage, DType, Layout, Result, Shape};
 use candle_metal_kernels;
-use candle_metal_kernels::{void_ptr, Kernels};
+use candle_metal_kernels::{void_ptr, Kernels, Source};
 use core::mem;
 use half::{bf16, f16};
 use metal;
@@ -11,6 +11,7 @@ use metal::mps::matrix::encode_gemm;
 use metal::mps::Float32;
 use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
 use std::sync::Arc;
+use tracing::debug;

 /// Metal related errors
 #[derive(thiserror::Error, Debug)]

@@ -55,9 +56,9 @@ impl std::ops::Deref for MetalDevice {
 }

 impl MetalDevice {
-    pub fn metal_device(&self) -> &metal::DeviceRef {
-        self.device.as_ref()
-    }
+    // pub fn metal_device(&self) -> &metal::DeviceRef {
+    //     self.device.as_ref()
+    // }

     pub fn id(&self) -> u64 {
         self.registry_id()

@@ -65,6 +66,7 @@ impl MetalDevice {

     fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
         let size = (element_count * dtype.size_in_bytes()) as u64;
+        // debug!("Allocate 1 - buffer size {size}");
         self.device
             .new_buffer(size, MTLResourceOptions::StorageModeManaged)
     }

@@ -103,73 +105,95 @@ impl BackendStorage for MetalStorage {

     fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         let device = self.device().clone();
-        let command_buffer = self.device.command_queue.new_owned_command_buffer();

         let shape = layout.shape();
         let dims = shape.dims();
         let el = shape.elem_count();
+        let dtype = self.dtype;

+        debug!("{shape:?} {el:?} {:?}", layout.stride());
+        let output_buffer = device.new_buffer(el, self.dtype);
+        // return Ok(Self {
+        //     buffer: output_buffer,
+        //     device: device.clone(),
+        //     dtype,
+        // });
         let function = self
             .device
             .kernels
-            .load_function(&device.device, "affine", "affine")
+            .load_function(&device.device, Source::Affine, "affine")
             .map_err(MetalError::from)?;

         let pipeline = device
             .new_compute_pipeline_state_with_function(&function)
             .map_err(MetalError::msg)?;
+        let command_buffer = self.device.command_queue.new_command_buffer();

-        let output_size = el * self.dtype.size_in_bytes();
-        let output_buffer = device.new_buffer(output_size, self.dtype);
-
-        let src_length = self.buffer.length() as usize - layout.start_offset();
-        let src = self.device.new_buffer(src_length, self.dtype);
-        let blit_encoder = command_buffer.new_blit_command_encoder();
-        blit_encoder.copy_from_buffer(
-            self.buffer.as_ref(),
-            layout.start_offset() as NSUInteger,
-            output_buffer.as_ref(),
-            0,
-            (src_length * self.dtype.size_in_bytes()) as NSUInteger,
-        );
-        blit_encoder.end_encoding();
+        assert_eq!(output_buffer.length(), self.buffer.length());

+        let length = el;
         let encoder = command_buffer.new_compute_command_encoder();
         encoder.set_compute_pipeline_state(&pipeline);
-        encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
+        // encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);

         encoder.set_bytes(0, 4, void_ptr(&el));
         encoder.set_bytes(1, 4, void_ptr(&dims));
-        let info = [dims, layout.stride()].concat();
-        let info_len = (info.len() * mem::size_of::<usize>()) as NSUInteger;
-        encoder.set_bytes(2, info_len, info.as_slice().as_ptr().cast());
-
-        encoder.set_buffer(3, Some(&src), 0);
-        encoder.set_buffer(4, Some(&output_buffer), 0);
-
-        encoder.set_bytes(5, 4, void_ptr(&(mul as f32)));
-        encoder.set_bytes(6, 4, void_ptr(&(add as f32)));
+        encoder.set_bytes(
+            2,
+            (mem::size_of::<usize>() * dims.len()) as u64,
+            dims.as_ptr() as *const core::ffi::c_void,
+        );
+        encoder.set_bytes(
+            3,
+            (mem::size_of::<usize>() * layout.stride().len()) as u64,
+            layout.stride().as_ptr() as *const core::ffi::c_void,
+        );
+        encoder.set_buffer(4, Some(&self.buffer), 0);
+        encoder.set_buffer(5, Some(&output_buffer), 0);
+
+        encoder.set_bytes(6, mem::size_of::<f32>() as u64, void_ptr(&(mul as f32)));
+        encoder.set_bytes(7, mem::size_of::<f32>() as u64, void_ptr(&(add as f32)));

         let grid_size = MTLSize {
-            width: output_size as NSUInteger,
-            height: 1,
-            depth: 1,
-        };
-
-        let thread_group_size = MTLSize {
             width: 1,
             height: 1,
             depth: 1,
         };

-        encoder.dispatch_threads(grid_size, thread_group_size);
+        let thread_group_size = MTLSize {
+            width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), el as u64),
+            height: 1,
+            depth: 1,
+        };
+
+        encoder.dispatch_thread_groups(grid_size, thread_group_size);
         encoder.end_encoding();

+        let start = std::time::Instant::now();
         command_buffer.commit();
+        // debug!(
+        //     "Affine {:?}({:?}, {:?}) - {:?}",
+        //     command_buffer.status(),
+        //     self.buffer.length(),
+        //     output_buffer.length(),
+        //     start.elapsed()
+        // );
         // command_buffer.wait_until_completed();
-        println!("Affine");
+        debug!(
+            "Affine {:?} - {:?}",
+            command_buffer.status(),
+            start.elapsed()
+        );

-        Ok(self.clone())
+        let capture = metal::CaptureManager::shared();
+        capture.stop_capture();
+        panic!("Done");
+
+        Ok(Self {
+            buffer: output_buffer,
+            device: device.clone(),
+            dtype,
+        })
     }

     fn powf(&self, _: &Layout, _: f64) -> Result<Self> {

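Note on the dispatch change above: `dispatch_threads` takes the total thread grid and lets Metal carve it into threadgroups, while `dispatch_thread_groups` takes a threadgroup count, so the `width: 1` grid now launches a single threadgroup and the kernel strides over the elements itself. A minimal sketch of that pattern with the metal crate (the helper name is ours, not the commit's):

    use metal::{ComputeCommandEncoderRef, ComputePipelineState, MTLSize};

    // Launch one threadgroup sized to cover `el` elements (capped at the
    // pipeline limit); the kernel is expected to walk the remainder itself.
    fn dispatch_single_group(
        encoder: &ComputeCommandEncoderRef,
        pipeline: &ComputePipelineState,
        el: usize,
    ) {
        let threads = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), el as u64);
        let group_count = MTLSize { width: 1, height: 1, depth: 1 };
        let group_size = MTLSize { width: threads, height: 1, depth: 1 };
        // First argument counts threadgroups, second counts threads per group.
        encoder.dispatch_thread_groups(group_count, group_size);
    }
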
@@ -180,10 +204,78 @@ impl BackendStorage for MetalStorage {
         todo!()
     }

-    fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
-        println!("TODO reduce_op");
-        Ok(self.clone())
+    fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
+        debug!("TODO reduce_op");
+        let src_stride = layout.stride();
+        let src_dims = layout.shape().dims();
+        let src_el: usize = src_dims.iter().product();
+        // Source dims and strides with the sum dims at the end.
+        let mut dims = vec![];
+        let mut stride = vec![];
+        let mut dst_el: usize = 1;
+        for (dim_idx, &d) in src_dims.iter().enumerate() {
+            if !sum_dims.contains(&dim_idx) {
+                dst_el *= d;
+                dims.push(d);
+                stride.push(src_stride[dim_idx]);
+            }
+        }
+        for &dim_idx in sum_dims.iter() {
+            dims.push(src_dims[dim_idx]);
+            stride.push(src_stride[dim_idx]);
+        }
+        // let el_to_sum_per_block = src_el / dst_el;
+        // // The reduction loop requires the shared array to be properly initialized and for
+        // // this we want the number of threads to be a power of two.
+        // let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
+        // let cfg = LaunchConfig {
+        //     // TODO: Maybe use grid_y if the output is too large?
+        //     // TODO: Specialized implementation when reducing on no or all dimensions or when
+        //     // reducing only aggregate a small number of elements together.
+        //     grid_dim: (dst_el as u32, 1, 1),
+        //     block_dim: (block_dim as u32, 1, 1),
+        //     shared_mem_bytes: 0,
+        // };
+        // let ds = dev
+        //     .htod_copy([dims.as_slice(), stride.as_slice()].concat())
+        //     .w()?;
+        // let src = &src.slice(layout.start_offset()..);
+        // let (name, check_empty, return_index) = match self.1 {
+        //     ReduceOp::Sum => ("fast_sum", false, false),
+        //     ReduceOp::Min => ("fast_min", true, false),
+        //     ReduceOp::Max => ("fast_max", true, false),
+        //     ReduceOp::ArgMin => ("fast_argmin", true, true),
+        //     ReduceOp::ArgMax => ("fast_argmax", true, true),
+        // };
+        // if check_empty && layout.shape().elem_count() == 0 {
+        //     Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
+        // }
+        // let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
+        // if return_index {
+        //     // SAFETY: filled in by the follow up kernel.
+        //     let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
+        //     let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+        //     // SAFETY: ffi.
+        //     unsafe { func.launch(cfg, params) }.w()?;
+        //     Ok(S::U32(out))
+        // } else {
+        //     // SAFETY: filled in by the follow up kernel.
+        //     let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+        //     let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
+        //     // SAFETY: ffi.
+        //     unsafe { func.launch(cfg, params) }.w()?;
+        //     Ok(wrap(out))
+        // }
+        // Ok(self.clone())
         // todo!()
+        let dtype = self.dtype;
+        let device = self.device();
+        let buffer = device.new_buffer(dst_el, dtype);
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
     }

     fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {

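Note: the new `reduce_op` body only sizes the destination for now (the kernel launch is still the commented-out CUDA port), but the dim/stride shuffle is the part worth reading: kept dimensions come first and define `dst_el`, reduced ones move to the end. A small worked sketch of that bookkeeping, assuming a contiguous [2, 3, 4] tensor reduced over dim 1:

    fn main() {
        let src_dims = [2usize, 3, 4];
        let src_stride = [12usize, 4, 1]; // row-major strides
        let sum_dims = [1usize];

        let mut dims = vec![];
        let mut stride = vec![];
        let mut dst_el = 1usize;
        for (dim_idx, &d) in src_dims.iter().enumerate() {
            if !sum_dims.contains(&dim_idx) {
                dst_el *= d; // kept dims define the output size
                dims.push(d);
                stride.push(src_stride[dim_idx]);
            }
        }
        for &dim_idx in sum_dims.iter() {
            dims.push(src_dims[dim_idx]);
            stride.push(src_stride[dim_idx]);
        }
        assert_eq!(dst_el, 8); // 2 * 4 elements survive the reduction
        assert_eq!(dims, vec![2, 4, 3]); // reduced dim moved last
        assert_eq!(stride, vec![12, 1, 4]);
    }
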
@@ -201,6 +293,12 @@ impl BackendStorage for MetalStorage {
         let dims = shape.dims();
         let el_count = shape.elem_count();
         let mut buffer = device.new_buffer(el_count, dtype);
+        // TODO remove
+        return Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        });
         let command_buffer = device.command_queue.new_command_buffer();
         if layout.is_contiguous() {
             use candle_metal_kernels::unary::contiguous;

@@ -227,9 +325,17 @@ impl BackendStorage for MetalStorage {
         } else {
             todo!("TODO Implement the kernel calling {}", B::KERNEL);
         }

+        let start = std::time::Instant::now();
         command_buffer.commit();
-        // command_buffer.wait_until_completed();
-        println!("Unary {:?}", B::KERNEL);
+        command_buffer.wait_until_completed();
+        debug!(
+            "Unary {:?} - {:?} - {:?} - {:?}",
+            B::KERNEL,
+            start.elapsed(),
+            self.buffer.length(),
+            buffer.length()
+        );

         Ok(Self {
             buffer,

@@ -239,13 +345,13 @@ impl BackendStorage for MetalStorage {
     }

     fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
-        println!("TODO Binary {:?}", B::NAME);
+        debug!("TODO Binary {:?}", B::NAME);
         Ok(self.clone())
         // todo!()
     }

     fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
-        println!("TODO where_cond");
+        debug!("TODO where_cond");
         Ok(rhs.clone())
         // todo!()
     }

@@ -312,9 +418,29 @@ impl BackendStorage for MetalStorage {
         todo!()
     }

-    fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
-        println!("TODO Index select");
-        Ok(self.clone())
+    fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
+        // todo!("TODO Index select {:?} {ids:?} {l:?} {ids_l:?} {dim:?}", self.buffer.length());
+        let src = self;
+        let ids_shape = ids_l.shape();
+        let ids_dims = ids_shape.dims();
+        // let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
+        // let src = match src_l.contiguous_offsets() {
+        //     Some((o1, o2)) => src.slice(o1..o2),
+        //     None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
+        // };
+        let left_size: usize = src_l.dims()[..dim].iter().product();
+        let right_size: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_size = src_l.dims()[dim];
+        let ids_dim_size = ids_shape.elem_count();
+        let dst_el = ids_shape.elem_count() * left_size * right_size;
+        let dtype = self.dtype;
+        let device = self.device();
+        let buffer = device.new_buffer(dst_el, dtype);
+        Ok(Self {
+            buffer,
+            device: device.clone(),
+            dtype,
+        })
         // todo!()
     }

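Note: like `reduce_op`, the `index_select` port only allocates a correctly sized output so far. The sizing splits the source shape around `dim`; a quick worked sketch, assuming a [2, 5, 3] source gathered with 4 ids on dim 1:

    fn main() {
        let src_dims = [2usize, 5, 3];
        let dim = 1;
        let n_ids = 4usize; // ids_shape.elem_count() in the hunk above

        let left_size: usize = src_dims[..dim].iter().product(); // 2
        let right_size: usize = src_dims[dim + 1..].iter().product(); // 3
        let dst_el = n_ids * left_size * right_size;
        assert_eq!(dst_el, 24); // output shape would be [2, 4, 3]
    }
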
@@ -354,7 +480,7 @@ impl BackendStorage for MetalStorage {
     }

     fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
-        println!("TODO Copy strided");
+        debug!("TODO Copy strided");
         Ok(())
     }
 }

@@ -398,7 +524,7 @@ impl MetalStorage {
             (DType::F32, DType::F32) => {
                 let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
                 if b != 1 {
-                    println!("TODO implement batched matmul for B={b}");
+                    debug!("TODO implement batched matmul for B={b}");
                     // bail!("Didn't implemented strided matmul yet");
                     return Ok(Self {
                         buffer: out_buffer,

@@ -407,7 +533,7 @@ impl MetalStorage {
                     });
                 }
                 if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
-                    println!(
+                    debug!(
                         "Didn't implemented non contiguous matmul yet {:?} {:?}",
                         lhs_l.is_contiguous(),
                         rhs_l.is_contiguous()

@@ -419,7 +545,7 @@ impl MetalStorage {
                     });
                 }

-                println!("GEMM");
+                debug!("GEMM");
                 let command_buffer = self.device.command_queue.new_command_buffer();
                 encode_gemm::<Float32, Float32, Float32>(
                     &self.device,

@@ -438,6 +564,7 @@ impl MetalStorage {
                 .map_err(MetalError::from)?;

                 command_buffer.commit();
+                command_buffer.wait_until_scheduled();

                 // println!("lhs {:?} {m} {k}", self.buffer.length());
                 // println!("rhs {:?} {k} {n}", rhs.buffer.length());

@@ -460,6 +587,19 @@ impl BackendDevice for MetalDevice {

     fn new(ordinal: usize) -> Result<Self> {
         let device = metal::Device::all().swap_remove(ordinal);
+
+        let capture = metal::CaptureManager::shared();
+        let descriptor = metal::CaptureDescriptor::new();
+        descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
+        println!("{:?}", std::env::current_dir()?);
+        descriptor.set_capture_device(&device);
+        let mut dir = std::env::current_dir()?;
+        dir.push("out.gputrace");
+        descriptor.set_output_url(dir);
+
+        capture
+            .start_capture(&descriptor)
+            .map_err(MetalError::from)?;
         let command_queue = device.new_command_queue();
         // let command_buffer = _command_queue.new_owned_command_buffer();
         let kernels = Arc::new(Kernels::new());

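Note: this is the "output a trace" half of the commit: capture starts when the device is created and is stopped (then `panic!`ed) in `affine` above, leaving an `out.gputrace` document that opens in Xcode. A standalone sketch of the same metal crate flow; running outside Xcode typically needs `METAL_CAPTURE_ENABLED=1` in the environment (our assumption, not stated in the commit):

    use metal::{CaptureDescriptor, CaptureManager, Device, MTLCaptureDestination};

    fn main() -> Result<(), String> {
        let device = Device::system_default().ok_or("no Metal device")?;
        let capture = CaptureManager::shared();
        let descriptor = CaptureDescriptor::new();
        descriptor.set_destination(MTLCaptureDestination::GpuTraceDocument);
        descriptor.set_capture_device(&device);
        descriptor.set_output_url(std::path::PathBuf::from("out.gputrace"));
        capture.start_capture(&descriptor)?;

        // ... encode and commit command buffers here; everything in between
        // lands in the trace document ...

        capture.stop_capture();
        Ok(())
    }
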
@@ -496,7 +636,7 @@ impl BackendDevice for MetalDevice {
     }

     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
-        let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
+        let option = metal::MTLResourceOptions::StorageModeManaged;
         let buffer = match storage {
             CpuStorage::U8(storage) => self.device.new_buffer_with_data(
                 storage.as_ptr() as *const core::ffi::c_void,

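Note: switching the upload buffers to `StorageModeManaged` matches `new_buffer` above, but managed buffers keep separate CPU and GPU copies, so CPU-side writes must be flagged explicitly. A hedged sketch of that obligation with the metal crate:

    use metal::{Device, MTLResourceOptions, NSRange};

    fn main() {
        let device = Device::system_default().expect("no Metal device");
        let data = [1.0f32, 2.0, 3.0, 4.0];
        let buffer = device.new_buffer_with_data(
            data.as_ptr() as *const core::ffi::c_void,
            (data.len() * std::mem::size_of::<f32>()) as u64,
            MTLResourceOptions::StorageModeManaged,
        );
        // After mutating contents() of a managed buffer, tell Metal which
        // bytes changed so the GPU copy gets refreshed before the next pass.
        buffer.did_modify_range(NSRange::new(0, buffer.length()));
    }
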
@@ -534,6 +674,7 @@ impl BackendDevice for MetalDevice {
                 option,
             ),
         };
+        // debug!("Allocate 2 - buffer size {}", buffer.length());
         Ok(Self::Storage {
             buffer,
             device: self.clone(),

@@ -1,4 +1,5 @@
 use crate::{Device, Result, Shape, Tensor};
+use tracing::debug;

 #[cfg(target_feature = "avx")]
 pub mod avx;

@@ -321,7 +322,7 @@ impl crate::CustomOp1 for QTensor {
         storage: &crate::MetalStorage,
         layout: &crate::Layout,
     ) -> Result<(crate::MetalStorage, Shape)> {
-        println!("TODO qmatmul");
+        debug!("TODO qmatmul");
         if !layout.is_contiguous() {
             crate::bail!("input tensor is not contiguous {layout:?}")
         }

@@ -238,13 +238,16 @@ fn main() -> anyhow::Result<()> {
     } else {
         Some(args.temperature)
     };
-    let _guard = if args.tracing {
-        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
-        tracing_subscriber::registry().with(chrome_layer).init();
-        Some(guard)
-    } else {
-        None
-    };
+    tracing_subscriber::fmt::init();
+    // let _guard = if args.tracing {
+    //     // let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
+    //     // tracing_subscriber::registry().with(chrome_layer).init();
+    //     tracing_subscriber::fmt::init();
+    //     None
+    //     // Some(guard)
+    // } else {
+    //     None
+    // };

     println!(
         "avx: {}, neon: {}, simd128: {}, f16c: {}",

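Note: with the chrome layer commented out, every `println!` converted to `debug!` in this commit stays silent unless a subscriber is installed; the `tracing_subscriber::fmt::init()` above provides one. A minimal sketch of how those events become visible (the RUST_LOG behavior assumes tracing-subscriber's default env handling):

    use tracing::debug;

    fn main() {
        // Install a formatting subscriber; the level is then typically
        // driven by RUST_LOG, e.g. `RUST_LOG=debug cargo run`.
        tracing_subscriber::fmt::init();
        debug!("TODO reduce_op"); // printed instead of dropped
    }
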
@@ -35,25 +35,27 @@ METAL_FUNC uint get_strided_index(
 kernel void affine(
     constant size_t &dim,
     constant size_t &num_dims,
-    constant size_t *info,
+    constant size_t *dims,
+    constant size_t *strides,

-    device float *inp [[buffer(3)]],
-    device float *out [[buffer(4)]],
+    device float *inp [[buffer(4)]],
+    device float *out [[buffer(5)]],

     constant float &mul,
-    constant float &add
+    constant float &add,
+    uint threadgroup_size [[threads_per_threadgroup]], \
+    uint thread_index [[thread_index_in_threadgroup]]
 ) {
-    constant size_t *dims = info;
-    constant size_t *strides = info + num_dims;
+    const size_t length = (dim + threadgroup_size - 1) / threadgroup_size;
+    const size_t start = thread_index * length;
+    const size_t stop = min(start + length, dim);

     if (is_contiguous(num_dims, dims, strides)) {
-        for (size_t i = 0; i < dim; i++) {
+        for (size_t i = start; i < stop; i++) {
             float x = inp ? inp[i] : out[i];
             out[i] = x * mul + add;
         }
     } else {
-        for (size_t i = 0; i < dim; i++) {
+        for (size_t i = start; i < stop; i++) {
             uint strided_i = get_strided_index(i, num_dims, dims, strides);
             float x = inp ? inp[strided_i] : out[strided_i];
             out[strided_i] = x * mul + add;

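Note: the kernel now splits the `dim` elements evenly across the threadgroup with a ceiling division, the last thread picking up the remainder. The same arithmetic in a tiny sketch, assuming 8 elements and 3 threads:

    fn main() {
        let dim = 8u64; // elements
        let threadgroup_size = 3u64; // threads
        let length = (dim + threadgroup_size - 1) / threadgroup_size; // ceil(8/3) = 3
        for thread_index in 0..threadgroup_size {
            let start = thread_index * length;
            let stop = (start + length).min(dim);
            println!("thread {thread_index}: elements {start}..{stop}");
            // thread 0: 0..3, thread 1: 3..6, thread 2: 6..8
        }
    }
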
@@ -424,15 +424,13 @@ mod tests {
     #[test]
     fn affine() {
         let device = device();

         let options = CompileOptions::new();
         let library = device.new_library_with_source(AFFINE, &options).unwrap();

         let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
         let output = [2.0f32, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
-        let dim: u32 = 8;
-        let num_dims: u32 = 4;
-        let info = [1u32, 2, 3];
+        let shape = vec![4usize, 2];
+        let strides = vec![2usize, 1];
         let mul: f32 = 1.5;
         let add: f32 = 1.1;

@@ -455,29 +453,42 @@ mod tests {
         let inputs_buffer = device.new_buffer_with_data(void_ptr(&input), input_size, options);
         let outputs_buffer = device.new_buffer_with_data(void_ptr(&output), output_size, options);

-        encoder.set_bytes(0, 4, void_ptr(&dim));
-        encoder.set_bytes(1, 4, void_ptr(&num_dims));
-        encoder.set_bytes(2, 4, void_ptr(&info));
+        let dim: usize = shape.iter().product();
+        let num_dims = shape.len();
+        encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&dim));
+        encoder.set_bytes(1, core::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
+        encoder.set_bytes(
+            2,
+            (core::mem::size_of::<usize>() * shape.len()) as u64,
+            shape.as_ptr() as *const c_void,
+        );
+        encoder.set_bytes(
+            3,
+            (core::mem::size_of::<usize>() * strides.len()) as u64,
+            strides.as_ptr() as *const c_void,
+        );

-        encoder.set_buffer(3, Some(&inputs_buffer), 0);
-        encoder.set_buffer(4, Some(&outputs_buffer), 0);
+        encoder.set_buffer(4, Some(&inputs_buffer), 0);
+        encoder.set_buffer(5, Some(&outputs_buffer), 0);

-        encoder.set_bytes(5, 4, void_ptr(&mul));
-        encoder.set_bytes(6, 4, void_ptr(&add));
+        encoder.set_bytes(6, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
+        encoder.set_bytes(7, core::mem::size_of::<f32>() as u64, void_ptr(&add));

-        let grid_size = MTLSize {
-            width: output.len() as NSUInteger,
+        let thread_group_count = MTLSize {
+            width: 1,
             height: 1,
             depth: 1,
         };

+        let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), dim as u64);
+        println!("WIDTH {width}");
         let thread_group_size = MTLSize {
-            width: pipeline.max_total_threads_per_threadgroup(),
+            width,
             height: 1,
             depth: 1,
         };

-        encoder.dispatch_threads(grid_size, thread_group_size);
+        encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
         encoder.end_encoding();
         command_buffer.commit();
         command_buffer.wait_until_completed();

@@ -545,7 +556,7 @@ mod tests {
             depth: 1,
         };

-        encoder.dispatch_threads(grid_size, thread_group_size);
+        encoder.dispatch_thread_groups(grid_size, thread_group_size);
         encoder.end_encoding();
         command_buffer.commit();
         command_buffer.wait_until_completed();

@@ -14,6 +14,7 @@ accelerate-src = { workspace = true, optional = true }
 candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
 half = { workspace = true }
 thiserror = { workspace = true }
+tracing = { workspace = true }
 intel-mkl-src = { workspace = true, optional = true }
 num-traits = { workspace = true }
 rayon = { workspace = true }

@@ -1,5 +1,6 @@
 use candle::{CpuStorage, Layout, Result, Shape, Tensor};
 use rayon::prelude::*;
+use tracing::debug;

 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
 /// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.

@@ -198,7 +199,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
         storage: &candle::MetalStorage,
         layout: &Layout,
     ) -> Result<(candle::MetalStorage, Shape)> {
-        println!("TODO softmax-last-dim");
+        debug!("TODO softmax-last-dim");
         Ok((storage.clone(), layout.shape().clone()))
     }
 }