Matmul (no batch, no strided, f32, f32 only) sort of done.

This commit is contained in:
Nicolas Patry
2023-11-01 17:36:51 +01:00
parent 492d164235
commit 198009453a
9 changed files with 205 additions and 96 deletions

View File

@ -49,11 +49,11 @@ mod device;
pub mod display;
mod dtype;
mod dummy_cuda_backend;
#[cfg(feature = "metal")]
pub mod metal_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "accelerate")]
mod metal_backend;
#[cfg(feature = "mkl")]

View File

@ -3,9 +3,13 @@ use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
pub use candle_metal;
use metal;
use core::mem;
use half::{f16, bf16};
use half::{bf16, f16};
use metal;
use metal::mps::matrix::{MatrixMultiplication, Matrix, MatrixDescriptor};
use metal::mps::{Float32, MPSDataType};
use metal::MTLResourceOptions;
use crate::bail;
/// Metal related errors
#[derive(thiserror::Error, Debug)]
@ -17,6 +21,7 @@ pub enum MetalError {
#[derive(Clone)]
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
}
impl std::fmt::Debug for MetalDevice {
@ -47,8 +52,7 @@ impl MetalDevice {
pub struct MetalStorage {
buffer: metal::Buffer,
device: MetalDevice,
dtype: DType
dtype: DType,
}
impl BackendStorage for MetalStorage {
@ -192,12 +196,77 @@ impl BackendStorage for MetalStorage {
rhs_l: &Layout,
) -> Result<Self> {
let elem_count = b * m * n;
let dev = &self.device;
match (self.dtype, rhs.dtype){
match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => {
todo!("MATMUL {b} {m} {n} {k}");
if b != 1 {
bail!("Didn't implemented strided matmul yet");
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
bail!("Didn't implemented non contiguous matmul yet");
}
let out_buffer = self.device.new_buffer(
(elem_count * mem::size_of::<f32>()) as u64,
MTLResourceOptions::empty(),
);
let m : u64 = m.try_into().expect("usize should fit u64");
let n: u64 = n.try_into().expect("usize should fit u64");
let k: u64 = k.try_into().expect("usize should fit u64");
// Create descriptors
let left_descriptor =
MatrixDescriptor::init_single(m, k, k * Float32::SIZE, Float32::TYPE_ID);
let right_descriptor =
MatrixDescriptor::init_single(k, n, n * Float32::SIZE, Float32::TYPE_ID);
let result_descriptor =
MatrixDescriptor::init_single(m, n, n * Float32::SIZE, Float32::TYPE_ID);
// Create matrix objects
let left_matrix =
Matrix::init_with_buffer_descriptor(&self.buffer, &left_descriptor)
.expect("Failed to create left matrix");
let right_matrix =
Matrix::init_with_buffer_descriptor(&rhs.buffer, &right_descriptor)
.expect("Failed to create left matrix");
let result_matrix =
Matrix::init_with_buffer_descriptor(&out_buffer, &result_descriptor)
.expect("Failed to create left matrix");
let transpose_left = false;
let transpose_right = false;
let alpha = 1.0;
let beta = 0.0;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.expect("Failed to create matrix multiplication kernel");
let buffer = self.device.command_queue.new_command_buffer();
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
buffer.commit();
Ok(Self{
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
})
}
_ => todo!("Unimplemented matmul for this pair")
_ => todo!("Unimplemented matmul for this pair"),
}
}
@ -211,7 +280,8 @@ impl BackendDevice for MetalDevice {
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
Ok(Self{device })
let command_queue = device.new_command_queue();
Ok(Self { device, command_queue })
}
fn set_seed(&self, _seed: u64) -> Result<()> {
@ -237,57 +307,47 @@ impl BackendDevice for MetalDevice {
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::CPUCacheModeDefaultCache;
let buffer = match storage {
CpuStorage::U8(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as u64,
option
)
}
CpuStorage::U32(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as u64,
option
)
}
CpuStorage::I64(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as u64,
option
)
}
CpuStorage::BF16(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as u64,
option
)
}
CpuStorage::F16(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as u64,
option
)
}
CpuStorage::F32(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as u64,
option
)
}
CpuStorage::F64(storage) => {
self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as u64,
option
)
}
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as u64,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as u64,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as u64,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as u64,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as u64,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as u64,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as u64,
option,
),
};
Ok(Self::Storage{buffer, device: self.clone(), dtype: storage.dtype()})
Ok(Self::Storage {
buffer,
device: self.clone(),
dtype: storage.dtype(),
})
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {

View File

@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, MetalStorage, Layout, Result, Shape, Tensor};
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@ -176,7 +176,11 @@ pub trait CustomOp1 {
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(&self, _storage: &MetalStorage, _layout: &Layout) -> Result<(MetalStorage, Shape)> {
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no cuda implementation for {}", self.name()).into(),
))

