Compare commits

..

3 Commits

SHA1 Message Date
69c1fb1ee8 Add a benchmark for the matmul slowness. 2023-10-11 15:49:42 +02:00
c55ebaf477 Use full tensors for zeros and ones. 2023-10-11 08:50:43 +02:00
4c91dd2ff4 Only optimize float tensors. 2023-10-10 09:45:49 +02:00
160 changed files with 792 additions and 14316 deletions

Binary file not shown.

View File

@ -1,62 +0,0 @@
name: PyO3-CI
on:
workflow_dispatch:
push:
branches:
- main
paths:
- candle-pyo3/**
pull_request:
paths:
- candle-pyo3/**
jobs:
build_and_test:
name: Check everything builds & tests
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [ubuntu-latest] # For now, only test on Linux
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- name: Install Python
uses: actions/setup-python@v4
with:
python-version: 3.11
architecture: "x64"
- name: Cache Cargo Registry
uses: actions/cache@v1
with:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}
- name: Install
working-directory: ./candle-pyo3
run: |
python -m venv .env
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r
- name: Check style
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python stub.py --check
black --check .
- name: Run tests
working-directory: ./candle-pyo3
run: |
source .env/bin/activate
python -m pytest -s -v tests

View File

@ -7,7 +7,13 @@ members = [
"candle-nn",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-examples/llama2-c",
"candle-wasm-examples/segment-anything",
"candle-wasm-examples/whisper",
"candle-wasm-examples/yolo",
"candle-wasm-examples/bert",
"candle-wasm-examples/phi",
"candle-wasm-examples/t5",
"candle-wasm-tests",
]
exclude = ["candle-flash-attn", "candle-kernels"]
@ -28,7 +34,8 @@ anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
clap = { version = "4.2.4", features = ["derive"] }
cudarc = { version = "0.9.14", features = ["f16"] }
gemm = { version = "0.16.6", features = ["wasm-simd128-enable"] }
# TODO: Switch back to the official gemm implementation once it has caught up.
gemm = { version = "0.16.0", package = "candle-gemm" }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
image = { version = "0.24.7", default-features = false, features = ["jpeg", "png"] }
@ -55,7 +62,6 @@ tracing-subscriber = "0.3.7"
wav = "1.0.0"
yoke = { version = "0.7.2", features = ["derive"] }
zip = { version = "0.6.6", default-features = false }
metal = { git = "https://github.com/ivarflakstad/metal-rs.git", features = ["mps"] }
[profile.release-with-debug]
inherits = "release"

View File

@ -56,19 +56,17 @@ These online demos run entirely in your browser:
- [T5](https://huggingface.co/spaces/radames/Candle-T5-Generation-Wasm): text generation.
- [Phi-v1.5](https://huggingface.co/spaces/radames/Candle-Phi-1.5-Wasm): text generation.
- [Segment Anything Model](https://huggingface.co/spaces/radames/candle-segment-anything-wasm): Image segmentation.
- [BLIP](https://huggingface.co/spaces/radames/Candle-BLIP-Image-Captioning): image captioning.
We also provide some command-line examples using state-of-the-art models:
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
pre-trained on 1T tokens of English and code datasets.
- [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
performance better than all publicly available 13b models as of 2023-09-28.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Quantized LLaMA](./candle-examples/examples/quantized/): quantized version of
the LLaMA model using the same quantization techniques as
[llama.cpp](https://github.com/ggerganov/llama.cpp).
@ -96,15 +94,10 @@ We also provide some command-line examples using state-of-the-art models
<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
[JinaBert](./candle-examples/examples/jina-bert/): useful for sentence embeddings.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/): useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
using self-supervision (can be used for imagenet classification, depth
evaluation, segmentation).
- [BLIP](./candle-examples/examples/blip/): image to text model, can be used to
generate captions for an image.
- [Marian-MT](./candle-examples/examples/marian-mt/): neural machine translation
model that generates translated text from the input text.
Run them using commands like:
```
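# hypothetical example; the actual command was truncated in this diff view
cargo run --example quantized --release
```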
@ -136,13 +129,8 @@ And then head over to
<!--- ANCHOR: useful_libraries --->
## Useful External Resources
- [`candle-tutorial`](https://github.com/ToluClassics/candle-tutorial): a
very detailed tutorial showing how to convert a PyTorch model to Candle.
- [`optimisers`](https://github.com/KGrewal1/optimisers): a collection of optimisers
including SGD with momentum, AdaGrad, AdaDelta, AdaMax, NAdam, RAdam, and RMSprop.
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora): a LoRA implementation
that conforms to the official `peft` implementation.
## Useful Libraries
- [`candle-lora`](https://github.com/EricLBuehler/candle-lora) provides a LoRA implementation that conforms to the official `peft` implementation.
If you have an addition to this list, please submit a pull request.
@ -167,20 +155,16 @@ If you have an addition to this list, please submit a pull request.
- Phi v1.5.
- Mistral 7b v0.1.
- StableLM-3B-4E1T.
- Replit-code-v1.5-3B.
- T5.
- Bert.
- Whisper (multi-lingual support).
- Text to image.
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Image to text.
- BLIP.
- Text to text.
- Marian MT (Machine Translation).
- Stable Diffusion v1.5, v2.1, XL v1.0.
- Wurstchen v2.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT.
- yolo-v3, yolo-v8.
- DINOv2.
- EfficientNet.
- yolo-v3.
- yolo-v8.
- Segment-Anything Model (SAM).
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.

View File

@ -12,9 +12,6 @@ compute_cap
8.9
```
You can also compile the Cuda kernels for a specific compute cap using the
`CUDA_COMPUTE_CAP=<compute cap>` environment variable.
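For instance, assuming the `8.9` compute cap reported above, a build might pin it as follows (illustrative invocation; the value is written without the dot, and the exact feature flags depend on your project):
```
CUDA_COMPUTE_CAP=89 cargo build --release --features cuda
```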
If any of the above commands errors out, please make sure to update your Cuda version.
2. Create a new app and add [`candle-core`](https://github.com/huggingface/candle/tree/main/candle-core) with Cuda support.

View File

@ -12,10 +12,7 @@ readme = "README.md"
[dependencies]
accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
tracing = { workspace = true }
candle-kernels = { path = "../candle-kernels", version = "0.3.0", optional = true }
candle-metal-kernels = { path = "../candle-metal-kernels", version = "0.3.0", optional = true }
metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
@ -42,4 +39,3 @@ cuda = ["cudarc", "dep:candle-kernels"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
metal = ["dep:candle-metal-kernels", "dep:metal"]

View File

@ -36,8 +36,6 @@ impl Tensor {
// Do not call recursively on the "leaf" nodes.
track_grad = true;
nodes
} else if node.dtype().is_int() {
nodes
} else if let Some(op) = node.op() {
match op {
Op::IndexAdd(t1, t2, t3, _)
@ -105,6 +103,7 @@ impl Tensor {
| Op::Broadcast(node)
| Op::Cmp(node, _)
| Op::Reduce(node, ReduceOp::Min | ReduceOp::Sum | ReduceOp::Max, _)
| Op::ToDType(node)
| Op::ToDevice(node)
| Op::Transpose(node, _, _)
| Op::Permute(node, _)
@ -117,15 +116,6 @@ impl Tensor {
track_grad |= tg;
nodes
}
Op::ToDType(node) => {
if node.dtype().is_float() {
let (tg, nodes) = walk(node, nodes, already_seen);
track_grad |= tg;
nodes
} else {
nodes
}
}
Op::Reduce(_, ReduceOp::ArgMin | ReduceOp::ArgMax, _) => nodes,
}
} else {
@ -238,13 +228,6 @@ impl Tensor {
.conv2d(&grad.transpose(0, 1)?, *padding, *dilation, *stride, 1)?
.transpose(0, 1)?;
let sum_grad = grads.or_insert(kernel)?;
let (_, _, k0, k1) = kernel.dims4()?;
let (_, _, g_k0, g_k1) = grad_kernel.dims4()?;
let grad_kernel = if g_k0 != k0 || g_k1 != k1 {
grad_kernel.narrow(2, 0, k0)?.narrow(3, 0, k1)?
} else {
grad_kernel
};
*sum_grad = sum_grad.add(&grad_kernel)?;
}
Op::ConvTranspose2D { .. } => Err(Error::BackwardNotSupported {
@ -391,7 +374,7 @@ impl Tensor {
}
Op::ToDType(arg) => {
let sum_grad = grads.or_insert(arg)?;
*sum_grad = sum_grad.add(&grad.to_dtype(arg.dtype())?)?
*sum_grad = sum_grad.add(&grad.to_dtype(node.dtype())?)?
}
Op::Copy(arg) => {
let sum_grad = grads.or_insert(arg)?;
@ -478,15 +461,7 @@ impl Tensor {
Op::Unary(_, UnaryOp::Round) => {
Err(Error::BackwardNotSupported { op: "round" })?
}
Op::Unary(arg, UnaryOp::Gelu) => {
let sum_grad = grads.or_insert(arg)?;
let cube = arg.powf(3.)?;
let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
let gelu_grad = (((0.5 * &tanh)?
+ (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
+ 0.5)?;
*sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
}
Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
Op::Unary(_, UnaryOp::GeluErf) => {
Err(Error::BackwardNotSupported { op: "gelu-erf" })?

View File

@ -804,11 +804,11 @@ impl<'a, I: IntDType> Map1 for Gather<'a, I> {
fn f<T: WithDType>(&self, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
None => Err(Error::RequiresContiguous { op: "gather" })?,
};
let src = match src_l.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
None => Err(Error::RequiresContiguous { op: "gather" })?,
};
let dim = self.dim;
let ids_dims = self.ids_l.dims();
@ -857,7 +857,7 @@ impl<'a, I: IntDType> Map1 for IndexSelect<'a, I> {
fn f<T: WithDType>(&self, src: &[T], layout: &Layout) -> Result<Vec<T>> {
let src = match layout.contiguous_offsets() {
Some((a, b)) => &src[a..b],
None => Err(Error::RequiresContiguous { op: "index-select" }.bt())?,
None => Err(Error::RequiresContiguous { op: "index-select" })?,
};
let dim = self.dim;
let n_ids = match self.ids_l.dims() {
@ -913,7 +913,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
None => Err(Error::RequiresContiguous { op: "scatter-add" })?,
Some((o1, o2)) => &src[o1..o2],
};
@ -929,7 +929,7 @@ impl<'a, I: IntDType> Map2 for ScatterAdd<'a, I> {
let ids = match self.ids_l.contiguous_offsets() {
Some((a, b)) => &self.ids[a..b],
None => Err(Error::RequiresContiguous { op: "gather" }.bt())?,
None => Err(Error::RequiresContiguous { op: "gather" })?,
};
for left_i in 0..ids_left_len {
let start_ids_idx = left_i * ids_right_len * ids_dim_len;
@ -971,7 +971,7 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
let mut dst = vec![T::zero(); dst_len];
copy_strided_src_(v1, &mut dst, 0, l1);
let src = match src_l.contiguous_offsets() {
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
None => Err(Error::RequiresContiguous { op: "index-add" })?,
Some((o1, o2)) => &src[o1..o2],
};
let dim = self.dim;
@ -2539,25 +2539,25 @@ impl BackendStorage for CpuStorage {
Self::U8(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
None => Err(Error::RequiresContiguous { op: "index-add" })?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::U32(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
None => Err(Error::RequiresContiguous { op: "index-add" })?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
Self::I64(ids) => {
let ids = match ids_l.contiguous_offsets() {
Some((a, b)) => &ids[a..b],
None => Err(Error::RequiresContiguous { op: "index-add" }.bt())?,
None => Err(Error::RequiresContiguous { op: "index-add" })?,
};
IndexAdd { ids, dim }.map(self, l, src, src_l)
}
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add").bt()),
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-add")),
}
}

View File

@ -224,10 +224,8 @@ impl BackendDevice for CudaDevice {
}
fn set_seed(&self, seed: u64) -> Result<()> {
// We do not call set_seed but instead create a new curand object. This ensures that the
// state will be identical and the same random numbers will be generated.
let mut curand = self.curand.lock().unwrap();
curand.0 = cudarc::curand::CudaRng::new(seed, self.device.clone()).w()?;
curand.0.set_seed(seed).w()?;
Ok(())
}
@ -2171,7 +2169,7 @@ impl BackendStorage for CudaStorage {
if src_l.is_contiguous() {
dev.dtod_copy(&src, &mut dst).w()?
} else {
let func = dev.get_or_load_func("ucopy_f64", kernels::UNARY)?;
let func = dev.get_or_load_func("ucopy_64", kernels::UNARY)?;
// SAFETY: Set later by running the kernel.
let params = (el_count, dims.len(), &ds, &src, &mut dst);
// SAFETY: ffi.

View File

@ -1,6 +1,6 @@
use crate::backend::BackendDevice;
use crate::cpu_backend::CpuDevice;
use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
use crate::{CpuStorage, DType, Result, Shape, Storage, WithDType};
/// A `DeviceLocation` represents a physical device whereas multiple `Device`
/// can live on the same location (typically for cuda devices).
@ -8,14 +8,12 @@ use crate::{bail, CpuStorage, DType, Result, Shape, Storage, WithDType};
pub enum DeviceLocation {
Cpu,
Cuda { gpu_id: usize },
Metal,
}
#[derive(Debug, Clone)]
pub enum Device {
Cpu,
Cuda(crate::CudaDevice),
Metal(crate::MetalDevice),
}
pub trait NdArray {
@ -105,14 +103,14 @@ impl<S: WithDType, const N1: usize, const N2: usize, const N3: usize, const N4:
impl<S: NdArray> NdArray for Vec<S> {
fn shape(&self) -> Result<Shape> {
if self.is_empty() {
bail!("empty array")
crate::bail!("empty array")
}
let shape0 = self[0].shape()?;
let n = self.len();
for v in self.iter() {
let shape = v.shape()?;
if shape != shape0 {
bail!("two elements have different shapes {shape:?} {shape0:?}")
crate::bail!("two elements have different shapes {shape:?} {shape0:?}")
}
}
Ok(Shape::from([[n].as_slice(), shape0.dims()].concat()))
@ -130,18 +128,6 @@ impl Device {
Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?))
}
pub fn new_metal(ordinal: usize) -> Result<Self> {
Ok(Self::Metal(crate::MetalDevice::new(ordinal)?))
}
pub fn set_seed(&self, seed: u64) -> Result<()> {
match self {
Self::Cpu => CpuDevice.set_seed(seed),
Self::Cuda(c) => c.set_seed(seed),
Self::Metal(m) => m.set_seed(seed),
}
}
pub fn same_device(&self, rhs: &Self) -> bool {
match (self, rhs) {
(Self::Cpu, Self::Cpu) => true,
@ -154,16 +140,21 @@ impl Device {
match self {
Self::Cpu => DeviceLocation::Cpu,
Self::Cuda(device) => device.location(),
Device::Metal(device) => device.location(),
}
}
pub fn is_cpu(&self) -> bool {
matches!(self, Self::Cpu)
match self {
Self::Cpu => true,
Self::Cuda(_) => false,
}
}
pub fn is_cuda(&self) -> bool {
matches!(self, Self::Cuda(_))
match self {
Self::Cpu => false,
Self::Cuda(_) => true,
}
}
pub fn cuda_if_available(ordinal: usize) -> Result<Self> {
@ -190,11 +181,6 @@ impl Device {
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(_device) => {
// let storage = device.rand_uniform(shape, dtype, lo, up)?;
// Ok(Storage::Metal(storage))
bail!("Metal rand_uniform not implemented")
}
}
}
@ -223,10 +209,6 @@ impl Device {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Metal(storage))
}
}
}
@ -249,10 +231,6 @@ impl Device {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.ones_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -266,10 +244,6 @@ impl Device {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = device.zeros_impl(shape, dtype)?;
Ok(Storage::Metal(storage))
}
}
}
@ -281,11 +255,6 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = array.to_cpu_storage();
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
@ -297,11 +266,6 @@ impl Device {
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Cuda(storage))
}
Device::Metal(device) => {
let storage = S::to_cpu_storage_owned(data);
let storage = device.storage_from_cpu_storage(&storage)?;
Ok(Storage::Metal(storage))
}
}
}
}

View File

@ -14,7 +14,6 @@ impl Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
_ => todo!(),
};
write!(f, "Tensor[")?;
@ -477,7 +476,6 @@ impl std::fmt::Display for Tensor {
crate::DeviceLocation::Cuda { gpu_id } => {
format!(", cuda:{}", gpu_id)
}
crate::DeviceLocation::Metal => todo!(),
};
write!(

View File

@ -1,201 +0,0 @@
#![allow(dead_code)]
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Error, Layout, Result, Shape};
#[derive(Debug, Clone)]
pub struct MetalDevice;
#[derive(Debug)]
pub struct MetalStorage;
macro_rules! fail {
() => {
unimplemented!("metal support has not been enabled, add `metal` feature to enable.")
};
}
impl crate::backend::BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn dtype(&self) -> DType {
fail!()
}
fn device(&self) -> &Self::Device {
fail!()
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn affine(&self, _: &Layout, _: f64, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn reduce_op(&self, _: ReduceOp, _: &Layout, _: &[usize]) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn to_dtype(&self, _: &Layout, _: DType) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn unary_impl<B: UnaryOpT>(&self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn binary_impl<B: BinaryOpT>(&self, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn where_cond(&self, _: &Layout, _: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv1d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv1D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv2d(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &crate::conv::ParamsConv2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &crate::conv::ParamsConvTranspose2D,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn matmul(
&self,
_: &Self,
_: (usize, usize, usize, usize),
_: &Layout,
_: &Layout,
) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
}
impl crate::backend::BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(_: usize) -> Result<Self> {
Err(Error::NotCompiledWithMetalSupport)
}
fn set_seed(&self, _: u64) -> Result<()> {
Err(Error::NotCompiledWithMetalSupport)
}
fn location(&self) -> crate::DeviceLocation {
fail!()
}
fn same_device(&self, _: &Self) -> bool {
fail!()
}
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn storage_from_cpu_storage(&self, _: &CpuStorage) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_uniform(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
fn rand_normal(&self, _: &Shape, _: DType, _: f64, _: f64) -> Result<Self::Storage> {
Err(Error::NotCompiledWithMetalSupport)
}
}

View File

@ -1,4 +1,4 @@
use crate::{metal_backend, DType, DeviceLocation, Layout, Shape};
use crate::{DType, DeviceLocation, Layout, Shape};
#[derive(Debug, Clone)]
pub struct MatMulUnexpectedStriding {
@ -142,9 +142,6 @@ pub enum Error {
#[error("{op} expects at least one tensor")]
OpRequiresAtLeastOneTensor { op: &'static str },
#[error("{op} expects at least two tensors")]
OpRequiresAtLeastTwoTensors { op: &'static str },
#[error("backward is not supported for {op}")]
BackwardNotSupported { op: &'static str },
@ -152,9 +149,6 @@ pub enum Error {
#[error("the candle crate has not been built with cuda support")]
NotCompiledWithCudaSupport,
#[error("the candle crate has not been built with metal support")]
NotCompiledWithMetalSupport,
#[error("cannot find tensor {path}")]
CannotFindTensor { path: String },
@ -162,9 +156,6 @@ pub enum Error {
#[error(transparent)]
Cuda(Box<dyn std::error::Error + Send + Sync>),
#[error("Metal error {0}")]
Metal(#[from] metal_backend::MetalError),
#[error(transparent)]
TryFromIntError(#[from] core::num::TryFromIntError),

View File

@ -52,10 +52,6 @@ mod dummy_cuda_backend;
pub mod error;
mod indexer;
pub mod layout;
#[cfg(feature = "metal")]
pub mod metal_backend;
#[cfg(feature = "accelerate")]
mod metal_backend;
#[cfg(feature = "mkl")]
mod mkl;
pub mod npy;
@ -91,12 +87,6 @@ pub use cuda_backend::{CudaDevice, CudaStorage};
#[cfg(not(feature = "cuda"))]
pub use dummy_cuda_backend::{CudaDevice, CudaStorage};
#[cfg(feature = "metal")]
pub use metal_backend::{MetalDevice, MetalStorage};
#[cfg(not(feature = "metal"))]
pub use dummy_metal_backend::{MetalDevice, MetalStorage};
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
@ -135,15 +125,3 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
self(xs)
}
}
// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor>;
}
impl<M: Module> ModuleT for M {
fn forward_t(&self, xs: &Tensor, _train: bool) -> Result<Tensor> {
self.forward(xs)
}
}

View File

@ -1,806 +0,0 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::conv::{ParamsConv1D, ParamsConv2D, ParamsConvTranspose2D};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{CpuStorage, DType, Layout, Result, Shape};
use candle_metal_kernels;
use candle_metal_kernels::{void_ptr, Kernels, Source};
use core::mem;
use half::{bf16, f16};
use metal;
use metal::mps::matrix::encode_gemm;
use metal::mps::Float32;
use metal::{Buffer, CompileOptions, MTLResourceOptions, MTLSize, NSUInteger};
use std::sync::Arc;
use tracing::debug;
/// Metal related errors
#[derive(thiserror::Error, Debug)]
pub enum MetalError {
#[error("{0}")]
Message(String),
#[error(transparent)]
KernelError(#[from] candle_metal_kernels::MetalKernelError),
}
impl From<String> for MetalError {
fn from(e: String) -> Self {
MetalError::Message(e)
}
}
impl MetalError {
fn msg<S: AsRef<str>>(msg: S) -> Self {
MetalError::Message(msg.as_ref().to_string())
}
}
#[derive(Clone)]
pub struct MetalDevice {
device: metal::Device,
command_queue: metal::CommandQueue,
kernels: Arc<candle_metal_kernels::Kernels>,
}
impl std::fmt::Debug for MetalDevice {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "MetalDevice({:?})", self.device.registry_id())
}
}
impl std::ops::Deref for MetalDevice {
type Target = metal::DeviceRef;
fn deref(&self) -> &Self::Target {
&self.device
}
}
impl MetalDevice {
// pub fn metal_device(&self) -> &metal::DeviceRef {
// self.device.as_ref()
// }
pub fn id(&self) -> u64 {
self.registry_id()
}
fn new_buffer(&self, element_count: usize, dtype: DType) -> Buffer {
let size = (element_count * dtype.size_in_bytes()) as u64;
// debug!("Allocate 1 - buffer size {size}");
self.device
.new_buffer(size, MTLResourceOptions::StorageModeManaged)
}
}
#[derive(Debug, Clone)]
pub struct MetalStorage {
buffer: metal::Buffer,
device: MetalDevice,
dtype: DType,
}
impl BackendStorage for MetalStorage {
type Device = MetalDevice;
fn try_clone(&self, _: &Layout) -> Result<Self> {
Ok(self.clone())
}
fn dtype(&self) -> DType {
self.dtype
}
fn device(&self) -> &Self::Device {
&self.device
}
fn to_cpu_storage(&self) -> Result<CpuStorage> {
match self.dtype {
DType::F32 => Ok(CpuStorage::F32(
self.buffer.read_to_vec(self.buffer.length() as usize / 4),
)),
dtype => todo!("Unsupported dtype {dtype:?}"),
}
}
fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
let device = self.device().clone();
let shape = layout.shape();
let dims = shape.dims();
let el = shape.elem_count();
let dtype = self.dtype;
assert!(layout.is_contiguous());
assert_eq!(dtype, DType::F32);
let mut buffer = device.new_buffer(el, self.dtype);
let command_buffer = self.device.command_queue.new_command_buffer();
candle_metal_kernels::call_affine(
&device.device,
&command_buffer,
&device.kernels,
el,
&self.buffer,
&mut buffer,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
return Ok(Self {
buffer,
device: device.clone(),
dtype,
});
}
fn powf(&self, _: &Layout, _: f64) -> Result<Self> {
todo!()
}
fn elu(&self, _: &Layout, _: f64) -> Result<Self> {
todo!()
}
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
debug!("TODO reduce_op {op:?}");
let src_stride = layout.stride();
let src_dims = layout.shape().dims();
let src_el: usize = src_dims.iter().product();
// Source dims and strides with the sum dims at the end.
let mut dims = vec![];
let mut stride = vec![];
let mut dst_el: usize = 1;
for (dim_idx, &d) in src_dims.iter().enumerate() {
if !sum_dims.contains(&dim_idx) {
dst_el *= d;
dims.push(d);
stride.push(src_stride[dim_idx]);
}
}
for &dim_idx in sum_dims.iter() {
dims.push(src_dims[dim_idx]);
stride.push(src_stride[dim_idx]);
}
// let el_to_sum_per_block = src_el / dst_el;
// // The reduction loop requires the shared array to be properly initialized and for
// // this we want the number of threads to be a power of two.
// let block_dim = usize::min(1024, el_to_sum_per_block).next_power_of_two();
// let cfg = LaunchConfig {
// // TODO: Maybe use grid_y if the output is too large?
// // TODO: Specialized implementation when reducing on no or all dimensions or when
// // reducing only aggregate a small number of elements together.
// grid_dim: (dst_el as u32, 1, 1),
// block_dim: (block_dim as u32, 1, 1),
// shared_mem_bytes: 0,
// };
// let ds = dev
// .htod_copy([dims.as_slice(), stride.as_slice()].concat())
// .w()?;
// let src = &src.slice(layout.start_offset()..);
// let (name, check_empty, return_index) = match self.1 {
// ReduceOp::Sum => ("fast_sum", false, false),
// ReduceOp::Min => ("fast_min", true, false),
// ReduceOp::Max => ("fast_max", true, false),
// ReduceOp::ArgMin => ("fast_argmin", true, true),
// ReduceOp::ArgMax => ("fast_argmax", true, true),
// };
// if check_empty && layout.shape().elem_count() == 0 {
// Err(crate::Error::EmptyTensor { op: "reduce" }.bt())?
// }
// let func = dev.get_or_load_func(&kernel_name::<T>(name), kernels::REDUCE)?;
// if return_index {
// // SAFETY: filled in by the follow up kernel.
// let out = unsafe { dev.alloc::<u32>(dst_el) }.w()?;
// let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// // SAFETY: ffi.
// unsafe { func.launch(cfg, params) }.w()?;
// Ok(S::U32(out))
// } else {
// // SAFETY: filled in by the follow up kernel.
// let out = unsafe { dev.alloc::<T>(dst_el) }.w()?;
// let params = (src_el, el_to_sum_per_block, src_dims.len(), &ds, src, &out);
// // SAFETY: ffi.
// unsafe { func.launch(cfg, params) }.w()?;
// Ok(wrap(out))
// }
// Ok(self.clone())
// todo!()
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn cmp(&self, _: CmpOp, _: &Self, _: &Layout, _: &Layout) -> Result<Self> {
todo!()
}
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
let device = self.device();
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (self.dtype, dtype) {
(DType::U32, DType::F32) => "cast_u32_f32",
(left, right) => todo!("to dtype {left:?} - {right:?}"),
};
candle_metal_kernels::call_cast_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!(
"TODO Implement the kernel calling cast {:?}-{:?}",
self.dtype,
dtype
);
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"cast {:?} - {:?} - {:?} - {:?}",
dtype,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = layout.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
// TODO remove
// return Ok(Self {
// buffer,
// device: device.clone(),
// dtype,
// });
let command_buffer = device.command_queue.new_command_buffer();
if layout.is_contiguous() {
use candle_metal_kernels::unary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("ucos", DType::F32) => contiguous::cos::FLOAT,
("usin", DType::F32) => contiguous::sin::FLOAT,
("usqr", DType::F32) => contiguous::sqr::FLOAT,
("usqrt", DType::F32) => contiguous::sqrt::FLOAT,
("uneg", DType::F32) => contiguous::neg::FLOAT,
("uexp", DType::F32) => contiguous::exp::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_unary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
todo!("TODO Implement the kernel calling {}", B::KERNEL);
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Unary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn binary_impl<B: BinaryOpT>(
&self,
rhs: &Self,
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let device = self.device();
let dtype = self.dtype;
let shape = lhs_l.shape();
let dims = shape.dims();
let el_count = shape.elem_count();
let mut buffer = device.new_buffer(el_count, dtype);
let command_buffer = device.command_queue.new_command_buffer();
if lhs_l.is_contiguous() && rhs_l.is_contiguous() {
use candle_metal_kernels::binary::contiguous;
let kernel_name = match (B::KERNEL, dtype) {
("add", DType::F32) => contiguous::add::FLOAT,
("badd", DType::F32) => contiguous::add::FLOAT,
("sub", DType::F32) => contiguous::sub::FLOAT,
("bsub", DType::F32) => contiguous::sub::FLOAT,
("mul", DType::F32) => contiguous::mul::FLOAT,
("bmul", DType::F32) => contiguous::mul::FLOAT,
("div", DType::F32) => contiguous::div::FLOAT,
("bdiv", DType::F32) => contiguous::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_contiguous(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
el_count,
&self.buffer,
&rhs.buffer,
&mut buffer,
)
.map_err(MetalError::from)?;
} else {
use candle_metal_kernels::binary::strided;
let kernel_name = match (B::KERNEL, dtype) {
("badd", DType::F32) => strided::add::FLOAT,
("bsub", DType::F32) => strided::sub::FLOAT,
("bmul", DType::F32) => strided::mul::FLOAT,
("bdiv", DType::F32) => strided::div::FLOAT,
(name, dtype) => todo!("Match {name} - {dtype:?}"),
};
candle_metal_kernels::call_binary_strided(
&device.device,
&command_buffer,
&device.kernels,
kernel_name,
lhs_l.dims(),
&self.buffer,
&lhs_l.stride(),
lhs_l.start_offset(),
&rhs.buffer,
&rhs_l.stride(),
rhs_l.start_offset(),
&mut buffer,
)
.map_err(MetalError::from)?;
}
let start = std::time::Instant::now();
command_buffer.commit();
// command_buffer.wait_until_scheduled();
debug!(
"Binary {:?} - {:?} - {:?} - {:?}",
B::KERNEL,
start.elapsed(),
self.buffer.length(),
buffer.length()
);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
}
fn where_cond(&self, _: &Layout, rhs: &Self, _: &Layout, _: &Self, _: &Layout) -> Result<Self> {
debug!("TODO where_cond");
Ok(rhs.clone())
// todo!()
}
fn conv1d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConv1D,
) -> Result<Self> {
todo!()
}
fn conv2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConv2D,
) -> Result<Self> {
todo!()
}
fn conv_transpose2d(
&self,
_l: &Layout,
_kernel: &Self,
_kernel_l: &Layout,
_params: &ParamsConvTranspose2D,
) -> Result<Self> {
todo!()
}
fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
todo!()
}
fn max_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
todo!()
}
fn upsample_nearest1d(&self, _: &Layout, _: usize) -> Result<Self> {
todo!()
}
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self> {
todo!()
}
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self> {
todo!()
}
fn scatter_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
todo!()
}
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
debug!(
"TODO Index select {:?} {:?} {src_l:?} {ids_l:?} {dim:?}",
self.buffer.length(),
ids.buffer.length(),
);
let src = self;
let ids_shape = ids_l.shape();
let ids_dims = ids_shape.dims();
// let ds = dev.htod_copy([ids_dims, ids_l.stride()].concat()).w()?;
// let src = match src_l.contiguous_offsets() {
// Some((o1, o2)) => src.slice(o1..o2),
// None => Err(crate::Error::RequiresContiguous { op: "index-select" }.bt())?,
// };
let left_size: usize = src_l.dims()[..dim].iter().product();
let right_size: usize = src_l.dims()[dim + 1..].iter().product();
let src_dim_size = src_l.dims()[dim];
let ids_dim_size = ids_shape.elem_count();
let dst_el = ids_shape.elem_count() * left_size * right_size;
let dtype = self.dtype;
let device = self.device();
let buffer = device.new_buffer(dst_el, dtype);
Ok(Self {
buffer,
device: device.clone(),
dtype,
})
// todo!()
}
fn index_add(
&self,
_: &Layout,
_: &Self,
_: &Layout,
_: &Self,
_: &Layout,
_: usize,
) -> Result<Self> {
todo!()
}
fn matmul(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let transpose_left = false;
let transpose_right = false;
let alpha = 1.0;
let beta = 0.0;
self.matmul_generic(
rhs,
(b, m, n, k),
lhs_l,
rhs_l,
transpose_left,
transpose_right,
alpha,
beta,
)
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
let src_shape = src_l.shape();
let dims = src_shape.dims();
let el_count = src_shape.elem_count();
if el_count == 0 {
return Ok(());
}
if src_l.is_contiguous() {
let command_buffer = self.device.command_queue.new_command_buffer();
let blip = command_buffer.new_blit_command_encoder();
blip.copy_from_buffer(
&self.buffer,
src_l.start_offset() as u64,
&dst.buffer,
dst_offset as u64,
self.buffer.length(),
);
} else {
let command_buffer = self.device.command_queue.new_command_buffer();
let kernel_name = match self.dtype {
DType::F32 => candle_metal_kernels::unary::strided::copy::FLOAT,
DType::F16 => candle_metal_kernels::unary::strided::copy::HALF,
DType::BF16 => candle_metal_kernels::unary::strided::copy::BFLOAT,
dtype => todo!("copy_strided not implemented for {dtype:?}"),
};
candle_metal_kernels::call_unary_strided(
&self.device.device,
&command_buffer,
&self.device.kernels,
kernel_name,
src_l.dims(),
&self.buffer,
&src_l.stride(),
src_l.start_offset(),
&mut dst.buffer,
dst_offset,
)
.map_err(MetalError::from)?;
command_buffer.commit();
}
Ok(())
}
}
impl MetalStorage {
pub(crate) fn matmul_t(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
) -> Result<Self> {
let transpose_left = false;
let transpose_right = true;
let alpha = 1.0;
let beta = 0.0;
self.matmul_generic(
rhs,
(b, m, n, k),
lhs_l,
rhs_l,
transpose_left,
transpose_right,
alpha,
beta,
)
}
pub(crate) fn matmul_generic(
&self,
rhs: &Self,
(b, m, n, k): (usize, usize, usize, usize),
lhs_l: &Layout,
rhs_l: &Layout,
transpose_left: bool,
transpose_right: bool,
alpha: f64,
beta: f64,
) -> Result<Self> {
let elem_count = b * m * n;
match (self.dtype, rhs.dtype) {
(DType::F32, DType::F32) => {
let mut out_buffer = self.device.new_buffer(elem_count, self.dtype);
if b != 1 {
debug!("TODO implement batched matmul for B={b}");
// bail!("Didn't implemented strided matmul yet");
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
});
}
if !lhs_l.is_contiguous() || !rhs_l.is_contiguous() {
debug!(
"TODO non contiguous matmul yet {:?} {:?}",
lhs_l.is_contiguous(),
rhs_l.is_contiguous()
);
return Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
});
}
debug!("GEMM");
let command_buffer = self.device.command_queue.new_command_buffer();
encode_gemm::<Float32, Float32, Float32>(
&self.device,
&command_buffer,
transpose_left,
transpose_right,
&self.buffer,
&rhs.buffer,
&mut out_buffer,
m as NSUInteger,
n as NSUInteger,
k as NSUInteger,
alpha as f32,
beta as f32,
Some(b as NSUInteger),
)
.map_err(MetalError::from)?;
command_buffer.commit();
// command_buffer.wait_until_scheduled();
Ok(Self {
buffer: out_buffer,
device: self.device.clone(),
dtype: self.dtype(),
})
}
_ => todo!("Unimplemented matmul for this pair"),
}
}
}
impl BackendDevice for MetalDevice {
type Storage = MetalStorage;
fn new(ordinal: usize) -> Result<Self> {
let device = metal::Device::all().swap_remove(ordinal);
// let capture = metal::CaptureManager::shared();
// let descriptor = metal::CaptureDescriptor::new();
// descriptor.set_destination(metal::MTLCaptureDestination::GpuTraceDocument);
// descriptor.set_capture_device(&device);
// let mut dir = std::env::current_dir()?;
// dir.push("out.gputrace");
// descriptor.set_output_url(dir);
// capture
// .start_capture(&descriptor)
// .map_err(MetalError::from)?;
let command_queue = device.new_command_queue();
// let command_buffer = _command_queue.new_owned_command_buffer();
let kernels = Arc::new(Kernels::new());
Ok(Self {
device,
command_queue,
// command_buffer,
kernels,
})
}
fn set_seed(&self, _seed: u64) -> Result<()> {
todo!("set_seed")
}
fn location(&self) -> crate::DeviceLocation {
crate::DeviceLocation::Metal
}
fn same_device(&self, rhs: &Self) -> bool {
self.device.registry_id() == rhs.device.registry_id()
}
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.zeros_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
let option = metal::MTLResourceOptions::StorageModeManaged;
let buffer = match storage {
CpuStorage::U8(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u8>()) as u64,
option,
),
CpuStorage::U32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<u32>()) as u64,
option,
),
CpuStorage::I64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<i64>()) as u64,
option,
),
CpuStorage::BF16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<bf16>()) as u64,
option,
),
CpuStorage::F16(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f16>()) as u64,
option,
),
CpuStorage::F32(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f32>()) as u64,
option,
),
CpuStorage::F64(storage) => self.device.new_buffer_with_data(
storage.as_ptr() as *const core::ffi::c_void,
(storage.len() * mem::size_of::<f64>()) as u64,
option,
),
};
// debug!("Allocate 2 - buffer size {}", buffer.length());
Ok(Self::Storage {
buffer,
device: self.clone(),
dtype: storage.dtype(),
})
}
fn rand_uniform(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_uniform(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
fn rand_normal(
&self,
shape: &Shape,
dtype: DType,
mean: f64,
stddev: f64,
) -> Result<Self::Storage> {
// TODO is there a better way ?
let cpu_storage = crate::cpu_backend::CpuDevice.rand_normal(shape, dtype, mean, stddev)?;
self.storage_from_cpu_storage(&cpu_storage)
}
}

View File

@ -250,6 +250,8 @@ impl Tensor {
if header.fortran_order {
return Err(Error::Npy("fortran order not supported".to_string()));
}
let mut data: Vec<u8> = vec![];
reader.read_to_end(&mut data)?;
Self::from_reader(header.shape(), header.descr, &mut reader)
}

View File

@ -1,5 +1,5 @@
#![allow(clippy::redundant_closure_call)]
use crate::{CpuStorage, CudaStorage, Layout, MetalStorage, Result, Shape, Tensor};
use crate::{CpuStorage, CudaStorage, Layout, Result, Shape, Tensor};
use half::{bf16, f16};
use num_traits::float::Float;
@ -174,18 +174,6 @@ pub trait CustomOp1 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_storage: &MetalStorage,
_layout: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
/// This function takes as argument the argument `arg` used in the forward pass, the result
/// produced by the forward operation `res` and the gradient of the result `grad_res`.
/// The function should return the gradient of the argument.
@ -221,20 +209,6 @@ pub trait CustomOp2 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
@ -277,22 +251,6 @@ pub trait CustomOp3 {
))
}
/// The forward pass, as run on a metal gpu device. Note that the storage can use arbitrary strides,
/// offsets etc so the associated layout should be used to access it.
fn metal_fwd(
&self,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
_: &MetalStorage,
_: &Layout,
) -> Result<(MetalStorage, Shape)> {
Err(crate::Error::Metal(
format!("no metal implementation for {}", self.name()).into(),
))
}
fn bwd(
&self,
_arg1: &Tensor,
@ -578,6 +536,7 @@ unary_op!(Log, "log", v, v.ln(), vs_ln, vd_ln);
unary_op!(Sin, "sin", v, v.sin(), vs_sin, vd_sin);
unary_op!(Cos, "cos", v, v.cos(), vs_cos, vd_cos);
unary_op!(Tanh, "tanh", v, v.tanh(), vs_tanh, vd_tanh);
unary_op!(Abs, "abs", v, v.abs());
unary_op!(Neg, "neg", v, -v);
unary_op!(Recip, "recip", v, v.recip());
unary_op!(Sqr, "sqr", v, v * v, vs_sqr, vd_sqr);
@ -707,40 +666,6 @@ impl UnaryOpT for Erf {
}
}
impl UnaryOpT for Abs {
const NAME: &'static str = "abs";
const KERNEL: &'static str = "uabs";
const V: Self = Abs;
#[inline(always)]
fn bf16(v: bf16) -> bf16 {
v.abs()
}
#[inline(always)]
fn f16(v: f16) -> f16 {
v.abs()
}
#[inline(always)]
fn f32(v: f32) -> f32 {
v.abs()
}
#[inline(always)]
fn f64(v: f64) -> f64 {
v.abs()
}
#[inline(always)]
fn u8(v: u8) -> u8 {
v
}
#[inline(always)]
fn u32(v: u32) -> u32 {
v
}
#[inline(always)]
fn i64(v: i64) -> i64 {
v.abs()
}
}
impl UnaryOpT for Ceil {
const NAME: &'static str = "ceil";
const KERNEL: &'static str = "uceil";

View File

@ -193,50 +193,6 @@ impl Object {
_ => Err(self),
}
}
pub fn into_tensor_info(
self,
name: Self,
dir_name: &std::path::Path,
) -> Result<Option<TensorInfo>> {
let name = match name.unicode() {
Ok(name) => name,
Err(_) => return Ok(None),
};
let (callable, args) = match self.reduce() {
Ok(callable_args) => callable_args,
_ => return Ok(None),
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor" && class_name == "_rebuild_from_type_v2" => {
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => return Ok(None),
};
let (layout, dtype, file_path, storage_size) = rebuild_args(args)?;
let mut path = dir_name.to_path_buf();
path.push(file_path);
Ok(Some(TensorInfo {
name,
dtype,
layout,
path: path.to_string_lossy().into_owned(),
storage_size,
}))
}
}
impl TryFrom<Object> for String {
@ -609,7 +565,6 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
"HalfStorage" => DType::F16,
"BFloat16Storage" => DType::BF16,
"ByteStorage" => DType::U8,
"LongStorage" => DType::I64,
other => {
crate::bail!("unsupported storage type {other}")
}
@ -668,10 +623,50 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
};
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
match value.into_tensor_info(name, &dir_name) {
Ok(Some(tensor_info)) => tensor_infos.push(tensor_info),
Ok(None) => {}
Err(err) => eprintln!("skipping: {err:?}"),
let name = match name.unicode() {
Ok(name) => name,
Err(_) => continue,
};
let (callable, args) = match value.reduce() {
Ok(callable_args) => callable_args,
_ => continue,
};
let (callable, args) = match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._tensor"
&& class_name == "_rebuild_from_type_v2" =>
{
let mut args = args.tuple()?;
let callable = args.remove(0);
let args = args.remove(1);
(callable, args)
}
_ => (callable, args),
};
match callable {
Object::Class {
module_name,
class_name,
} if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {}
_ => continue,
};
match rebuild_args(args) {
Ok((layout, dtype, file_path, storage_size)) => {
let mut path = dir_name.clone();
path.push(file_path);
tensor_infos.push(TensorInfo {
name,
dtype,
layout,
path: path.to_string_lossy().into_owned(),
storage_size,
})
}
Err(err) => {
eprintln!("skipping {name}: {err:?}")
}
}
}
}
@ -728,16 +723,3 @@ impl PthTensors {
Ok(Some(tensor))
}
}
/// Read all the tensors from a PyTorch pth file.
pub fn read_all<P: AsRef<std::path::Path>>(path: P) -> Result<Vec<(String, Tensor)>> {
let pth = PthTensors::new(path)?;
let tensor_names = pth.tensor_infos.keys();
let mut tensors = Vec::with_capacity(tensor_names.len());
for name in tensor_names {
if let Some(tensor) = pth.get(name)? {
tensors.push((name.to_string(), tensor))
}
}
Ok(tensors)
}

View File

@ -50,9 +50,14 @@ pub(crate) unsafe fn mul_sum_i8_pairs_float(x: __m256i, y: __m256i) -> __m256 {
#[inline(always)]
pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = _mm256_setzero_ps();
for (x, y) in xs.iter().zip(ys.iter()) {

View File

@ -1,7 +1,7 @@
//! Support for the GGML file format.
use super::{k_quants, GgmlDType};
use crate::{Device, Result};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use std::collections::HashMap;
@ -121,12 +121,11 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
raw_data: &[u8],
size_in_bytes: usize,
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let raw_data_ptr = raw_data.as_ptr();
let n_blocks = size_in_bytes / std::mem::size_of::<T>();
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
super::QTensor::new(data.to_vec(), dims, device)
super::QTensor::new(data.to_vec(), dims)
}
/// Creates a [Tensor] from a raw GGML tensor.
@ -134,7 +133,6 @@ pub fn qtensor_from_ggml(
ggml_dtype: GgmlDType,
raw_data: &[u8],
dims: Vec<usize>,
device: &Device,
) -> Result<super::QTensor> {
let tensor_elems = dims.iter().product::<usize>();
let blck_size = ggml_dtype.blck_size();
@ -146,38 +144,18 @@ pub fn qtensor_from_ggml(
let size_in_bytes = tensor_elems / blck_size * ggml_dtype.type_size();
match ggml_dtype {
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims, device),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims, device),
GgmlDType::Q4_0 => {
from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4_1 => {
from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_0 => {
from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5_1 => {
from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q8_0 => {
from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q2K => {
from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q3K => {
from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q4K => {
from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q5K => {
from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::Q6K => {
from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims, device)
}
GgmlDType::F32 => from_raw_data::<f32>(raw_data, size_in_bytes, dims),
GgmlDType::F16 => from_raw_data::<half::f16>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_0 => from_raw_data::<k_quants::BlockQ4_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q4_1 => from_raw_data::<k_quants::BlockQ4_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_0 => from_raw_data::<k_quants::BlockQ5_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q5_1 => from_raw_data::<k_quants::BlockQ5_1>(raw_data, size_in_bytes, dims),
GgmlDType::Q8_0 => from_raw_data::<k_quants::BlockQ8_0>(raw_data, size_in_bytes, dims),
GgmlDType::Q2K => from_raw_data::<k_quants::BlockQ2K>(raw_data, size_in_bytes, dims),
GgmlDType::Q3K => from_raw_data::<k_quants::BlockQ3K>(raw_data, size_in_bytes, dims),
GgmlDType::Q4K => from_raw_data::<k_quants::BlockQ4K>(raw_data, size_in_bytes, dims),
GgmlDType::Q5K => from_raw_data::<k_quants::BlockQ5K>(raw_data, size_in_bytes, dims),
GgmlDType::Q6K => from_raw_data::<k_quants::BlockQ6K>(raw_data, size_in_bytes, dims),
_ => crate::bail!("quantized type {ggml_dtype:?} is not supported yet"),
}
}
@ -185,7 +163,6 @@ pub fn qtensor_from_ggml(
fn read_one_tensor<R: std::io::Seek + std::io::Read>(
reader: &mut R,
magic: VersionedMagic,
device: &Device,
) -> Result<(String, super::QTensor)> {
let n_dims = reader.read_u32::<LittleEndian>()?;
let name_len = reader.read_u32::<LittleEndian>()?;
@ -210,7 +187,7 @@ fn read_one_tensor<R: std::io::Seek + std::io::Read>(
// TODO: Mmap version to avoid copying the data around?
let mut raw_data = vec![0u8; size_in_bytes];
reader.read_exact(&mut raw_data)?;
match qtensor_from_ggml(ggml_dtype, &raw_data, dims, device) {
match qtensor_from_ggml(ggml_dtype, &raw_data, dims) {
Ok(tensor) => Ok((name, tensor)),
Err(e) => crate::bail!("Error creating tensor {name}: {e}"),
}
@ -224,10 +201,7 @@ pub struct Content {
}
impl Content {
pub fn read<R: std::io::Seek + std::io::Read>(
reader: &mut R,
device: &Device,
) -> Result<Content> {
pub fn read<R: std::io::Seek + std::io::Read>(reader: &mut R) -> Result<Content> {
// https://github.com/ggerganov/llama.cpp/blob/468ea24fb4633a0d681f7ac84089566c1c6190cb/llama.cpp#L505
let last_position = reader.seek(std::io::SeekFrom::End(0))?;
reader.seek(std::io::SeekFrom::Start(0))?;
@ -237,7 +211,7 @@ impl Content {
let mut tensors = HashMap::new();
while reader.stream_position()? != last_position {
let (name, tensor) = read_one_tensor(reader, magic, device)?;
let (name, tensor) = read_one_tensor(reader, magic)?;
tensors.insert(name, tensor);
}
Ok(Self {

View File

@ -3,7 +3,7 @@
//! Spec: https://github.com/philpax/ggml/blob/gguf-spec/docs/gguf.md
use super::{GgmlDType, QTensor};
use crate::{Device, Result};
use crate::Result;
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::collections::HashMap;
@ -57,7 +57,6 @@ impl TensorInfo {
&self,
reader: &mut R,
tensor_data_offset: u64,
device: &Device,
) -> Result<QTensor> {
let tensor_elems = self.shape.elem_count();
let blck_size = self.ggml_dtype.blck_size();
@ -70,12 +69,7 @@ impl TensorInfo {
let mut raw_data = vec![0u8; size_in_bytes];
reader.seek(std::io::SeekFrom::Start(tensor_data_offset + self.offset))?;
reader.read_exact(&mut raw_data)?;
super::ggml_file::qtensor_from_ggml(
self.ggml_dtype,
&raw_data,
self.shape.dims().to_vec(),
device,
)
super::ggml_file::qtensor_from_ggml(self.ggml_dtype, &raw_data, self.shape.dims().to_vec())
}
}
@ -456,13 +450,12 @@ impl Content {
&self,
reader: &mut R,
name: &str,
device: &Device,
) -> Result<QTensor> {
let tensor_info = match self.tensor_infos.get(name) {
Some(tensor_info) => tensor_info,
None => crate::bail!("cannot find tensor info for {name}"),
};
tensor_info.read(reader, self.tensor_data_offset, device)
tensor_info.read(reader, self.tensor_data_offset)
}
}

View File

@ -236,9 +236,14 @@ impl GgmlType for BlockQ4_0 {
fn vec_dot_unopt(n: usize, xs: &[Self], ys: &[Self::VecDotType]) -> Result<f32> {
let qk = QK8_0;
let nb = n / qk;
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
// Generic implementation.
let mut sumf = 0f32;
for (xs, ys) in xs.iter().zip(ys.iter()) {

View File

@ -1,5 +1,4 @@
use crate::{Device, Result, Shape, Tensor};
use tracing::debug;
#[cfg(target_feature = "avx")]
pub mod avx;
@ -15,7 +14,6 @@ pub mod utils;
pub use k_quants::GgmlType;
pub struct QTensor {
device: Device,
data: Box<dyn QuantizedType>,
shape: Shape,
}
@ -172,20 +170,17 @@ impl QTensor {
pub fn new<S: Into<Shape>, T: k_quants::GgmlType + Send + Sync + 'static>(
data: Vec<T>,
shape: S,
device: &Device,
) -> Result<Self> {
let shape = shape.into();
check_shape::<T>(&shape)?;
Ok(Self {
data: Box::new(data),
shape,
device: device.clone(),
})
}
pub fn quantize<T: k_quants::GgmlType + Send + Sync + 'static>(src: &Tensor) -> Result<Self> {
let shape = src.shape();
let device = src.device();
check_shape::<T>(shape)?;
let src = src
.to_dtype(crate::DType::F32)?
@ -202,7 +197,6 @@ impl QTensor {
Ok(Self {
data: Box::new(data),
shape: shape.clone(),
device: device.clone(),
})
}
@ -218,12 +212,7 @@ impl QTensor {
&self.shape
}
pub fn device(&self) -> &Device {
&self.device
}
pub fn dequantize(&self, device: &Device) -> Result<Tensor> {
// TODO: Skip the CPU part on Metal.
let mut f32_data = vec![0f32; self.shape.elem_count()];
self.data.to_float(&mut f32_data)?;
Tensor::from_vec(f32_data, &self.shape, device)
@ -316,46 +305,6 @@ impl crate::CustomOp1 for QTensor {
)?;
Ok((crate::CpuStorage::F32(dst_storage), dst_shape))
}
fn metal_fwd(
&self,
storage: &crate::MetalStorage,
layout: &crate::Layout,
) -> Result<(crate::MetalStorage, Shape)> {
debug!("TODO qmatmul");
if !layout.is_contiguous() {
crate::bail!("input tensor is not contiguous {layout:?}")
}
let src_shape = layout.shape();
// self is transposed so n is first then k.
let (n, k) = self.shape.dims2()?;
if src_shape.rank() < 2 {
crate::bail!("input tensor has only one dimension {layout:?}")
}
let mut dst_shape = src_shape.dims().to_vec();
let last_k = dst_shape.pop().unwrap();
if last_k != k {
crate::bail!("input tensor {layout:?} incompatible with {:?}", self.shape)
}
dst_shape.push(n);
let dst_shape = Shape::from(dst_shape);
// let storage = storage.as_slice::<f32>()?;
// let storage =
// &storage[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
let dst_storage = vec![0f32; dst_shape.elem_count()];
// self.matmul_t(
// (dst_shape.elem_count() / n, k, n),
// storage,
// &mut dst_storage,
// )?;
let cpu_storage = crate::CpuStorage::F32(dst_storage);
use crate::backend::{BackendDevice, BackendStorage};
if let Device::Metal(device) = &self.device {
Ok((device.storage_from_cpu_storage(&cpu_storage)?, dst_shape))
} else {
crate::bail!("qtensor not on metal device")
}
}
}
impl QMatMul {

View File

@ -19,29 +19,42 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
for i in 0..nb {
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];
let m4b = vdupq_n_u8(0x0F);
let s8b = vdupq_n_s8(0x8);
let v0_0 = vld1q_u8(x0.qs.as_ptr());
let v0_1 = vld1q_u8(x1.qs.as_ptr());
// 4-bit -> 8-bit
let v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
let v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
let v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
let v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// sub 8
let v0_0ls = vsubq_s8(v0_0l, s8b);
let v0_0hs = vsubq_s8(v0_0h, s8b);
let v0_1ls = vsubq_s8(v0_1l, s8b);
let v0_1hs = vsubq_s8(v0_1h, s8b);
// load y
let v1_0l = vld1q_s8(y0.qs.as_ptr());
let v1_0h = vld1q_s8(y0.qs.as_ptr().add(16));
let v1_1l = vld1q_s8(y1.qs.as_ptr());
let v1_1h = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO: Support dotprod when it's available outside of nightly.
let pl0l = vmull_s8(vget_low_s8(v0_0ls), vget_low_s8(v1_0l));
@ -49,16 +62,28 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
let ph0l = vmull_s8(vget_low_s8(v0_0hs), vget_low_s8(v1_0h));
let ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
let pl1l = vmull_s8(vget_low_s8(v0_1ls), vget_low_s8(v1_1l));
let pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
let ph1l = vmull_s8(vget_low_s8(v0_1hs), vget_low_s8(v1_1h));
let ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
let pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
let ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
let pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
let ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(pl0, ph0)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(pl1, ph1)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0))
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
}
}
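
The loop above unrolls two blocks per iteration (hence the `nb % 2` check) and widens the i8 products through `vmull_s8`/`vpaddlq_s16` before the float multiply-accumulate. As a scalar reference, a sketch of the per-block dot product, assuming the standard Q4_0 layout (an f16 scale `d` plus 16 packed bytes):

```rust
// Scalar reference for a single Q4_0 x Q8_0 block pair. A Q4_0 block packs
// 32 weights into 16 bytes: low nibbles hold elements 0..16, high nibbles
// elements 16..32, both stored with a +8 offset.
fn q4_0_block_dot(x_qs: &[u8; 16], x_d: f32, y_qs: &[i8; 32], y_d: f32) -> f32 {
    let mut sum = 0i32;
    for j in 0..16 {
        let lo = (x_qs[j] & 0x0F) as i32 - 8;
        let hi = (x_qs[j] >> 4) as i32 - 8;
        sum += lo * y_qs[j] as i32 + hi * y_qs[j + 16] as i32;
    }
    sum as f32 * x_d * y_d
}
```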
@ -69,18 +94,28 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut sumv0 = vdupq_n_f32(0.0f32);
for i in 0..nb {
let mut sumv1 = vdupq_n_f32(0.0f32);
for i in (0..nb).step_by(2) {
let x0 = &xs[i];
let x1 = &xs[i + 1];
let y0 = &ys[i];
let y1 = &ys[i + 1];
let x0_0 = vld1q_s8(x0.qs.as_ptr());
let x0_1 = vld1q_s8(x0.qs.as_ptr().add(16));
let x1_0 = vld1q_s8(x1.qs.as_ptr());
let x1_1 = vld1q_s8(x1.qs.as_ptr().add(16));
// load y
let y0_0 = vld1q_s8(y0.qs.as_ptr());
let y0_1 = vld1q_s8(y0.qs.as_ptr().add(16));
let y1_0 = vld1q_s8(y1.qs.as_ptr());
let y1_1 = vld1q_s8(y1.qs.as_ptr().add(16));
// TODO: Use dotprod once the intrinsics are available outside of nightly.
let p0_0 = vmull_s8(vget_low_s8(x0_0), vget_low_s8(y0_0));
@ -88,16 +123,28 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
let p0_2 = vmull_s8(vget_low_s8(x0_1), vget_low_s8(y0_1));
let p0_3 = vmull_s8(vget_high_s8(x0_1), vget_high_s8(y0_1));
let p1_0 = vmull_s8(vget_low_s8(x1_0), vget_low_s8(y1_0));
let p1_1 = vmull_s8(vget_high_s8(x1_0), vget_high_s8(y1_0));
let p1_2 = vmull_s8(vget_low_s8(x1_1), vget_low_s8(y1_1));
let p1_3 = vmull_s8(vget_high_s8(x1_1), vget_high_s8(y1_1));
let p0 = vaddq_s32(vpaddlq_s16(p0_0), vpaddlq_s16(p0_1));
let p1 = vaddq_s32(vpaddlq_s16(p0_2), vpaddlq_s16(p0_3));
let p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1));
let p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3));
sumv0 = vmlaq_n_f32(
sumv0,
vcvtq_f32_s32(vaddq_s32(p0, p1)),
x0.d.to_f32() * y0.d.to_f32(),
);
sumv1 = vmlaq_n_f32(
sumv1,
vcvtq_f32_s32(vaddq_s32(p2, p3)),
x1.d.to_f32() * y1.d.to_f32(),
);
}
Ok(vaddvq_f32(sumv0))
Ok(vaddvq_f32(sumv0) + vaddvq_f32(sumv1))
}
}
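
The Q8_0 path follows the same two-block unroll. A scalar sketch of one block pair, assuming 32 signed bytes and an f16 scale per block:

```rust
// Scalar reference for one Q8_0 x Q8_0 block pair: a plain integer dot
// product of the quantized bytes, scaled by both block scales.
fn q8_0_block_dot(x_qs: &[i8; 32], x_d: f32, y_qs: &[i8; 32], y_d: f32) -> f32 {
    let sum: i32 = x_qs
        .iter()
        .zip(y_qs.iter())
        .map(|(&x, &y)| x as i32 * y as i32)
        .sum();
    sum as f32 * x_d * y_d
}
```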

View File

@ -11,6 +11,10 @@ pub(crate) fn vec_dot_q4_0_q8_0(n: usize, xs: &[BlockQ4_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q4_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {
@ -57,6 +61,10 @@ pub(crate) fn vec_dot_q8_0_q8_0(n: usize, xs: &[BlockQ8_0], ys: &[BlockQ8_0]) ->
if n % QK8_0 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {n} is not divisible by {qk}")
}
let nb = n / QK8_0;
if nb % 2 != 0 {
crate::bail!("vec_dot_q8_0_q8_0: {nb} is not even")
}
unsafe {
let mut acc = f32x4_splat(0.0f32);
for (x, y) in xs.iter().zip(ys.iter()) {

View File

@ -203,7 +203,7 @@ impl Shape {
/// Check whether the two shapes are compatible for broadcast, and if it is the case return the
/// broadcasted shape. This is to be used for binary pointwise ops.
pub fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
pub(crate) fn broadcast_shape_binary_op(&self, rhs: &Self, op: &'static str) -> Result<Shape> {
let lhs = self;
let lhs_dims = lhs.dims();
let rhs_dims = rhs.dims();
@ -511,119 +511,154 @@ impl ShapeWithOneHole for ((),) {
}
}
fn hole_size(el_count: usize, prod_d: usize, s: &dyn std::fmt::Debug) -> Result<usize> {
if prod_d == 0 {
crate::bail!("cannot reshape tensor of {el_count} elements to {s:?}")
}
if el_count % prod_d != 0 {
crate::bail!("cannot reshape tensor with {el_count} elements to {s:?}")
}
Ok(el_count / prod_d)
}
impl ShapeWithOneHole for ((), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1) = self;
Ok((hole_size(el_count, d1, &self)?, d1).into())
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((el_count / d1, d1).into())
}
}
impl ShapeWithOneHole for (usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, ()) = self;
Ok((d1, hole_size(el_count, d1, &self)?).into())
if el_count % d1 != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d1}")
}
Ok((d1, el_count / d1).into())
}
}
impl ShapeWithOneHole for ((), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2) = self;
Ok((hole_size(el_count, d1 * d2, &self)?, d1, d2).into())
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2).into())
}
}
impl ShapeWithOneHole for (usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2) = self;
Ok((d1, hole_size(el_count, d1 * d2, &self)?, d2).into())
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2).into())
}
}
impl ShapeWithOneHole for (usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, ()) = self;
Ok((d1, d2, hole_size(el_count, d1 * d2, &self)?).into())
let d = d1 * d2;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d, d1, d2, d3).into())
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d, d2, d3).into())
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d2, d, d3).into())
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, ()) = self;
let d = hole_size(el_count, d1 * d2 * d3, &self)?;
Ok((d1, d2, d3, d).into())
let d = d1 * d2 * d3;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d).into())
}
}
impl ShapeWithOneHole for ((), usize, usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let ((), d1, d2, d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d, d1, d2, d3, d4).into())
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((el_count / d, d1, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, (), usize, usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, (), d2, d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d, d2, d3, d4).into())
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, el_count / d, d2, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, (), usize, usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, (), d3, d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d, d3, d4).into())
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, el_count / d, d3, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, (), usize) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, (), d4) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d3, d, d4).into())
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, el_count / d, d4).into())
}
}
impl ShapeWithOneHole for (usize, usize, usize, usize, ()) {
fn into_shape(self, el_count: usize) -> Result<Shape> {
let (d1, d2, d3, d4, ()) = self;
let d = hole_size(el_count, d1 * d2 * d3 * d4, &self)?;
Ok((d1, d2, d3, d4, d).into())
let d = d1 * d2 * d3 * d4;
if el_count % d != 0 {
crate::bail!("tensor number of elements {el_count} is not divisible by {d}")
}
Ok((d1, d2, d3, d4, el_count / d).into())
}
}
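
A quick sketch of what these `ShapeWithOneHole` impls make possible at the call site, with `()` marking the dimension to infer:

```rust
use candle_core::{DType, Device, Tensor};

fn reshape_with_hole() -> candle_core::Result<()> {
    let t = Tensor::zeros((2, 3, 4), DType::F32, &Device::Cpu)?;
    // `()` is the inferred dimension: 24 elements / (3 * 4) = 2.
    let r = t.reshape(((), 3, 4))?;
    assert_eq!(r.dims(), &[2, 3, 4]);
    // A non-divisible target such as ((), 5) returns an error instead.
    assert!(t.reshape(((), 5)).is_err());
    Ok(())
}
```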

View File

@ -1,6 +1,6 @@
use crate::backend::BackendStorage;
use crate::op::{self, CmpOp, CustomOp1, CustomOp2, CustomOp3, ReduceOp};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, Result, Shape};
// We do not want to implement Clone on Storage as cloning may fail because of
// out of memory. Instead try_clone should be used.
@ -8,7 +8,6 @@ use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage,
pub enum Storage {
Cpu(CpuStorage),
Cuda(CudaStorage),
Metal(MetalStorage),
}
impl Storage {
@ -19,10 +18,6 @@ impl Storage {
let storage = storage.try_clone(layout)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.try_clone(layout)?;
Ok(Self::Metal(storage))
}
}
}
@ -30,7 +25,6 @@ impl Storage {
match self {
Self::Cpu(_) => Device::Cpu,
Self::Cuda(storage) => Device::Cuda(storage.device().clone()),
Self::Metal(storage) => Device::Metal(storage.device().clone()),
}
}
@ -38,7 +32,6 @@ impl Storage {
match self {
Self::Cpu(storage) => storage.dtype(),
Self::Cuda(storage) => storage.dtype(),
Self::Metal(storage) => storage.dtype(),
}
}
@ -72,10 +65,6 @@ impl Storage {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.affine(layout, mul, add)?;
Ok(Self::Metal(storage))
}
}
}
@ -89,10 +78,6 @@ impl Storage {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.powf(layout, alpha)?;
Ok(Self::Metal(storage))
}
}
}
@ -106,10 +91,6 @@ impl Storage {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.elu(layout, alpha)?;
Ok(Self::Metal(storage))
}
}
}
@ -131,10 +112,6 @@ impl Storage {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.cmp(op, rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@ -158,10 +135,6 @@ impl Storage {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.reduce_op(op, layout, s)?;
Ok(Self::Metal(storage))
}
}
}
@ -175,10 +148,6 @@ impl Storage {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.to_dtype(layout, dtype)?;
Ok(Self::Metal(storage))
}
}
}
@ -192,10 +161,6 @@ impl Storage {
let (storage, shape) = c.cuda_fwd(storage, l)?;
Ok((Self::Cuda(storage), shape))
}
Self::Metal(storage) => {
let (storage, shape) = c.metal_fwd(storage, l)?;
Ok((Self::Metal(storage), shape))
}
}
}
@ -216,10 +181,6 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2)?;
Ok((Self::Cuda(s), shape))
}
(Self::Metal(s1), Self::Metal(s2)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(),
}
}
@ -244,10 +205,6 @@ impl Storage {
let (s, shape) = c.cuda_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Cuda(s), shape))
}
(Self::Metal(s1), Self::Metal(s2), Self::Metal(s3)) => {
let (s, shape) = c.metal_fwd(s1, l1, s2, l2, s3, l3)?;
Ok((Self::Metal(s), shape))
}
_ => unreachable!(),
}
}
@ -262,10 +219,6 @@ impl Storage {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.unary_impl::<B>(layout)?;
Ok(Self::Metal(storage))
}
}
}
@ -286,10 +239,6 @@ impl Storage {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.binary_impl::<B>(rhs, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => {
// Should not happen because of the same device check above but we're defensive
// anyway.
@ -321,10 +270,6 @@ impl Storage {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv1d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -352,10 +297,6 @@ impl Storage {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -383,10 +324,6 @@ impl Storage {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Cuda(s))
}
(Storage::Metal(inp), Storage::Metal(kernel)) => {
let s = inp.conv_transpose2d(l, kernel, kernel_l, params)?;
Ok(Self::Metal(s))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -411,10 +348,6 @@ impl Storage {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.avg_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
@ -433,10 +366,6 @@ impl Storage {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.max_pool2d(layout, kernel_size, stride)?;
Ok(Self::Metal(storage))
}
}
}
@ -450,10 +379,6 @@ impl Storage {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest1d(layout, sz)?;
Ok(Self::Metal(storage))
}
}
}
@ -467,10 +392,6 @@ impl Storage {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Cuda(storage))
}
Self::Metal(storage) => {
let storage = storage.upsample_nearest2d(layout, h, w)?;
Ok(Self::Metal(storage))
}
}
}
@ -494,10 +415,6 @@ impl Storage {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(cond), Self::Metal(t), Self::Metal(f)) => {
let storage = cond.where_cond(layout, t, layout_t, f, layout_f)?;
Ok(Self::Metal(storage))
}
(_, lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -524,10 +441,6 @@ impl Storage {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes)) => {
let storage = s.gather(l, indexes, indexes_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -552,10 +465,6 @@ impl Storage {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -580,10 +489,6 @@ impl Storage {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
let storage = s.index_add(l, indexes, indexes_l, source, source_l, d)?;
Ok(Self::Metal(storage))
}
_ => unreachable!(),
}
}
@ -605,10 +510,6 @@ impl Storage {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.index_select(rhs, lhs_l, rhs_l, d)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -636,10 +537,6 @@ impl Storage {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Cuda(storage))
}
(Self::Metal(lhs), Self::Metal(rhs)) => {
let storage = lhs.matmul(rhs, bmnk, lhs_layout, rhs_layout)?;
Ok(Self::Metal(storage))
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),
@ -659,9 +556,6 @@ impl Storage {
match (self, dst) {
(Self::Cpu(src), Self::Cpu(dst)) => src.copy_strided_src(dst, dst_offset, src_l),
(Self::Cuda(src), Self::Cuda(dst)) => Ok(src.copy_strided_src(dst, dst_offset, src_l)?),
(Self::Metal(src), Self::Metal(dst)) => {
Ok(src.copy_strided_src(dst, dst_offset, src_l)?)
}
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
lhs: lhs.device().location(),
rhs: rhs.device().location(),

View File

@ -6,7 +6,7 @@ use crate::op::{
};
use crate::scalar::TensorOrScalar;
use crate::shape::{Dim, Dims};
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
use crate::{storage::Storage, DType, Device, Error, Layout, Result, Shape};
use std::sync::{Arc, RwLock};
/// Unique identifier for tensors.
@ -385,21 +385,11 @@ impl Tensor {
step: D,
device: &Device,
) -> Result<Self> {
if D::is_zero(&step) {
crate::bail!("step cannot be zero")
}
let mut data = vec![];
let mut current = start;
if step >= D::zero() {
while current < end {
data.push(current);
current += step;
}
} else {
while current > end {
data.push(current);
current += step;
}
while current < end {
data.push(current);
current += step;
}
let len = data.len();
Self::from_vec_impl(data, len, device, false)
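
For illustration, the behavior the signed-step version above enables, matching the `arange` tests later in this diff:

```rust
use candle_core::{Device, Tensor};

fn arange_examples() -> candle_core::Result<()> {
    // Ascending with a positive step: 0, 2, 4 (end is exclusive).
    let up = Tensor::arange_step(0u8, 5u8, 2, &Device::Cpu)?;
    assert_eq!(up.to_vec1::<u8>()?, [0, 2, 4]);
    // Descending, only valid with the negative-step branch: 5, 4, 3, 2, 1.
    let down = Tensor::arange_step(5i64, 0i64, -1, &Device::Cpu)?;
    assert_eq!(down.to_vec1::<i64>()?, [5, 4, 3, 2, 1]);
    Ok(())
}
```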
@ -459,7 +449,7 @@ impl Tensor {
/// Returns true if the computation graph should track this op, that is if it is
/// a variable or if it has some variable as dependencies.
pub fn track_op(&self) -> bool {
pub(crate) fn track_op(&self) -> bool {
self.is_variable || self.op.is_some()
}
@ -523,7 +513,6 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(cpu_storage) => from_cpu_storage(cpu_storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -551,73 +540,6 @@ impl Tensor {
Ok(inp)
}
/// Creates grids of coordinates specified by the 1D inputs.
///
/// # Arguments
///
/// * `args` - A slice of 1D tensors.
/// * `xy_indexing` - Whether to use xy indexing or ij indexing. If xy is selected, the
/// first dimension corresponds to the cardinality of the second input and the second
/// dimension corresponds to the cardinality of the first input. If ij is selected, the
/// dimensions are in the same order as the cardinality of the inputs.
///
/// # Examples
///
/// ```rust
/// use candle_core::{Tensor, Device, Shape};
/// let x = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
/// let y = Tensor::new(&[4f32, 5., 6.], &Device::Cpu)?;
///
/// let grids_xy = Tensor::meshgrid(&[&x, &y], true)?;
///
/// assert_eq!(grids_xy.len(), 2);
/// assert_eq!(grids_xy[0].dims(), &[3, 3]);
///
/// assert_eq!(grids_xy[0].to_vec2::<f32>()?, &[[1., 2., 3.], [1., 2., 3.], [1., 2., 3.]]);
/// assert_eq!(grids_xy[1].to_vec2::<f32>()?, &[[4., 4., 4.], [5., 5., 5.], [6., 6., 6.]]);
///
/// let grids_ij = Tensor::meshgrid(&[&x, &y], false)?;
///
/// assert_eq!(grids_ij[0].to_vec2::<f32>()?, &[[1., 1., 1.], [2., 2., 2.], [3., 3., 3.]]);
/// assert_eq!(grids_ij[1].to_vec2::<f32>()?, &[[4., 5., 6.], [4., 5., 6.], [4., 5., 6.]]);
/// # Ok::<(), candle_core::Error>(())
/// ```
///
/// # Errors
///
/// * Will return `Err` if `args` contains less than 2 tensors.
///
pub fn meshgrid<A: AsRef<Tensor>>(args: &[A], xy_indexing: bool) -> Result<Vec<Self>> {
if args.len() <= 1 {
Err(Error::OpRequiresAtLeastTwoTensors { op: "meshgrid" }.bt())?
}
let args: Vec<_> = if xy_indexing {
args.iter().rev().collect()
} else {
args.iter().collect()
};
let mut shape = Vec::with_capacity(args.len());
for arg in args.iter() {
shape.push(arg.as_ref().dims1()?)
}
let mut grids = Vec::with_capacity(args.len());
for idx in 0..args.len() {
let mut ones = vec![1usize; args.len()];
ones[idx] = shape[idx];
let arg = args[idx].as_ref().reshape(ones)?;
let mut repeats = shape.clone();
repeats[idx] = 1;
let repeated_tensor = arg.repeat(repeats)?;
grids.push(repeated_tensor);
}
if xy_indexing {
grids.reverse();
}
Ok(grids)
}
/// This operation multiplies the input tensor by `mul` then adds `add` and return the result.
/// The input values `mul` and `add` are casted to the appropriate type so some rounding might
/// be performed.
@ -693,23 +615,15 @@ impl Tensor {
pub fn narrow<D: Dim>(&self, dim: D, start: usize, len: usize) -> Result<Self> {
let dims = self.dims();
let dim = dim.to_index(self.shape(), "narrow")?;
let err = |msg| {
Err::<(), _>(
Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
msg,
}
.bt(),
)
};
if start > dims[dim] {
err("start > dim_len")?
}
if start.saturating_add(len) > dims[dim] {
err("start + len > dim_len")?
if start + len > dims[dim] {
Err(Error::NarrowInvalidArgs {
shape: self.shape().clone(),
dim,
start,
len,
msg: "start + len > dim_len",
}
.bt())?
}
if start == 0 && dims[dim] == len {
Ok(self.clone())
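
A small usage sketch of `narrow`, including the error case these bounds checks guard against:

```rust
use candle_core::{Device, Tensor};

fn narrow_example() -> candle_core::Result<()> {
    let t = Tensor::new(&[[0f32, 1., 2.], [3., 4., 5.]], &Device::Cpu)?;
    // Keep a slice of dimension 0 starting at index 1, of length 1: the second row.
    let row = t.narrow(0, 1, 1)?;
    assert_eq!(row.to_vec2::<f32>()?, [[3.0, 4.0, 5.0]]);
    // start + len beyond the dimension size returns an error rather than panicking.
    assert!(t.narrow(1, 2, 2).is_err());
    Ok(())
}
```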
@ -1197,16 +1111,14 @@ impl Tensor {
op: "scatter-add (self, src)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
}
.bt())?
})?
}
if indexes.dims() != source.dims() {
Err(Error::ShapeMismatchBinaryOp {
op: "scatter-add (indexes, src)",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
}
.bt())?
})?
}
let storage = self.storage().scatter_add(
self.layout(),
@ -1278,8 +1190,7 @@ impl Tensor {
op: "slice-scatter (self, src)",
lhs: self.shape().clone(),
rhs: src.shape().clone(),
}
.bt())?
})?
}
let mut storage = self.device().zeros(self.shape(), self.dtype())?;
self.storage()
@ -1313,8 +1224,7 @@ impl Tensor {
op: "index-add (self, source)",
lhs: self.shape().clone(),
rhs: source.shape().clone(),
}
.bt())?
})?
}
// The number of elements in indexes must match the dimension on which the add is
// performed on the source tensor (and the index values from `indexes` are taken from
@ -1325,8 +1235,7 @@ impl Tensor {
op: "index-add (ids, source))",
lhs: indexes.shape().clone(),
rhs: source.shape().clone(),
}
.bt())?
})?
}
let storage = self.storage().index_add(
self.layout(),
@ -1374,8 +1283,7 @@ impl Tensor {
op: "gather",
lhs: self.shape().clone(),
rhs: indexes.shape().clone(),
}
.bt())?
})?
}
let storage =
self.storage()
@ -1449,7 +1357,6 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1480,7 +1387,6 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1521,7 +1427,6 @@ impl Tensor {
match &*self.storage() {
Storage::Cpu(storage) => from_cpu_storage(storage),
Storage::Cuda(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
Storage::Metal(storage) => from_cpu_storage(&storage.to_cpu_storage()?),
}
}
@ -1685,24 +1590,6 @@ impl Tensor {
}
}
/// Returns the sub-tensor fixing the index at `index` on the dimension `dim`.
///
/// ```rust
/// use candle_core::{Tensor, Device};
/// let tensor = Tensor::new(&[[0f32, 1.], [2., 3.], [4., 5.]], &Device::Cpu)?;
/// let t = tensor.get_on_dim(1, 0)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[0., 2., 4.]);
/// let t = tensor.get_on_dim(1, 1)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[1., 3., 5.]);
/// let t = tensor.get_on_dim(0, 1)?;
/// assert_eq!(t.to_vec1::<f32>()?, &[2., 3.]);
/// # Ok::<(), candle_core::Error>(())
/// ```
pub fn get_on_dim<D: Dim>(&self, dim: D, index: usize) -> Result<Tensor> {
let dim = dim.to_index(self.shape(), "get_on_dim")?;
self.narrow(dim, index, 1)?.squeeze(dim)
}
/// Returns a tensor that is a transposed version of the input, the two last dimensions of the
/// input are swapped.
///
@ -1841,9 +1728,6 @@ impl Tensor {
Storage::Cuda(cuda.storage_from_cpu_storage(&cpu_storage)?)
}
(Storage::Cpu(storage), Device::Cpu) => Storage::Cpu(storage.clone()),
_ => {
bail!("not implemented yet")
}
};
let op = BackpropOp::new1(self, Op::ToDevice);
let tensor_ = Tensor_ {
@ -2243,56 +2127,11 @@ impl Tensor {
}
}
/// Pad the input tensor by replicating the edge values along dimension `dim`. This adds `left` elements before the
/// input tensor values and `right` elements after.
pub fn pad_with_same<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
if left == 0 && right == 0 {
Ok(self.clone())
} else if self.elem_count() == 0 {
crate::bail!("cannot use pad_with_same on an empty tensor")
} else if left == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
let mut v = vec![self];
for _ in 0..right {
v.push(&r)
}
Tensor::cat(&v, dim)
} else if right == 0 {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let l = self.narrow(dim, 0, 1)?;
let mut v = vec![];
for _ in 0..left {
v.push(&l)
}
v.push(self);
Tensor::cat(&v, dim)
} else {
let dim = dim.to_index(self.shape(), "pad_with_same")?;
let l = self.narrow(dim, 0, 1)?;
let r = self.narrow(dim, self.dim(dim)? - 1, 1)?;
let mut v = vec![];
for _ in 0..left {
v.push(&l)
}
v.push(self);
for _ in 0..right {
v.push(&r)
}
Tensor::cat(&v, dim)
}
}
/// Run the `forward` method of `m` on `self`.
pub fn apply<M: crate::Module>(&self, m: &M) -> Result<Self> {
m.forward(self)
}
/// Run the `forward` method of `m` on `self`.
pub fn apply_t<M: crate::ModuleT>(&self, m: &M, train: bool) -> Result<Self> {
m.forward_t(self, train)
}
pub(crate) fn storage(&self) -> std::sync::RwLockReadGuard<'_, Storage> {
self.storage.read().unwrap()
}

View File

@ -23,10 +23,6 @@ pub fn cuda_is_available() -> bool {
cfg!(feature = "cuda")
}
pub fn metal_is_available() -> bool {
cfg!(feature = "metal")
}
pub fn with_avx() -> bool {
cfg!(target_feature = "avx")
}

View File

@ -479,71 +479,6 @@ fn conv2d_grad(dev: &Device) -> Result<()> {
]
]
);
// Replicate the issue from https://github.com/huggingface/candle/issues/1212
let res = t.i((.., .., 0..4, 0..4))?.conv2d(&w, 0, 2, 1, 1)?;
let loss = res.sqr()?.sum_all()?;
assert_eq!(test_utils::to_vec0_round(&loss, 2)?, 21.12f32);
let grads = loss.backward()?;
let grad_t = grads.get(&t).unwrap();
let grad_w = grads.get(&w).unwrap();
assert_eq!(grad_t.dims(), [1, 4, 5, 5]);
assert_eq!(grad_w.dims(), [2, 4, 3, 3]);
assert_eq!(
test_utils::to_vec3_round(&grad_t.i(0)?, 2)?,
[
[
[9.29, -7.03, 7.87, 0.0, 0.0],
[-1.8, -7.82, 5.9, 0.0, 0.0],
[-3.12, 4.49, 5.52, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[21.73, 3.39, 4.77, 0.0, 0.0],
[8.25, 3.73, 27.61, 0.0, 0.0],
[-20.55, -5.61, -2.77, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[-8.98, 9.91, -7.15, 0.0, 0.0],
[4.93, -0.33, 4.56, 0.0, 0.0],
[-6.7, -5.76, -8.05, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
],
[
[23.54, 6.98, -10.0, 0.0, 0.0],
[9.65, 6.18, 18.72, 0.0, 0.0],
[3.29, -5.27, 0.79, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0, 0.0]
]
]
);
assert_eq!(
test_utils::to_vec3_round(&grad_w.i(0)?, 2)?,
[
[
[-3.47, 7.44, 0.66],
[12.89, -3.4, -9.29],
[-14.16, -0.83, 7.14]
],
[
[-3.23, 5.37, -3.02],
[-2.12, -11.24, 1.94],
[6.97, 7.2, 2.99]
],
[
[-4.04, -3.31, 4.87],
[-6.68, -5.68, 1.73],
[-5.54, 4.32, 0.52]
],
[[-4.72, 1.5, 4.72], [3.79, 4.04, 6.76], [-4.6, 5.8, 6.93]]
]
);
Ok(())
}

View File

@ -192,19 +192,6 @@ fn unary_grad(device: &Device) -> Result<()> {
test_utils::to_vec1_round(grad_x, 2)?,
[0.01, 0.42, 0.0, 0.98],
);
// Tested against PyTorch's nn.GELU(approximate = 'tanh').
let y = x.gelu()?;
let grads = y.backward()?;
let grad_x = grads.get(&x).context("no grad for x")?;
assert_eq!(
test_utils::to_vec1_round(&y, 4)?,
[2.9964, 0.8412, 3.9999, 0.0839]
);
assert_eq!(
test_utils::to_vec1_round(grad_x, 4)?,
[1.0116, 1.0830, 1.0003, 0.6188],
);
Ok(())
}

View File

@ -1,9 +0,0 @@
import numpy as np
x = np.arange(10)
# Write a npy file.
np.save("test.npy", x)
# Write multiple values to a npz file.
values = { "x": x, "x_plus_one": x + 1 }
np.savez("test.npz", **values)

View File

@ -1,24 +0,0 @@
use candle_core::{DType, Result, Tensor};
#[test]
fn npy() -> Result<()> {
let npy = Tensor::read_npy("tests/test.npy")?;
assert_eq!(
npy.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
);
Ok(())
}
#[test]
fn npz() -> Result<()> {
let npz = Tensor::read_npz("tests/test.npz")?;
assert_eq!(npz.len(), 2);
assert_eq!(npz[0].0, "x");
assert_eq!(npz[1].0, "x_plus_one");
assert_eq!(
npz[1].1.to_dtype(DType::U8)?.to_vec1::<u8>()?,
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
);
Ok(())
}

View File

@ -29,26 +29,7 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
Ok(())
}
fn arange(device: &Device) -> Result<()> {
assert_eq!(
Tensor::arange(0u8, 5u8, device)?.to_vec1::<u8>()?,
[0, 1, 2, 3, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 2, device)?.to_vec1::<u8>()?,
[0, 2, 4],
);
assert_eq!(
Tensor::arange_step(0u8, 5u8, 3, device)?.to_vec1::<u8>()?,
[0, 3],
);
assert_eq!(
Tensor::arange_step(5i64, 0i64, -1, device)?.to_vec1::<i64>()?,
[5, 4, 3, 2, 1],
);
Ok(())
}
@ -1056,7 +1037,6 @@ fn randn(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu);
test_device!(ones, ones_cpu, ones_gpu);
test_device!(arange, arange_cpu, arange_gpu);
test_device!(add_mul, add_mul_cpu, add_mul_gpu);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu);
test_device!(narrow, narrow_cpu, narrow_gpu);
@ -1093,27 +1073,3 @@ fn randn_hasneg() -> Result<()> {
}
Ok(())
}
#[test]
fn pad_with_same() -> Result<()> {
let t = Tensor::arange(1f32, 5f32, &Device::Cpu)?.reshape((2, 2))?;
let t0 = t.pad_with_same(0, 1, 2)?;
assert_eq!(
t0.to_vec2::<f32>()?,
[[1.0, 2.0], [1.0, 2.0], [3.0, 4.0], [3.0, 4.0], [3.0, 4.0]]
);
let t1 = t.pad_with_same(1, 1, 2)?;
assert_eq!(
t1.to_vec2::<f32>()?,
[[1.0, 1.0, 2.0, 2.0, 2.0], [3.0, 3.0, 4.0, 4.0, 4.0]]
);
Ok(())
}
#[test]
fn i64_abs() -> Result<()> {
let t = Tensor::new(&[-42i64, 1337], &Device::Cpu)?;
let t = t.abs()?;
assert_eq!(t.to_vec1::<i64>()?, [42, 1337]);
Ok(())
}

Binary file not shown.

Binary file not shown.

View File

@ -21,7 +21,6 @@ half = { workspace = true, optional = true }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
pyo3 = { version = "0.20.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
safetensors = { workspace = true }
serde = { workspace = true }
@ -51,7 +50,6 @@ anyhow = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cudnn = ["candle/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
mkl = ["dep:intel-mkl-src", "candle/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
@ -60,7 +58,3 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"]
[[example]]
name = "llama_multiprocess"
required-features = ["cuda", "nccl", "flash-attn"]
[[example]]
name = "reinforcement-learning"
required-features = ["pyo3"]

View File

@ -5,11 +5,11 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
use anyhow::{Error as E, Result};
use anyhow::{anyhow, Error as E, Result};
use candle::Tensor;
use candle_nn::VarBuilder;
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use hf_hub::{api::sync::Api, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};
#[derive(Parser, Debug)]
@ -19,6 +19,10 @@ struct Args {
#[arg(long)]
cpu: bool,
/// Run offline (you must have the files already cached)
#[arg(long)]
offline: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@ -34,10 +38,6 @@ struct Args {
#[arg(long)]
prompt: Option<String>,
/// Use the pytorch weights rather than the safetensors ones
#[arg(long)]
use_pth: bool,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
@ -60,27 +60,34 @@ impl Args {
};
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
let (config_filename, tokenizer_filename, weights_filename) = {
let (config_filename, tokenizer_filename, weights_filename) = if self.offline {
let cache = Cache::default().repo(repo);
(
cache
.get("config.json")
.ok_or(anyhow!("Missing config file in cache"))?,
cache
.get("tokenizer.json")
.ok_or(anyhow!("Missing tokenizer file in cache"))?,
cache
.get("model.safetensors")
.ok_or(anyhow!("Missing weights file in cache"))?,
)
} else {
let api = Api::new()?;
let api = api.repo(repo);
let config = api.get("config.json")?;
let tokenizer = api.get("tokenizer.json")?;
let weights = if self.use_pth {
api.get("pytorch_model.bin")?
} else {
api.get("model.safetensors")?
};
(config, tokenizer, weights)
(
api.get("config.json")?,
api.get("tokenizer.json")?,
api.get("model.safetensors")?,
)
};
let config = std::fs::read_to_string(config_filename)?;
let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let vb = if self.use_pth {
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
} else {
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer))
}
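
A hypothetical invocation of the offline mode this diff adds, assuming the example binary is named `bert` (the cache must be populated by an online run first):

```bash
# First run online to populate the local hf-hub cache.
cargo run --example bert --release -- --prompt "Hello world"
# Then reuse the cached files without network access.
cargo run --example bert --release -- --prompt "Hello world" --offline
```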

View File

@ -1,19 +0,0 @@
# candle-blip
The
[blip-image-captioning](https://huggingface.co/Salesforce/blip-image-captioning-base)
model can generate captions for an input image.
## Running on an example
```bash
cargo run --example blip --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
```
Running on CPU, to run on GPU, build this example with `--features cuda`
loaded image Tensor[dims 3, 384, 384; f32]
model built
several cyclists are riding down a road with cars behind them
```
```
![Leading group, Giro d'Italia 2021](../yolo-v8/assets/bike.jpg)

View File

@ -1,154 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::Parser;
use candle::{DType, Device, Result, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::blip;
use candle_transformers::models::quantized_blip;
use tokenizers::Tokenizer;
enum Model {
M(blip::BlipForConditionalGeneration),
Q(quantized_blip::BlipForConditionalGeneration),
}
impl Model {
fn text_decoder_forward(&mut self, xs: &Tensor, img_xs: &Tensor) -> Result<Tensor> {
match self {
Self::M(m) => m.text_decoder().forward(xs, img_xs),
Self::Q(m) => m.text_decoder().forward(xs, img_xs),
}
}
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
}
const SEP_TOKEN_ID: u32 = 102;
/// Loads an image from disk using the image crate and returns a tensor with shape
/// (3, 384, 384). OpenAI normalization is applied.
pub fn load_image<P: AsRef<std::path::Path>>(p: P) -> Result<Tensor> {
let img = image::io::Reader::open(p)?
.decode()
.map_err(candle::Error::wrap)?
.resize_to_fill(384, 384, image::imageops::FilterType::Triangle);
let img = img.to_rgb8();
let data = img.into_raw();
let data = Tensor::from_vec(data, (384, 384, 3), &Device::Cpu)?.permute((2, 0, 1))?;
let mean =
Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], &Device::Cpu)?.reshape((3, 1, 1))?;
let std = Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], &Device::Cpu)?
.reshape((3, 1, 1))?;
(data.to_dtype(candle::DType::F32)? / 255.)?
.broadcast_sub(&mean)?
.broadcast_div(&std)
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
if args.quantized {
let api = api.model("lmz/candle-blip".to_string());
api.get("blip-image-captioning-large-q4k.gguf")?
} else {
let api = api.repo(hf_hub::Repo::with_revision(
"Salesforce/blip-image-captioning-large".to_string(),
hf_hub::RepoType::Model,
"refs/pr/18".to_string(),
));
api.get("model.safetensors")?
}
}
Some(model) => model.into(),
};
let tokenizer = match args.tokenizer {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("Salesforce/blip-image-captioning-large".to_string());
api.get("tokenizer.json")?
}
Some(file) => file.into(),
};
let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let mut tokenizer = TokenOutputStream::new(tokenizer);
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let config = blip::Config::image_captioning_large();
let (image_embeds, device, mut model) = if args.quantized {
let device = Device::Cpu;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let vb = quantized_blip::VarBuilder::from_gguf(model_file)?;
let model = quantized_blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::Q(model))
} else {
let device = candle_examples::device(args.cpu)?;
let image = load_image(args.image)?.to_device(&device)?;
println!("loaded image {image:?}");
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = blip::BlipForConditionalGeneration::new(&config, vb)?;
let image_embeds = image.unsqueeze(0)?.apply(model.vision_model())?;
(image_embeds, device, Model::M(model))
};
let mut token_ids = vec![30522u32];
for index in 0..1000 {
let context_size = if index > 0 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.text_decoder_forward(&input_ids, &image_embeds)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
if token == SEP_TOKEN_ID {
break;
}
token_ids.push(token);
if let Some(t) = tokenizer.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -1,59 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::convmixer;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-convmixer".into());
api.get("convmixer_1024_20_ks9_p14.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = convmixer::c1024_20(1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -1,45 +0,0 @@
# candle-jina-bert
Jina-Bert is a BERT-based embedding model with a context size of 8192 tokens, [model
card](https://huggingface.co/jinaai/jina-embeddings-v2-base-en). In this example
it can be used for two different tasks:
- Compute sentence embeddings for a prompt.
- Compute similarities between a set of sentences.
## Sentence embeddings
Jina-Bert is used to compute the sentence embeddings for a prompt. The model weights
are downloaded from the hub on the first run.
```bash
cargo run --example jina-bert --release -- --prompt "Here is a test sentence"
> [[[ 0.1595, -0.9885, 0.6494, ..., 0.3003, -0.6901, -1.2355],
> [ 0.0374, -0.1798, 1.3359, ..., 0.6731, 0.2133, -1.6807],
> [ 0.1700, -0.8534, 0.8924, ..., -0.1785, -0.0727, -1.5087],
> ...
> [-0.3113, -1.3665, 0.2027, ..., -0.2519, 0.1711, -1.5811],
> [ 0.0907, -1.0492, 0.5382, ..., 0.0242, -0.7077, -1.0830],
> [ 0.0369, -0.6343, 0.6105, ..., 0.0671, 0.3778, -1.1505]]]
> Tensor[[1, 7, 768], f32]
```
## Similarities
In this example, Jina-Bert computes the sentence embeddings for a set of
sentences (hardcoded in the example). Cosine similarities are then computed for
each sentence pair and reported in decreasing order, so the first pair listed
contains the two sentences with the highest similarity score.
The sentence embeddings are computed by average pooling over all the
sentence tokens, including any padding.
```bash
cargo run --example jina-bert --release
> score: 0.94 'The new movie is awesome' 'The new movie is so great'
> score: 0.81 'The cat sits outside' 'The cat plays in the garden'
> score: 0.78 'I love pasta' 'Do you like pizza?'
> score: 0.68 'I love pasta' 'The new movie is awesome'
> score: 0.67 'A man is playing guitar' 'A woman watches TV'
```
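
For reference, the pairwise score is a plain cosine similarity over the pooled embeddings; a minimal sketch of the computation the example performs internally:

```rust
// Cosine similarity between two pooled embedding vectors:
// dot(u, v) / (|u| * |v|).
fn cosine(e_i: &[f32], e_j: &[f32]) -> f32 {
    let dot: f32 = e_i.iter().zip(e_j).map(|(a, b)| a * b).sum();
    let n_i: f32 = e_i.iter().map(|a| a * a).sum::<f32>().sqrt();
    let n_j: f32 = e_j.iter().map(|a| a * a).sum::<f32>().sqrt();
    dot / (n_i * n_j)
}
```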

View File

@ -1,180 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_transformers::models::jina_bert::{BertModel, Config};
use anyhow::Error as E;
use candle::{DType, Module, Tensor};
use candle_nn::VarBuilder;
use clap::Parser;
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
/// The number of times to run the prompt.
#[arg(long, default_value = "1")]
n: usize,
/// L2 normalization for embeddings.
#[arg(long, default_value = "true")]
normalize_embeddings: bool,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
model: Option<String>,
}
impl Args {
fn build_model_and_tokenizer(&self) -> anyhow::Result<(BertModel, tokenizers::Tokenizer)> {
use hf_hub::{api::sync::Api, Repo, RepoType};
let model = match &self.model {
Some(model_file) => std::path::PathBuf::from(model_file),
None => Api::new()?
.repo(Repo::new(
"jinaai/jina-embeddings-v2-base-en".to_string(),
RepoType::Model,
))
.get("model.safetensors")?,
};
let tokenizer = match &self.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => Api::new()?
.repo(Repo::new(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
RepoType::Model,
))
.get("tokenizer.json")?,
};
let device = candle_examples::device(self.cpu)?;
let config = Config::v2_base();
let tokenizer = tokenizers::Tokenizer::from_file(tokenizer).map_err(E::msg)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let model = BertModel::new(vb, &config)?;
Ok((model, tokenizer))
}
}
fn main() -> anyhow::Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let start = std::time::Instant::now();
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
let device = &model.device;
if let Some(prompt) = args.prompt {
let tokenizer = tokenizer
.with_padding(None)
.with_truncation(None)
.map_err(E::msg)?;
let tokens = tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
println!("Loaded and encoded {:?}", start.elapsed());
for idx in 0..args.n {
let start = std::time::Instant::now();
let ys = model.forward(&token_ids)?;
if idx == 0 {
println!("{ys}");
}
println!("Took {:?}", start.elapsed());
}
} else {
let sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
];
let n_sentences = sentences.len();
if let Some(pp) = tokenizer.get_padding_mut() {
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
} else {
let pp = tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
};
tokenizer.with_padding(Some(pp));
}
let tokens = tokenizer
.encode_batch(sentences.to_vec(), true)
.map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
let token_ids = Tensor::stack(&token_ids, 0)?;
println!("running inference on batch {:?}", token_ids.shape());
let embeddings = model.forward(&token_ids)?;
println!("generated embeddings {:?}", embeddings.shape());
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
let (_n_sentence, n_tokens, _hidden_size) = embeddings.dims3()?;
let embeddings = (embeddings.sum(1)? / (n_tokens as f64))?;
let embeddings = if args.normalize_embeddings {
normalize_l2(&embeddings)?
} else {
embeddings
};
println!("pooled embeddings {:?}", embeddings.shape());
let mut similarities = vec![];
for i in 0..n_sentences {
let e_i = embeddings.get(i)?;
for j in (i + 1)..n_sentences {
let e_j = embeddings.get(j)?;
let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::<f32>()?;
let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::<f32>()?;
let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt();
similarities.push((cosine_similarity, i, j))
}
}
similarities.sort_by(|u, v| v.0.total_cmp(&u.0));
for &(score, i, j) in similarities[..5].iter() {
println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j])
}
}
Ok(())
}
pub fn normalize_l2(v: &Tensor) -> candle::Result<Tensor> {
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

View File

@ -6,10 +6,9 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
use candle_transformers::models::llama2_c as model;
use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod model;
mod training;
mod weights;
use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result};
@ -20,7 +19,6 @@ use std::io::Write;
use tokenizers::Tokenizer;
use model::{Config, Llama};
use qmodel::QLlama;
use weights::TransformerWeights;
#[derive(Parser, Debug, Clone)]
@ -154,20 +152,6 @@ fn main() -> anyhow::Result<()> {
Ok(())
}
enum Model {
Llama(Llama),
QLlama(QLlama),
}
impl Model {
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
match self {
Self::Llama(l) => Ok(l.forward(xs, pos)?),
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
}
}
}
fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
use std::io::BufRead;
@ -257,66 +241,24 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let device = candle_examples::device(common_args.cpu)?;
let is_gguf = config_path.extension().map_or(false, |v| v == "gguf");
let is_safetensors = config_path
.extension()
.map_or(false, |v| v == "safetensors");
let (model, config) = if is_gguf {
let vb = qmodel::VarBuilder::from_gguf(config_path)?;
let (_vocab_size, dim) = vb
.get_no_shape("model.embed_tokens.weight")?
.shape()
.dims2()?;
let config = match dim {
64 => Config::tiny_260k(),
288 => Config::tiny_15m(),
512 => Config::tiny_42m(),
768 => Config::tiny_110m(),
_ => anyhow::bail!("no config for dim {dim}"),
};
let freq_cis_real = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_real",
)?
.dequantize(&candle::Device::Cpu)?;
let freq_cis_imag = vb
.get(
(config.seq_len, config.head_size() / 2),
"rot.freq_cis_imag",
)?
.dequantize(&candle::Device::Cpu)?;
let fake_vb = candle_nn::VarBuilder::from_tensors(
[
("freq_cis_real".to_string(), freq_cis_real),
("freq_cis_imag".to_string(), freq_cis_imag),
]
.into_iter()
.collect(),
candle::DType::F32,
&candle::Device::Cpu,
);
let cache = model::Cache::new(true, &config, fake_vb)?;
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
(model, config)
} else if is_safetensors {
let config = Config::tiny_15m();
let (vb, config) = if is_safetensors {
let config = Config::tiny();
let tensors = candle::safetensors::load(config_path, &device)?;
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
(vb, config)
} else {
let mut file = std::fs::File::open(config_path)?;
let config = Config::from_reader(&mut file)?;
println!("{config:?}");
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
let vb = weights.var_builder(&config, &device)?;
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
(model, config)
(vb, config)
};
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
let model = Llama::load(vb, &cache, config)?;
println!("starting the inference loop");
let mut logits_processor = LogitsProcessor::new(299792458, args.temperature, args.top_p);
@ -331,7 +273,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
let start_gen = std::time::Instant::now();
for index in 0.. {
if tokens.len() >= config.seq_len {
if tokens.len() >= model.config.seq_len {
break;
}
let context_size = if index > 0 { 1 } else { tokens.len() };

View File

@ -17,20 +17,7 @@ pub struct Config {
}
impl Config {
pub fn tiny_260k() -> Self {
Self {
dim: 64,
hidden_dim: 768,
n_layers: 5,
n_heads: 8,
n_kv_heads: 4,
vocab_size: 32000,
seq_len: 512,
norm_eps: 1e-5,
}
}
pub fn tiny_15m() -> Self {
pub fn tiny() -> Self {
Self {
dim: 288,
hidden_dim: 768,
@ -42,32 +29,6 @@ impl Config {
norm_eps: 1e-5,
}
}
pub fn tiny_42m() -> Self {
Self {
dim: 512,
hidden_dim: 768,
n_layers: 8,
n_heads: 8,
n_kv_heads: 8,
vocab_size: 32000,
seq_len: 1024,
norm_eps: 1e-5,
}
}
pub fn tiny_110m() -> Self {
Self {
dim: 768,
hidden_dim: 768,
n_layers: 12,
n_heads: 12,
n_kv_heads: 12,
vocab_size: 32000,
seq_len: 1024,
norm_eps: 1e-5,
}
}
}
#[derive(Clone)]
@ -75,9 +36,9 @@ pub struct Cache {
masks: Arc<Mutex<HashMap<usize, Tensor>>>,
pub use_kv_cache: bool,
#[allow(clippy::type_complexity)]
pub kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
pub cos: Tensor,
pub sin: Tensor,
kvs: Arc<Mutex<Vec<Option<(Tensor, Tensor)>>>>,
cos: Tensor,
sin: Tensor,
device: Device,
}
@ -114,7 +75,7 @@ impl Cache {
})
}
pub fn mask(&self, t: usize) -> Result<Tensor> {
fn mask(&self, t: usize) -> Result<Tensor> {
let mut masks = self.masks.lock().unwrap();
if let Some(mask) = masks.get(&t) {
Ok(mask.clone())

View File

@ -33,7 +33,7 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
);
let varmap = candle_nn::VarMap::new();
let vb = candle_nn::VarBuilder::from_varmap(&varmap, DType::F32, &device);
let config = Config::tiny_15m();
let config = Config::tiny();
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);

View File

@ -1,8 +1,9 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle::{DType, Device, IndexOp, Shape, Tensor};
use candle_nn::VarBuilder;
use super::llama2_c::Config;
use crate::model::Config;
pub struct TransformerWeights {
// token embedding table

View File

@ -1,38 +0,0 @@
# candle-marian-mt
`marian-mt` is a neural machine translation model. In this example it is used to
translate text from French to English. See the associated [model
card](https://huggingface.co/Helsinki-NLP/opus-mt-tc-big-fr-en) for details on
the model itself.
## Running an example
```bash
cargo run --example marian-mt --release -- \
--text "Demain, dès l'aube, à l'heure où blanchit la campagne, Je partirai. Vois-tu, je sais que tu m'attends. J'irai par la forêt, j'irai par la montagne. Je ne puis demeurer loin de toi plus longtemps."
```
```
<NIL> Tomorrow, at dawn, at the time when the country is whitening, I will go. See,
I know you are waiting for me. I will go through the forest, I will go through the
mountain. I cannot stay far from you any longer.</s>
```
## Generating the tokenizer.json files
You can use the following script to generate the `tokenizer.json` config files
from the hf-hub repos. This requires the `tokenizers` and `sentencepiece`
packages to be installed and uses the `convert_slow_tokenizer.py` script from this
directory.
```python
from convert_slow_tokenizer import MarianConverter
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-fr-en", use_fast=False)
fast_tokenizer = MarianConverter(tokenizer, index=0).converted()
fast_tokenizer.save(f"tokenizer-marian-base-fr.json")
fast_tokenizer = MarianConverter(tokenizer, index=1).converted()
fast_tokenizer.save(f"tokenizer-marian-base-en.json")
```

File diff suppressed because it is too large

View File

@ -1,152 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::Error as E;
use clap::{Parser, ValueEnum};
use candle::{DType, Tensor};
use candle_examples::token_output_stream::TokenOutputStream;
use candle_nn::VarBuilder;
use candle_transformers::models::marian;
use tokenizers::Tokenizer;
#[derive(Clone, Debug, Copy, ValueEnum)]
enum Which {
Base,
Big,
}
// TODO: Maybe add support for the conditional prompt.
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
tokenizer_dec: Option<String>,
/// Choose the variant of the model to run.
#[arg(long, default_value = "big")]
which: Which,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Use the quantized version of the model.
#[arg(long)]
quantized: bool,
/// Text to be translated
#[arg(long)]
text: String,
}
pub fn main() -> anyhow::Result<()> {
use hf_hub::api::sync::Api;
let args = Args::parse();
let config = match args.which {
Which::Base => marian::Config::opus_mt_fr_en(),
Which::Big => marian::Config::opus_mt_tc_big_fr_en(),
};
let tokenizer = {
let tokenizer = match args.tokenizer {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-fr.json",
Which::Big => "tokenizer-marian-fr.json",
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let tokenizer_dec = {
let tokenizer = match args.tokenizer_dec {
Some(tokenizer) => std::path::PathBuf::from(tokenizer),
None => {
let name = match args.which {
Which::Base => "tokenizer-marian-base-en.json",
Which::Big => "tokenizer-marian-en.json",
};
Api::new()?
.model("lmz/candle-marian".to_string())
.get(name)?
}
};
Tokenizer::from_file(&tokenizer).map_err(E::msg)?
};
let mut tokenizer_dec = TokenOutputStream::new(tokenizer_dec);
let device = candle_examples::device(args.cpu)?;
let vb = {
let model = match args.model {
Some(model) => std::path::PathBuf::from(model),
None => match args.which {
Which::Base => Api::new()?
.repo(hf_hub::Repo::with_revision(
"Helsinki-NLP/opus-mt-fr-en".to_string(),
hf_hub::RepoType::Model,
"refs/pr/4".to_string(),
))
.get("model.safetensors")?,
Which::Big => Api::new()?
.model("Helsinki-NLP/opus-mt-tc-big-fr-en".to_string())
.get("model.safetensors")?,
},
};
unsafe { VarBuilder::from_mmaped_safetensors(&[&model], DType::F32, &device)? }
};
let mut model = marian::MTModel::new(&config, vb)?;
let mut logits_processor =
candle_transformers::generation::LogitsProcessor::new(1337, None, None);
let encoder_xs = {
let mut tokens = tokenizer
.encode(args.text, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
tokens.push(config.eos_token_id);
let tokens = Tensor::new(tokens.as_slice(), &device)?.unsqueeze(0)?;
model.encoder().forward(&tokens, 0)?
};
let mut token_ids = vec![config.decoder_start_token_id];
for index in 0..1000 {
let context_size = if index >= 1 { 1 } else { token_ids.len() };
let start_pos = token_ids.len().saturating_sub(context_size);
let input_ids = Tensor::new(&token_ids[start_pos..], &device)?.unsqueeze(0)?;
let logits = model.decode(&input_ids, &encoder_xs, start_pos)?;
let logits = logits.squeeze(0)?;
let logits = logits.get(logits.dim(0)? - 1)?;
let token = logits_processor.sample(&logits)?;
token_ids.push(token);
if let Some(t) = tokenizer_dec.next_token(token)? {
use std::io::Write;
print!("{t}");
std::io::stdout().flush()?;
}
if token == config.eos_token_id || token == config.forced_eos_token_id {
break;
}
}
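// Note: with temperature and top-p both None, LogitsProcessor::sample
// reduces to greedy argmax decoding, so the translation loop above is
// deterministic despite the fixed seed.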
if let Some(rest) = tokenizer_dec.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
println!();
Ok(())
}

View File

@ -9,7 +9,7 @@ use clap::{Parser, ValueEnum};
use rand::prelude::*;
use candle::{DType, Result, Tensor, D};
use candle_nn::{loss, ops, Conv2d, Linear, Module, ModuleT, Optimizer, VarBuilder, VarMap};
use candle_nn::{loss, ops, Conv2d, Linear, Module, Optimizer, VarBuilder, VarMap};
const IMAGE_DIM: usize = 784;
const LABELS: usize = 10;
@ -95,7 +95,7 @@ impl ConvNet {
.flatten_from(1)?
.apply(&self.fc1)?
.relu()?;
self.dropout.forward_t(&xs, train)?.apply(&self.fc2)
self.dropout.forward(&xs, train)?.apply(&self.fc2)
}
}

View File

@ -41,16 +41,3 @@ def median(arr):
else:
return arr[n//2]
```
This also supports the [Puffin Phi v2
model](https://huggingface.co/teknium/Puffin-Phi-v2) for human interaction.
```
$ cargo run --example phi --release -- \
--prompt "USER: What would you do on a sunny day in Paris?\nASSISTANT:" \
--sample-len 200 --model puffin-phi-v2 --quantized
USER: What would you do on a sunny day in Paris?
ASSISTANT: On a sunny day in Paris, you could visit the Musée du Louvre to admire the famous
painting "Mona Lisa" by Leonardo da Vinci. You might also want to stroll along the Champs-Élysées
and enjoy the beautiful architecture of the buildings around you. Don't forget to stop by a café
for a cup of coffee and to soak up the sun!"
```

View File

@ -5,7 +5,7 @@ extern crate intel_mkl_src;
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::{Parser, ValueEnum};
use clap::Parser;
use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
@ -28,7 +28,6 @@ struct TextGeneration {
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
@ -41,7 +40,6 @@ impl TextGeneration {
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
@ -51,7 +49,6 @@ impl TextGeneration {
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
@ -59,24 +56,20 @@ impl TextGeneration {
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the phi model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
print!("{prompt}");
std::io::stdout().flush()?;
let mut tokens = self
.tokenizer
.encode(prompt, true)
.map_err(E::msg)?
.get_ids()
.to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
@ -117,16 +110,6 @@ impl TextGeneration {
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum WhichModel {
#[value(name = "1")]
V1,
#[value(name = "1.5")]
V1_5,
PuffinPhiV2,
PhiHermes,
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@ -138,10 +121,6 @@ struct Args {
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
@ -161,21 +140,15 @@ struct Args {
#[arg(long, short = 'n', default_value_t = 100)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "microsoft/phi-1_5")]
model_id: String,
#[arg(long, default_value = "1.5")]
model: WhichModel,
#[arg(long)]
revision: Option<String>,
#[arg(long, default_value = "refs/pr/18")]
revision: String,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
#[arg(long)]
quantized: bool,
@ -216,62 +189,20 @@ fn main() -> Result<()> {
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => {
if args.quantized {
"lmz/candle-quantized-phi".to_string()
} else {
match args.model {
WhichModel::V1 => "microsoft/phi-1".to_string(),
WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
"lmz/candle-quantized-phi".to_string()
}
}
}
}
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => {
if args.quantized {
"main".to_string()
} else {
match args.model {
WhichModel::V1 => "refs/pr/2".to_string(),
WhichModel::V1_5 => "refs/pr/18".to_string(),
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => "main".to_string(),
}
}
}
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => match args.model {
WhichModel::V1 | WhichModel::V1_5 => repo.get("tokenizer.json")?,
WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
repo.get("tokenizer-puffin-phi-v2.json")?
}
},
};
let repo = api.repo(Repo::with_revision(
args.model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = repo.get("tokenizer.json")?;
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
None => {
if args.quantized {
match args.model {
WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2-q4k.gguf")?,
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B-q4k.gguf")?,
}
api.model("lmz/candle-quantized-phi".to_string())
.get("model-q4k.gguf")?
} else {
match args.model {
WhichModel::V1 | WhichModel::V1_5 => repo.get("model.safetensors")?,
WhichModel::PuffinPhiV2 => repo.get("model-puffin-phi-v2.safetensors")?,
WhichModel::PhiHermes => repo.get("model-phi-hermes-1_3B.safetensors")?,
}
repo.get("model.safetensors")?
}
}
};
@ -279,12 +210,7 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = match args.model {
WhichModel::V1 => Config::v1(),
WhichModel::V1_5 => Config::v1_5(),
WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
};
let config = Config::v1_5();
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let model = QMixFormer::new(&config, vb)?;
@ -305,7 +231,6 @@ fn main() -> Result<()> {
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;

View File

@ -9,7 +9,7 @@ use std::io::Write;
use tokenizers::Tokenizer;
use candle::quantized::{ggml_file, gguf_file};
use candle::Tensor;
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::quantized_llama as model;
@ -48,8 +48,6 @@ enum Which {
Mistral7b,
#[value(name = "7b-mistral-instruct")]
Mistral7bInstruct,
#[value(name = "7b-zephyr")]
Zephyr7b,
}
impl Which {
@ -64,7 +62,7 @@ impl Which {
| Self::L7bCode
| Self::L13bCode
| Self::L34bCode => false,
Self::Mistral7b | Self::Mistral7bInstruct | Self::Zephyr7b => true,
Self::Mistral7b | Self::Mistral7bInstruct => true,
}
}
}
@ -176,10 +174,6 @@ impl Args {
"TheBloke/Mistral-7B-Instruct-v0.1-GGUF",
"mistral-7b-instruct-v0.1.Q4_K_S.gguf",
),
Which::Zephyr7b => (
"TheBloke/zephyr-7B-alpha-GGUF",
"zephyr-7b-alpha.Q4_K_M.gguf",
),
};
let api = hf_hub::api::sync::Api::new()?;
let api = api.model(repo.to_string());
@ -232,13 +226,11 @@ fn main() -> anyhow::Result<()> {
use tracing_subscriber::prelude::*;
let args = Args::parse();
let device = candle_examples::device(false)?;
let temperature = if args.temperature == 0. {
None
} else {
Some(args.temperature)
};
tracing_subscriber::fmt::init();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
@ -278,10 +270,10 @@ fn main() -> anyhow::Result<()> {
&format_size(total_size_in_bytes),
start.elapsed().as_secs_f32(),
);
ModelWeights::from_gguf(model, &mut file, &device)?
ModelWeights::from_gguf(model, &mut file)?
}
Some("ggml" | "bin") | Some(_) | None => {
let model = ggml_file::Content::read(&mut file, &device)?;
let model = ggml_file::Content::read(&mut file)?;
let mut total_size_in_bytes = 0;
for (_, tensor) in model.tensors.iter() {
let elem_count = tensor.shape().elem_count();
@ -303,13 +295,9 @@ fn main() -> anyhow::Result<()> {
| Which::L7bCode
| Which::L13bCode
| Which::L34bCode => 1,
Which::Mistral7b
| Which::Mistral7bInstruct
| Which::Zephyr7b
| Which::L70b
| Which::L70bChat => 8,
Which::Mistral7b | Which::Mistral7bInstruct | Which::L70b | Which::L70bChat => 8,
};
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa), &device)?
ModelWeights::from_ggml(model, args.gqa.unwrap_or(default_gqa))?
}
};
println!("model built");
@ -368,13 +356,10 @@ fn main() -> anyhow::Result<()> {
let start_prompt_processing = std::time::Instant::now();
let mut next_token = {
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
let input = Tensor::new(prompt_tokens.as_slice(), &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, 0)?;
let logits = logits.squeeze(0)?;
// TODO Remove this once implementation is finished.
let logits = logits.ones_like()?;
// logits_processor.sample(&logits)?
15043
logits_processor.sample(&logits)?
};
let prompt_dt = start_prompt_processing.elapsed();
all_tokens.push(next_token);
@ -384,7 +369,7 @@ fn main() -> anyhow::Result<()> {
let start_post_prompt = std::time::Instant::now();
for index in 0..to_sample {
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
let input = Tensor::new(&[next_token], &Device::Cpu)?.unsqueeze(0)?;
let logits = model.forward(&input, prompt_tokens.len() + index)?;
let logits = logits.squeeze(0)?;
let logits = if args.repeat_penalty == 1. {
@ -397,10 +382,7 @@ fn main() -> anyhow::Result<()> {
&all_tokens[start_at..],
)?
};
// TODO Remove this once implementation is finished.
// let logits = logits.ones_like()?;
// next_token = logits_processor.sample(&logits)?;
let next_token = 15043;
next_token = logits_processor.sample(&logits)?;
all_tokens.push(next_token);
print_token(next_token, &tokenizer);
if next_token == eos_token {

View File

@ -1,16 +0,0 @@
# candle-reinforcement-learning
Reinforcement Learning examples for candle.
This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:
```bash
pip install "gymnasium[accept-rom-license]"
```
In order to run the example, use the following command. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.
```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
```

View File

@ -1,308 +0,0 @@
import gymnasium as gym
import numpy as np
from collections import deque
from PIL import Image
from multiprocessing import Process, Pipe
# atari_wrappers.py
class NoopResetEnv(gym.Wrapper):
def __init__(self, env, noop_max=30):
"""Sample initial states by taking random number of no-ops on reset.
No-op is assumed to be action 0.
"""
gym.Wrapper.__init__(self, env)
self.noop_max = noop_max
self.override_num_noops = None
assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
def reset(self):
""" Do no-op action for a number of steps in [1, noop_max]."""
self.env.reset()
if self.override_num_noops is not None:
noops = self.override_num_noops
else:
noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101
assert noops > 0
obs = None
for _ in range(noops):
obs, _, done, _ = self.env.step(0)
if done:
obs = self.env.reset()
return obs
class FireResetEnv(gym.Wrapper):
def __init__(self, env):
"""Take action on reset for environments that are fixed until firing."""
gym.Wrapper.__init__(self, env)
assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
assert len(env.unwrapped.get_action_meanings()) >= 3
def reset(self):
self.env.reset()
obs, _, done, _ = self.env.step(1)
if done:
self.env.reset()
obs, _, done, _ = self.env.step(2)
if done:
self.env.reset()
return obs
class ImageSaver(gym.Wrapper):
def __init__(self, env, img_path, rank):
gym.Wrapper.__init__(self, env)
self._cnt = 0
self._img_path = img_path
self._rank = rank
def step(self, action):
step_result = self.env.step(action)
obs, _, _, _ = step_result
img = Image.fromarray(obs, 'RGB')
img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
self._cnt += 1
return step_result
class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game over.
Done by DeepMind for the DQN and co. since it helps value estimation.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True
def step(self, action):
obs, reward, done, info = self.env.step(action)
self.was_real_done = done
# check current lives, make loss of life terminal,
# then update lives to handle bonus lives
lives = self.env.unwrapped.ale.lives()
if lives < self.lives and lives > 0:
# for Qbert sometimes we stay in the lives == 0 condition for a few frames
# so it's important to keep lives > 0, so that we only reset once
# the environment advertises done.
done = True
self.lives = lives
return obs, reward, done, info
def reset(self):
"""Reset only when lives are exhausted.
This way all states are still reachable even though lives are episodic,
and the learner need not know about any of this behind-the-scenes.
"""
if self.was_real_done:
obs = self.env.reset()
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.ale.lives()
return obs
class MaxAndSkipEnv(gym.Wrapper):
def __init__(self, env, skip=4):
"""Return only every `skip`-th frame"""
gym.Wrapper.__init__(self, env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = deque(maxlen=2)
self._skip = skip
def step(self, action):
"""Repeat action, sum reward, and max over last observations."""
total_reward = 0.0
done = None
for _ in range(self._skip):
obs, reward, done, info = self.env.step(action)
self._obs_buffer.append(obs)
total_reward += reward
if done:
break
max_frame = np.max(np.stack(self._obs_buffer), axis=0)
return max_frame, total_reward, done, info
def reset(self):
"""Clear past frame buffer and init. to first obs. from inner env."""
self._obs_buffer.clear()
obs = self.env.reset()
self._obs_buffer.append(obs)
return obs
class ClipRewardEnv(gym.RewardWrapper):
def reward(self, reward):
"""Bin reward to {+1, 0, -1} by its sign."""
return np.sign(reward)
class WarpFrame(gym.ObservationWrapper):
def __init__(self, env):
"""Warp frames to 84x84 as done in the Nature paper and later work."""
gym.ObservationWrapper.__init__(self, env)
self.res = 84
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')
def observation(self, obs):
frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
resample=Image.BILINEAR), dtype=np.uint8)
return frame.reshape((self.res, self.res, 1))
class FrameStack(gym.Wrapper):
def __init__(self, env, k):
"""Buffer observations and stack across channels (last axis)."""
gym.Wrapper.__init__(self, env)
self.k = k
self.frames = deque([], maxlen=k)
shp = env.observation_space.shape
assert shp[2] == 1 # can only stack 1-channel frames
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')
def reset(self):
"""Clear buffer and re-fill by duplicating the first observation."""
ob = self.env.reset()
for _ in range(self.k): self.frames.append(ob)
return self.observation()
def step(self, action):
ob, reward, done, info = self.env.step(action)
self.frames.append(ob)
return self.observation(), reward, done, info
def observation(self):
assert len(self.frames) == self.k
return np.concatenate(self.frames, axis=2)
def wrap_deepmind(env, episode_life=True, clip_rewards=True):
"""Configure environment for DeepMind-style Atari.
Note: this does not include frame stacking!"""
assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip
if episode_life:
env = EpisodicLifeEnv(env)
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
if 'FIRE' in env.unwrapped.get_action_meanings():
env = FireResetEnv(env)
env = WarpFrame(env)
if clip_rewards:
env = ClipRewardEnv(env)
return env
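# Usage sketch: the env id must contain "NoFrameskip" (see the assert above),
# and frame stacking is applied separately when wanted, e.g.:
#   env = wrap_deepmind(gym.make("BreakoutNoFrameskip-v4"))
#   env = FrameStack(env, 4)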
# envs.py
def make_env(env_id, img_dir, seed, rank):
def _thunk():
env = gym.make(env_id)
env.reset(seed=(seed + rank))
if img_dir is not None:
env = ImageSaver(env, img_dir, rank)
env = wrap_deepmind(env)
env = WrapPyTorch(env)
return env
return _thunk
class WrapPyTorch(gym.ObservationWrapper):
def __init__(self, env=None):
super(WrapPyTorch, self).__init__(env)
self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')
def observation(self, observation):
return observation.transpose(2, 0, 1)
# vecenv.py
class VecEnv(object):
"""
Vectorized environment base class
"""
def step(self, vac):
"""
Apply sequence of actions to sequence of environments
actions -> (observations, rewards, news)
where 'news' is a boolean vector indicating whether each element is new.
"""
raise NotImplementedError
def reset(self):
"""
Reset all environments
"""
raise NotImplementedError
def close(self):
pass
# subproc_vec_env.py
def worker(remote, env_fn_wrapper):
env = env_fn_wrapper.x()
while True:
cmd, data = remote.recv()
if cmd == 'step':
ob, reward, done, info = env.step(data)
if done:
ob = env.reset()
remote.send((ob, reward, done, info))
elif cmd == 'reset':
ob = env.reset()
remote.send(ob)
elif cmd == 'close':
remote.close()
break
elif cmd == 'get_spaces':
remote.send((env.action_space, env.observation_space))
else:
raise NotImplementedError
class CloudpickleWrapper(object):
"""
Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
"""
def __init__(self, x):
self.x = x
def __getstate__(self):
import cloudpickle
return cloudpickle.dumps(self.x)
def __setstate__(self, ob):
import pickle
self.x = pickle.loads(ob)
class SubprocVecEnv(VecEnv):
def __init__(self, env_fns):
"""
envs: list of gym environments to run in subprocesses
"""
nenvs = len(env_fns)
self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
for p in self.ps:
p.start()
self.remotes[0].send(('get_spaces', None))
self.action_space, self.observation_space = self.remotes[0].recv()
def step(self, actions):
for remote, action in zip(self.remotes, actions):
remote.send(('step', action))
results = [remote.recv() for remote in self.remotes]
obs, rews, dones, infos = zip(*results)
return np.stack(obs), np.stack(rews), np.stack(dones), infos
def reset(self):
for remote in self.remotes:
remote.send(('reset', None))
return np.stack([remote.recv() for remote in self.remotes])
def close(self):
for remote in self.remotes:
remote.send(('close', None))
for p in self.ps:
p.join()
@property
def num_envs(self):
return len(self.remotes)
# Create the environment.
def make(env_name, img_dir, num_processes):
envs = SubprocVecEnv([
make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
])
return envs

View File

@ -1,451 +0,0 @@
use std::collections::VecDeque;
use std::fmt::Display;
use candle::{DType, Device, Error, Module, Result, Tensor, Var};
use candle_nn::{
func, linear, sequential::seq, Activation, AdamW, Optimizer, ParamsAdamW, Sequential,
VarBuilder, VarMap,
};
use rand::{distributions::Uniform, thread_rng, Rng};
pub struct OuNoise {
mu: f64,
theta: f64,
sigma: f64,
state: Tensor,
}
impl OuNoise {
pub fn new(mu: f64, theta: f64, sigma: f64, size_action: usize) -> Result<Self> {
Ok(Self {
mu,
theta,
sigma,
state: Tensor::ones(size_action, DType::F32, &Device::Cpu)?,
})
}
pub fn sample(&mut self) -> Result<Tensor> {
let rand = Tensor::randn_like(&self.state, 0.0, 1.0)?;
let dx = ((self.theta * (self.mu - &self.state)?)? + (self.sigma * rand)?)?;
self.state = (&self.state + dx)?;
Ok(self.state.clone())
}
}
#[derive(Clone)]
struct Transition {
state: Tensor,
action: Tensor,
reward: Tensor,
next_state: Tensor,
terminated: bool,
truncated: bool,
}
impl Transition {
fn new(
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) -> Self {
Self {
state: state.clone(),
action: action.clone(),
reward: reward.clone(),
next_state: next_state.clone(),
terminated,
truncated,
}
}
}
pub struct ReplayBuffer {
buffer: VecDeque<Transition>,
capacity: usize,
size: usize,
}
impl ReplayBuffer {
pub fn new(capacity: usize) -> Self {
Self {
buffer: VecDeque::with_capacity(capacity),
capacity,
size: 0,
}
}
pub fn push(
&mut self,
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) {
if self.size == self.capacity {
self.buffer.pop_front();
} else {
self.size += 1;
}
self.buffer.push_back(Transition::new(
state, action, reward, next_state, terminated, truncated,
));
}
#[allow(clippy::type_complexity)]
pub fn random_batch(
&self,
batch_size: usize,
) -> Result<Option<(Tensor, Tensor, Tensor, Tensor, Vec<bool>, Vec<bool>)>> {
if self.size < batch_size {
Ok(None)
} else {
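// Sample uniformly over 0..size with replacement, so a batch may
// contain duplicate transitions; this is the usual choice for
// DDPG-style replay buffers.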
let transitions: Vec<&Transition> = thread_rng()
.sample_iter(Uniform::from(0..self.size))
.take(batch_size)
.map(|i| self.buffer.get(i).unwrap())
.collect();
let states: Vec<Tensor> = transitions
.iter()
.map(|t| t.state.unsqueeze(0))
.collect::<Result<_>>()?;
let actions: Vec<Tensor> = transitions
.iter()
.map(|t| t.action.unsqueeze(0))
.collect::<Result<_>>()?;
let rewards: Vec<Tensor> = transitions
.iter()
.map(|t| t.reward.unsqueeze(0))
.collect::<Result<_>>()?;
let next_states: Vec<Tensor> = transitions
.iter()
.map(|t| t.next_state.unsqueeze(0))
.collect::<Result<_>>()?;
let terminateds: Vec<bool> = transitions.iter().map(|t| t.terminated).collect();
let truncateds: Vec<bool> = transitions.iter().map(|t| t.truncated).collect();
Ok(Some((
Tensor::cat(&states, 0)?,
Tensor::cat(&actions, 0)?,
Tensor::cat(&rewards, 0)?,
Tensor::cat(&next_states, 0)?,
terminateds,
truncateds,
)))
}
}
}
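// Soft (Polyak) update of the target parameters:
// theta_target <- tau * theta_online + (1 - tau) * theta_target.
// With tau = 1.0 this copies the online network into the target network.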
fn track(
varmap: &mut VarMap,
vb: &VarBuilder,
target_prefix: &str,
network_prefix: &str,
dims: &[(usize, usize)],
tau: f64,
) -> Result<()> {
for (i, &(in_dim, out_dim)) in dims.iter().enumerate() {
let target_w = vb.get((out_dim, in_dim), &format!("{target_prefix}-fc{i}.weight"))?;
let network_w = vb.get((out_dim, in_dim), &format!("{network_prefix}-fc{i}.weight"))?;
varmap.set_one(
format!("{target_prefix}-fc{i}.weight"),
((tau * network_w)? + ((1.0 - tau) * target_w)?)?,
)?;
let target_b = vb.get(out_dim, &format!("{target_prefix}-fc{i}.bias"))?;
let network_b = vb.get(out_dim, &format!("{network_prefix}-fc{i}.bias"))?;
varmap.set_one(
format!("{target_prefix}-fc{i}.bias"),
((tau * network_b)? + ((1.0 - tau) * target_b)?)?,
)?;
}
Ok(())
}
struct Actor<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
network: Sequential,
target_network: Sequential,
size_state: usize,
size_action: usize,
dims: Vec<(usize, usize)>,
}
impl Actor<'_> {
fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
let mut varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, dtype, device);
let dims = vec![(size_state, 400), (400, 300), (300, size_action)];
let make_network = |prefix: &str| {
let seq = seq()
.add(linear(
dims[0].0,
dims[0].1,
vb.pp(format!("{prefix}-fc0")),
)?)
.add(Activation::Relu)
.add(linear(
dims[1].0,
dims[1].1,
vb.pp(format!("{prefix}-fc1")),
)?)
.add(Activation::Relu)
.add(linear(
dims[2].0,
dims[2].1,
vb.pp(format!("{prefix}-fc2")),
)?)
.add(func(|xs| xs.tanh()));
Ok::<Sequential, Error>(seq)
};
let network = make_network("actor")?;
let target_network = make_network("target-actor")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-actor", "actor", &dims, 1.0);
Ok(Self {
varmap,
vb,
network,
target_network,
size_state,
size_action,
dims,
})
}
fn forward(&self, state: &Tensor) -> Result<Tensor> {
self.network.forward(state)
}
fn target_forward(&self, state: &Tensor) -> Result<Tensor> {
self.target_network.forward(state)
}
fn track(&mut self, tau: f64) -> Result<()> {
track(
&mut self.varmap,
&self.vb,
"target-actor",
"actor",
&self.dims,
tau,
)
}
}
struct Critic<'a> {
varmap: VarMap,
vb: VarBuilder<'a>,
network: Sequential,
target_network: Sequential,
size_state: usize,
size_action: usize,
dims: Vec<(usize, usize)>,
}
impl Critic<'_> {
fn new(device: &Device, dtype: DType, size_state: usize, size_action: usize) -> Result<Self> {
let mut varmap = VarMap::new();
let vb = VarBuilder::from_varmap(&varmap, dtype, device);
let dims: Vec<(usize, usize)> = vec![(size_state + size_action, 400), (400, 300), (300, 1)];
let make_network = |prefix: &str| {
let seq = seq()
.add(linear(
dims[0].0,
dims[0].1,
vb.pp(format!("{prefix}-fc0")),
)?)
.add(Activation::Relu)
.add(linear(
dims[1].0,
dims[1].1,
vb.pp(format!("{prefix}-fc1")),
)?)
.add(Activation::Relu)
.add(linear(
dims[2].0,
dims[2].1,
vb.pp(format!("{prefix}-fc2")),
)?);
Ok::<Sequential, Error>(seq)
};
let network = make_network("critic")?;
let target_network = make_network("target-critic")?;
// this sets the two networks to be equal to each other using tau = 1.0
track(&mut varmap, &vb, "target-critic", "critic", &dims, 1.0);
Ok(Self {
varmap,
vb,
network,
target_network,
size_state,
size_action,
dims,
})
}
fn forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
let xs = Tensor::cat(&[action, state], 1)?;
self.network.forward(&xs)
}
fn target_forward(&self, state: &Tensor, action: &Tensor) -> Result<Tensor> {
let xs = Tensor::cat(&[action, state], 1)?;
self.target_network.forward(&xs)
}
fn track(&mut self, tau: f64) -> Result<()> {
track(
&mut self.varmap,
&self.vb,
"target-critic",
"critic",
&self.dims,
tau,
)
}
}
#[allow(clippy::upper_case_acronyms)]
pub struct DDPG<'a> {
actor: Actor<'a>,
actor_optim: AdamW,
critic: Critic<'a>,
critic_optim: AdamW,
gamma: f64,
tau: f64,
replay_buffer: ReplayBuffer,
ou_noise: OuNoise,
size_state: usize,
size_action: usize,
pub train: bool,
}
impl DDPG<'_> {
#[allow(clippy::too_many_arguments)]
pub fn new(
device: &Device,
size_state: usize,
size_action: usize,
train: bool,
actor_lr: f64,
critic_lr: f64,
gamma: f64,
tau: f64,
buffer_capacity: usize,
ou_noise: OuNoise,
) -> Result<Self> {
let filter_by_prefix = |varmap: &VarMap, prefix: &str| {
varmap
.data()
.lock()
.unwrap()
.iter()
.filter_map(|(name, var)| name.starts_with(prefix).then_some(var.clone()))
.collect::<Vec<Var>>()
};
let actor = Actor::new(device, DType::F32, size_state, size_action)?;
let actor_optim = AdamW::new(
filter_by_prefix(&actor.varmap, "actor"),
ParamsAdamW {
lr: actor_lr,
..Default::default()
},
)?;
let critic = Critic::new(device, DType::F32, size_state, size_action)?;
let critic_optim = AdamW::new(
filter_by_prefix(&critic.varmap, "critic"),
ParamsAdamW {
lr: critic_lr,
..Default::default()
},
)?;
Ok(Self {
actor,
actor_optim,
critic,
critic_optim,
gamma,
tau,
replay_buffer: ReplayBuffer::new(buffer_capacity),
ou_noise,
size_state,
size_action,
train,
})
}
pub fn remember(
&mut self,
state: &Tensor,
action: &Tensor,
reward: &Tensor,
next_state: &Tensor,
terminated: bool,
truncated: bool,
) {
self.replay_buffer
.push(state, action, reward, next_state, terminated, truncated)
}
pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self
.actor
.forward(&state.detach()?.unsqueeze(0)?)?
.squeeze(0)?;
let actions = if self.train {
(actions + self.ou_noise.sample()?)?
} else {
actions
};
actions.squeeze(0)?.to_scalar::<f32>()
}
pub fn train(&mut self, batch_size: usize) -> Result<()> {
let (states, actions, rewards, next_states, _, _) =
match self.replay_buffer.random_batch(batch_size)? {
Some(v) => v,
_ => return Ok(()),
};
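// Critic update: Bellman target y = r + gamma * Q'(s', mu'(s')),
// detached so no gradient flows into the target networks, followed by
// an MSE loss against Q(s, a). The actor then maximizes Q(s, mu(s)).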
let q_target = self
.critic
.target_forward(&next_states, &self.actor.target_forward(&next_states)?)?;
let q_target = (rewards + (self.gamma * q_target)?.detach())?;
let q = self.critic.forward(&states, &actions)?;
let diff = (q_target - q)?;
let critic_loss = diff.sqr()?.mean_all()?;
self.critic_optim.backward_step(&critic_loss)?;
let actor_loss = self
.critic
.forward(&states, &self.actor.forward(&states)?)?
.mean_all()?
.neg()?;
self.actor_optim.backward_step(&actor_loss)?;
self.critic.track(self.tau)?;
self.actor.track(self.tau)?;
Ok(())
}
}

View File

@ -1,112 +0,0 @@
#![allow(unused)]
//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym)
use candle::{Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;
/// The return value for a step.
#[derive(Debug)]
pub struct Step<A> {
pub state: Tensor,
pub action: A,
pub reward: f64,
pub terminated: bool,
pub truncated: bool,
}
impl<A: Copy> Step<A> {
/// Returns a copy of this step changing the observation tensor.
pub fn copy_with_obs(&self, state: &Tensor) -> Step<A> {
Step {
state: state.clone(),
action: self.action,
reward: self.reward,
terminated: self.terminated,
truncated: self.truncated,
}
}
}
/// A Gymnasium (OpenAI Gym successor) session.
pub struct GymEnv {
env: PyObject,
action_space: usize,
observation_space: Vec<usize>,
}
fn w(res: PyErr) -> candle::Error {
candle::Error::wrap(res)
}
impl GymEnv {
/// Creates a new session of the specified OpenAI Gym environment.
pub fn new(name: &str) -> Result<GymEnv> {
Python::with_gil(|py| {
let gym = py.import("gymnasium")?;
let make = gym.getattr("make")?;
let env = make.call1((name,))?;
let action_space = env.getattr("action_space")?;
let action_space = if let Ok(val) = action_space.getattr("n") {
val.extract()?
} else {
let action_space: Vec<usize> = action_space.getattr("shape")?.extract()?;
action_space[0]
};
let observation_space = env.getattr("observation_space")?;
let observation_space = observation_space.getattr("shape")?.extract()?;
Ok(GymEnv {
env: env.into(),
action_space,
observation_space,
})
})
.map_err(w)
}
/// Resets the environment, returning the observation tensor.
pub fn reset(&self, seed: u64) -> Result<Tensor> {
let state: Vec<f32> = Python::with_gil(|py| {
let kwargs = PyDict::new(py);
kwargs.set_item("seed", seed)?;
let state = self.env.call_method(py, "reset", (), Some(kwargs))?;
state.as_ref(py).get_item(0)?.extract()
})
.map_err(w)?;
Tensor::new(state, &Device::Cpu)
}
/// Applies an environment step using the specified action.
pub fn step<A: pyo3::IntoPy<pyo3::Py<pyo3::PyAny>> + Clone>(
&self,
action: A,
) -> Result<Step<A>> {
let (state, reward, terminated, truncated) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action.clone(),), None)?;
let step = step.as_ref(py);
let state: Vec<f32> = step.get_item(0)?.extract()?;
let reward: f64 = step.get_item(1)?.extract()?;
let terminated: bool = step.get_item(2)?.extract()?;
let truncated: bool = step.get_item(3)?.extract()?;
Ok((state, reward, terminated, truncated))
})
.map_err(w)?;
let state = Tensor::new(state, &Device::Cpu)?;
Ok(Step {
state,
action,
reward,
terminated,
truncated,
})
}
/// Returns the number of allowed actions for this environment.
pub fn action_space(&self) -> usize {
self.action_space
}
/// Returns the shape of the observation tensors.
pub fn observation_space(&self) -> &[usize] {
&self.observation_space
}
}
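// Usage sketch, mirroring main.rs below: create a session, reset it with a
// seed, then step with an action vector matching the action space.
// let env = GymEnv::new("Pendulum-v1")?;
// let state = env.reset(42)?;
// let step = env.step(vec![0.0f32])?;
// println!("reward={} terminated={}", step.reward, step.terminated);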

View File

@ -1,144 +0,0 @@
#![allow(unused)]
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
mod gym_env;
mod vec_gym_env;
mod ddpg;
use candle::{Device, Result, Tensor};
use clap::Parser;
use rand::Rng;
// The impact of the q value of the next state on the current state's q value.
const GAMMA: f64 = 0.99;
// The weight for updating the target networks.
const TAU: f64 = 0.005;
// The capacity of the replay buffer used for sampling training data.
const REPLAY_BUFFER_CAPACITY: usize = 100_000;
// The training batch size for each training iteration.
const TRAINING_BATCH_SIZE: usize = 100;
// The total number of episodes.
const MAX_EPISODES: usize = 100;
// The maximum length of an episode.
const EPISODE_LENGTH: usize = 200;
// The number of training iterations after one episode finishes.
const TRAINING_ITERATIONS: usize = 200;
// Ornstein-Uhlenbeck process parameters.
const MU: f64 = 0.0;
const THETA: f64 = 0.15;
const SIGMA: f64 = 0.1;
const ACTOR_LEARNING_RATE: f64 = 1e-4;
const CRITIC_LEARNING_RATE: f64 = 1e-3;
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
let env = gym_env::GymEnv::new("Pendulum-v1")?;
println!("action space: {}", env.action_space());
println!("observation space: {:?}", env.observation_space());
let size_state = env.observation_space().iter().product::<usize>();
let size_action = env.action_space();
let mut agent = ddpg::DDPG::new(
&Device::Cpu,
size_state,
size_action,
true,
ACTOR_LEARNING_RATE,
CRITIC_LEARNING_RATE,
GAMMA,
TAU,
REPLAY_BUFFER_CAPACITY,
ddpg::OuNoise::new(MU, THETA, SIGMA, size_action)?,
)?;
let mut rng = rand::thread_rng();
for episode in 0..MAX_EPISODES {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
agent.remember(
&state,
&Tensor::new(vec![action], &Device::Cpu)?,
&Tensor::new(vec![step.reward as f32], &Device::Cpu)?,
&step.state,
step.terminated,
step.truncated,
);
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
for _ in 0..TRAINING_ITERATIONS {
agent.train(TRAINING_BATCH_SIZE)?;
}
}
println!("Testing...");
agent.train = false;
for episode in 0..10 {
// let mut state = env.reset(episode as u64)?;
let mut state = env.reset(rng.gen::<u64>())?;
let mut total_reward = 0.0;
for _ in 0..EPISODE_LENGTH {
let mut action = 2.0 * agent.actions(&state)?;
action = action.clamp(-2.0, 2.0);
let step = env.step(vec![action])?;
total_reward += step.reward;
if step.terminated || step.truncated {
break;
}
state = step.state;
}
println!("episode {episode} with total reward of {total_reward}");
}
Ok(())
}

View File

@ -1,91 +0,0 @@
#![allow(unused)]
//! Vectorized version of the gym environment.
use candle::{DType, Device, Result, Tensor};
use pyo3::prelude::*;
use pyo3::types::PyDict;
#[derive(Debug)]
pub struct Step {
pub obs: Tensor,
pub reward: Tensor,
pub is_done: Tensor,
}
pub struct VecGymEnv {
env: PyObject,
action_space: usize,
observation_space: Vec<usize>,
}
fn w(res: PyErr) -> candle::Error {
candle::Error::wrap(res)
}
impl VecGymEnv {
pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result<VecGymEnv> {
Python::with_gil(|py| {
let sys = py.import("sys")?;
let path = sys.getattr("path")?;
let _ = path.call_method1(
"append",
("candle-examples/examples/reinforcement-learning",),
)?;
let gym = py.import("atari_wrappers")?;
let make = gym.getattr("make")?;
let env = make.call1((name, img_dir, nprocesses))?;
let action_space = env.getattr("action_space")?;
let action_space = action_space.getattr("n")?.extract()?;
let observation_space = env.getattr("observation_space")?;
let observation_space: Vec<usize> = observation_space.getattr("shape")?.extract()?;
let observation_space =
[vec![nprocesses].as_slice(), observation_space.as_slice()].concat();
Ok(VecGymEnv {
env: env.into(),
action_space,
observation_space,
})
})
.map_err(w)
}
pub fn reset(&self) -> Result<Tensor> {
let obs = Python::with_gil(|py| {
let obs = self.env.call_method0(py, "reset")?;
let obs = obs.call_method0(py, "flatten")?;
obs.extract::<Vec<f32>>(py)
})
.map_err(w)?;
Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice())
}
pub fn step(&self, action: Vec<usize>) -> Result<Step> {
let (obs, reward, is_done) = Python::with_gil(|py| {
let step = self.env.call_method(py, "step", (action,), None)?;
let step = step.as_ref(py);
let obs = step.get_item(0)?.call_method("flatten", (), None)?;
let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?;
let obs: Vec<u8> = obs_buffer.to_vec(py)?;
let reward: Vec<f32> = step.get_item(1)?.extract()?;
let is_done: Vec<f32> = step.get_item(2)?.extract()?;
Ok((obs, reward, is_done))
})
.map_err(w)?;
let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)?
.to_dtype(DType::F32)?;
let reward = Tensor::new(reward, &Device::Cpu)?;
let is_done = Tensor::new(is_done, &Device::Cpu)?;
Ok(Step {
obs,
reward,
is_done,
})
}
pub fn action_space(&self) -> usize {
self.action_space
}
pub fn observation_space(&self) -> &[usize] {
&self.observation_space
}
}
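// Usage sketch (hypothetical Atari id; the wrappers require "NoFrameskip"):
// let env = VecGymEnv::new("PongNoFrameskip-v4", None, 8)?;
// let obs = env.reset()?; // shape: [nprocesses, ...single-env shape]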

View File

@ -1,40 +0,0 @@
# candle-replit-code: code completion specialized model.
[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
language model specialized for code completion. This model uses 3.3B parameters
in `bfloat16` (so the GPU version will only work on recent NVIDIA cards).
## Running an example
```bash
cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
```
This produces the following output.
```
def fibonacci(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
a, b = 0, 1
while a < n:
print(a, end=' ')
a, b = b, a+b
print()
def fibonacci_loop(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
result = []
a, b = 0, 1
while a < n:
result.append(a)
a, b = b, a+b
return result
def fibonacci_generator(n): # write Fibonacci series up to n
"""Print a Fibonacci series up to n."""
a, b = 0, 1
while a < n:
yield a
a, b = b, a+b
```

View File

@ -1,265 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;
use candle_transformers::models::mpt::{Config, Model as M};
use candle_transformers::models::quantized_mpt::Model as Q;
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::Tokenizer;
enum Model {
M(M),
Q(Q),
}
impl Model {
fn forward(&mut self, xs: &Tensor) -> candle::Result<Tensor> {
match self {
Self::M(model) => model.forward(xs),
Self::Q(model) => model.forward(xs),
}
}
}
struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
}
impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
}
}
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
use std::io::Write;
println!("starting the inference loop");
let tokens = self.tokenizer.encode(prompt, true).map_err(E::msg)?;
if tokens.is_empty() {
anyhow::bail!("Empty prompts are not supported in the phi model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => anyhow::bail!("cannot find the endoftext token"),
};
print!("{prompt}");
std::io::stdout().flush()?;
let start_gen = std::time::Instant::now();
for index in 0..sample_len {
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(DType::F32)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};
let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self.tokenizer.decode(&[next_token], true).map_err(E::msg)?;
print!("{token}");
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,
#[arg(long)]
prompt: String,
/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,
/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,
/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 1000)]
sample_len: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long)]
revision: Option<String>,
#[arg(long)]
quantized: bool,
#[arg(long)]
weight_file: Option<String>,
#[arg(long)]
tokenizer: Option<String>,
/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.)]
repeat_penalty: f32,
/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.),
args.repeat_penalty,
args.repeat_last_n
);
let start = std::time::Instant::now();
let api = Api::new()?;
let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "lmz/candle-replit-code".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};
let filename = match args.weight_file {
Some(weight_file) => std::path::PathBuf::from(weight_file),
None => {
if args.quantized {
repo.get("model-replit-code-v1_5-q4k.gguf")?
} else {
repo.get("model.safetensors")?
}
}
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now();
let config = Config::replit_code_v1_5_3b();
let (model, device) = if args.quantized {
let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename)?;
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
(model, Device::Cpu)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
(model, device)
};
println!("loaded the model in {:?}", start.elapsed());
let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
);
pipeline.run(&args.prompt, args.sample_len)?;
Ok(())
}

View File

@ -1,19 +0,0 @@
# candle-resnet
A candle implementation of inference using a pre-trained [ResNet](https://arxiv.org/abs/1512.03385).
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example resnet --release -- --image tiger.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built
tiger, Panthera tigris : 90.21%
tiger cat : 8.93%
lion, king of beasts, Panthera leo: 0.35%
leopard, Panthera pardus: 0.16%
jaguar, panther, Panthera onca, Felis onca: 0.09%
```

View File

@ -1,12 +0,0 @@
# This script exports pre-trained model weights in the safetensors format.
import numpy as np
import torch
import torchvision
from safetensors import torch as stt
m = torchvision.models.resnet50(pretrained=True)
stt.save_file(m.state_dict(), 'resnet50.safetensors')
m = torchvision.models.resnet101(pretrained=True)
stt.save_file(m.state_dict(), 'resnet101.safetensors')
m = torchvision.models.resnet152(pretrained=True)
stt.save_file(m.state_dict(), 'resnet152.safetensors')
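# The exported files can then be loaded by the candle example below, e.g.:
#   cargo run --example resnet --release -- --model resnet50.safetensors --which 50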

View File

@ -1,90 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::resnet;
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
#[value(name = "18")]
Resnet18,
#[value(name = "34")]
Resnet34,
#[value(name = "50")]
Resnet50,
#[value(name = "101")]
Resnet101,
#[value(name = "152")]
Resnet152,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Variant of the model to use.
#[arg(value_enum, long, default_value_t = Which::Resnet18)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("lmz/candle-resnet".into());
let filename = match args.which {
Which::Resnet18 => "resnet18.safetensors",
Which::Resnet34 => "resnet34.safetensors",
Which::Resnet50 => "resnet50.safetensors",
Which::Resnet101 => "resnet101.safetensors",
Which::Resnet152 => "resnet152.safetensors",
};
api.get(filename)?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let class_count = candle_examples::imagenet::CLASS_COUNT as usize;
let model = match args.which {
Which::Resnet18 => resnet::resnet18(class_count, vb)?,
Which::Resnet34 => resnet::resnet34(class_count, vb)?,
Which::Resnet50 => resnet::resnet50(class_count, vb)?,
Which::Resnet101 => resnet::resnet101(class_count, vb)?,
Which::Resnet152 => resnet::resnet152(class_count, vb)?,
};
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -50,9 +50,6 @@ cached.
Enabling flash-attention requires both a cargo feature flag, `--features flash-attn`,
and the command line flag `--use-flash-attn`.
Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs
(e.g., A100/H100, RTX 3090/4090).
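A sketch of a full invocation combining the two flags (assuming the example
target is named `stable-diffusion`):
```bash
cargo run --example stable-diffusion --release --features flash-attn -- \
  --prompt "A rusty robot holding a candle" --use-flash-attn
```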
## Image to Image Pipeline
...

View File

@ -143,6 +143,7 @@ fn main() -> Result<()> {
let args = Args::parse();
let _guard = if args.tracing {
println!("tracing...");
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)

View File

@ -1,13 +0,0 @@
## VGG Model Implementation
This example demonstrates the implementation of VGG models (VGG13, VGG16, VGG19) using the Candle library.
The VGG models are defined in `candle-transformers/src/models/vgg.rs`. The main function in `candle-examples/examples/vgg/main.rs` loads an image, selects the VGG model based on the provided argument, and applies the model to the loaded image.
You can run the example with the following command:
```bash
cargo run --example vgg --release -- --image ../yolo-v8/assets/bike.jpg --which vgg13
```
In the command above, `--image` specifies the path to the image file and `--which` specifies the VGG model to use (vgg13, vgg16, or vgg19).

View File

@ -1,77 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{ModuleT, VarBuilder};
use candle_transformers::models::vgg::{Models, Vgg};
use clap::{Parser, ValueEnum};
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
Vgg13,
Vgg16,
Vgg19,
}
#[derive(Parser)]
struct Args {
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Variant of the model to use.
#[arg(value_enum, long, default_value_t = Which::Vgg13)]
which: Which,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let api = hf_hub::api::sync::Api::new()?;
let repo = match args.which {
Which::Vgg13 => "timm/vgg13.tv_in1k",
Which::Vgg16 => "timm/vgg16.tv_in1k",
Which::Vgg19 => "timm/vgg19.tv_in1k",
};
let api = api.model(repo.into());
let filename = "model.safetensors";
let model_file = api.get(filename)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = match args.which {
Which::Vgg13 => Vgg::new(vb, Models::Vgg13)?,
Which::Vgg16 => Vgg::new(vb, Models::Vgg16)?,
Which::Vgg19 => Vgg::new(vb, Models::Vgg19)?,
};
let logits = model.forward_t(&image, /*train=*/ false)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
// Sort the predictions and take the top 5
let mut top: Vec<_> = prs.iter().enumerate().collect();
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
let top = top.into_iter().take(5).collect::<Vec<_>>();
// Print the top predictions
for &(i, p) in &top {
println!(
"{:50}: {:.2}%",
candle_examples::imagenet::CLASSES[i],
p * 100.0
);
}
Ok(())
}

View File

@ -1,20 +0,0 @@
# candle-vit
Vision Transformer (ViT) model implementation following the lines of
[vit-base-patch16-224](https://huggingface.co/google/vit-base-patch16-224)
This uses a classification head trained on the ImageNet dataset and returns the
probabilities for the top-5 classes.
## Running an example
```
$ cargo run --example vit --release -- --image tiger.jpg
loaded image Tensor[dims 3, 224, 224; f32]
model built
tiger, Panthera tigris : 100.00%
tiger cat : 0.00%
jaguar, panther, Panthera onca, Felis onca: 0.00%
leopard, Panthera pardus: 0.00%
lion, king of beasts, Panthera leo: 0.00%
```

View File

@ -1,59 +0,0 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, D};
use candle_nn::VarBuilder;
use candle_transformers::models::vit;
#[derive(Parser)]
struct Args {
#[arg(long)]
model: Option<String>,
#[arg(long)]
image: String,
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
}
pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
let device = candle_examples::device(args.cpu)?;
let image = candle_examples::imagenet::load_image224(args.image)?;
println!("loaded image {image:?}");
let model_file = match args.model {
None => {
let api = hf_hub::api::sync::Api::new()?;
let api = api.model("google/vit-base-patch16-224".into());
api.get("model.safetensors")?
}
Some(model) => model.into(),
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let model = vit::Model::new(&vit::Config::vit_base_patch16_224(), 1000, vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
.i(0)?
.to_vec1::<f32>()?;
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
for &(category_idx, pr) in prs.iter().take(5) {
println!(
"{:24}: {:.2}%",
candle_examples::imagenet::CLASSES[category_idx],
100. * pr
);
}
Ok(())
}

View File

@ -108,7 +108,7 @@ pub fn parse_config<T: AsRef<Path>>(path: T) -> Result<Darknet> {
}
enum Bl {
Layer(Box<dyn candle_nn::Module + Send + Sync>),
Layer(Box<dyn candle_nn::Module + Send>),
Route(Vec<usize>),
Shortcut(usize),
Yolo(usize, Vec<(usize, usize)>),

View File

@ -1,5 +1,7 @@
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{batch_norm, conv2d, conv2d_no_bias, Conv2d, Conv2dConfig, Module, VarBuilder};
use candle_nn::{
batch_norm, conv2d, conv2d_no_bias, BatchNorm, Conv2d, Conv2dConfig, Module, VarBuilder,
};
#[derive(Clone, Copy, PartialEq, Debug)]
pub struct Multiples {
@ -74,6 +76,7 @@ impl Module for Upsample {
#[derive(Debug)]
struct ConvBlock {
conv: Conv2d,
bn: BatchNorm,
span: tracing::Span,
}
@ -93,10 +96,11 @@ impl ConvBlock {
groups: 1,
dilation: 1,
};
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?;
let bn = batch_norm(c2, 1e-3, vb.pp("bn"))?;
let conv = conv2d_no_bias(c1, c2, k, cfg, vb.pp("conv"))?.absorb_bn(&bn)?;
Ok(Self {
conv,
bn,
span: tracing::span!(tracing::Level::TRACE, "conv-block"),
})
}
@ -106,6 +110,7 @@ impl Module for ConvBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = self.conv.forward(xs)?;
let xs = self.bn.forward(&xs)?;
candle_nn::ops::silu(&xs)
}
}

View File

@ -2,30 +2,17 @@ pub mod coco_classes;
pub mod imagenet;
pub mod token_output_stream;
use candle::utils::{cuda_is_available, metal_is_available};
use candle::{Device, Result, Tensor};
pub fn device(cpu: bool) -> Result<Device> {
if cpu {
Ok(Device::Cpu)
} else {
if cuda_is_available() {
Ok(Device::new_cuda(0)?)
} else if metal_is_available() {
Ok(Device::new_metal(0)?)
} else {
#[cfg(all(target_os = "macos", target_arch = "aarch64"))]
{
println!("Running on CPU, to run on GPU(metal), build this example with `--features metal`");
}
#[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
{
println!(
"Running on CPU, to run on GPU, build this example with `--features cuda`"
);
}
Ok(Device::Cpu)
let device = Device::cuda_if_available(0)?;
if !device.is_cuda() {
println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
}
Ok(device)
}
}

View File

@ -84,19 +84,12 @@ fn main() -> Result<()> {
(kernel_dir.join(f), obj_file)
})
.collect();
let out_modified: Result<_, _> = out_file.metadata().and_then(|m| m.modified());
let should_compile = if out_file.exists() {
kernel_dir
.read_dir()
.expect("kernels folder should exist")
.any(|entry| {
if let (Ok(entry), Ok(out_modified)) = (entry, &out_modified) {
let in_modified = entry.metadata().unwrap().modified().unwrap();
in_modified.duration_since(*out_modified).is_ok()
} else {
true
}
})
cu_files.iter().any(|(cu_file, _)| {
let out_modified = out_file.metadata().unwrap().modified().unwrap();
let in_modified = cu_file.metadata().unwrap().modified().unwrap();
in_modified.duration_since(out_modified).is_ok()
})
} else {
true
};
@ -107,19 +100,12 @@ fn main() -> Result<()> {
let mut command = std::process::Command::new("nvcc");
command
.arg("-std=c++17")
.arg("-O3")
.arg("-U__CUDA_NO_HALF_OPERATORS__")
.arg("-U__CUDA_NO_HALF_CONVERSIONS__")
.arg("-U__CUDA_NO_HALF2_OPERATORS__")
.arg("-U__CUDA_NO_BFLOAT16_CONVERSIONS__")
.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("-c")
.args(["-o", obj_file.to_str().unwrap()])
.args(["--default-stream", "per-thread"])
.arg("-Icutlass/include")
.arg("--expt-relaxed-constexpr")
.arg("--expt-extended-lambda")
.arg("--use_fast_math")
.arg("--verbose");
if let Ok(ccbin_path) = &ccbin_env {
command
@ -217,21 +203,13 @@ fn set_cuda_include_dir() -> Result<()> {
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse compute cap")?
} else {
// Use nvidia-smi to get the current compute cap
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
@ -242,19 +220,16 @@ fn compute_cap() -> Result<usize> {
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
cap.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
@ -268,21 +243,30 @@ fn compute_cap() -> Result<usize> {
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
if !codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}."
);
}
*codes.last().unwrap()
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str
.parse::<usize>()
.with_context(|| format!("cannot parse as usize '{compute_cap_str}'"))?;
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
Ok(compute_cap)
}
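
As an illustration of the output this function parses, a `nvidia-smi` query on an Ampere card might return (sample values only):

```bash
$ nvidia-smi --query-gpu=compute_cap --format=csv
compute_cap
8.6
```

The header line is asserted, the dot is stripped, and `8.6` is parsed as compute cap `86`.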

View File

@ -12,6 +12,5 @@ license = "MIT OR Apache-2.0"
[dependencies]
[build-dependencies]
anyhow = { version = "1", features = ["backtrace"] }
glob = "0.3.1"
rayon = "1.7.0"
rayon = "1.7.0"

View File

@ -1,5 +1,4 @@
use std::io::Write;
fn main() {
println!("cargo:rerun-if-changed=build.rs");
@ -24,8 +23,6 @@ fn main() {
}
mod cuda {
use anyhow::{Context, Result};
pub fn set_include_dir() {
use std::path::PathBuf;
// NOTE: copied from cudarc build.rs.
@ -103,52 +100,107 @@ mod cuda {
include_directories.sort();
include_directories.dedup();
let compute_cap = compute_cap().expect("Could not get Cuda compute cap");
#[allow(unused)]
let include_options: Vec<String> = include_directories
.into_iter()
.map(|s| "-I".to_string() + &s.into_os_string().into_string().unwrap())
.collect::<Vec<_>>();
// let start = std::time::Instant::now();
// Grab compute code from nvidia-smi
let mut compute_cap = {
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.expect("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let mut lines = out.lines();
assert_eq!(lines.next().unwrap(), "compute_cap");
let cap = lines.next().unwrap().replace('.', "");
cap.parse::<usize>().unwrap()
};
// Grab available GPU codes from nvcc and select the highest one
let max_nvcc_code = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
if !codes.contains(&compute_cap) {
panic!("nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {codes:?}.");
}
*codes.last().unwrap()
};
// If nvidia-smi compute_cap is higher than the highest gpu code from nvcc,
// then choose the highest gpu code in nvcc
if compute_cap > max_nvcc_code {
println!(
"cargo:warning=Lowering gpu arch {compute_cap} to max nvcc target {max_nvcc_code}."
);
compute_cap = max_nvcc_code;
}
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
compute_cap = compute_cap_str.parse::<usize>().unwrap();
println!("cargo:warning=Using gpu arch {compute_cap} from $CUDA_COMPUTE_CAP");
}
println!("cargo:rustc-env=CUDA_COMPUTE_CAP=sm_{compute_cap}");
let ccbin_env = std::env::var("CANDLE_NVCC_CCBIN");
println!("cargo:rerun-if-env-changed=CANDLE_NVCC_CCBIN");
let children = kernel_paths
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
.par_iter()
.flat_map(|p| {
let mut output = p.clone();
output.set_extension("ptx");
let output_filename = std::path::Path::new(&out_dir).to_path_buf().join("out").with_file_name(output.file_name().unwrap());
let ignore = if output_filename.exists() {
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
let in_modified = p.metadata().unwrap().modified().unwrap();
out_modified.duration_since(in_modified).is_ok()
} else {
false
};
if ignore {
None
} else {
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
let ignore = if output_filename.exists() {
let out_modified = output_filename.metadata().unwrap().modified().unwrap();
let in_modified = p.metadata().unwrap().modified().unwrap();
out_modified.duration_since(in_modified).is_ok()
} else {
false
};
if ignore {
None
} else {
let mut command = std::process::Command::new("nvcc");
command.arg(format!("--gpu-architecture=sm_{compute_cap}"))
.arg("--ptx")
.args(["--default-stream", "per-thread"])
.args(["--output-directory", &out_dir])
// Flash attention only
// .arg("--expt-relaxed-constexpr")
.args(&include_options);
if let Ok(ccbin_path) = &ccbin_env {
command
.arg("-allow-unsupported-compiler")
.args(["-ccbin", ccbin_path]);
}
command.arg(p);
Some((p, command.spawn()
.expect("nvcc failed to start. Ensure that you have CUDA installed and that `nvcc` is in your PATH.").wait_with_output()))
}
})
.collect::<Vec<_>>();
}})
.collect::<Vec<_>>();
let ptx_paths: Vec<PathBuf> = glob::glob(&format!("{out_dir}/**/*.ptx"))
.unwrap()
@ -168,76 +220,4 @@ mod cuda {
}
(write, kernel_paths)
}
#[allow(unused)]
fn compute_cap() -> Result<usize> {
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");
// Try to parse compute caps from env
let mut compute_cap = if let Ok(compute_cap_str) = std::env::var("CUDA_COMPUTE_CAP") {
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={compute_cap_str}");
compute_cap_str
.parse::<usize>()
.context("Could not parse code")?
} else {
// Use nvidia-smi to get the current compute cap
let out = std::process::Command::new("nvidia-smi")
.arg("--query-gpu=compute_cap")
.arg("--format=csv")
.output()
.context("`nvidia-smi` failed. Ensure that you have CUDA installed and that `nvidia-smi` is in your PATH.")?;
let out = std::str::from_utf8(&out.stdout).context("stdout is not a utf8 string")?;
let mut lines = out.lines();
assert_eq!(
lines.next().context("missing line in stdout")?,
"compute_cap"
);
let cap = lines
.next()
.context("missing line in stdout")?
.replace('.', "");
let cap = cap
.parse::<usize>()
.with_context(|| format!("cannot parse as int {cap}"))?;
println!("cargo:rustc-env=CUDA_COMPUTE_CAP={cap}");
cap
};
// Grab available GPU codes from nvcc and select the highest one
let (supported_nvcc_codes, max_nvcc_code) = {
let out = std::process::Command::new("nvcc")
.arg("--list-gpu-code")
.output()
.expect("`nvcc` failed. Ensure that you have CUDA installed and that `nvcc` is in your PATH.");
let out = std::str::from_utf8(&out.stdout).unwrap();
let out = out.lines().collect::<Vec<&str>>();
let mut codes = Vec::with_capacity(out.len());
for code in out {
let code = code.split('_').collect::<Vec<&str>>();
if !code.is_empty() && code.contains(&"sm") {
if let Ok(num) = code[1].parse::<usize>() {
codes.push(num);
}
}
}
codes.sort();
let max_nvcc_code = *codes.last().context("no gpu codes parsed from nvcc")?;
(codes, max_nvcc_code)
};
// Check that nvcc supports the asked compute caps
if !supported_nvcc_codes.contains(&compute_cap) {
anyhow::bail!(
"nvcc cannot target gpu arch {compute_cap}. Available nvcc targets are {supported_nvcc_codes:?}."
);
}
if compute_cap > max_nvcc_code {
anyhow::bail!(
"CUDA compute cap {compute_cap} is higher than the highest gpu code from nvcc {max_nvcc_code}"
);
}
Ok(compute_cap)
}
}

View File

@ -1,53 +1,6 @@
#include "cuda_utils.cuh"
#include<stdint.h>
template <typename S, typename T>
__device__ void cast_(
const size_t numel,
const size_t num_dims,
const size_t *info,
const S *inp,
T *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = inp[i];
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = inp[strided_i];
}
}
}
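// Cast via an intermediate type I, used for pairs without a direct device-side
// conversion (e.g. __nv_bfloat16 -> uint8_t goes through float below).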
template <typename S, typename T, typename I>
__device__ void cast_through(
const size_t numel,
const size_t num_dims,
const size_t *info,
const S *inp,
T *out
) {
const size_t *dims = info;
const size_t *strides = info + num_dims;
if (is_contiguous(num_dims, dims, strides)) {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
out[i] = static_cast<T>(static_cast<I>(inp[i]));
}
}
else {
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
unsigned strided_i = get_strided_index(i, num_dims, dims, strides);
out[i] = static_cast<T>(static_cast<I>(inp[strided_i]));
}
}
}
#define CAST_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
@ -56,10 +9,22 @@ extern "C" __global__ void FN_NAME( \
const SRC_TYPENAME *inp, \
DST_TYPENAME *out \
) { \
cast_<SRC_TYPENAME, DST_TYPENAME>(numel, num_dims, info, inp, out); \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
out[i] = inp[i]; \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
out[i] = inp[strided_i]; \
} \
} \
} \
#define CAST_THROUGH_OP(SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME, FN_NAME) \
#define CAST_BF_OP(SRC_TYPENAME, DST_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const size_t numel, \
const size_t num_dims, \
@ -67,12 +32,25 @@ extern "C" __global__ void FN_NAME( \
const SRC_TYPENAME *inp, \
DST_TYPENAME *out \
) { \
cast_through<SRC_TYPENAME, DST_TYPENAME, INT_TYPENAME>(numel, num_dims, info, inp, out); \
const size_t *dims = info; \
const size_t *strides = info + num_dims; \
if (is_contiguous(num_dims, dims, strides)) { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
out[i] = (DST_TYPENAME) (float) inp[i]; \
} \
} \
else { \
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
out[i] = (DST_TYPENAME) (float) inp[strided_i]; \
} \
} \
} \
#if __CUDA_ARCH__ >= 800
CAST_OP(__nv_bfloat16, __nv_bfloat16, cast_bf16_bf16)
// CAST_OP(__nv_bfloat16, uint8_t, cast_bf16_u8)
CAST_OP(__nv_bfloat16, uint32_t, cast_bf16_u32)
CAST_OP(__nv_bfloat16, float, cast_bf16_f32)
CAST_OP(__nv_bfloat16, double, cast_bf16_f64)
@ -80,15 +58,14 @@ CAST_OP(uint8_t, __nv_bfloat16, cast_u8_bf16)
CAST_OP(uint32_t, __nv_bfloat16, cast_u32_bf16)
CAST_OP(float, __nv_bfloat16, cast_f32_bf16)
CAST_OP(double, __nv_bfloat16, cast_f64_bf16)
CAST_THROUGH_OP(__nv_bfloat16, uint8_t, float, cast_bf16_u8)
CAST_THROUGH_OP(__nv_bfloat16, __half, float, cast_bf16_f16)
CAST_THROUGH_OP(__half, __nv_bfloat16, float, cast_f16_bf16)
CAST_BF_OP(__nv_bfloat16, __half, cast_bf16_f16)
CAST_BF_OP(__half, __nv_bfloat16, cast_f16_bf16)
#endif
#if __CUDA_ARCH__ >= 530
CAST_OP(__half, __half, cast_f16_f16)
CAST_THROUGH_OP(__half, uint8_t, float, cast_f16_u8)
// CAST_OP(__half, uint8_t, cast_f16_u8 )
CAST_OP(__half, uint32_t, cast_f16_u32)
CAST_OP(__half, float, cast_f16_f32)
CAST_OP(__half, double, cast_f16_f64)

View File

@ -1,17 +0,0 @@
[package]
name = "candle-metal-kernels"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true
[dependencies]
metal = { workspace = true }
once_cell = "1.18.0"
thiserror = { workspace = true }
[dev-dependencies]
half = { workspace = true }

View File

@ -1,3 +0,0 @@
# candle-metal-kernels
This crate contains Metal kernels used from candle.
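A minimal usage sketch, mirroring the crate's own tests (assumes a system-default Metal device; the buffer handling and `read_to_vec` follow the `StorageModeManaged` pattern used in the tests):

```rust
use candle_metal_kernels::{call_unary_contiguous, unary, Kernels};
use metal::{Device, MTLResourceOptions};

fn main() {
    let device = Device::system_default().expect("no Metal device found");
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let v = [1.0f32, 2.0, 3.0];
    let options = MTLResourceOptions::StorageModeManaged;
    let input = device.new_buffer_with_data(
        v.as_ptr() as *const core::ffi::c_void,
        (v.len() * core::mem::size_of::<f32>()) as u64,
        options,
    );
    let mut output = device.new_buffer((v.len() * core::mem::size_of::<f32>()) as u64, options);
    // Element-wise cosine over a contiguous f32 buffer.
    call_unary_contiguous(
        &device,
        &command_buffer,
        &kernels,
        unary::contiguous::cos::FLOAT,
        v.len(),
        &input,
        &mut output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    // read_to_vec comes from the metal-rs fork pinned in the workspace.
    let results = output.read_to_vec::<f32>(v.len());
    println!("{results:?}");
}
```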

View File

@ -1,44 +0,0 @@
#include <metal_stdlib>
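// Map a linear element index into a strided layout: peel coordinates off the
// innermost dimension outward and re-project each one through its stride.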
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define AFFINE(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
constant float &mul, \
constant float &add, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = input[i] * mul + add; \
} \
} \
AFFINE(affine_float, float)
AFFINE(affine_half, half)
#if __METAL_VERSION__ >= 310
AFFINE(affine_bfloat, bfloat);
#endif

View File

@ -1,78 +0,0 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define BINARY(FN, TYPENAME, OUT_TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[i]; \
TYPENAME y = right[i]; \
output[i] = OUT_TYPENAME(FN); \
} \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *left_strides, \
constant size_t *right_strides, \
device const TYPENAME *left, \
device const TYPENAME *right, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
TYPENAME x = left[get_strided_index(i, num_dims, dims, left_strides)]; \
TYPENAME y = right[get_strided_index(i, num_dims, dims, right_strides)]; \
output[i] = OUT_TYPENAME(FN); \
} \
}
#define BINARY_OP(FN, NAME) \
BINARY(FN, float, float, NAME##_float, NAME##_float_strided); \
BINARY(FN, half, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_BINARY_OP(FN, NAME) \
BINARY(FN, bfloat, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
BINARY_OP(x + y, add)
BINARY_OP(x - y, sub)
BINARY_OP(x * y, mul)
BINARY_OP(x / y, div)
#if __METAL_VERSION__ >= 310
BFLOAT_BINARY_OP(x + y, add)
BFLOAT_BINARY_OP(x - y, sub)
BFLOAT_BINARY_OP(x * y, mul)
BFLOAT_BINARY_OP(x / y, div)
#endif

View File

@ -1,58 +0,0 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
using namespace metal;
#define CAST(FN_NAME, FN_NAME_STRIDED, LEFT_TYPENAME, RIGHT_TYPENAME) \
kernel void FN_NAME( \
constant size_t &dim, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[i]); \
} \
} \
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const LEFT_TYPENAME *input, \
device RIGHT_TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = RIGHT_TYPENAME(input[get_strided_index(i, num_dims, dims, strides)]); \
} \
}
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
#if __METAL_VERSION__ >= 310
#endif

View File

@ -1,75 +0,0 @@
#include <metal_stdlib>
using namespace metal;
template <typename T, typename I>
void index_add(
device I *ids [[buffer(0)]],
device T *inp [[buffer(1)]],
device T *out [[buffer(2)]],
constant uint &ids_dim_size,
constant uint &left_size,
constant uint &dst_dim_size,
constant uint &right_size,
uint threadgroup_size [[threads_per_threadgroup]],
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]],
uint thread_index [[thread_index_in_threadgroup]]
) {
const uint gid = thread_index + (threadgroup_position_in_grid * threadgroup_size);
if (gid >= left_size * right_size) {
return;
}
const uint i = gid;
const uint pre = i / right_size;
const uint post = i % right_size;
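    // Each thread owns one (pre, post) slice of the output; for every id j it
    // accumulates the matching input row into the destination row ids[j].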
for (uint j = 0; j < ids_dim_size; j++) {
const uint idx = ids[j];
const uint src_i = (pre * ids_dim_size + j) * right_size + post;
const uint dst_i = (pre * dst_dim_size + idx) * right_size + post;
out[dst_i] += inp[src_i];
}
}
#define IA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
kernel void FN_NAME( \
device INDEX_TYPENAME *ids [[buffer(0)]], \
device TYPENAME *inp [[buffer(1)]], \
device TYPENAME *out [[buffer(2)]], \
constant uint &ids_dim_size, \
constant uint &left_size, \
constant uint &dst_dim_size, \
constant uint &right_size, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint threadgroup_position_in_grid [[threadgroup_position_in_grid]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { index_add<TYPENAME, INDEX_TYPENAME>(ids, inp, out, ids_dim_size, left_size, dst_dim_size, right_size, threadgroup_size, threadgroup_position_in_grid, thread_index); } \
#if __METAL_VERSION__ >= 310
IA_OP(bfloat, int64_t, ia_i64_bf16)
IA_OP(bfloat, uint32_t, ia_u32_bf16)
IA_OP(bfloat, uint8_t, ia_u8_bf16)
#endif
IA_OP(half, uint32_t, ia_u32_f16)
IA_OP(half, uint8_t, ia_u8_f16)
IA_OP(float, int64_t, ia_i64_f32)
IA_OP(uint8_t, int64_t, ia_i64_u8)
IA_OP(int64_t, int64_t, ia_i64_i64)
IA_OP(uint32_t, int64_t, ia_i64_u32)
IA_OP(float, uint32_t, ia_u32_f32)
IA_OP(uint8_t, uint32_t, ia_u32_u8)
IA_OP(int64_t, uint32_t, ia_u32_i64)
IA_OP(uint32_t, uint32_t, ia_u32_u32)
IA_OP(float, uint8_t, ia_u8_f32)
IA_OP(uint8_t, uint8_t, ia_u8_u8)
IA_OP(uint32_t, uint8_t, ia_u8_u32)
IA_OP(int64_t, uint8_t, ia_u8_i64)

View File

@ -1,862 +0,0 @@
use metal::{
Buffer, CommandBufferRef, CompileOptions, ComputePipelineDescriptor, Device, Function, Library,
MTLSize,
};
use std::collections::HashMap;
use std::ffi::c_void;
use std::sync::RwLock;
const AFFINE: &str = include_str!("affine.metal");
const INDEXING: &str = include_str!("indexing.metal");
const UNARY: &str = include_str!("unary.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
Indexing,
Unary,
Binary,
Cast,
}
macro_rules! ops{
($($name:ident),+) => {
pub mod contiguous {
pub struct Kernel(pub(crate) &'static str);
$(
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat"));
}
)+
}
pub mod strided {
pub struct Kernel(pub(crate) &'static str);
$(
pub mod $name {
use super::Kernel;
pub const FLOAT: Kernel = Kernel(concat!(stringify!($name), "_float_strided"));
pub const HALF: Kernel = Kernel(concat!(stringify!($name), "_half_strided"));
pub const BFLOAT: Kernel = Kernel(concat!(stringify!($name), "_bfloat_strided"));
}
)+
}
};
}
pub mod unary {
ops!(cos, sin, exp, sqr, sqrt, neg, copy);
}
pub mod binary {
ops!(add, sub, mul, div);
}
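// For reference, the macro above expands to string constants that line up with
// the kernel names generated by the Metal sources, e.g.:
//   unary::contiguous::cos::FLOAT -> Kernel("cos_float")
//   unary::strided::cos::FLOAT    -> Kernel("cos_float_strided")
//   binary::contiguous::add::HALF -> Kernel("add_half")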
// static LIBRARY_SOURCES: Lazy<HashMap<&'static str, &'static str>> = Lazy::new(|| {
// let mut l = HashMap::new();
// l.insert("affine", AFFINE);
// l.insert("indexing", INDEXING);
// l.insert("unary", UNARY);
// l
// });
//
#[derive(thiserror::Error, Debug)]
pub enum MetalKernelError {
#[error("Could not lock kernel map: {0}")]
LockError(String),
#[error("Error while loading library: {0}")]
LoadLibraryError(String),
#[error("Error while loading function: {0}")]
LoadFunctionError(String),
}
impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
fn from(e: std::sync::PoisonError<T>) -> Self {
Self::LockError(e.to_string())
}
}
type KernelMap<T> = HashMap<&'static str, T>;
type Libraries = HashMap<Source, Library>;
type Functions = KernelMap<Function>;
#[derive(Debug)]
pub struct Kernels {
libraries: RwLock<Libraries>,
funcs: RwLock<Functions>,
}
impl Kernels {
pub fn new() -> Self {
let libraries = RwLock::new(Libraries::new());
let funcs = RwLock::new(Functions::new());
Self { libraries, funcs }
}
// pub fn init(device: &Device) -> Result<Self, MetalKernelError> {
// let kernels = Self::new();
// kernels.load_libraries(device)?;
// Ok(kernels)
// }
// fn load_libraries(&self, device: &Device) -> Result<(), MetalKernelError> {
// for name in LIBRARY_SOURCES.keys() {
// self.load_library(device, name)?;
// }
// Ok(())
// }
fn get_library_source(&self, source: Source) -> &'static str {
// LIBRARY_SOURCES.get(name).cloned()
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
Source::Binary => BINARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
}
}
pub fn load_library(
&self,
device: &Device,
source: Source,
) -> Result<Library, MetalKernelError> {
let mut libraries = self.libraries.write()?;
if let Some(lib) = libraries.get(&source) {
Ok(lib.clone())
} else {
let source_content = self.get_library_source(source);
let lib = device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?;
libraries.insert(source, lib.clone());
Ok(lib)
}
}
pub fn load_function(
&self,
device: &Device,
source: Source,
name: &'static str,
) -> Result<Function, MetalKernelError> {
let mut funcs = self.funcs.write()?;
if let Some(func) = funcs.get(name) {
Ok(func.clone())
} else {
let func = self
.load_library(device, source)?
.get_function(name, None)
.map_err(|e| MetalKernelError::LoadFunctionError(e.to_string()))?;
funcs.insert(name, func.clone());
Ok(func)
}
}
}
pub fn call_unary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: unary::contiguous::Kernel,
length: usize,
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Unary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_unary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: unary::strided::Kernel,
shape: &[usize],
input: &Buffer,
strides: &[usize],
offset: usize,
output: &mut Buffer,
output_offset: usize,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Unary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
encoder.set_bytes(
2,
(shape.len() * std::mem::size_of::<usize>()) as u64,
shape.as_ptr() as *const c_void,
);
encoder.set_bytes(
3,
(strides.len() * std::mem::size_of::<usize>()) as u64,
strides.as_ptr() as *const c_void,
);
encoder.set_buffer(4, Some(&input), offset as u64);
encoder.set_buffer(5, Some(&output), output_offset as u64);
let width = output.length();
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_binary_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: binary::contiguous::Kernel,
length: usize,
left: &Buffer,
right: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Binary, kernel_name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_buffer(1, Some(&left), 0);
encoder.set_buffer(2, Some(&right), 0);
encoder.set_buffer(3, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_binary_strided(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
name: binary::strided::Kernel,
shape: &[usize],
left_input: &Buffer,
left_strides: &[usize],
left_offset: usize,
right_input: &Buffer,
right_strides: &[usize],
right_offset: usize,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Binary, name.0)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let num_dims: usize = shape.len();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
let length: usize = shape.iter().product();
encoder.set_bytes(0, std::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_bytes(1, std::mem::size_of::<usize>() as u64, void_ptr(&num_dims));
encoder.set_bytes(
2,
(shape.len() * std::mem::size_of::<usize>()) as u64,
shape.as_ptr() as *const c_void,
);
encoder.set_bytes(
3,
(left_strides.len() * std::mem::size_of::<usize>()) as u64,
left_strides.as_ptr() as *const c_void,
);
encoder.set_bytes(
4,
(right_strides.len() * std::mem::size_of::<usize>()) as u64,
right_strides.as_ptr() as *const c_void,
);
encoder.set_buffer(5, Some(&left_input), left_offset as u64);
encoder.set_buffer(6, Some(&right_input), right_offset as u64);
encoder.set_buffer(7, Some(&output), 0);
let width = output.length();
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: std::cmp::min(pipeline.max_total_threads_per_threadgroup(), width),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn call_cast_contiguous(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
kernel_name: &'static str,
length: usize,
input: &Buffer,
output: &mut Buffer,
) -> Result<(), MetalKernelError> {
// println!("Kernel {:?}", kernel_name.0);
// assert_eq!(input.length(), output.length());
let func = kernels.load_function(device, Source::Cast, kernel_name)?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&length));
encoder.set_buffer(1, Some(&input), 0);
encoder.set_buffer(2, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), length as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
pub fn void_ptr<T>(v: &T) -> *const c_void {
(v as *const T).cast()
}
pub fn call_affine(
device: &Device,
command_buffer: &CommandBufferRef,
kernels: &Kernels,
size: usize,
input: &Buffer,
output: &mut Buffer,
mul: f32,
add: f32,
) -> Result<(), MetalKernelError> {
let func = kernels.load_function(device, Source::Affine, "affine_float")?;
let pipeline_state_descriptor = ComputePipelineDescriptor::new();
pipeline_state_descriptor.set_compute_function(Some(&func));
let pipeline = device
.new_compute_pipeline_state_with_function(
pipeline_state_descriptor.compute_function().unwrap(),
)
.unwrap();
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_bytes(0, core::mem::size_of::<usize>() as u64, void_ptr(&size));
encoder.set_bytes(1, core::mem::size_of::<f32>() as u64, void_ptr(&mul));
encoder.set_bytes(2, core::mem::size_of::<f32>() as u64, void_ptr(&add));
encoder.set_buffer(3, Some(&input), 0);
encoder.set_buffer(4, Some(&output), 0);
let thread_group_count = MTLSize {
width: 1,
height: 1,
depth: 1,
};
let width = std::cmp::min(pipeline.max_total_threads_per_threadgroup(), size as u64);
let thread_group_size = MTLSize {
width,
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
encoder.end_encoding();
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use half::f16;
use metal::{CompileOptions, Device, MTLResourceOptions, MTLSize, NSUInteger};
use std::mem;
fn device() -> Device {
Device::system_default().unwrap()
}
fn approx(v: Vec<f32>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t * b) / b).collect()
}
fn approx_f16(v: Vec<f16>, digits: i32) -> Vec<f32> {
let b = 10f32.powi(digits);
v.iter().map(|t| f32::round(t.to_f32() * b) / b).collect()
}
fn run<T: Clone>(v: &[T], name: unary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
call_unary_contiguous(
&device,
&command_buffer,
&kernels,
name,
v.len(),
&input,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
fn run_binary<T: Clone>(x: &[T], y: &[T], name: binary::contiguous::Kernel) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let left = device.new_buffer_with_data(
x.as_ptr() as *const core::ffi::c_void,
(x.len() * core::mem::size_of::<T>()) as u64,
options,
);
let right = device.new_buffer_with_data(
y.as_ptr() as *const core::ffi::c_void,
(y.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((x.len() * core::mem::size_of::<T>()) as u64, options);
call_binary_contiguous(
&device,
&command_buffer,
&kernels,
name,
x.len(),
&left,
&right,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(x.len())
}
fn run_strided<T: Clone>(
v: &[T],
kernel: unary::strided::Kernel,
shape: &[usize],
strides: &[usize],
offset: usize,
) -> Vec<T> {
let device = device();
let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
let kernels = Kernels::new();
call_unary_strided(
&device,
&command_buffer,
&kernels,
kernel,
shape,
&input,
strides,
offset,
&mut output,
0,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn cos_f32() {
let v = vec![1.0f32, 2.0, 3.0];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403, -0.4161, -0.99]);
assert_eq!(approx(expected, 4), vec![0.5403, -0.4161, -0.99]);
let v = vec![1.0f32; 10_000];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn cos_f32_strided() {
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
// Shape = [6], strides = [1];
let shape = vec![6];
let strides = vec![1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Contiguous
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Transposed
let v = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
let shape = vec![3, 2];
let strides = vec![1, 3];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(
approx(results, 4),
vec![0.5403, -0.6536, -0.4161, 0.2837, -0.99, 0.9602]
);
assert_eq!(
approx(expected, 4),
vec![0.5403, -0.4161, -0.99, -0.6536, 0.2837, 0.9602]
);
// Very large
let v = vec![1.0f32; 10_000];
let shape = vec![2, 5_000];
let strides = vec![2, 1];
let offset = 0;
let results = run_strided(&v, unary::strided::cos::FLOAT, &shape, &strides, offset);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
#[test]
fn binary_add_f32() {
let left = vec![1.0f32, 2.0, 3.0];
let right = vec![2.0f32, 3.1, 4.2];
let results = run_binary(&left, &right, binary::contiguous::add::FLOAT);
let expected: Vec<_> = left
.iter()
.zip(right.iter())
.map(|(&x, &y)| x + y)
.collect();
assert_eq!(approx(results, 4), vec![3.0f32, 5.1, 7.2]);
assert_eq!(approx(expected, 4), vec![3.0f32, 5.1, 7.2]);
}
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<U>()) as u64, options);
call_cast_contiguous(
&device,
&command_buffer,
&kernels,
name,
v.len(),
&input,
&mut output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<U>(v.len())
}
#[test]
fn cast_u32_f32() {
let v = vec![1u32, 2, 3];
let results = cast(&v, "cast_u32_f32");
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
let v = vec![1.0f32; 10_000];
let results = run(&v, unary::contiguous::cos::FLOAT);
let expected: Vec<_> = v.iter().map(|v| v.cos()).collect();
assert_eq!(approx(results, 4), vec![0.5403; 10_000]);
assert_eq!(approx(expected, 4), vec![0.5403; 10_000]);
}
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input = device.new_buffer_with_data(
v.as_ptr() as *const core::ffi::c_void,
(v.len() * core::mem::size_of::<T>()) as u64,
options,
);
let mut output = device.new_buffer((v.len() * core::mem::size_of::<T>()) as u64, options);
let size = v.len();
call_affine(
&device,
&command_buffer,
&kernels,
size,
&input,
&mut output,
mul as f32,
add as f32,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
output.read_to_vec::<T>(v.len())
}
#[test]
fn affine() {
let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let mul = 1.5;
let add = 1.1;
let result = run_affine(&input, mul, add);
assert_eq!(result, vec![2.6, 4.1, 5.6, 7.1, 8.6, 10.1, 11.6, 13.1]);
let input = [1.0f32; 40_000];
let mul = 1.5;
let add = 1.1;
let result = run_affine(&input, mul, add);
assert_eq!(result, vec![2.6; 40_000]);
}
#[test]
fn index_add() {
let device = Device::system_default().expect("no device found");
let options = CompileOptions::new();
let library = device.new_library_with_source(INDEXING, &options).unwrap();
let left = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];
let right = [1.0f32; 15];
let index = [0u32, 4, 2];
let ids_dim_size = index.len() as u32;
let dst_dim_size: u32 = 15;
let left_size: u32 = 3;
let right_size: u32 = 3;
let function = library.get_function("ia_u32_f32", None).unwrap();
let pipeline = device
.new_compute_pipeline_state_with_function(&function)
.unwrap();
let options = MTLResourceOptions::StorageModeManaged;
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let encoder = command_buffer.new_compute_command_encoder();
let ids_size = (index.len() * mem::size_of::<u32>()) as NSUInteger;
let input_size = (left.len() * mem::size_of::<f32>()) as NSUInteger;
let output_size = (right.len() * mem::size_of::<f32>()) as NSUInteger;
encoder.set_compute_pipeline_state(&pipeline);
encoder.set_threadgroup_memory_length(0, output_size as NSUInteger);
let index_buffer = device.new_buffer_with_data(void_ptr(&index), ids_size, options);
let inputs_buffer = device.new_buffer_with_data(void_ptr(&left), input_size, options);
let outputs_buffer = device.new_buffer_with_data(void_ptr(&right), output_size, options);
encoder.set_buffer(0, Some(&index_buffer), 0);
encoder.set_buffer(1, Some(&inputs_buffer), 0);
encoder.set_buffer(2, Some(&outputs_buffer), 0);
encoder.set_bytes(3, 4, void_ptr(&ids_dim_size));
encoder.set_bytes(4, 4, void_ptr(&left_size));
encoder.set_bytes(5, 4, void_ptr(&dst_dim_size));
encoder.set_bytes(6, 4, void_ptr(&right_size));
let grid_size = MTLSize {
width: right.len() as NSUInteger,
height: 1,
depth: 1,
};
let thread_group_size = MTLSize {
width: pipeline.max_total_threads_per_threadgroup(),
height: 1,
depth: 1,
};
encoder.dispatch_thread_groups(grid_size, thread_group_size);
encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
let expected = vec![
2.0, 3.0, 4.0, 1.0, 1.0, 1.0, 8.0, 9.0, 10.0, 1.0, 1.0, 1.0, 5.0, 6.0, 7.0,
];
let result = outputs_buffer.read_to_vec::<f32>(right.len());
assert_eq!(result, expected);
}
#[test]
fn cos_f16() {
let v: Vec<f16> = [1.0f32, 2.0, 3.0]
.iter()
.map(|v| f16::from_f32(*v))
.collect();
let results = run(&v, unary::contiguous::cos::HALF);
let expected: Vec<f16> = v.iter().map(|v| f16::from_f32(v.to_f32().cos())).collect();
assert_eq!(approx_f16(results, 4), vec![0.54, -0.4165, -0.9902]);
assert_eq!(approx_f16(expected, 4), vec![0.5405, -0.4163, -0.9902]);
}
}

View File

@ -1,82 +0,0 @@
#include <metal_stdlib>
METAL_FUNC uint get_strided_index(
uint idx,
constant size_t &num_dims,
constant size_t *dims,
constant size_t *strides
) {
uint strided_i = 0;
for (uint d = 0; d < num_dims; d++) {
uint dim_idx = num_dims - 1 - d;
strided_i += (idx % dims[dim_idx]) * strides[dim_idx];
idx /= dims[dim_idx];
}
return strided_i;
}
template <typename T> METAL_FUNC T sqr(T in){ return in * in; }
template <typename T> METAL_FUNC T neg(T in){ return -in; }
template <typename T> METAL_FUNC T id(T in){ return in; }
using namespace metal;
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
constant size_t &dim, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[i])); \
} \
}\
kernel void FN_NAME_STRIDED( \
constant size_t &dim, \
constant size_t &num_dims, \
constant size_t *dims, \
constant size_t *strides, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint threadgroup_size [[threads_per_threadgroup]], \
uint thread_index [[thread_index_in_threadgroup]] \
) { \
const size_t length = (dim + threadgroup_size - 1) / threadgroup_size; \
const size_t start = thread_index * length; \
const size_t stop = min(start + length, dim); \
for (size_t i = start; i < stop; i++){ \
output[i] = TYPENAME(FN(input[get_strided_index(i, num_dims, dims, strides)])); \
} \
}
#define UNARY_OP(NAME) \
UNARY(NAME, float, NAME##_float, NAME##_float_strided); \
UNARY(NAME, half, NAME##_half, NAME##_half_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bfloat, NAME##_bfloat_strided);
UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
UNARY_OP(sqrt)
UNARY_OP(neg)
UNARY_OP(exp)
UNARY(id, float, copy_float, copy_float_strided)
UNARY(id, half, copy_half, copy_half_strided)
#if __METAL_VERSION__ >= 310
BFLOAT_UNARY_OP(cos)
BFLOAT_UNARY_OP(sin)
BFLOAT_UNARY_OP(sqr)
BFLOAT_UNARY_OP(sqrt)
BFLOAT_UNARY_OP(neg)
BFLOAT_UNARY_OP(exp)
#endif

View File

@ -14,7 +14,6 @@ accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
half = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rayon = { workspace = true }
@ -29,5 +28,4 @@ clap = { workspace = true }
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
metal = ["candle/metal"]
mkl = ["dep:intel-mkl-src", "candle/mkl"]

View File

@ -180,25 +180,8 @@ impl Benchmark for Conv2dIm2Col {
const ITERS: usize = 5;
}
struct MatMul;
impl Benchmark for MatMul {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let lhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
let rhs = Tensor::randn(0f32, 1., (1024, 1024), &Device::Cpu)?;
Ok((lhs, rhs))
}
fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
d.0.matmul(&d.1)
}
const ITERS: usize = 100;
}
struct MatVec;
impl Benchmark for MatVec {
struct Matmul;
impl Benchmark for Matmul {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
@ -288,7 +271,6 @@ enum Task {
Conv2d,
Conv2dIm2Col,
Matmul,
Matvec,
Qmatmul,
Softmax,
SoftmaxLastDim,
@ -311,8 +293,7 @@ fn main() -> Result<()> {
Task::Conv1d => run::<Conv1d>(args.iters)?,
Task::Conv2d => run::<Conv2d>(args.iters)?,
Task::Conv2dIm2Col => run::<Conv2dIm2Col>(args.iters)?,
Task::Matmul => run::<MatMul>(args.iters)?,
Task::Matvec => run::<MatVec>(args.iters)?,
Task::Matmul => run::<Matmul>(args.iters)?,
Task::Softmax => run::<Softmax>(args.iters)?,
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
Task::Qmatmul => run::<QMatMul>(args.iters)?,

View File

@ -9,11 +9,8 @@ pub enum Activation {
#[serde(rename = "gated-gelu")]
NewGelu,
Relu,
Relu2,
Relu6,
Silu,
Sigmoid,
Swish,
Elu(f64),
LeakyRelu(f64),
}
@ -25,11 +22,8 @@ impl super::Module for Activation {
// https://github.com/huggingface/transformers/blob/12f043eaeaabfef6f6efea411d98e6f6d3c094b7/src/transformers/activations.py#L49-L78
Self::NewGelu => xs.gelu(),
Self::Relu => xs.relu(),
Self::Relu2 => xs.relu()?.sqr(),
Self::Relu6 => xs.clamp(0f32, 6f32),
Self::Silu => crate::ops::silu(xs),
Self::Sigmoid => crate::ops::sigmoid(xs),
Self::Swish => xs * crate::ops::sigmoid(xs)?,
&Self::Elu(alpha) => xs.elu(alpha),
&Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
}
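
Several of these variants are simple compositions of the others, as the match arms above show. Their scalar definitions, as a standalone sketch rather than candle API:

    fn sigmoid(x: f64) -> f64 { 1.0 / (1.0 + (-x).exp()) }
    fn relu(x: f64) -> f64 { x.max(0.0) }
    fn relu2(x: f64) -> f64 { relu(x) * relu(x) }   // relu, then square
    fn relu6(x: f64) -> f64 { x.clamp(0.0, 6.0) }   // relu capped at 6
    fn silu(x: f64) -> f64 { x * sigmoid(x) }       // identical to swish

    fn main() {
        assert!((silu(1.0) - 0.731_058_6).abs() < 1e-6);
    }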

View File

@ -100,23 +100,9 @@ impl BatchNorm {
num_features,
})
}
}
pub fn running_mean(&self) -> &Tensor {
&self.running_mean
}
pub fn running_var(&self) -> &Tensor {
&self.running_var
}
pub fn eps(&self) -> f64 {
self.eps
}
pub fn weight_and_bias(&self) -> Option<(&Tensor, &Tensor)> {
self.weight_and_bias.as_ref().map(|v| (&v.0, &v.1))
}
impl BatchNorm {
pub fn forward_learning(&self, x: &Tensor) -> Result<Tensor> {
let x_dtype = x.dtype();
let internal_dtype = match x_dtype {
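
The accessors above expose exactly the statistics needed at inference time. As a scalar sketch (a hypothetical helper, not the candle implementation):

    // y = (x - running_mean) / sqrt(running_var + eps) * weight + bias
    fn batch_norm_inference(x: f64, mean: f64, var: f64, eps: f64, w: f64, b: f64) -> f64 {
        (x - mean) / (var + eps).sqrt() * w + b
    }

    fn main() {
        assert_eq!(batch_norm_inference(1.0, 0.0, 1.0, 0.0, 1.0, 0.0), 1.0);
    }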

View File

@ -1,5 +1,4 @@
//! Convolution Layers.
use crate::BatchNorm;
use candle::{Result, Tensor};
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -116,26 +115,6 @@ impl Conv2d {
pub fn bias(&self) -> Option<&Tensor> {
self.bias.as_ref()
}
pub fn absorb_bn(&self, bn: &BatchNorm) -> Result<Self> {
if let Some((w_bn, b_bn)) = bn.weight_and_bias() {
let std_ = w_bn.div(&((bn.running_var() + bn.eps())?.sqrt()?))?;
let weight = self
.weight()
.broadcast_mul(&(std_.reshape((self.weight().dims4()?.0, 1, 1, 1))?))?;
let bias = match &self.bias {
None => b_bn.sub(&(std_.mul(bn.running_mean())?))?,
Some(bias) => b_bn.add(&(std_.mul(&bias.sub(bn.running_mean())?)?))?,
};
Ok(Self {
weight,
bias: Some(bias),
config: self.config,
})
} else {
candle::bail!("batch norm does not have weight_and_bias")
}
}
}
impl crate::Module for Conv2d {
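
`absorb_bn` folds the batch-norm statistics into the convolution so the fused layer computes the same function. The per-output-channel math, sketched on scalars (illustrative only):

    // Fold batch-norm stats into a conv's per-output-channel weight scale and bias.
    fn fold_bn(w: f64, b: f64, gamma: f64, beta: f64, mean: f64, var: f64, eps: f64) -> (f64, f64) {
        let std = gamma / (var + eps).sqrt();
        (w * std, beta + (b - mean) * std)
    }

    fn main() {
        let (w, b, gamma, beta, mean, var, eps) = (2.0, 0.5, 1.5, -0.25, 0.1, 4.0, 1e-5);
        let (wf, bf) = fold_bn(w, b, gamma, beta, mean, var, eps);
        let x = 3.0;
        // Applying batch norm after the conv...
        let y_bn = ((w * x + b) - mean) / (var + eps).sqrt() * gamma + beta;
        // ...matches the folded conv applied directly.
        assert!((y_bn - (wf * x + bf)).abs() < 1e-9);
    }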

View File

@ -1,12 +1,10 @@
//! Layers defined by closures.
use candle::{Result, Tensor};
use std::sync::Arc;
/// A layer defined by a simple closure.
#[derive(Clone)]
pub struct Func<'a> {
#[allow(clippy::type_complexity)]
f: Arc<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync>,
f: Box<dyn 'a + Fn(&Tensor) -> Result<Tensor> + Send>,
}
impl<'a> std::fmt::Debug for Func<'a> {
@ -17,9 +15,9 @@ impl<'a> std::fmt::Debug for Func<'a> {
pub fn func<'a, F>(f: F) -> Func<'a>
where
F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
F: 'a + Fn(&Tensor) -> Result<Tensor> + Send,
{
Func { f: Arc::new(f) }
Func { f: Box::new(f) }
}
impl<'a> super::Module for Func<'a> {
@ -27,47 +25,3 @@ impl<'a> super::Module for Func<'a> {
(*self.f)(xs)
}
}
impl<'a> Func<'a> {
pub fn new<F>(f: F) -> Self
where
F: 'a + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
{
Self { f: Arc::new(f) }
}
}
/// A layer defined by a simple closure.
#[derive(Clone)]
pub struct FuncT<'a> {
#[allow(clippy::type_complexity)]
f: Arc<dyn 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync>,
}
impl<'a> std::fmt::Debug for FuncT<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "func")
}
}
pub fn func_t<'a, F>(f: F) -> FuncT<'a>
where
F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
{
FuncT { f: Arc::new(f) }
}
impl<'a> super::ModuleT for FuncT<'a> {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
(*self.f)(xs, train)
}
}
impl<'a> FuncT<'a> {
pub fn new<F>(f: F) -> Self
where
F: 'a + Fn(&Tensor, bool) -> Result<Tensor> + Send + Sync,
{
Self { f: Arc::new(f) }
}
}
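
One side of this hunk stores the closure behind an `Arc` with a `Sync` bound, which is what makes `Func` cheaply cloneable and shareable across threads; the `Box`-based version has a single owner. Either way, usage looks the same. A hedged sketch, assuming the candle crates as dependencies:

    use candle::{Device, Module, Result, Tensor};
    use candle_nn::func;

    fn main() -> Result<()> {
        // y = 2x + 1 as a closure layer.
        let layer = func(|xs: &Tensor| xs.affine(2.0, 1.0));
        let xs = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
        let ys = layer.forward(&xs)?;
        // Cloning only bumps the Arc's reference count.
        let _layer2 = layer.clone();
        println!("{ys}");
        Ok(())
    }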

View File

@ -11,7 +11,6 @@ pub mod loss;
pub mod ops;
pub mod optim;
pub mod rnn;
pub mod sequential;
pub mod var_builder;
pub mod var_map;
@ -22,7 +21,7 @@ pub use conv::{
Conv1dConfig, Conv2d, Conv2dConfig, ConvTranspose2d, ConvTranspose2dConfig,
};
pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};
pub use func::{func, Func};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
@ -30,8 +29,7 @@ pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use sequential::{seq, Sequential};
pub use var_builder::VarBuilder;
pub use var_map::VarMap;
pub use candle::{Module, ModuleT};
pub use candle::Module;

View File

@ -48,25 +48,3 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
(inp - target)?.sqr()?.mean_all()
}
/// The binary cross-entropy with logit loss.
///
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to be raw logits.
/// * [target]: The ground truth labels as a float tensor of dimension `N, C` where `N` is the batch size and `C` the number
/// of categories.
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
let inp = crate::ops::sigmoid(inp)?;
let left_side = target * inp.log()?;
let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;
let loss = left_side? + right_side?;
let loss = loss?.neg()?.mean_all()?;
Ok(loss)
}
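
In scalar form the loss above is `-mean(t * ln(p) + (1 - t) * ln(1 - p))` with `p = sigmoid(logit)`. A standalone sketch, independent of candle:

    fn sigmoid(x: f64) -> f64 {
        1.0 / (1.0 + (-x).exp())
    }

    // loss = -mean(t * ln(p) + (1 - t) * ln(1 - p)), with p = sigmoid(logit)
    fn bce_with_logit(logits: &[f64], targets: &[f64]) -> f64 {
        let n = logits.len() as f64;
        logits
            .iter()
            .zip(targets)
            .map(|(&x, &t)| {
                let p = sigmoid(x);
                -(t * p.ln() + (1.0 - t) * (1.0 - p).ln())
            })
            .sum::<f64>()
            / n
    }

    fn main() {
        // A single confident, correct prediction yields a small loss.
        assert!(bce_with_logit(&[4.0], &[1.0]) < 0.02);
    }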

View File

@ -1,6 +1,5 @@
use candle::{CpuStorage, Layout, Result, Shape, Tensor};
use rayon::prelude::*;
use tracing::debug;
/// Applies the softmax function to the input tensor, rescaling the elements so that elements on
/// a slice of fixed index on dimension `dim` are between 0 and 1 and sum to 1.
@ -85,12 +84,6 @@ impl Dropout {
}
}
impl candle::ModuleT for Dropout {
fn forward_t(&self, xs: &Tensor, train: bool) -> Result<Tensor> {
self.forward(xs, train)
}
}
struct SoftmaxLastDim;
impl candle::CustomOp1 for SoftmaxLastDim {
@ -192,16 +185,6 @@ impl candle::CustomOp1 for SoftmaxLastDim {
};
Ok((dst, layout.shape().clone()))
}
#[cfg(feature = "metal")]
fn metal_fwd(
&self,
storage: &candle::MetalStorage,
layout: &Layout,
) -> Result<(candle::MetalStorage, Shape)> {
debug!("TODO softmax-last-dim");
Ok((storage.clone(), layout.shape().clone()))
}
}
pub fn softmax_last_dim(xs: &Tensor) -> Result<Tensor> {
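
`softmax_last_dim` computes, over each slice of the last dimension, a numerically stabilized softmax. A scalar sketch of that computation (illustrative, not the candle kernel):

    // Numerically stable softmax over one slice of the last dimension:
    // subtract the max before exponentiating to avoid overflow.
    fn softmax(xs: &[f32]) -> Vec<f32> {
        let max = xs.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = xs.iter().map(|x| (x - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        exps.iter().map(|e| e / sum).collect()
    }

    fn main() {
        let p = softmax(&[1.0, 2.0, 3.0]);
        assert!((p.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    }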

View File

@ -1,62 +0,0 @@
//! A sequential layer used to chain multiple layers and closures.
use candle::{Module, Result, Tensor};
/// A sequential layer combining multiple other layers.
pub struct Sequential {
layers: Vec<Box<dyn Module>>,
}
/// Creates a new empty sequential layer.
pub fn seq() -> Sequential {
Sequential { layers: vec![] }
}
impl Sequential {
/// The number of sub-layers embedded in this layer.
pub fn len(&self) -> i64 {
self.layers.len() as i64
}
/// Returns true if this layer does not have any sub-layer.
pub fn is_empty(&self) -> bool {
self.layers.is_empty()
}
}
impl Module for Sequential {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs)?
}
Ok(xs)
}
}
impl Sequential {
/// Appends a layer after all the current layers.
#[allow(clippy::should_implement_trait)]
pub fn add<M: Module + 'static>(mut self, layer: M) -> Self {
self.layers.push(Box::new(layer));
self
}
/// Appends a closure after all the current layers.
pub fn add_fn<F>(self, f: F) -> Self
where
F: 'static + Fn(&Tensor) -> Result<Tensor> + Send + Sync,
{
self.add(super::func(f))
}
/// Applies the forward pass and returns the output for each layer.
pub fn forward_all(&self, xs: &Tensor) -> Result<Vec<Tensor>> {
let mut vec = Vec::with_capacity(self.layers.len());
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs)?;
vec.push(xs.clone())
}
Ok(vec)
}
}
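
A hedged usage sketch of `Sequential`, chaining two closure layers; it assumes the candle crates as dependencies and uses only the API shown above:

    use candle::{Device, Module, Result, Tensor};
    use candle_nn::seq;

    fn main() -> Result<()> {
        // y = relu(2x + 1), built from two closure layers.
        let model = seq()
            .add_fn(|xs: &Tensor| xs.affine(2.0, 1.0))
            .add_fn(|xs: &Tensor| xs.relu());
        let xs = Tensor::new(&[-2f32, 0., 2.], &Device::Cpu)?;
        let ys = model.forward(&xs)?;
        println!("{ys}");
        Ok(())
    }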

View File

@ -191,7 +191,6 @@ impl<'a, B: Backend> VarBuilderArgs<'a, B> {
}
struct Zeros;
impl SimpleBackend for Zeros {
fn get(&self, s: Shape, _: &str, _: crate::Init, dtype: DType, dev: &Device) -> Result<Tensor> {
Tensor::zeros(s, dtype, dev)
@ -326,39 +325,6 @@ impl SimpleBackend for candle::npy::NpzTensors {
}
}
impl SimpleBackend for candle::pickle::PthTensors {
fn get(
&self,
s: Shape,
path: &str,
_: crate::Init,
dtype: DType,
dev: &Device,
) -> Result<Tensor> {
let tensor = match self.get(path)? {
None => Err(Error::CannotFindTensor {
path: path.to_string(),
}
.bt())?,
Some(tensor) => tensor,
};
let tensor = tensor.to_device(dev)?.to_dtype(dtype)?;
if tensor.shape() != &s {
Err(candle::Error::UnexpectedShape {
msg: format!("shape mismatch for {path}"),
expected: s,
got: tensor.shape().clone(),
}
.bt())?
}
Ok(tensor)
}
fn contains_tensor(&self, name: &str) -> bool {
self.get(name).map_or(false, |v| v.is_some())
}
}
impl SimpleBackend for candle::safetensors::MmapedSafetensors {
fn get(
&self,
@ -472,16 +438,9 @@ impl<'a> VarBuilder<'a> {
let npz = candle::npy::NpzTensors::new(p)?;
Ok(Self::new(Box::new(npz), dtype, dev.clone()))
}
/// Initializes a `VarBuilder` that retrieves tensors stored in a pytorch pth file.
pub fn from_pth<P: AsRef<std::path::Path>>(p: P, dtype: DType, dev: &Device) -> Result<Self> {
let pth = candle::pickle::PthTensors::new(p)?;
Ok(Self::new(Box::new(pth), dtype, dev.clone()))
}
}
pub struct ShardedSafeTensors(candle::safetensors::MmapedSafetensors);
pub type ShardedVarBuilder<'a> = VarBuilderArgs<'a, ShardedSafeTensors>;
impl ShardedSafeTensors {
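
A hedged sketch of `from_pth` in use; the file path and tensor name below are placeholders:

    use candle::{DType, Device, Result};
    use candle_nn::VarBuilder;

    fn main() -> Result<()> {
        let vb = VarBuilder::from_pth("model.pth", DType::F32, &Device::Cpu)?;
        // `get` checks that the stored tensor matches the requested shape.
        let w = vb.get((768, 768), "encoder.weight")?;
        println!("{:?}", w.shape());
        Ok(())
    }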

View File

@ -39,50 +39,3 @@ fn nll_and_cross_entropy() -> Result<()> {
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
Ok(())
}
/* Equivalent python code:
import torch
import torch.nn.functional as F
inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
[ 0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[ 1.3081, 0.6641, 1.1802, -0.2547],
[ 0.5292, 0.7636, 0.3692, -0.8318]])
target = torch.Tensor([[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.]])
print(F.binary_cross_entropy_with_logits(inp, target))
*/
#[test]
fn binary_cross_entropy_with_logit() -> Result<()> {
let cpu = Device::Cpu;
let inp = [
[2.3611f32, -0.8813, -0.5006, -0.2178],
[0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[1.3081, 0.6641, 1.1802, -0.2547],
[0.5292, 0.7636, 0.3692, -0.8318],
];
let target = [
[0.0f32, 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.],
];
let inp = Tensor::new(&inp, &cpu)?;
let target = Tensor::new(&target, &cpu)?;
let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;
assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
Ok(())
}

View File

@ -14,18 +14,14 @@ name = "candle"
crate-type = ["cdylib"]
[dependencies]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
[build-dependencies]
pyo3-build-config = "0.20"
pyo3-build-config = "0.19"
[features]
default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src","candle/mkl"]

View File

@ -1,3 +0,0 @@
This Python module contains external type hints for certain `candle` classes. This is only necessary for `magic` methods, e.g. `__add__`, as their text signatures can't be set via pyo3.
The classes in this module will be parsed by the `stub.py` script and interleaved with the signatures of the actual pyo3 `candle.candle` module.

Some files were not shown because too many files have changed in this diff.