Make the Python Wrapper more Hackable and simplify Quantization (#1010)

* Some first `Module` implementations

* Add `state_dict` and `load_state_dict` functionality

* Move modules around and create `candle.nn.Linear`

* Add `nn.Embedding` and `nn.LayerNorm`

* Add BERT implementation

* Batch q-matmul

* Automatically dequantize `QTensors` if a `Tensor` is expected

* Add Module `.to()`, `.cuda()`, `cpu()` and `.type()` functionality

* Unittests for `Module`, `Tensor` and `candle.utils`

* Add `pytorch` like slicing to `Tensor`

* Cleanup and BERT fixes

* `black` formatting + unit-test for `nn.Linear`

* Refactor slicing implementation
Lukas Kreussel
2023-10-06 20:01:07 +02:00
committed by GitHub
parent b0442eff8a
commit 904bbdae65
25 changed files with 2426 additions and 182 deletions
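Taken together, the bullets above add a small PyTorch-flavored surface. A minimal sketch of how it composes, assuming a `candle` wheel built from this commit (shapes shown as comments):

import candle
from candle.nn import Linear

layer = Linear(4, 2)                    # weight/bias are plain candle Tensors (see linear.py below)
x = candle.randn((3, 4))
y = layer(x)                            # Module.__call__ dispatches to forward
print(y.shape)                          # (3, 2)
print(list(layer.state_dict().keys()))  # ['weight', 'bias']
first_row = x[0]                        # the new pytorch-like slicing on Tensor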

11
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,11 @@
{
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none",
"python.testing.pytestArgs": [
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}

View File

@ -1,3 +1,4 @@
tests/_workdir
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]

104
candle-pyo3/e5.py Normal file
View File

@ -0,0 +1,104 @@
from candle.utils import load_safetensors, save_gguf, load_gguf
from candle.models.bert import BertModel, Config
import json
from candle import Tensor
from tqdm import tqdm
from dataclasses import fields
import os
import time
from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, AutoModel
import torch
if __name__ == "__main__":
model_name = "intfloat/e5-small-v2"
model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
config_file = hf_hub_download(repo_id=model_name, filename="config.json")
tensors = load_safetensors(model_file)
config = Config()
with open(config_file, "r") as f:
raw_config = json.load(f)
for field in fields(config):
if field.name in raw_config:
setattr(config, field.name, raw_config[field.name])
# Load the model
model = BertModel(config)
model.load_state_dict(tensors)
hf_model = AutoModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
sentences = [
"The cat sits outside",
"A man is playing guitar",
"I love pasta",
"The new movie is awesome",
"The cat plays in the garden",
"A woman watches TV",
"The new movie is so great",
"Do you like pizza?",
]
def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
"""Average the hidden states according to the attention mask"""
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
tokenized = tokenizer(sentences, padding=True)
tokens = Tensor(tokenized["input_ids"])
token_type_ids = Tensor(tokenized["token_type_ids"])
encoder_out, _ = model.forward(tokens, token_type_ids)
hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
hf_result = hf_model(**hf_tokenized)["last_hidden_state"]
hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"])
candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"])
loss = torch.nn.L1Loss()
error = loss(hf_pooled, candle_pooled).mean().item()
print(f"Mean error between torch-referenze and candle: {error}")
# Quantize all attention 'weights'
quantized_tensors = {}
    for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors"):
if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name):
# check if the tensor is k-quantizable
if tensor.shape[-1] % 256 == 0:
new_tensor = tensor.quantize("q4k")
else:
new_tensor = tensor.quantize("q5_0")
quantized_tensors[name] = new_tensor
else:
quantized_tensors[name] = tensor.quantize("q8_0")
print(f"Saving quantized tensors")
# Remove all None values from the config
config_to_save = {k: v for k, v in config.__dict__.items() if v is not None}
# Save the model
quantized_model_file = "e5_small.gguf"
save_gguf(quantized_model_file, quantized_tensors, config_to_save)
file_size_mb = os.path.getsize(model_file) / 1024 / 1024
file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024
print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")
# Load the model from the gguf
tensors, raw_config = load_gguf(quantized_model_file)
config = Config()
for field in fields(config):
if field.name in raw_config:
setattr(config, field.name, raw_config[field.name])
model = BertModel(config)
# "embeddings.position_ids" is missing in the gguf as it is i64
model.load_state_dict(tensors, strict=False)
# Run the model again
encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids)
encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu")
candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"])
error = loss(hf_pooled, candle_pooled_2).mean().item()
print(f"Mean error between torch-referenze and quantized-candle: {error}")

View File

@ -1,5 +1,30 @@
import logging
try:
from .candle import *
except ImportError as e:
# If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
import os
import platform
# Try to locate CUDA_PATH environment variable
cuda_path = os.environ.get("CUDA_PATH", None)
if cuda_path:
logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
if platform.system() == "Windows":
cuda_path = os.path.join(cuda_path, "bin")
else:
cuda_path = os.path.join(cuda_path, "lib64")
logging.warning(f"Adding {cuda_path} to DLL search path...")
os.add_dll_directory(cuda_path)
try:
from .candle import *
except ImportError as inner_e:
raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")
__doc__ = candle.__doc__
if hasattr(candle, "__all__"):
__all__ = candle.__all__

View File

@ -0,0 +1,8 @@
# Generated content DO NOT EDIT
from .. import functional
gelu = functional.gelu
relu = functional.relu
silu = functional.silu
softmax = functional.softmax
tanh = functional.tanh

View File

@ -4,6 +4,20 @@ from os import PathLike
from candle.typing import _ArrayLike, Device
from candle import Tensor, DType, QTensor
@staticmethod
def gelu(tensor: Tensor) -> Tensor:
"""
Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
"""
pass
@staticmethod
def relu(tensor: Tensor) -> Tensor:
"""
Applies the Rectified Linear Unit (ReLU) function to a given tensor.
"""
pass
@staticmethod
def silu(tensor: Tensor) -> Tensor:
"""
@ -17,3 +31,10 @@ def softmax(tensor: Tensor, dim: int) -> Tensor:
    Applies the Softmax function to a given tensor.
"""
pass
@staticmethod
def tanh(tensor: Tensor) -> Tensor:
"""
Applies the tanh function to a given tensor.
"""
pass
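The stubs above describe runtime ops; a small usage sketch (shapes are illustrative):

import candle
import candle.functional as F

t = candle.randn((2, 3))
probs = F.softmax(t, dim=-1)  # each row sums to ~1
activated = F.gelu(t)
squashed = F.tanh(t)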

View File

@ -0,0 +1,194 @@
from dataclasses import dataclass
from typing import Tuple, Optional
from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList
from candle import Tensor
import candle
import candle.functional as F
@dataclass
class Config:
vocab_size: int = 30522
hidden_size: int = 768
num_hidden_layers: int = 12
num_attention_heads: int = 12
intermediate_size: int = 3072
hidden_act: str = "gelu"
hidden_dropout_prob: float = 0.1
max_position_embeddings: int = 512
type_vocab_size: int = 2
initializer_range: float = 0.02
layer_norm_eps: float = 1e-12
pad_token_id: int = 0
position_embedding_type: str = "absolute"
use_cache: bool = True
classifier_dropout: Optional[float] = None
model_type: Optional[str] = "bert"
class BertSelfAttention(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.num_attention_heads = config.num_attention_heads
self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
all_head_size = int(config.num_attention_heads * self.attention_head_size)
hidden_size = config.hidden_size
self.query = Linear(hidden_size, all_head_size)
self.key = Linear(hidden_size, all_head_size)
self.value = Linear(hidden_size, all_head_size)
def transpose_for_scores(self, x: Tensor) -> Tensor:
new_x_shape = x.shape[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
x = x.reshape(new_x_shape).transpose(1, 2)
return x.contiguous()
def forward(self, hidden_states: Tensor) -> Tensor:
query = self.query.forward(hidden_states)
key = self.key.forward(hidden_states)
value = self.value.forward(hidden_states)
query = self.transpose_for_scores(query)
key = self.transpose_for_scores(key)
value = self.transpose_for_scores(value)
attention_scores = query.matmul(key.t())
attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
attention_probs = F.softmax(attention_scores, dim=-1)
context_layer = attention_probs.matmul(value)
context_layer = context_layer.transpose(1, 2).contiguous()
context_layer = context_layer.flatten_from(-2)
return context_layer
class BertSelfOutput(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.dense = Linear(config.hidden_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
hidden_states = self.dense.forward(hidden_states)
return self.LayerNorm.forward(hidden_states + input_tensor)
class BertAttention(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.self = BertSelfAttention(config)
self.output = BertSelfOutput(config)
def forward(self, hidden_states: Tensor) -> Tensor:
self_outputs = self.self.forward(hidden_states)
attention_output = self.output.forward(self_outputs, hidden_states)
return attention_output
class BertIntermediate(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.dense = Linear(config.hidden_size, config.intermediate_size)
self.act = F.gelu if config.hidden_act == "gelu" else F.relu
def forward(self, hidden_states: Tensor) -> Tensor:
hidden_states = self.dense.forward(hidden_states)
return self.act(hidden_states)
class BertOutput(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.dense = Linear(config.intermediate_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
hidden_states = self.dense.forward(hidden_states)
return self.LayerNorm.forward(hidden_states + input_tensor)
class BertLayer(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states: Tensor) -> Tensor:
attention_output = self.attention.forward(hidden_states)
# TODO: Support cross-attention?
# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
# TODO: Support something similar to `apply_chunking_to_forward`?
intermediate_output = self.intermediate.forward(attention_output)
layer_output = self.output.forward(intermediate_output, attention_output)
return layer_output
class BertEncoder(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.layer = ModuleList()
for _ in range(config.num_hidden_layers):
self.layer.append(BertLayer(config))
def forward(self, hidden_states: Tensor) -> Tensor:
        for layer in self.layer:
            hidden_states = layer.forward(hidden_states)
return hidden_states
class BertEmbeddings(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size)
self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size)
self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.position_ids = candle.Tensor(list(range(config.max_position_embeddings))).reshape(
(1, config.max_position_embeddings)
)
def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor:
(_batch_size, seq_len) = input_ids.shape
input_embeddings = self.word_embeddings.forward(input_ids)
token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)
embeddings: Tensor = input_embeddings + token_type_embeddings
position_ids = list(range(seq_len))
position_ids = Tensor(position_ids).to_dtype(input_ids.dtype).to_device(input_ids.device)
embeddings = embeddings.broadcast_add(self.position_embeddings.forward(position_ids))
embeddings = self.LayerNorm(embeddings)
return embeddings
class BertPooler(Module):
def __init__(self, config: Config) -> None:
super().__init__()
self.dense = Linear(config.hidden_size, config.hidden_size)
self.activation = F.tanh
def forward(self, hidden_states: Tensor) -> Tensor:
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense.forward(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output
# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
class BertModel(Module):
def __init__(self, config: Config, add_pooling_layer=True) -> None:
super().__init__()
self.config = config
self.embeddings = BertEmbeddings(config)
self.encoder = BertEncoder(config)
self.pooler = BertPooler(config) if add_pooling_layer else None
def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
embeddings = self.embeddings.forward(input_ids, token_type_ids)
encoder_out = self.encoder.forward(embeddings)
pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
return encoder_out, pooled_output
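A quick smoke test for the model above; the token ids below are made up, and with randomly initialized weights the outputs are only meaningful as a shape check:

if __name__ == "__main__":
    config = Config(num_hidden_layers=2)          # smaller than the default 12 for a fast check
    model = BertModel(config)
    input_ids = Tensor([[101, 2023, 2003, 102]])  # arbitrary ids, not a real tokenization
    token_type_ids = Tensor([[0, 0, 0, 0]])
    encoder_out, pooled = model.forward(input_ids, token_type_ids)
    print(encoder_out.shape, pooled.shape)        # (1, 4, 768) and (1, 768)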

View File

@ -0,0 +1,150 @@
import candle
from typing import Dict, Tuple, Any
from candle import Tensor, QTensor, utils, nn
from candle.nn import Module, ModuleList
def masked_fill(on_false: Tensor, mask: Tensor, on_true: float) -> Tensor:
shape = mask.shape
on_true = candle.tensor(on_true).broadcast_as(shape)
return mask.where_cond(on_true, on_false)
def precompute_freqs_cis(hparams: Dict[str, Any], freq_base: float, max_seq_len: int):
head_dim = hparams["n_embd"] // hparams["n_head"]
theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
theta = candle.tensor(theta)
idx_theta = [float(i) for i in range(max_seq_len)]
idx_theta = candle.tensor(idx_theta).reshape((max_seq_len, 1))
m = idx_theta.matmul(theta.unsqueeze(0))
return (m.cos(), m.sin())
class RmsNorm(Module):
def __init__(self, qtensor: QTensor):
super().__init__()
self.weight = qtensor.dequantize()
def forward(self, x: Tensor) -> Tensor:
b_size, seq_len, hidden_size = x.shape
norm_x = x.sqr().sum_keepdim(2) / hidden_size
x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
return x_normed.broadcast_mul(self.weight)
class QuantizedLayer(Module):
def __init__(
self,
layer_idx: int,
hparams: Dict[str, Any],
all_tensors: Dict[str, QTensor],
cos_sin: Tuple[Tensor, Tensor],
):
super().__init__()
p = f"layers.{layer_idx}"
self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
self.n_head = hparams["n_head"]
self.n_kv_head = self.n_head
self.head_dim = hparams["n_embd"] // self.n_head
self.kv_cache = None
self.cos = cos_sin[0]
self.sin = cos_sin[1]
self._non_persistent_buffers_set.add("cos")
self._non_persistent_buffers_set.add("sin")
def forward(self, x: Tensor, mask: Tensor, index_pos: int) -> Tensor:
residual = x
x = self.attn_norm(x)
attn = self.forward_attn(x, mask, index_pos)
x = attn + residual
residual = x
x = self.ffn_norm(x)
w1 = self.ffw1.matmul_t(x)
w3 = self.ffw3.matmul_t(x)
mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
return mlp + residual
def forward_attn(self, x: Tensor, mask: Tensor, index_pos: int):
b_size, seq_len, n_embd = x.shape
q = self.attention_wq.matmul_t(x)
k = self.attention_wk.matmul_t(x)
v = self.attention_wv.matmul_t(x)
q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
q = self.apply_rotary_emb(q, index_pos)
k = self.apply_rotary_emb(k, index_pos)
if self.kv_cache is not None and index_pos > 0:
prev_k, prev_v = self.kv_cache
k = candle.cat([prev_k, k], 2).contiguous()
v = candle.cat([prev_v, v], 2).contiguous()
self.kv_cache = (k, v)
# TODO: maybe repeat k/v here if we start supporting MQA.
att = q.matmul(k.t()) / self.head_dim**0.5
mask = mask.broadcast_as(att.shape)
att = masked_fill(att, mask, float("-inf"))
att = nn.softmax(att, -1)
y = att.matmul(v.contiguous())
y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
return self.attention_wo.matmul_t(y)
def apply_rotary_emb(self, x: Tensor, index_pos: int):
b_size, n_head, seq_len, n_embd = x.shape
cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
x = x.reshape((b_size, n_head, seq_len, n_embd // 2, 2))
x0 = x.narrow(-1, 0, 1)
x1 = x.narrow(-1, 1, 1)
y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
rope = candle.cat([y0, y1], -1)
return rope.flatten_from(-2)
class QuantizedLlama(Module):
def __init__(self, hparams: Dict[str, Any], all_tensors: Dict[str, QTensor]):
super().__init__()
self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
self.norm = RmsNorm(all_tensors["norm.weight"])
self.output = all_tensors["output.weight"]
self.layers = ModuleList()
rope_freq = hparams.get("rope_freq", 10000.0)
cos_sin = precompute_freqs_cis(hparams, rope_freq, hparams["context_length"])
for layer_idx in range(hparams["n_layer"]):
layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
self.layers.append(layer)
def forward(self, token: Tensor, index_pos: int) -> Tensor:
b_size, seq_len = token.shape
vocab_size, hidden_size = self.tok_embeddings.shape
token = token.reshape((b_size * seq_len,))
x = self.tok_embeddings.index_select(token, 0)
x = x.reshape((b_size, seq_len, hidden_size))
        # Causal mask: query position i may only attend to key positions j <= i.
        mask = [int(j > i) for i in range(seq_len) for j in range(seq_len)]
mask = candle.tensor(mask).reshape((seq_len, seq_len))
for layer in self.layers:
x = layer(x, mask, index_pos)
x = self.norm(x)
x = x.narrow(1, -1, 1).squeeze(1)
x = self.output.matmul_t(x)
return x
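A hypothetical greedy-decoding driver for the model above; `model.gguf` and the prompt token are placeholders, and the GGUF metadata is assumed to carry the `n_layer`/`n_head`/`n_embd`/`context_length` entries used in `__init__`:

if __name__ == "__main__":
    from candle.utils import load_gguf

    all_tensors, hparams = load_gguf("model.gguf")  # placeholder path
    model = QuantizedLlama(hparams, all_tensors)
    tokens, index_pos = [1], 0                      # placeholder BOS-style prompt
    for _ in range(16):
        inp = tokens if index_pos == 0 else tokens[-1:]
        logits = model.forward(candle.tensor([inp]), index_pos)
        index_pos += len(inp)
        row = logits.values()[0]                    # (vocab,) as a Python list
        tokens.append(max(range(len(row)), key=row.__getitem__))  # greedy argmax
    print(tokens)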

View File

@ -1,5 +1,5 @@
-# Generated content DO NOT EDIT
-from .. import nn
-silu = nn.silu
-softmax = nn.softmax
+from .module import Module
+from .container import Sequential, ModuleList, ModuleDict
+from .sparse import Embedding
+from .normalization import LayerNorm
+from .linear import Linear

View File

@ -0,0 +1,483 @@
# see https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
from .module import Module
from typing import (
Any,
Dict,
Iterable,
Iterator,
Mapping,
Optional,
overload,
Tuple,
TypeVar,
Union,
)
from collections import OrderedDict, abc as container_abcs
import operator
from itertools import chain, islice
__all__ = ["Sequential", "ModuleList", "ModuleDict"]
T = TypeVar("T", bound=Module)
def _addindent(s_: str, numSpaces: int):
s = s_.split("\n")
# don't do anything for single-line stuff
if len(s) == 1:
return s_
first = s.pop(0)
s = [(numSpaces * " ") + line for line in s]
s = "\n".join(s)
s = first + "\n" + s
return s
class Sequential(Module):
r"""A sequential container.
Modules will be added to it in the order they are passed in the
constructor. Alternatively, an ``OrderedDict`` of modules can be
passed in. The ``forward()`` method of ``Sequential`` accepts any
input and forwards it to the first module it contains. It then
"chains" outputs to inputs sequentially for each subsequent module,
finally returning the output of the last module.
The value a ``Sequential`` provides over manually calling a sequence
of modules is that it allows treating the whole container as a
single module, such that performing a transformation on the
``Sequential`` applies to each of the modules it stores (which are
each a registered submodule of the ``Sequential``).
What's the difference between a ``Sequential`` and a
:class:`candle.nn.ModuleList`? A ``ModuleList`` is exactly what it
sounds like--a list for storing ``Module`` s! On the other hand,
the layers in a ``Sequential`` are connected in a cascading way.
"""
_modules: Dict[str, Module] # type: ignore[assignment]
@overload
def __init__(self, *args: Module) -> None:
...
@overload
def __init__(self, arg: "OrderedDict[str, Module]") -> None:
...
def __init__(self, *args):
super().__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)
def _get_item_by_idx(self, iterator, idx) -> T:
"""Get the idx-th item of the iterator"""
size = len(self)
idx = operator.index(idx)
if not -size <= idx < size:
raise IndexError("index {} is out of range".format(idx))
idx %= size
return next(islice(iterator, idx, None))
def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
if isinstance(idx, slice):
return self.__class__(OrderedDict(list(self._modules.items())[idx]))
else:
return self._get_item_by_idx(self._modules.values(), idx)
def __setitem__(self, idx: int, module: Module) -> None:
key: str = self._get_item_by_idx(self._modules.keys(), idx)
return setattr(self, key, module)
def __delitem__(self, idx: Union[slice, int]) -> None:
if isinstance(idx, slice):
for key in list(self._modules.keys())[idx]:
delattr(self, key)
else:
key = self._get_item_by_idx(self._modules.keys(), idx)
delattr(self, key)
# To preserve numbering
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
def __len__(self) -> int:
return len(self._modules)
def __add__(self, other) -> "Sequential":
if isinstance(other, Sequential):
ret = Sequential()
for layer in self:
ret.append(layer)
for layer in other:
ret.append(layer)
return ret
else:
raise ValueError(
"add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other)))
)
def pop(self, key: Union[int, slice]) -> Module:
v = self[key]
del self[key]
return v
def __iadd__(self, other) -> "Sequential":
if isinstance(other, Sequential):
offset = len(self)
for i, module in enumerate(other):
self.add_module(str(i + offset), module)
return self
else:
raise ValueError(
"add operator supports only objects " "of Sequential class, but {} is given.".format(str(type(other)))
)
def __mul__(self, other: int) -> "Sequential":
if not isinstance(other, int):
raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
elif other <= 0:
raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
else:
combined = Sequential()
offset = 0
for _ in range(other):
for module in self:
combined.add_module(str(offset), module)
offset += 1
return combined
def __rmul__(self, other: int) -> "Sequential":
return self.__mul__(other)
def __imul__(self, other: int) -> "Sequential":
if not isinstance(other, int):
raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
elif other <= 0:
raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
else:
len_original = len(self)
offset = len(self)
for _ in range(other - 1):
for i in range(len_original):
self.add_module(str(i + offset), self._modules[str(i)])
offset += len_original
return self
def __dir__(self):
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())
# NB: We can't really type check this function as the type of input
# may change dynamically (as is tested in
# TestScript.test_sequential_intermediary_types). Cannot annotate
# with Any as TorchScript expects a more precise type
def forward(self, input):
for module in self:
input = module(input)
return input
def append(self, module: Module) -> "Sequential":
r"""Appends a given module to the end.
Args:
module (nn.Module): module to append
"""
self.add_module(str(len(self)), module)
return self
def insert(self, index: int, module: Module) -> "Sequential":
if not isinstance(module, Module):
raise AssertionError("module should be of type: {}".format(Module))
n = len(self._modules)
if not (-n <= index <= n):
raise IndexError("Index out of range: {}".format(index))
if index < 0:
index += n
for i in range(n, index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
return self
def extend(self, sequential) -> "Sequential":
for layer in sequential:
self.append(layer)
return self
class ModuleList(Module):
r"""Holds submodules in a list.
:class:`~candle.nn.ModuleList` can be indexed like a regular Python list, but
modules it contains are properly registered, and will be visible by all
:class:`~candle.nn.Module` methods.
Args:
modules (iterable, optional): an iterable of modules to add
Example::
class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
"""
_modules: Dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
super().__init__()
if modules is not None:
self += modules
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of modules"""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError("index {} is out of range".format(idx))
if idx < 0:
idx += len(self)
return str(idx)
def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
if isinstance(idx, slice):
return self.__class__(list(self._modules.values())[idx])
else:
return self._modules[self._get_abs_string_index(idx)]
def __setitem__(self, idx: int, module: Module) -> None:
idx = self._get_abs_string_index(idx)
return setattr(self, str(idx), module)
def __delitem__(self, idx: Union[int, slice]) -> None:
if isinstance(idx, slice):
for k in range(len(self._modules))[idx]:
delattr(self, str(k))
else:
delattr(self, self._get_abs_string_index(idx))
# To preserve numbering, self._modules is being reconstructed with modules after deletion
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
def __len__(self) -> int:
return len(self._modules)
def __iter__(self) -> Iterator[Module]:
return iter(self._modules.values())
def __iadd__(self, modules: Iterable[Module]) -> "ModuleList":
return self.extend(modules)
def __add__(self, other: Iterable[Module]) -> "ModuleList":
combined = ModuleList()
for i, module in enumerate(chain(self, other)):
combined.add_module(str(i), module)
return combined
def __repr__(self):
"""A custom repr for ModuleList that compresses repeated module representations"""
list_of_reprs = [repr(item) for item in self]
if len(list_of_reprs) == 0:
return self._get_name() + "()"
start_end_indices = [[0, 0]]
repeated_blocks = [list_of_reprs[0]]
for i, r in enumerate(list_of_reprs[1:], 1):
if r == repeated_blocks[-1]:
start_end_indices[-1][1] += 1
continue
start_end_indices.append([i, i])
repeated_blocks.append(r)
lines = []
main_str = self._get_name() + "("
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
local_repr = f"({start_id}): {b}" # default repr
if start_id != end_id:
n = end_id - start_id + 1
local_repr = f"({start_id}-{end_id}): {n} x {b}"
local_repr = _addindent(local_repr, 2)
lines.append(local_repr)
main_str += "\n " + "\n ".join(lines) + "\n"
main_str += ")"
return main_str
def __dir__(self):
keys = super().__dir__()
keys = [key for key in keys if not key.isdigit()]
return keys
def insert(self, index: int, module: Module) -> None:
r"""Insert a given module before a given index in the list.
Args:
index (int): index to insert.
module (nn.Module): module to insert
"""
for i in range(len(self._modules), index, -1):
self._modules[str(i)] = self._modules[str(i - 1)]
self._modules[str(index)] = module
def append(self, module: Module) -> "ModuleList":
r"""Appends a given module to the end of the list.
Args:
module (nn.Module): module to append
"""
self.add_module(str(len(self)), module)
return self
def pop(self, key: Union[int, slice]) -> Module:
v = self[key]
del self[key]
return v
def extend(self, modules: Iterable[Module]) -> "ModuleList":
r"""Appends modules from a Python iterable to the end of the list.
Args:
modules (iterable): iterable of modules to append
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError(
"ModuleList.extend should be called with an " "iterable, but got " + type(modules).__name__
)
offset = len(self)
for i, module in enumerate(modules):
self.add_module(str(offset + i), module)
return self
    # forward is intentionally omitted here; calls fall back to Module's default forward
class ModuleDict(Module):
r"""Holds submodules in a dictionary.
:class:`~candle.nn.ModuleDict` can be indexed like a regular Python dictionary,
but modules it contains are properly registered, and will be visible by all
:class:`~candle.nn.Module` methods.
:class:`~candle.nn.ModuleDict` is an **ordered** dictionary that respects
* the order of insertion, and
* in :meth:`~candle.nn.ModuleDict.update`, the order of the merged
``OrderedDict``, ``dict`` (started from Python 3.6) or another
:class:`~candle.nn.ModuleDict` (the argument to
:meth:`~candle.nn.ModuleDict.update`).
Note that :meth:`~candle.nn.ModuleDict.update` with other unordered mapping
types (e.g., Python's plain ``dict`` before Python version 3.6) does not
preserve the order of the merged mapping.
Args:
modules (iterable, optional): a mapping (dictionary) of (string: module)
or an iterable of key-value pairs of type (string, module)
"""
_modules: Dict[str, Module] # type: ignore[assignment]
def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
super().__init__()
if modules is not None:
self.update(modules)
def __getitem__(self, key: str) -> Module:
return self._modules[key]
def __setitem__(self, key: str, module: Module) -> None:
self.add_module(key, module)
def __delitem__(self, key: str) -> None:
del self._modules[key]
def __len__(self) -> int:
return len(self._modules)
def __iter__(self) -> Iterator[str]:
return iter(self._modules)
def __contains__(self, key: str) -> bool:
return key in self._modules
def clear(self) -> None:
"""Remove all items from the ModuleDict."""
self._modules.clear()
def pop(self, key: str) -> Module:
r"""Remove key from the ModuleDict and return its module.
Args:
key (str): key to pop from the ModuleDict
"""
v = self[key]
del self[key]
return v
def keys(self) -> Iterable[str]:
r"""Return an iterable of the ModuleDict keys."""
return self._modules.keys()
def items(self) -> Iterable[Tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs."""
return self._modules.items()
def values(self) -> Iterable[Module]:
r"""Return an iterable of the ModuleDict values."""
return self._modules.values()
def update(self, modules: Mapping[str, Module]) -> None:
r"""Update the :class:`~candle.nn.ModuleDict` with the key-value pairs from a
mapping or an iterable, overwriting existing keys.
.. note::
If :attr:`modules` is an ``OrderedDict``, a :class:`~candle.nn.ModuleDict`, or
an iterable of key-value pairs, the order of new elements in it is preserved.
Args:
modules (iterable): a mapping (dictionary) from string to :class:`~candle.nn.Module`,
or an iterable of key-value pairs of type (string, :class:`~candle.nn.Module`)
"""
if not isinstance(modules, container_abcs.Iterable):
raise TypeError(
"ModuleDict.update should be called with an "
"iterable of key/value pairs, but got " + type(modules).__name__
)
if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
for key, module in modules.items():
self[key] = module
else:
# modules here can be a list with two items
for j, m in enumerate(modules):
if not isinstance(m, container_abcs.Iterable):
raise TypeError(
"ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(m).__name__
)
if not len(m) == 2:
raise ValueError(
"ModuleDict update sequence element "
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
)
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
# that's too cumbersome to type correctly with overloads, so we add an ignore here
self[m[0]] = m[1] # type: ignore[assignment]
    # forward is intentionally omitted here; calls fall back to Module's default forward
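The three containers above in one short sketch, using the `Linear` module from this PR:

import candle
from candle.nn import Sequential, ModuleList, ModuleDict, Linear

net = Sequential(Linear(8, 16), Linear(16, 4))
x = candle.randn((2, 8))
print(net(x).shape)                     # chained forward: (2, 4)

blocks = ModuleList([Linear(4, 4) for _ in range(3)])
heads = ModuleDict({"cls": Linear(4, 2)})
print(len(blocks), list(heads.keys()))  # 3 ['cls']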

View File

@ -0,0 +1,119 @@
import math
from typing import Any
import candle
from candle import Tensor
from .module import Module
# See https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/linear.py
class Identity(Module):
r"""A placeholder identity operator that is argument-insensitive.
Args:
args: any argument (unused)
kwargs: any keyword argument (unused)
Shape:
- Input: :math:`(*)`, where :math:`*` means any number of dimensions.
- Output: :math:`(*)`, same shape as the input.
Examples::
>>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
>>> input = candle.randn(128, 20)
>>> output = m(input)
>>> print(output.shape)
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
def forward(self, input: Tensor) -> Tensor:
return input
class Linear(Module):
r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`
Args:
in_features: size of each input sample
out_features: size of each output sample
bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``
Shape:
- Input: :math:`(*, H_{in})` where :math:`*` means any number of
dimensions including none and :math:`H_{in} = \text{in\_features}`.
- Output: :math:`(*, H_{out})` where all but the last dimension
are the same shape as the input and :math:`H_{out} = \text{out\_features}`.
Attributes:
weight: the learnable weights of the module of shape
:math:`(\text{out\_features}, \text{in\_features})`. The values are
initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
:math:`k = \frac{1}{\text{in\_features}}`
bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
If :attr:`bias` is ``True``, the values are initialized from
:math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
:math:`k = \frac{1}{\text{in\_features}}`
"""
__constants__ = ["in_features", "out_features"]
in_features: int
out_features: int
weight: Tensor
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# Allow 'weight' to be quantized
self._quantizable_buffers.add("weight")
self.in_features = in_features
self.out_features = out_features
# TODO: Do actual initialization here: e.g. kaiming_uniform or xavier_uniform
self.weight = candle.ones((out_features, in_features), **factory_kwargs)
if bias:
self.bias = candle.zeros((out_features,), **factory_kwargs)
else:
self.bias = None
def forward(self, x: Tensor) -> Tensor:
dims = x.shape
last_dim = dims[-1]
if isinstance(self.weight, candle.QTensor):
            if len(dims) < 3:
                matmul_result = self.weight.matmul_t(x)
            elif len(dims) == 3:
                b, n, m = dims
                output_shape = (b, n, self.out_features)
                re = x.reshape((b * n, m))
                matmul_result = self.weight.matmul_t(re).reshape(output_shape)
            else:
                raise NotImplementedError("'QTensor.matmul_t' is not implemented for more than 3 dimensions")
            # Add the bias exactly once, and return a Tensor on every path.
            if self.bias is not None:
                matmul_result = matmul_result.broadcast_add(self.bias)
            return matmul_result
else:
if self.weight.shape[-1] == last_dim and len(dims) < 3:
w = self.weight.t()
else:
batch_size = dims[0]
w = self.weight.broadcast_left((batch_size,)).t()
x = x.matmul(w)
if self.bias is not None:
x = x.broadcast_add(self.bias)
return x
def extra_repr(self) -> str:
return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"

View File

@ -0,0 +1,702 @@
from candle import Tensor, QTensor, DType
from typing import (
Dict,
Tuple,
Any,
Optional,
Union,
Iterator,
Set,
overload,
Mapping,
TypeVar,
List,
)
from collections import OrderedDict, namedtuple
TensorLike = Union[Tensor, QTensor]
T = TypeVar("T", bound="Module")
class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
def __repr__(self):
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super().__repr__()
__str__ = __repr__
# see: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py
class Module:
"""
Pytorch like Module.
Base class for all neural network modules.
Your models should also subclass this class.
"""
_modules: Dict[str, Optional["Module"]]
_buffers: Dict[str, Optional[TensorLike]]
_non_persistent_buffers_set: Set[str]
_quantizable_buffers: Set[str]
_version: int = 1
def __init__(self, *args, **kwargs) -> None:
"""
Initializes internal Module state
"""
super().__setattr__("_modules", OrderedDict())
super().__setattr__("_buffers", OrderedDict())
super().__setattr__("_non_persistent_buffers_set", set())
super().__setattr__("_quantizable_buffers", set())
def __call__(self, *input):
"""
Call self as a function.
"""
return self.forward(*input)
def forward(self, *input):
"""
Defines the computation performed at every call.
Should be overridden by all subclasses.
"""
pass
def children(self) -> Iterator["Module"]:
r"""Returns an iterator over immediate children modules.
Yields:
Module: a child module
"""
for name, module in self.named_children():
yield module
def named_children(self) -> Iterator[Tuple[str, "Module"]]:
r"""Returns an iterator over immediate children modules, yielding both
the name of the module as well as the module itself.
Yields:
(str, Module): Tuple containing a name and child module
Example::
>>> for name, module in model.named_children():
>>> if name in ['conv4', 'conv5']:
>>> print(module)
"""
memo = set()
for name, module in self._modules.items():
if module is not None and module not in memo:
memo.add(module)
yield name, module
def add_module(self, name: str, module: Optional["Module"]) -> None:
r"""Adds a child module to the current module.
The module can be accessed as an attribute using the given name.
Args:
name (str): name of the child module. The child module can be
accessed from this module using the given name
module (Module): child module to be added to the module.
"""
if not isinstance(module, Module) and module is not None:
raise TypeError(f"{str(module)} is not a Module subclass")
elif not isinstance(name, str):
raise TypeError(f"module name should be a string. Got {name}")
elif hasattr(self, name) and name not in self._modules:
raise KeyError(f"attribute '{name}' already exists")
elif "." in name:
raise KeyError(f'module name can\'t contain ".", got: {name}')
elif name == "":
raise KeyError('module name can\'t be empty string ""')
self._modules[name] = module
def register_module(self, name: str, module: Optional["Module"]) -> None:
r"""Alias for :func:`add_module`."""
self.add_module(name, module)
def modules(self) -> Iterator["Module"]:
r"""Returns an iterator over all modules in the network."""
for _, module in self.named_modules():
yield module
def named_modules(
self,
memo: Optional[Set["Module"]] = None,
prefix: str = "",
remove_duplicate: bool = True,
):
r"""Returns an iterator over all modules in the network, yielding
both the name of the module as well as the module itself.
Args:
memo: a memo to store the set of modules already added to the result
prefix: a prefix that will be added to the name of the module
remove_duplicate: whether to remove the duplicated module instances in the result
or not
Yields:
(str, Module): Tuple of name and module
Note:
Duplicate modules are returned only once. In the following
example, ``l`` will be returned only once.
"""
if memo is None:
memo = set()
if self not in memo:
if remove_duplicate:
memo.add(self)
yield prefix, self
for name, module in self._modules.items():
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
yield m
def buffers(self, recurse: bool = True) -> Iterator[TensorLike]:
"""
Returns an iterator over module buffers.
"""
for name, buf in self.named_buffers(recurse=recurse):
yield buf
def named_buffers(
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
) -> Iterator[Tuple[str, TensorLike]]:
r"""Returns an iterator over module buffers, yielding both the
name of the buffer as well as the buffer itself.
Args:
prefix (str): prefix to prepend to all buffer names.
recurse (bool, optional): if True, then yields buffers of this module
and all submodules. Otherwise, yields only buffers that
are direct members of this module. Defaults to True.
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
Yields:
(str, Tensor): Tuple containing the name and buffer
Example::
>>> for name, buf in self.named_buffers():
>>> if name in ['running_var']:
>>> print(buf.size())
"""
gen = self._named_members(
lambda module: module._buffers.items(),
prefix=prefix,
recurse=recurse,
remove_duplicate=remove_duplicate,
)
yield from gen
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
@overload
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
...
@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
...
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
r"""Returns a dictionary containing references to the whole state of the module.
Both parameters and persistent buffers (e.g. running averages) are
included. Keys are corresponding parameter and buffer names.
Parameters and buffers set to ``None`` are not included.
.. note::
The returned object is a shallow copy. It contains references
to the module's parameters and buffers.
.. warning::
Currently ``state_dict()`` also accepts positional arguments for
``destination``, ``prefix`` and ``keep_vars`` in order. However,
this is being deprecated and keyword arguments will be enforced in
future releases.
.. warning::
Please avoid the use of argument ``destination`` as it is not
designed for end-users.
Args:
destination (dict, optional): If provided, the state of module will
be updated into the dict and the same object is returned.
Otherwise, an ``OrderedDict`` will be created and returned.
Default: ``None``.
prefix (str, optional): a prefix added to parameter and buffer
names to compose the keys in state_dict. Default: ``''``.
keep_vars (bool, optional): by default the :class:`~candle.Tensor` s
returned in the state dict are detached from autograd. If it's
set to ``True``, detaching will not be performed.
Default: ``False``.
Returns:
dict:
a dictionary containing a whole state of the module
Example::
>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
"""
# TODO: Remove `args` and the parsing logic when BC allows.
if len(args) > 0:
if destination is None:
destination = args[0]
if len(args) > 1 and prefix == "":
prefix = args[1]
if len(args) > 2 and keep_vars is False:
keep_vars = args[2]
if destination is None:
destination = OrderedDict()
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
if hasattr(destination, "_metadata"):
destination._metadata[prefix[:-1]] = local_metadata
self._save_to_state_dict(destination, prefix, keep_vars)
for name, module in self._modules.items():
if module is not None:
module.state_dict(
destination=destination,
prefix=prefix + name + ".",
keep_vars=keep_vars,
)
return destination
def _save_to_state_dict(self, destination, prefix, keep_vars):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~candle.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
if isinstance(buf, Tensor):
destination[prefix + name] = buf if keep_vars else buf.detach()
else:
destination[prefix + name] = buf
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
r"""Copies parameters and buffers from :attr:`state_dict` into
this module and its descendants. If :attr:`strict` is ``True``, then
the keys of :attr:`state_dict` must exactly match the keys returned
by this module's :meth:`~candle.nn.Module.state_dict` function.
.. warning::
If :attr:`assign` is ``True`` the optimizer must be created after
the call to :attr:`load_state_dict`.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
strict (bool, optional): whether to strictly enforce that the keys
in :attr:`state_dict` match the keys returned by this module's
:meth:`~candle.nn.Module.state_dict` function. Default: ``True``
assign (bool, optional): whether to assign items in the state
dictionary to their corresponding keys in the module instead
of copying them inplace into the module's current parameters and buffers.
When ``False``, the properties of the tensors in the current
module are preserved while when ``True``, the properties of the
Tensors in the state dict are preserved.
Default: ``False``
Returns:
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
Note:
If a parameter or buffer is registered as ``None`` and its corresponding key
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
``RuntimeError``.
"""
if not isinstance(state_dict, Mapping):
raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
missing_keys: List[str] = []
unexpected_keys: List[str] = []
error_msgs: List[str] = []
# copy state_dict so _load_from_state_dict can modify it
metadata = getattr(state_dict, "_metadata", None)
state_dict = OrderedDict(state_dict)
if metadata is not None:
# mypy isn't aware that "_metadata" exists in state_dict
state_dict._metadata = metadata # type: ignore[attr-defined]
def load(module, local_state_dict, prefix=""):
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
if assign:
local_metadata["assign_to_params_buffers"] = assign
module._load_from_state_dict(
local_state_dict,
prefix,
local_metadata,
True,
missing_keys,
unexpected_keys,
error_msgs,
)
for name, child in module._modules.items():
if child is not None:
child_prefix = prefix + name + "."
child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
load(child, child_state_dict, child_prefix)
load(self, state_dict)
del load
if strict:
if len(unexpected_keys) > 0:
error_msgs.insert(
0,
"Unexpected key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in unexpected_keys)),
)
if len(missing_keys) > 0:
error_msgs.insert(
0,
"Missing key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in missing_keys)),
)
if len(error_msgs) > 0:
raise RuntimeError(
"Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs))
)
return _IncompatibleKeys(missing_keys, unexpected_keys)
def _load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
strict,
missing_keys,
unexpected_keys,
error_msgs,
):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~candle.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
Additionally, :attr:`local_metadata` can also contain the key
`assign_to_params_buffers` that indicates whether keys should be
assigned their corresponding tensor in the state_dict.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~candle.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~candle.nn.Module.load_state_dict`
"""
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = persistent_buffers.items()
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if not isinstance(input_param, (Tensor, QTensor)):
error_msgs.append(
f'While copying the parameter named "{key}", '
"expected Tensor-like object from checkpoint but "
f"received {type(input_param)}"
)
continue
if input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append(
"size mismatch for {}: copying a param with shape {} from checkpoint, "
"the shape in current model is {}.".format(key, input_param.shape, param.shape)
)
continue
try:
# Shape checks are already done above -> Just assign tensor
setattr(self, name, input_param)
except Exception as ex:
error_msgs.append(
f'While copying the parameter named "{key}", '
f"whose dimensions in the model are {param.shape} and "
f"whose dimensions in the checkpoint are {input_param.shape}, "
f"an exception occurred : {ex.args}."
)
elif strict:
missing_keys.append(key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix):
input_name = key[len(prefix) :]
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
def _named_members(self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True):
r"""Helper method for yielding various names + members of modules."""
memo = set()
modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
for module_prefix, module in modules:
members = get_members_fn(module)
for k, v in members:
if v is None or v in memo:
continue
if remove_duplicate:
memo.add(v)
name = module_prefix + ("." if module_prefix else "") + k
yield name, v
def _get_name(self):
return self.__class__.__name__
def _apply(self, fn):
for module in self.children():
module._apply(fn)
for key, buf in self._buffers.items():
if buf is not None:
self._buffers[key] = fn(buf)
return self
    def __move_tensor_to_device(self, tensor: TensorLike, device: str):
        if isinstance(tensor, Tensor):
            return tensor.to_device(device)
        elif device == "cpu":
            # QTensors already live on the CPU, so this is a no-op.
            return tensor
        else:
            raise NotImplementedError("Cannot offload QTensor to cuda, yet!")
def device(self) -> str:
"""
Gets the device of the module, by inspecting its tensors.
"""
tensor = next(self.buffers())
if isinstance(tensor, Tensor):
return tensor.device
else:
# QTensors can only be on the CPU
return "cpu"
def cuda(self: T) -> T:
r"""Moves all model parameters and buffers to the GPU.
This also makes associated parameters and buffers different objects. So
it should be called before constructing optimizer if the module will
live on GPU while being optimized.
.. note::
This method modifies the module in-place.
Returns:
Module: self
"""
def to_cuda(t: TensorLike):
return self.__move_tensor_to_device(t, "cuda")
return self._apply(to_cuda)
def cpu(self: T) -> T:
r"""Moves all model parameters and buffers to the CPU.
.. note::
This method modifies the module in-place.
Returns:
Module: self
"""
def to_cpu(t: TensorLike):
return self.__move_tensor_to_device(t, "cpu")
return self._apply(to_cpu)
def __cast_tensor(self, tensor: TensorLike, dtype: Union[DType, str]):
if isinstance(tensor, Tensor):
return tensor.to_dtype(dtype)
else:
raise TypeError("candle.Module.to only accepts Tensor dtypes, but got desired dtype={}".format(dtype))
def type(self: T, dst_type: Union[DType, str]) -> T:
r"""Casts all parameters and buffers to :attr:`dst_type`.
.. note::
This method modifies the module in-place.
Args:
dst_type (type or string): the desired type
Returns:
Module: self
"""
def cast(t: TensorLike):
return self.__cast_tensor(t, dst_type)
return self._apply(cast)
@overload
def to(
self: T,
device: str = ...,
dtype: Optional[Union[DType, str]] = ...,
) -> T:
...
@overload
def to(self: T, dtype: Union[DType, str]) -> T:
...
def to(self, *args, **kwargs):
r"""Moves and/or casts the parameters and buffers.
This can be called as
.. function:: to(device=None, dtype=None)
:noindex:
.. function:: to(dtype)
:noindex:
See below for examples.
.. note::
This method modifies the module in-place.
Args:
device (:class:`candle.device`): the desired device of the parameters
and buffers in this module
dtype (:class:`candle.dtype`): the desired floating point dtype of
the parameters and buffers in this module
Returns:
Module: self
"""
device = None
dtype = None
if args:
for arg in args:
# Assuming arg can be a string representing a device or a dtype
if isinstance(arg, str):
lower_arg = str(arg).lower()
if lower_arg.startswith("cuda") or lower_arg == "cpu":
device = lower_arg
else:
dtype = arg
elif isinstance(arg, DType):
dtype = str(arg)
else:
raise TypeError("Module.to() received an invalid combination of arguments. Got: {}".format(args))
        if kwargs:
            device = kwargs.get("device", device)
            dtype = kwargs.get("dtype", dtype)
            if dtype is not None:
                # Normalize DType values to their string names; leave None as None.
                dtype = str(dtype)
if device:
device = device.lower()
if dtype:
dtype = dtype.lower()
if dtype not in ["f32", "f16", "f64"]:
raise TypeError(
"candle.Module.to only accepts floating point" "dtypes, but got desired dtype={}".format(dtype)
)
def convert(t):
if dtype:
t = self.__cast_tensor(t, dtype)
if device:
t = self.__move_tensor_to_device(t, device)
return t
return self._apply(convert)
def __setattr__(self, __name: str, __value: Any) -> None:
if isinstance(__value, Module):
self._modules[__name] = __value
elif isinstance(__value, QTensor):
if __name in self._quantizable_buffers:
                ggml_type = __value.ggml_dtype.lower()  # avoid shadowing the `type` builtin
                if ggml_type in ["f32", "f16"]:
                    # It is faster to just dequantize the tensor here and use the normal tensor operations
                    dequant = __value.dequantize()
                    if ggml_type == "f16":
                        dequant = dequant.to_dtype("f16")
self._buffers[__name] = dequant
else:
self._buffers[__name] = __value
else:
# We expect a normal tensor here => dequantize it
self._buffers[__name] = __value.dequantize()
elif isinstance(__value, Tensor):
self._buffers[__name] = __value
else:
super().__setattr__(__name, __value)
def __getattr__(self, __name: str) -> Any:
if "_modules" in self.__dict__:
modules = self.__dict__["_modules"]
if __name in modules:
return modules[__name]
if "_buffers" in self.__dict__:
tensors = self.__dict__["_buffers"]
if __name in tensors:
return tensors[__name]
return super().__getattribute__(__name)
def __delattr__(self, name):
if name in self._buffers:
del self._buffers[name]
elif name in self._modules:
del self._modules[name]
else:
super().__delattr__(name)
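The `to` docstring above says "See below for examples", so here is a minimal usage sketch. `TinyNet` is a hypothetical module defined only for illustration, and the `cuda()` branch assumes a CUDA-enabled build:

import candle
from candle.nn import Module, Linear
from candle.utils import cuda_is_available

class TinyNet(Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.proj = Linear(4, 2)  # registers proj.weight / proj.bias as buffers

net = TinyNet()
net.type("f16")       # cast every registered Tensor buffer to half precision
net.to(device="cpu")  # keyword form; a no-op here, but exercises the same path
if cuda_is_available():
    net.cuda()        # equivalent to net.to(device="cuda")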

View File

@ -0,0 +1,54 @@
import candle
from candle import Tensor
from .module import Module
from typing import Union, List, Tuple, Optional, Any
_shape_t = Union[int, List[int]]
import numbers
class LayerNorm(Module):
r"""Applies Layer Normalization over a mini-batch of inputs as described in
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`
math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
"""
__constants__ = ["normalized_shape", "eps"]
normalized_shape: Tuple[int, ...]
eps: float
def __init__(
self,
normalized_shape: _shape_t,
eps: float = 1e-5,
bias: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape,)
self.normalized_shape = tuple(normalized_shape)
self.eps = eps
self.weight = candle.ones(normalized_shape, **factory_kwargs)
if bias:
self.bias = candle.zeros(normalized_shape, **factory_kwargs)
else:
self.bias = None
def forward(self, input: Tensor) -> Tensor:
# The hardcoded dim 2 assumes a 3D input of shape (batch, seq, hidden).
mean_x = input.sum_keepdim(2) / float(self.normalized_shape[-1])
x = input.broadcast_sub(mean_x)
norm_x = x.sqr().sum_keepdim(2) / float(self.normalized_shape[-1])
x_normed = x.broadcast_div((norm_x + self.eps).sqrt())
x = x_normed.broadcast_mul(self.weight)
if self.bias is not None:
x = x.broadcast_add(self.bias)
return x
def extra_repr(self) -> str:
return "{normalized_shape}, eps={eps}".format(**self.__dict__)

View File

@ -0,0 +1,39 @@
from .module import Module
from typing import Optional, Tuple, Any
from candle import Tensor
import candle
class Embedding(Module):
"""A simple lookup table that stores embeddings of a fixed dictionary and size.
This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, and the output is the corresponding
word embeddings.
Args:
num_embeddings (int): size of the dictionary of embeddings
embedding_dim (int): the size of each embedding vector
Attributes:
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
initialized from :math:`\mathcal{N}(0, 1)`
Shape:
- Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
"""
def __init__(self, num_embeddings: int, embedding_dim: int, device=None) -> None:
factory_kwargs = {"device": device}
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.weight = candle.randn((num_embeddings, embedding_dim), **factory_kwargs)
def forward(self, indexes: Tensor) -> Tensor:
final_dims = list(indexes.shape)
final_dims.append(self.embedding_dim)
indexes = indexes.flatten_all()
values = self.weight.index_select(indexes, 0)
return values.reshape(final_dims)
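A minimal usage sketch; the `to_dtype("u32")` cast is an assumption here, since `index_select` needs an integer index tensor while `Tensor([...])` yields floats:

import candle
from candle import Tensor
from candle.nn import Embedding

emb = Embedding(num_embeddings=10, embedding_dim=4)
indexes = Tensor([1, 2, 3]).to_dtype("u32")  # assumed integer cast for index_select
out = emb.forward(indexes)
assert out.shape == (3, 4)  # input shape (*) -> output shape (*, embedding_dim)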

View File

@ -2,7 +2,7 @@ from typing import TypeVar, Union, Sequence
_T = TypeVar("_T")
_ArrayLike = Union[
_T,
Sequence[_T],
Sequence[Sequence[_T]],
@ -10,7 +10,7 @@ _ArrayLike = Union[
Sequence[Sequence[Sequence[Sequence[_T]]]],
]
CPU:str = "cpu"
CUDA:str = "cuda"
CPU: str = "cpu"
CUDA: str = "cuda"
Device = TypeVar("Device", CPU, CUDA)
Device = TypeVar("Device", CPU, CUDA)

View File

@ -28,3 +28,7 @@ features = ["pyo3/extension-module"]
[tool.black]
line-length = 119
target-version = ['py35']
[project.optional-dependencies]
testing = ["pytest", "black==22.3"]
huggingface = ["transformers>=4.33.3", "huggingface-hub>=0.17.3"]

View File

@ -2,181 +2,59 @@
import sys
from typing import Dict, Tuple, Any
import candle
from candle import Tensor, QTensor, utils, nn
from candle.models.llama import QuantizedLlama
from candle import utils
MAX_SEQ_LEN = 4096
def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor):
shape = mask.shape
on_true = candle.tensor(on_true).broadcast_as(shape)
return mask.where_cond(on_true, on_false)
class RmsNorm:
def __init__(self, qtensor:QTensor):
self.weight = qtensor.dequantize()
def __call__(self, x:Tensor):
b_size, seq_len, hidden_size = x.shape
norm_x = x.sqr().sum_keepdim(2) / hidden_size
x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
return x_normed.broadcast_mul(self.weight)
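In math form, RmsNorm scales each position by the root of its mean square: with d = hidden_size, ε = 1e-5 and dequantized weight w,

y_i = \frac{x_i}{\sqrt{\tfrac{1}{d} \sum_{j=1}^{d} x_j^2 + \varepsilon}} \cdot w_i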
class QuantizedLayer:
def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]):
p = f"layers.{layer_idx}"
self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
self.n_head = hparams["n_head"]
self.n_kv_head = self.n_head
self.head_dim = hparams["n_embd"] // self.n_head
self.kv_cache = None
self.cos = cos_sin[0]
self.sin = cos_sin[1]
def __call__(self, x:Tensor, mask:Tensor, index_pos:int):
residual = x
x = self.attn_norm(x)
attn = self.forward_attn(x, mask, index_pos)
x = attn + residual
residual = x
x = self.ffn_norm(x)
w1 = self.ffw1.matmul_t(x)
w3 = self.ffw3.matmul_t(x)
mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
return mlp + residual
def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int):
b_size, seq_len, n_embd = x.shape
q = self.attention_wq.matmul_t(x)
k = self.attention_wk.matmul_t(x)
v = self.attention_wv.matmul_t(x)
q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
q = self.apply_rotary_emb(q, index_pos)
k = self.apply_rotary_emb(k, index_pos)
if self.kv_cache is not None and index_pos > 0:
prev_k, prev_v = self.kv_cache
k = candle.cat([prev_k, k], 2).contiguous()
v = candle.cat([prev_v, v], 2).contiguous()
self.kv_cache = (k, v)
# TODO: maybe repeat k/v here if we start supporting MQA.
att = q.matmul(k.t()) / self.head_dim**0.5
mask = mask.broadcast_as(att.shape)
att = masked_fill(att, mask, float("-inf"))
att = nn.softmax(att, -1)
y = att.matmul(v.contiguous())
y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
return self.attention_wo.matmul_t(y)
def apply_rotary_emb(self, x:Tensor, index_pos:int):
(b_size, n_head, seq_len, n_embd) = x.shape
cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2))
x0 = x.narrow(-1, 0, 1)
x1 = x.narrow(-1, 1, 1)
y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
rope = candle.cat([y0, y1], -1)
return rope.flatten_from(-2)
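The pairwise update above is the standard rotary embedding: with m the absolute token position and θ_i the frequencies from precompute_freqs_cis below, each pair (x_{2i}, x_{2i+1}) is rotated by the angle m·θ_i:

\begin{pmatrix} y_{2i} \\ y_{2i+1} \end{pmatrix}
= \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix}
\begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix},
\qquad \theta_i = \text{freq\_base}^{-2i/\text{head\_dim}}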
def precompute_freqs_cis(hparams, freq_base):
head_dim = hparams["n_embd"] // hparams["n_head"]
theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
theta = candle.tensor(theta)
idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
m = idx_theta.matmul(theta.unsqueeze(0))
return (m.cos(), m.sin())
class QuantizedLlama:
def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]):
self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
self.norm = RmsNorm(all_tensors["norm.weight"])
self.output = all_tensors["output.weight"]
self.layers = []
rope_freq = hparams.get("rope_freq", 10000.)
cos_sin = precompute_freqs_cis(hparams, rope_freq)
for layer_idx in range(hparams["n_layer"]):
layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
self.layers.append(layer)
def __call__(self, token:Tensor, index_pos:int):
b_size, seq_len = token.shape
vocab_size, hidden_size = self.tok_embeddings.shape
token = token.reshape((b_size * seq_len,))
x = self.tok_embeddings.index_select(token, 0)
x = x.reshape((b_size, seq_len, hidden_size))
mask = [int(j > i) for j in range(seq_len) for i in range(seq_len)]
mask = candle.tensor(mask).reshape((seq_len, seq_len))
for layer in self.layers:
x = layer(x, mask, index_pos)
x = self.norm(x)
x = x.narrow(1, -1, 1).squeeze(1)
x = self.output.matmul_t(x)
return x
def gguf_rename(tensor_name:str):
if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
if tensor_name == 'output_norm.weight': return 'norm.weight'
tensor_name = tensor_name.replace('blk.', 'layers.')
tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.')
tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.')
tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.')
tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.')
tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.')
tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.')
tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.')
tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.')
def gguf_rename(tensor_name: str):
if tensor_name == "token_embd.weight":
return "tok_embeddings.weight"
if tensor_name == "output_norm.weight":
return "norm.weight"
tensor_name = tensor_name.replace("blk.", "layers.")
tensor_name = tensor_name.replace(".attn_q.", ".attention.wq.")
tensor_name = tensor_name.replace(".attn_k.", ".attention.wk.")
tensor_name = tensor_name.replace(".attn_v.", ".attention.wv.")
tensor_name = tensor_name.replace(".attn_output.", ".attention.wo.")
tensor_name = tensor_name.replace(".ffn_gate.", ".feed_forward.w1.")
tensor_name = tensor_name.replace(".ffn_down.", ".feed_forward.w2.")
tensor_name = tensor_name.replace(".ffn_up.", ".feed_forward.w3.")
tensor_name = tensor_name.replace(".attn_norm.", ".attention_norm.")
return tensor_name
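For example, the chain of replacements maps GGUF tensor names onto the layout the model code expects:

assert gguf_rename("token_embd.weight") == "tok_embeddings.weight"
assert gguf_rename("blk.0.attn_q.weight") == "layers.0.attention.wq.weight"
assert gguf_rename("blk.7.ffn_down.weight") == "layers.7.feed_forward.w2.weight"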
def main():
if len(sys.argv) < 2:
raise ValueError("missing weight file argument")
filename = sys.argv[1]
print(f"reading model file {filename}")
if filename.endswith("gguf"):
all_tensors, metadata = utils.load_gguf(sys.argv[1])
all_tensors, metadata = utils.load_gguf(filename)
vocab = metadata["tokenizer.ggml.tokens"]
for i, v in enumerate(vocab):
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
vocab[i] = "\n" if v == "<0x0A>" else v.replace("▁", " ")
hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
print(hparams)
hparams = {
'n_vocab': len(vocab),
'n_embd': metadata['llama.embedding_length'],
'n_mult': 256,
'n_head': metadata['llama.attention.head_count'],
'n_head_kv': metadata['llama.attention.head_count_kv'],
'n_layer': metadata['llama.block_count'],
'n_rot': metadata['llama.rope.dimension_count'],
'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
'ftype': metadata['general.file_type'],
"n_vocab": len(vocab),
"n_embd": metadata["llama.embedding_length"],
"n_mult": 256,
"n_head": metadata["llama.attention.head_count"],
"n_head_kv": metadata["llama.attention.head_count_kv"],
"n_layer": metadata["llama.block_count"],
"n_rot": metadata["llama.rope.dimension_count"],
"rope_freq": metadata.get("llama.rope.freq_base", 10000.0),
"ftype": metadata["general.file_type"],
"context_length": metadata["llama.context_length"],
}
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()}
else:
all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1])
all_tensors, hparams, vocab = utils.load_ggml(filename)
hparams["context_length"] = 2048
print(hparams)
model = QuantizedLlama(hparams, all_tensors)
print("model built, starting inference")
@ -185,13 +63,14 @@ def main():
for token_idx in range(500):
last_token = tokens[-1]
lt = candle.tensor([last_token]).unsqueeze(0)
logits = model(lt, len(tokens))
logits = model.forward(lt, len(tokens))
# Greedy sampling for now
# pr = candle.nn.softmax(logits, -1)
m = logits.get(0).argmax_keepdim(-1)
next_token = m.values()[0]
print(vocab[next_token], end='', flush=True)
print(vocab[next_token], end="", flush=True)
tokens.append(next_token)
if __name__ == '__main__':
if __name__ == "__main__":
main()
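As the inline comment notes, sampling is greedy for now. A temperature-sampling variant could replace the argmax step; this is only a sketch, assuming softmax is reachable through the functional submodule registered further below, with `sample_token` a hypothetical helper:

import random
from candle import functional as F

def sample_token(logits, temperature: float = 0.8) -> int:
    # Hypothetical helper: scale the logits, softmax over the vocab dim, sample once.
    probs = F.softmax(logits.get(0) / temperature, -1).values()
    return random.choices(range(len(probs)), weights=probs, k=1)[0]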

View File

@ -3,6 +3,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
use pyo3::ToPyObject;
use std::os::raw::c_long;
use std::sync::Arc;
use half::{bf16, f16};
@ -196,6 +197,12 @@ trait MapDType {
}
}
enum Indexer {
Index(usize),
Slice(usize, usize),
Ellipsis,
}
#[pymethods]
impl PyTensor {
#[new]
@ -436,6 +443,95 @@ impl PyTensor {
))
}
#[getter]
/// Index a tensor.
/// &RETURNS&: Tensor
fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
let mut indexers: Vec<Indexer> = vec![];
let dims = self.0.shape().dims();
let to_absolute_index = |index: isize, current_dim: usize| {
// Convert a relative index to an absolute index, e.g. tensor[-1] -> tensor[dims[0] - 1]
let actual_index = if index < 0 {
dims[current_dim] as isize + index
} else {
index
};
// Check that the index is in range
if actual_index < 0 || actual_index >= dims[current_dim] as isize {
return Err(PyTypeError::new_err(format!(
"index out of range for dimension '{i}' with indexer '{value}'",
i = current_dim,
value = index
)));
}
Ok(actual_index as usize)
};
if let Ok(index) = idx.extract(py) {
// Handle a single index e.g. tensor[0] or tensor[-1]
indexers.push(Indexer::Index(to_absolute_index(index, 0)?));
} else if let Ok(slice) = idx.downcast::<pyo3::types::PySlice>(py) {
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
let index = slice.indices(dims[0] as c_long)?;
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
} else if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
// Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1]
if tuple.len() > dims.len() {
return Err(PyTypeError::new_err("provided too many indices"));
}
for (i, item) in tuple.iter().enumerate() {
if item.is_ellipsis() {
// Handle '...' e.g. tensor[..., 0]
if i > 0 {
return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation"));
}
indexers.push(Indexer::Ellipsis);
} else if let Ok(slice) = item.downcast::<pyo3::types::PySlice>() {
// Handle slice
let index = slice.indices(dims[i] as c_long)?;
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
} else if let Ok(index) = item.extract::<isize>() {
indexers.push(Indexer::Index(to_absolute_index(index, i)?));
} else {
return Err(PyTypeError::new_err("unsupported index"));
}
}
} else {
return Err(PyTypeError::new_err("unsupported index"));
}
let mut x = self.0.clone();
let mut current_dim = 0;
// Apply the indexers
for indexer in indexers.iter() {
x = match indexer {
Indexer::Index(n) => x
.narrow(current_dim, *n, 1)
.map_err(wrap_err)?
.squeeze(current_dim)
.map_err(wrap_err)?,
Indexer::Slice(start, stop) => {
let out = x
.narrow(current_dim, *start, stop.saturating_sub(*start))
.map_err(wrap_err)?;
current_dim += 1;
out
}
Indexer::Ellipsis => {
// Ellipsis is a special case: all remaining dimensions are selected,
// so advance current_dim past the dimensions it stands for.
current_dim += dims.len() - (indexers.len() - 1);
x
}
}
}
Ok(Self(x))
}
/// Add two tensors.
/// &RETURNS&: Tensor
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
@ -697,7 +793,7 @@ impl PyTensor {
/// &RETURNS&: QTensor
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
use ::candle::quantized;
let res = match quantized_dtype {
let res = match quantized_dtype.to_lowercase().as_str() {
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
@ -1137,9 +1233,39 @@ fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
Ok(PyTensor(s))
}
fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Gaussian Error Linear Unit (GELU) function to a given tensor (erf formulation).
/// &RETURNS&: Tensor
fn gelu(tensor: PyTensor) -> PyResult<PyTensor> {
let s = tensor.0.gelu_erf().map_err(wrap_err)?;
Ok(PyTensor(s))
}
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the Rectified Linear Unit (ReLU) function to a given tensor.
/// &RETURNS&: Tensor
fn relu(tensor: PyTensor) -> PyResult<PyTensor> {
let s = tensor.0.relu().map_err(wrap_err)?;
Ok(PyTensor(s))
}
#[pyfunction]
#[pyo3(text_signature = "(tensor:Tensor)")]
/// Applies the tanh function to a given tensor.
/// &RETURNS&: Tensor
fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
let s = tensor.0.tanh().map_err(wrap_err)?;
Ok(PyTensor(s))
}
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
m.add_function(wrap_pyfunction!(gelu, m)?)?;
m.add_function(wrap_pyfunction!(relu, m)?)?;
m.add_function(wrap_pyfunction!(tanh, m)?)?;
Ok(())
}
@ -1148,8 +1274,8 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
candle_utils(py, utils)?;
m.add_submodule(utils)?;
let nn = PyModule::new(py, "nn")?;
candle_nn_m(py, nn)?;
let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
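With the submodule now registered as `functional`, the new activations are callable from Python roughly as follows (a sketch, assuming the submodule resolves as an attribute of `candle`):

import candle

t = candle.randn((2, 3))
print(candle.functional.relu(t))
print(candle.functional.gelu(t))  # erf-based GELU, per the binding above
print(candle.functional.tanh(t))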

View File

@ -1,4 +1,4 @@
#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
import argparse
import inspect
import os
@ -23,7 +23,7 @@ def do_indent(text: Optional[str], indent: str):
return text.replace("\n", f"\n{indent}")
def function(obj, indent:str, text_signature:str=None):
def function(obj, indent: str, text_signature: str = None):
if text_signature is None:
text_signature = obj.__text_signature__
@ -32,12 +32,12 @@ def function(obj, indent:str, text_signature:str=None):
if doc_string is None:
doc_string = ""
# Check if we have a return type annotation in the docstring
return_type = None
doc_lines = doc_string.split("\n")
if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
# Extract the return type and remove it from the docstring
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER) :].strip()
doc_string = "\n".join(doc_lines[:-1])
string = ""
@ -115,7 +115,7 @@ def pyi_file(obj, indent=""):
body += f"{indent+INDENT}pass\n"
body += "\n"
for (name, fn) in fns:
for name, fn in fns:
body += pyi_file(fn, indent=indent)
if not body:
@ -221,12 +221,12 @@ if __name__ == "__main__":
args = parser.parse_args()
#Enable execution from the candle and candle-pyo3 directories
# Enable execution from the candle and candle-pyo3 directories
cwd = Path.cwd()
directory = "py_src/candle/"
if cwd.name != "candle-pyo3":
directory = f"candle-pyo3/{directory}"
import candle
write(candle.candle, directory, "candle", check=args.check)

View File

@ -7,7 +7,7 @@ print(t + t)
t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
print(t)
print(t+t)
print(t + t)
t = t.reshape([2, 4])
print(t.matmul(t.t()))

View File

View File

@ -0,0 +1,38 @@
import candle
from candle import Tensor
from candle.nn import Linear
def test_linear_layer_can_be_constructed():
linear = Linear(10, 10)
assert linear is not None
def test_linear_layer_can_forward_a_singular_input():
linear = Linear(384, 1536)
input_tensor = candle.randn((8, 384))
output = linear.forward(input_tensor)
assert output.shape == (8, 1536)
def test_linear_layer_can_forward_a_batched_input():
linear = Linear(384, 1536)
input_tensor = candle.randn((16, 8, 384))
output = linear.forward(input_tensor)
assert output.shape == (16, 8, 1536)
def test_quantized_linear_layer_can_forward_a_singular_input():
linear = Linear(384, 1536)
linear.weight = linear.weight.quantize("q4_0")
input_tensor = candle.randn((8, 384))
output = linear.forward(input_tensor)
assert output.shape == (8, 1536)
def test_quantized_linear_layer_can_forward_a_batched_input():
linear = Linear(384, 1536)
linear.weight = linear.weight.quantize("q4_0")
input_tensor = candle.randn((16, 8, 384))
output = linear.forward(input_tensor)
assert output.shape == (16, 8, 1536)

View File

@ -0,0 +1,161 @@
import candle
from candle import Tensor, QTensor
from candle.nn import Module, Linear
from candle.utils import cuda_is_available
import pytest
def test_module_can_be_constructed():
class A(Module):
pass
a = A()
assert a is not None
assert len(list(a.buffers())) == 0
def test_module_registers_tensors():
class A(Module):
def __init__(self):
super().__init__()
self.t = Tensor(42.0)
a = A()
named_buffers = dict(a.named_buffers())
assert len(named_buffers) == 1
assert "t" in named_buffers
def test_module_registers_submodules():
class A(Module):
def __init__(self):
super().__init__()
self.linear = Linear(10, 20)
a = A()
named_modules = dict(a.named_modules())
named_buffers = dict(a.named_buffers())
assert len(named_buffers) == 2
assert "linear" in named_modules
assert "linear.weight" in named_buffers
assert "linear.bias" in named_buffers
def test_module_can_dump_statedict():
class A(Module):
def __init__(self):
super().__init__()
self.linear = Linear(10, 20)
self.t = Tensor(42.0)
a = A()
state_dict = a.state_dict()
assert hasattr(state_dict, "_metadata")
assert "t" in state_dict
assert "linear.weight" in state_dict
assert "linear.bias" in state_dict
assert len(state_dict) == 3
def test_module_can_load_statedict():
class A(Module):
def __init__(self):
super().__init__()
self.linear = Linear(10, 20)
self.t = Tensor(42.0)
statedict = {
"linear.weight": candle.ones((20, 10)),
"linear.bias": candle.zeros((20,)),
"t": Tensor(42.0),
}
a = A()
a.load_state_dict(statedict)
def test_module_throws_on_shape_mismatch():
class A(Module):
def __init__(self):
super().__init__()
self.t = Tensor(42.0)
statedict = {
"t": candle.ones((20,)),
}
a = A()
with pytest.raises(RuntimeError) as excinfo:
a.load_state_dict(statedict)
assert "size mismatch" in str(excinfo.value)
def test_module_throws_on_missing_key():
class A(Module):
def __init__(self):
super().__init__()
self.t = Tensor(42.0)
statedict = {
"not_t": Tensor(42.0),
}
a = A()
with pytest.raises(RuntimeError) as excinfo:
a.load_state_dict(statedict)
assert 'Missing key(s) in state_dict: "t".' in str(excinfo.value)
def test_module_can_load_quantized_tensors():
class A(Module):
def __init__(self):
super().__init__()
self.t = candle.randn((16, 256))
self._quantizable_buffers.add("t")
statedict = {
"t": candle.ones((16, 256)).quantize("q4_0"),
}
a = A()
a.load_state_dict(statedict)
assert isinstance(a.t, QTensor)
assert a.t.ggml_dtype == "Q4_0"
def test_module_dequantizes_tensors_automatically():
class A(Module):
def __init__(self):
super().__init__()
self.t = candle.randn((16, 256))
statedict = {
"t": candle.ones((16, 256)).quantize("q4_0"),
}
a = A()
a.load_state_dict(statedict)
assert isinstance(a.t, Tensor)
@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_module_can_be_moved_to_cuda():
class A(Module):
def __init__(self):
super().__init__()
self.t = candle.randn((16, 256))
a = A()
a.cuda()
assert a.t.device == "cuda"
@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_module_can_be_moved_from_cuda_to_cpu():
class A(Module):
def __init__(self):
super().__init__()
self.t = candle.randn((16, 256))
a = A()
a.cuda()
assert a.t.device == "cuda"
a.cpu()
assert a.t.device == "cpu"

View File

@ -0,0 +1,74 @@
import candle
from candle import Tensor
def test_tensor_can_be_constructed():
t = Tensor(42.0)
assert t.values() == 42.0
def test_tensor_can_be_constructed_from_list():
t = Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
assert t.values() == [3.0, 1, 4, 1, 5, 9, 2, 6]
def test_tensor_can_be_constructed_from_list_of_lists():
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
assert t.values() == [[3.0, 1, 4, 1], [5, 9, 2, 6]]
def test_tensor_can_be_quantized():
t = candle.randn((16, 256))
for format in [
"q4_0",
"q4_1",
"q5_0",
"q5_1",
"q8_0",
"q2k",
"q3k",
"q4k",
"q5k",
"q8k",
]:
for formatted_format in [format.upper(), format.lower()]:
quant_t = t.quantize(formatted_format)
assert quant_t.ggml_dtype.lower() == format.lower()
assert quant_t.shape == t.shape
def test_tensor_can_be_indexed():
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
assert t[0].values() == [3.0, 1.0, 4.0, 1.0]
assert t[1].values() == [5.0, 9.0, 2.0, 6.0]
assert t[-1].values() == [5.0, 9.0, 2.0, 6.0]
assert t[-2].values() == [3.0, 1.0, 4.0, 1.0]
def test_tensor_can_be_sliced():
t = Tensor([3.0, 1, 4, 10, 5, 9, 2, 6])
assert t[0:4].values() == [3.0, 1.0, 4.0, 10.0]
assert t[4:8].values() == [5.0, 9.0, 2.0, 6.0]
assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
assert t[-4:-2].values() == [5.0, 9.0]
def test_tensor_can_be_sliced_2d():
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
assert t[:, 0].values() == [3.0, 5]
assert t[:, 1].values() == [1.0, 9.0]
assert t[0, 0].values() == 3.0
assert t[:, -1].values() == [1.0, 6.0]
assert t[:, -4].values() == [3.0, 5]
assert t[..., 0].values() == [3.0, 5]
def test_tensor_can_be_sliced_3d():
t = Tensor([[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]]])
assert t[:, :, 0].values() == [[1, 5], [9, 13]]
assert t[:, :, 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
assert t[:, 0, 0].values() == [1, 9]
assert t[..., 0].values() == [[1, 5], [9, 13]]
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]

View File

@ -0,0 +1,51 @@
import candle
from candle import Tensor, QTensor
from candle.utils import load_safetensors, save_gguf, load_gguf, save_safetensors
from pathlib import Path
TEST_DIR = Path(__file__).parent.parent / "_workdir"
TEST_DIR.mkdir(exist_ok=True)
def test_can_roundtrip_safetensors():
tensors = {
"a": candle.randn((16, 256)),
"b": candle.randn((16, 16)),
}
file = str(TEST_DIR / "test.safetensors")
save_safetensors(file, tensors)
loaded_tensors = load_safetensors(file)
assert set(tensors.keys()) == set(loaded_tensors.keys())
for key in tensors.keys():
assert tensors[key].values() == loaded_tensors[key].values(), "Values are not equal"
assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal"
assert str(tensors[key].dtype) == str(loaded_tensors[key].dtype), "Dtypes are not equal"
def test_can_roundtrip_gguf():
metadata = {
"a": 1,
"b": "foo",
"c": [1, 2, 3],
"d": [[1, 2], [3, 4]],
}
tensors = {
"a": candle.randn((16, 256)).quantize("q4_0"),
"b": candle.randn((16, 16)).quantize("f32"),
}
file = str(TEST_DIR / "test.gguf")
save_gguf(file, tensors, metadata)
loaded_tensors, loaded_metadata = load_gguf(file)
assert set(metadata.keys()) == set(loaded_metadata.keys())
for key in metadata.keys():
assert metadata[key] == loaded_metadata[key]
assert set(tensors.keys()) == set(loaded_tensors.keys())
for key in tensors.keys():
assert tensors[key].dequantize().values() == loaded_tensors[key].dequantize().values(), "Values are not equal"
assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal"
assert str(tensors[key].ggml_dtype) == str(loaded_tensors[key].ggml_dtype), "Dtypes are not equal"