Mirror of https://github.com/huggingface/candle.git (synced 2025-06-18 03:28:50 +00:00)

Commit: Debugging index_add.
@@ -9,7 +9,7 @@ use half::{bf16, f16};
 use metal;
 use metal::mps::matrix::{Matrix, MatrixDescriptor, MatrixMultiplication};
 use metal::mps::{Float32, MPSDataType};
-use metal::{MTLResourceOptions, Buffer};
+use metal::{Buffer, MTLResourceOptions};
 
 /// Metal related errors
 #[derive(thiserror::Error, Debug)]
@@ -48,12 +48,9 @@ impl MetalDevice {
         self.registry_id()
     }
 
-    fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer{
+    fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
         let size = (element_count * dtype.size_in_bytes()) as u64;
-        self.device.new_buffer(
-            size,
-            MTLResourceOptions::empty(),
-        )
+        self.device.new_buffer(size, MTLResourceOptions::empty())
     }
 }
 
@@ -80,9 +77,11 @@ impl BackendStorage for MetalStorage {
     }
 
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
-        match self.dtype{
-            DType::F32 => Ok(CpuStorage::F32(self.buffer.read_to_vec(self.buffer.length() as usize / 4))),
-            dtype => todo!("Unsupported dtype {dtype:?}")
+        match self.dtype {
+            DType::F32 => Ok(CpuStorage::F32(
+                self.buffer.read_to_vec(self.buffer.length() as usize / 4),
+            )),
+            dtype => todo!("Unsupported dtype {dtype:?}"),
         }
     }
 
@@ -123,7 +122,11 @@ impl BackendStorage for MetalStorage {
         let mut buffer = device.new_buffer(el_count, dtype);
         todo!("Implement the kernel calling");
         // device.kernels.call_unary(U::KERNEL, &self.buffer, &mut buffer, el_count, dtype);
-        Ok(Self { buffer, device, dtype })
+        Ok(Self {
+            buffer,
+            device,
+            dtype,
+        })
     }
 
     fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
@@ -295,7 +298,11 @@ impl MetalStorage {
             });
         }
         if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
-            println!("Didn't implemented non contiguous matmul yet {:?} {:?}", lhs_l.is_contiguous(), rhs_l.is_contiguous());
+            println!(
+                "Didn't implemented non contiguous matmul yet {:?} {:?}",
+                lhs_l.is_contiguous(),
+                rhs_l.is_contiguous()
+            );
             return Ok(Self {
                 buffer: out_buffer,
                 device: self.device.clone(),
@@ -361,7 +368,6 @@ impl MetalStorage {
     }
 }
 
-
 impl BackendDevice for MetalDevice {
     type Storage = MetalStorage;
 
@@ -446,13 +452,25 @@ impl BackendDevice for MetalDevice {
         })
     }
 
-    fn rand_uniform(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
+    fn rand_uniform(
+        &self,
+        shape: &Shape,
+        dtype: DType,
+        mean: f64,
+        stddev: f64,
+    ) -> Result<Self::Storage> {
         // TODO is there a better way ?
         let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
         self.storage_from_cpu_storage(&cpu_storage)
     }
 
-    fn rand_normal(&self, shape: &Shape, dtype: DType, mean: f64, stddev: f64) -> Result<Self::Storage> {
+    fn rand_normal(
+        &self,
+        shape: &Shape,
+        dtype: DType,
+        mean: f64,
+        stddev: f64,
+    ) -> Result<Self::Storage> {
         // TODO is there a better way ?
         let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
         self.storage_from_cpu_storage(&cpu_storage)
@@ -349,12 +349,9 @@ impl crate::CustomOp1 for QTensor {
         // )?;
         let cpu_storage = crate::CpuStorage::F32(dst_storage);
         use crate::backend::{BackendDevice, BackendStorage};
-        if let Device::Metal(device) = &self.device{
-            Ok((
-                device.storage_from_cpu_storage(&cpu_storage)?,
-                dst_shape,
-            ))
-        }else{
+        if let Device::Metal(device) = &self.device {
+            Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
+        } else {
             crate::bail!("qtensor not on metal device")
         }
     }
@@ -9,7 +9,7 @@ use std::io::Write;
 use tokenizers::Tokenizer;
 
 use candle::quantized::{ggml_file, gguf_file};
-use candle::{Tensor};
+use candle::Tensor;
 use candle_transformers::generation::LogitsProcessor;
 
 use candle_transformers::models::quantized_llama as model;
candle-metal-kernels/src/indexing.metal (new file, 49 lines)
@@ -0,0 +1,49 @@
+#include <metal_stdlib>
+#include <metal_config>
+#define METAL_FUNC inline __attribute__((__always_inline__))
+using namespace metal;
+
+struct fault_counter {
+    uint counter;
+    uint tolerance;
+
+    fault_counter(uint tolerance) {
+        this->counter = 0;
+        this->tolerance = tolerance;
+    }
+
+    bool quit() {
+        counter += 1;
+        return (counter > tolerance);
+    }
+};
+
+constant uint IDS_DIM_SIZE [[function_constant(0)]];
+constant uint SRC_DIM_SIZE [[function_constant(1)]];
+constant uint DST_DIM_SIZE [[function_constant(2)]];
+constant uint LEFT_SIZE [[function_constant(3)]];
+constant uint RIGHT_SIZE [[function_constant(4)]];
+constant uint NUMEL = LEFT_SIZE * RIGHT_SIZE;
+
+kernel void index_add(
+    device uint *ids [[buffer(0)]],
+    device float *inp [[buffer(1)]],
+    device float *out [[buffer(2)]],
+
+    uint grid_size [[threadgroups_per_grid]], // gridDim
+    uint gid [[thread_position_in_grid]], // blockIdx
+    uint num_threads [[threads_per_grid]], // blockDim
+    uint thread_index [[thread_index_in_threadgroup]] // threadIdx
+) {
+    for (uint i = gid * num_threads + thread_index; i < NUMEL; i += num_threads * grid_size) {
+        const uint pre = i / RIGHT_SIZE;
+        const uint post = i % RIGHT_SIZE;
+
+        for (uint j = 0; j < IDS_DIM_SIZE; j++) {
+            const uint idx = ids[j];
+            const uint src_i = (pre * IDS_DIM_SIZE + j) * RIGHT_SIZE + post;
+            const uint dst_i = (pre * DST_DIM_SIZE + idx) * RIGHT_SIZE + post;
+            out[dst_i] += inp[src_i];
+        }
+    }
+}
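For readability, here is a minimal CPU sketch (in Rust, not part of the commit) of the addressing scheme used by the index_add kernel above. The function name index_add_cpu and its signature are illustrative; the parameter names mirror the kernel's function constants.

// Hedged sketch: a CPU reference for the addressing used by `index_add` above.
// Not part of the commit; the name and signature are illustrative only.
fn index_add_cpu(
    ids: &[u32],
    inp: &[f32],
    out: &mut [f32],
    ids_dim_size: usize, // IDS_DIM_SIZE
    dst_dim_size: usize, // DST_DIM_SIZE
    left_size: usize,    // LEFT_SIZE
    right_size: usize,   // RIGHT_SIZE
) {
    // The kernel strides this outer loop across threads; on the CPU it is sequential.
    for i in 0..left_size * right_size {
        let pre = i / right_size;
        let post = i % right_size;
        for j in 0..ids_dim_size {
            let idx = ids[j] as usize;
            let src_i = (pre * ids_dim_size + j) * right_size + post;
            let dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];
        }
    }
}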
@@ -1,8 +1,9 @@
-use metal::{Buffer, Device, Function, Library, CompileOptions};
+use metal::{Buffer, CompileOptions, Device, Function, Library, NSUInteger};
 use std::collections::HashMap;
 use std::sync::RwLock;
 
-static UNARY: &'static str = include_str!("unary.metal");
+pub const INDEXING: &str = include_str!("indexing.metal");
+pub const UNARY: &str = include_str!("unary.metal");
 
 pub enum Error {}
 
@@ -63,10 +64,16 @@ fn call_unary(func: &Function, input: &Buffer, output: &Buffer, length: usize) {
 mod tests {
     use super::*;
     use metal::{
-        ComputePipelineDescriptor, MTLResourceOptions, MTLResourceUsage, MTLSize,
+        CompileOptions, ComputePipelineDescriptor, Device, FunctionConstantValues, MTLDataType,
+        MTLResourceOptions, MTLResourceUsage, MTLSize, NSUInteger,
     };
+    use std::ffi::c_void;
+    use std::mem;
 
-    fn approx(v: Vec<f32>, digits: i32) -> Vec<f32>{
+    pub fn void_ptr<T>(v: &T) -> *const c_void {
+        (v as *const T).cast()
+    }
+    fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
         let b = 10f32.powi(digits);
         v.iter().map(|t| f32::round(t * b) / b).collect()
     }
@@ -80,11 +87,11 @@ mod tests {
         let command_buffer = command_queue.new_command_buffer();
         let encoder = command_buffer.new_compute_command_encoder();
         let input = device.new_buffer_with_data(
-            v.as_ptr() as *const core::ffi::c_void,
-            (v.len() * core::mem::size_of::<f32>()) as u64,
+            v.as_ptr() as *const c_void,
+            (v.len() * mem::size_of::<f32>()) as u64,
             option,
         );
-        let output = device.new_buffer((v.len() * core::mem::size_of::<f32>()) as u64, option);
+        let output = device.new_buffer((v.len() * mem::size_of::<f32>()) as u64, option);
         let library = device
             .new_library_with_source(UNARY, &CompileOptions::new())
             .expect("Failed to load unary library");
@@ -130,9 +137,93 @@ mod tests {
         encoder.end_encoding();
         command_buffer.commit();
         command_buffer.wait_until_completed();
+
         let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
         let results = output.read_to_vec::<f32>(v.len());
         assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
         assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
     }
+
+    #[test]
+    fn index_add() {
+        let device = Device::system_default().expect("no device found");
+
+        let options = CompileOptions::new();
+        let library = device.new_library_with_source(INDEXING, &options).unwrap();
+
+        let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
+        let right = [1.0f32; 15];
+        let index = [0u32, 4, 2];
+        let ids_dim_size = index.len() as u32;
+        let src_dim_size = 2u32;
+        let dst_dim_size = 2u32;
+        let left_size = left.len() as u32;
+        let right_size = right.len() as u32;
+        let numel = left_size * right_size;
+
+        let fcv = FunctionConstantValues::new();
+        fcv.set_constant_value_at_index(void_ptr(&ids_dim_size), MTLDataType::UInt, 0);
+        fcv.set_constant_value_at_index(void_ptr(&src_dim_size), MTLDataType::UInt, 1);
+        fcv.set_constant_value_at_index(void_ptr(&dst_dim_size), MTLDataType::UInt, 2);
+        fcv.set_constant_value_at_index(void_ptr(&left_size), MTLDataType::UInt, 3);
+        fcv.set_constant_value_at_index(void_ptr(&right_size), MTLDataType::UInt, 4);
+
+        let function = library.get_function("index_add", Some(fcv)).unwrap();
+        let pipeline = device
+            .new_compute_pipeline_state_with_function(&function)
+            .unwrap();
+        let options = MTLResourceOptions::StorageModeShared;
+
+        let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
+        let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
+        let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
+
+        let ids = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
+        let inputs = device.new_buffer_with_data(void_ptr(&left), input_size, options);
+        let outputs = device.new_buffer_with_data(void_ptr(&right), output_size, options);
+
+        let command_queue = device.new_command_queue();
+        let command_buffer = command_queue.new_command_buffer();
+        let encoder = command_buffer.new_compute_command_encoder();
+
+        encoder.set_compute_pipeline_state(&pipeline);
+        let thread_group_memory_length = output_size;
+        encoder.set_threadgroup_memory_length(0, thread_group_memory_length as NSUInteger);
+
+        encoder.use_resource(&ids, MTLResourceUsage::Read);
+        encoder.use_resource(&inputs, MTLResourceUsage::Read);
+        encoder.use_resource(&outputs, MTLResourceUsage::Write);
+
+        encoder.set_buffer(0, Some(&ids), 0);
+        encoder.set_buffer(1, Some(&inputs), 0);
+        encoder.set_buffer(2, Some(&outputs), 0);
+        let width = 16;
+
+        let thread_group_count = MTLSize {
+            width,
+            height: 1,
+            depth: 1,
+        };
+
+        let thread_group_size = MTLSize {
+            width: (numel as NSUInteger + width) / width,
+            height: 1,
+            depth: 1,
+        };
+        println!("{:?}", thread_group_count);
+        println!("{:?}", thread_group_size);
+
+        encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+        encoder.end_encoding();
+        command_buffer.commit();
+        command_buffer.wait_until_completed();
+
+        let expected = vec![
+            2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
+        ];
+        let result = outputs.read_to_vec::<f32>(right.len());
+        println!("{:?}", result);
+
+        assert_eq!(result, expected);
+    }
 }
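For reference, the expected vector follows from plain index_add semantics on rows of length three: starting from five rows of ones, the source rows [1, 2, 3], [4, 5, 6] and [7, 8, 9] are accumulated into rows 0, 4 and 2 (the entries of `index`), giving [2, 3, 4], [1, 1, 1], [8, 9, 10], [1, 1, 1], [5, 6, 7]. The CPU sketch after the kernel source reproduces this with ids_dim_size = 3, dst_dim_size = 5, left_size = 1 and right_size = 3 (an illustrative parameterization, not the function-constant values set by this test).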