View File

@ -1,7 +1,7 @@
//! Support for the GGML file format.
use super::{k_quants, GgmlDType};
use crate::{Result, Device};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
@ -148,16 +148,36 @@ pub fn qtensor_from_ggml(
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
@ -204,7 +224,10 @@ pub struct Content {
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R, device: &Device) -> Result<Content> {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;

View File

@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::{Result, Device};
use crate::{Device, Result};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@ -70,7 +70,12 @@ impl TensorInfo {
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec(), device)
super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
}
}

View File

@ -178,7 +178,7 @@ impl QTensor {
Ok(Self {
data: Box::new(data),
shape,
device: device.clone()
device: device.clone(),
})
}
@ -201,7 +201,7 @@ impl QTensor {
Ok(Self {
data: Box::new(data),
shape: shape.clone(),
device: device.clone()
device: device.clone(),
})
}

View File

@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, MetalStorage, DType, Device, Error, Layout, Result, Shape};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -659,7 +659,9 @@ impl Storage {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),

View File

@ -2,25 +2,27 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
use candle::{Device, Result, Tensor};
use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else {
if cuda_is_available(){
if cuda_is_available() {
Ok(Device::new_cuda(0)?)
}else if metal_is_available(){
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
}else{
#[cfg(all(target_os="macos", target_arch="aarch64"))]
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
}
#[cfg(not(all(target_os="macos", target_arch="aarch64")))]
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
println!(
"Running on CPU, to run on GPU, build this example with `--features cuda`"
);
}
Ok(Device::Cpu)
}

View File

@ -181,14 +181,20 @@ pub struct ModelWeights {
span_output: tracing::Span,
}
fn precomput_freqs_cis(head_dim: usize, freq_base: f32, device: &Device) -> Result<(Tensor, Tensor)> {
fn precomput_freqs_cis(
head_dim: usize,
freq_base: f32,
device: &Device,
) -> Result<(Tensor, Tensor)> {
let theta: Vec<_> = (0..head_dim)
.step_by(2)
.map(|i| 1f32 / freq_base.powf(i as f32 / head_dim as f32))
.collect();
let theta = Tensor::new(theta.as_slice(), device)?;
let range: Vec<f32> = (0..MAX_SEQ_LEN).map(|r| r as f32).collect();
let idx_theta = Tensor::new(range.as_slice(), device)?.reshape((MAX_SEQ_LEN, 1))?.matmul(&theta.reshape((1, theta.elem_count()))?)?;
let idx_theta = Tensor::new(range.as_slice(), device)?
.reshape((MAX_SEQ_LEN, 1))?
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
// TODO This change avoids allocating on Metal and then casting since allocating directly on
// CPU as f32 seems just as fast
// let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
@ -260,7 +266,7 @@ impl ModelWeights {
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
ct: gguf_file::Content,
reader: &mut R,
device: &Device
device: &Device,
) -> Result<Self> {
let md_get = |s: &str| match ct.metadata.get(s) {
None => candle::bail!("cannot find {s} in metadata"),
@ -283,7 +289,10 @@ impl ModelWeights {
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
let tok_embeddings = tok_embeddings.dequantize(device)?;
let norm = RmsNorm::new(ct.tensor(reader, "output_norm.weight", device)?, rms_norm_eps)?;
let norm = RmsNorm::new(
ct.tensor(reader, "output_norm.weight", device)?,
rms_norm_eps,
)?;
let output = ct.tensor(reader, "output.weight", device)?;
let mut layers = Vec::with_capacity(block_count);
for layer_idx in 0..block_count {
@ -291,11 +300,15 @@ impl ModelWeights {
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
let attention_wo = ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let feed_forward_w1 = ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 = ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let attention_wo =
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
let feed_forward_w1 =
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
let feed_forward_w2 =
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
let feed_forward_w3 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
let attention_norm = ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
let attention_norm =
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?;
let ffn_norm = ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?;
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");