Mirror of https://github.com/huggingface/candle.git
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unittests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch`-like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit-test for `nn.Linear`
* Refactor slicing implementation
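(Illustrative sketch, not part of the diff: what the new API enables end to end, assuming an installed `candle` wheel and the `candle.randn` factory from the bindings.)

import candle
from candle import Tensor
from candle.nn import Module, Linear

class TinyNet(Module):
    def __init__(self):
        super().__init__()
        self.proj = Linear(4, 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.proj(x)

net = TinyNet()
state = net.state_dict()    # {"proj.weight": ..., "proj.bias": ...}
net.load_state_dict(state)  # round-trips, PyTorch-style
y = net(candle.randn((3, 4)))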
.vscode/settings.json (new file, 11 lines, vendored)
@@ -0,0 +1,11 @@
{
    "[python]": {
        "editor.defaultFormatter": "ms-python.black-formatter"
    },
    "python.formatting.provider": "none",
    "python.testing.pytestArgs": [
        "candle-pyo3"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": true
}
candle-pyo3/.gitignore (modified, vendored)
@@ -1,3 +1,4 @@
+tests/_workdir
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
candle-pyo3/e5.py (new file, 104 lines)
@@ -0,0 +1,104 @@
from candle.utils import load_safetensors, save_gguf, load_gguf
from candle.models.bert import BertModel, Config
import json
from candle import Tensor
from tqdm import tqdm
from dataclasses import fields
import os
import time

from huggingface_hub import hf_hub_download
from transformers import BertTokenizer, AutoModel
import torch

if __name__ == "__main__":
    model_name = "intfloat/e5-small-v2"
    model_file = hf_hub_download(repo_id=model_name, filename="model.safetensors")
    config_file = hf_hub_download(repo_id=model_name, filename="config.json")

    tensors = load_safetensors(model_file)
    config = Config()
    with open(config_file, "r") as f:
        raw_config = json.load(f)
        for field in fields(config):
            if field.name in raw_config:
                setattr(config, field.name, raw_config[field.name])

    # Load the model
    model = BertModel(config)
    model.load_state_dict(tensors)

    hf_model = AutoModel.from_pretrained(model_name)
    tokenizer = BertTokenizer.from_pretrained(model_name)

    sentences = [
        "The cat sits outside",
        "A man is playing guitar",
        "I love pasta",
        "The new movie is awesome",
        "The cat plays in the garden",
        "A woman watches TV",
        "The new movie is so great",
        "Do you like pizza?",
    ]

    def average_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor):
        """Average the hidden states according to the attention mask"""
        last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    tokenized = tokenizer(sentences, padding=True)
    tokens = Tensor(tokenized["input_ids"])
    token_type_ids = Tensor(tokenized["token_type_ids"])
    encoder_out, _ = model.forward(tokens, token_type_ids)

    hf_tokenized = tokenizer(sentences, padding=True, return_tensors="pt")
    hf_result = hf_model(**hf_tokenized)["last_hidden_state"]

    hf_pooled = average_pool(hf_result, hf_tokenized["attention_mask"])
    candle_pooled = average_pool(torch.tensor(encoder_out.values()), hf_tokenized["attention_mask"])

    loss = torch.nn.L1Loss()
    error = loss(hf_pooled, candle_pooled).mean().item()
    print(f"Mean error between torch reference and candle: {error}")

    # Quantize all attention 'weights'
    quantized_tensors = {}
    for name, tensor in tqdm(tensors.items(), desc="Quantizing tensors to 5-Bit"):
        if name.endswith("weight") and ("attention" in name or "intermediate" in name or "output" in name):
            # Check whether the tensor is k-quantizable: k-quants pack weights into
            # super-blocks of 256 elements, so the last dim must be divisible by 256.
            if tensor.shape[-1] % 256 == 0:
                new_tensor = tensor.quantize("q4k")
            else:
                new_tensor = tensor.quantize("q5_0")
            quantized_tensors[name] = new_tensor
        else:
            quantized_tensors[name] = tensor.quantize("q8_0")

    print("Saving quantized tensors")
    # Remove all None values from the config
    config_to_save = {k: v for k, v in config.__dict__.items() if v is not None}
    # Save the model
    quantized_model_file = "e5_small.gguf"
    save_gguf(quantized_model_file, quantized_tensors, config_to_save)

    file_size_mb = os.path.getsize(model_file) / 1024 / 1024
    file_size_mb_compressed = os.path.getsize(quantized_model_file) / 1024 / 1024
    print(f"Compressed model from {file_size_mb:.2f} MB to {file_size_mb_compressed:.2f} MB")

    # Load the model from the gguf
    tensors, raw_config = load_gguf(quantized_model_file)
    config = Config()
    for field in fields(config):
        if field.name in raw_config:
            setattr(config, field.name, raw_config[field.name])
    model = BertModel(config)
    # "embeddings.position_ids" is missing in the gguf as it is i64
    model.load_state_dict(tensors, strict=False)

    # Run the model again
    encoder_out_2, pooled_output_2 = model.forward(tokens, token_type_ids)
    encoder_out_2, pooled_output_2 = encoder_out_2.to_device("cpu"), pooled_output_2.to_device("cpu")

    candle_pooled_2 = average_pool(torch.tensor(encoder_out_2.values()), hf_tokenized["attention_mask"])
    error = loss(hf_pooled, candle_pooled_2).mean().item()
    print(f"Mean error between torch reference and quantized candle: {error}")
candle-pyo3/py_src/candle/__init__.py (modified)
@@ -1,5 +1,30 @@
-from .candle import *
+import logging
+
+try:
+    from .candle import *
+except ImportError as e:
+    # If we are in development mode, or we did not bundle the CUDA DLLs, we try to locate them here
+    logging.warning("CUDA DLLs were not bundled with this package. Trying to locate them...")
+    import os
+    import platform
+
+    # Try to locate CUDA_PATH environment variable
+    cuda_path = os.environ.get("CUDA_PATH", None)
+    if cuda_path:
+        logging.warning(f"Found CUDA_PATH environment variable: {cuda_path}")
+        if platform.system() == "Windows":
+            cuda_path = os.path.join(cuda_path, "bin")
+        else:
+            cuda_path = os.path.join(cuda_path, "lib64")
+
+        logging.warning(f"Adding {cuda_path} to DLL search path...")
+        os.add_dll_directory(cuda_path)
+
+    try:
+        from .candle import *
+    except ImportError as inner_e:
+        raise ImportError("Could not locate CUDA DLLs. Please check the documentation for more information.")
+
 __doc__ = candle.__doc__
 if hasattr(candle, "__all__"):
     __all__ = candle.__all__
candle-pyo3/py_src/candle/functional/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
# Generated content DO NOT EDIT
from .. import functional

gelu = functional.gelu
relu = functional.relu
silu = functional.silu
softmax = functional.softmax
tanh = functional.tanh
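(Aside, not part of the diff: a minimal usage sketch of these re-exports; `candle.randn` is assumed from the bindings.)

import candle
import candle.functional as F

t = candle.randn((2, 3))
probs = F.softmax(t, -1)  # one probability distribution per row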
candle-pyo3/py_src/candle/functional/__init__.pyi (modified)
@@ -4,6 +4,20 @@ from os import PathLike
 from candle.typing import _ArrayLike, Device
 from candle import Tensor, DType, QTensor
 
+@staticmethod
+def gelu(tensor: Tensor) -> Tensor:
+    """
+    Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
+    """
+    pass
+
+@staticmethod
+def relu(tensor: Tensor) -> Tensor:
+    """
+    Applies the Rectified Linear Unit (ReLU) function to a given tensor.
+    """
+    pass
+
 @staticmethod
 def silu(tensor: Tensor) -> Tensor:
     """
@@ -17,3 +31,10 @@ def softmax(tensor: Tensor, dim: int) -> Tensor:
     Applies the Softmax function to a given tensor.
     """
     pass
+
+@staticmethod
+def tanh(tensor: Tensor) -> Tensor:
+    """
+    Applies the tanh function to a given tensor.
+    """
+    pass
candle-pyo3/py_src/candle/models/bert.py (new file, 194 lines)
@@ -0,0 +1,194 @@
from dataclasses import dataclass
from typing import Tuple, Optional
from candle.nn import Module, Embedding, LayerNorm, Linear, ModuleList
from candle import Tensor
import candle
import candle.functional as F


@dataclass
class Config:
    vocab_size: int = 30522
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    intermediate_size: int = 3072
    hidden_act: str = "gelu"
    hidden_dropout_prob: float = 0.1
    max_position_embeddings: int = 512
    type_vocab_size: int = 2
    initializer_range: float = 0.02
    layer_norm_eps: float = 1e-12
    pad_token_id: int = 0
    position_embedding_type: str = "absolute"
    use_cache: bool = True
    classifier_dropout: Optional[float] = None
    model_type: Optional[str] = "bert"


class BertSelfAttention(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        all_head_size = int(config.num_attention_heads * self.attention_head_size)
        hidden_size = config.hidden_size
        self.query = Linear(hidden_size, all_head_size)
        self.key = Linear(hidden_size, all_head_size)
        self.value = Linear(hidden_size, all_head_size)

    def transpose_for_scores(self, x: Tensor) -> Tensor:
        new_x_shape = x.shape[:-1] + (
            self.num_attention_heads,
            self.attention_head_size,
        )
        x = x.reshape(new_x_shape).transpose(1, 2)
        return x.contiguous()

    def forward(self, hidden_states: Tensor) -> Tensor:
        query = self.query.forward(hidden_states)
        key = self.key.forward(hidden_states)
        value = self.value.forward(hidden_states)

        query = self.transpose_for_scores(query)
        key = self.transpose_for_scores(key)
        value = self.transpose_for_scores(value)

        attention_scores = query.matmul(key.t())
        attention_scores = attention_scores / (float(self.attention_head_size) ** 0.5)
        attention_probs = F.softmax(attention_scores, dim=-1)

        context_layer = attention_probs.matmul(value)
        context_layer = context_layer.transpose(1, 2).contiguous()
        context_layer = context_layer.flatten_from(-2)
        return context_layer


class BertSelfOutput(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
        hidden_states = self.dense.forward(hidden_states)
        return self.LayerNorm.forward(hidden_states + input_tensor)


class BertAttention(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.self = BertSelfAttention(config)
        self.output = BertSelfOutput(config)

    def forward(self, hidden_states: Tensor) -> Tensor:
        self_outputs = self.self.forward(hidden_states)
        attention_output = self.output.forward(self_outputs, hidden_states)
        return attention_output


class BertIntermediate(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.dense = Linear(config.hidden_size, config.intermediate_size)
        self.act = F.gelu if config.hidden_act == "gelu" else F.relu

    def forward(self, hidden_states: Tensor) -> Tensor:
        hidden_states = self.dense.forward(hidden_states)
        return self.act(hidden_states)


class BertOutput(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.dense = Linear(config.intermediate_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: Tensor, input_tensor: Tensor) -> Tensor:
        hidden_states = self.dense.forward(hidden_states)
        return self.LayerNorm.forward(hidden_states + input_tensor)


class BertLayer(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.attention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states: Tensor) -> Tensor:
        attention_output = self.attention.forward(hidden_states)
        # TODO: Support cross-attention?
        # https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523
        # TODO: Support something similar to `apply_chunking_to_forward`?
        intermediate_output = self.intermediate.forward(attention_output)
        layer_output = self.output.forward(intermediate_output, attention_output)
        return layer_output


class BertEncoder(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.layer = ModuleList()
        for _ in range(config.num_hidden_layers):
            self.layer.append(BertLayer(config))

    def forward(self, hidden_states: Tensor) -> Tensor:
        for l in self.layer:
            hidden_states = l.forward(hidden_states)
        return hidden_states


class BertEmbeddings(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.word_embeddings = Embedding(config.vocab_size, config.hidden_size)
        self.position_embeddings = Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = Embedding(config.type_vocab_size, config.hidden_size)
        self.LayerNorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.position_ids = candle.Tensor(list(range(config.max_position_embeddings))).reshape(
            (1, config.max_position_embeddings)
        )

    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tensor:
        (_batch_size, seq_len) = input_ids.shape
        input_embeddings = self.word_embeddings.forward(input_ids)
        token_type_embeddings = self.token_type_embeddings.forward(token_type_ids)
        embeddings: Tensor = input_embeddings + token_type_embeddings

        position_ids = list(range(seq_len))
        position_ids = Tensor(position_ids).to_dtype(input_ids.dtype).to_device(input_ids.device)

        embeddings = embeddings.broadcast_add(self.position_embeddings.forward(position_ids))
        embeddings = self.LayerNorm(embeddings)
        return embeddings


class BertPooler(Module):
    def __init__(self, config: Config) -> None:
        super().__init__()
        self.dense = Linear(config.hidden_size, config.hidden_size)
        self.activation = F.tanh

    def forward(self, hidden_states: Tensor) -> Tensor:
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense.forward(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


# https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L874
class BertModel(Module):
    def __init__(self, config: Config, add_pooling_layer=True) -> None:
        super().__init__()
        self.config = config
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

    def forward(self, input_ids: Tensor, token_type_ids: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        embeddings = self.embeddings.forward(input_ids, token_type_ids)
        encoder_out = self.encoder.forward(embeddings)
        pooled_output = self.pooler(encoder_out) if self.pooler is not None else None
        return encoder_out, pooled_output
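(Aside, not part of the diff: a quick, hedged smoke test of the BERT implementation above. The token ids are made up and the weights are the constructor defaults, so only the shapes are meaningful.)

from candle import Tensor
from candle.models.bert import BertModel, Config

config = Config()
model = BertModel(config)
input_ids = Tensor([[101, 7592, 102]])  # hypothetical "[CLS] hello [SEP]"
token_type_ids = Tensor([[0, 0, 0]])
encoder_out, pooled = model.forward(input_ids, token_type_ids)
print(encoder_out.shape)  # (1, 3, 768)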
candle-pyo3/py_src/candle/models/llama.py (new file, 150 lines)
@@ -0,0 +1,150 @@
import candle
from typing import Dict, Tuple, Any
from candle import Tensor, QTensor, utils, nn
from candle.nn import Module, ModuleList


def masked_fill(on_false: Tensor, mask: Tensor, on_true: Tensor):
    shape = mask.shape
    on_true = candle.tensor(on_true).broadcast_as(shape)
    return mask.where_cond(on_true, on_false)


def precompute_freqs_cis(hparams: Dict[str, Any], freq_base: float, max_seq_len: int):
    head_dim = hparams["n_embd"] // hparams["n_head"]
    theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
    theta = candle.tensor(theta)
    idx_theta = [float(i) for i in range(max_seq_len)]
    idx_theta = candle.tensor(idx_theta).reshape((max_seq_len, 1))
    m = idx_theta.matmul(theta.unsqueeze(0))
    return (m.cos(), m.sin())


class RmsNorm(Module):
    def __init__(self, qtensor: QTensor):
        super().__init__()
        self.weight = qtensor.dequantize()

    def forward(self, x: Tensor) -> Tensor:
        b_size, seq_len, hidden_size = x.shape
        norm_x = x.sqr().sum_keepdim(2) / hidden_size
        x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
        return x_normed.broadcast_mul(self.weight)


class QuantizedLayer(Module):
    def __init__(
        self,
        layer_idx: int,
        hparams: Dict[str, Any],
        all_tensors: Dict[str, QTensor],
        cos_sin: Tuple[Tensor, Tensor],
    ):
        super().__init__()
        p = f"layers.{layer_idx}"
        self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
        self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
        self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
        self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
        self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
        self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
        self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
        self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
        self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])

        self.n_head = hparams["n_head"]
        self.n_kv_head = self.n_head
        self.head_dim = hparams["n_embd"] // self.n_head

        self.kv_cache = None
        self.cos = cos_sin[0]
        self.sin = cos_sin[1]
        self._non_persistent_buffers_set.add("cos")
        self._non_persistent_buffers_set.add("sin")

    def forward(self, x: Tensor, mask: Tensor, index_pos: int) -> Tensor:
        residual = x
        x = self.attn_norm(x)
        attn = self.forward_attn(x, mask, index_pos)
        x = attn + residual

        residual = x
        x = self.ffn_norm(x)
        w1 = self.ffw1.matmul_t(x)
        w3 = self.ffw3.matmul_t(x)
        mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)

        return mlp + residual

    def forward_attn(self, x: Tensor, mask: Tensor, index_pos: int):
        b_size, seq_len, n_embd = x.shape
        q = self.attention_wq.matmul_t(x)
        k = self.attention_wk.matmul_t(x)
        v = self.attention_wv.matmul_t(x)

        q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
        k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
        v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)

        q = self.apply_rotary_emb(q, index_pos)
        k = self.apply_rotary_emb(k, index_pos)

        if self.kv_cache is not None and index_pos > 0:
            prev_k, prev_v = self.kv_cache
            k = candle.cat([prev_k, k], 2).contiguous()
            v = candle.cat([prev_v, v], 2).contiguous()

        self.kv_cache = (k, v)

        # TODO: maybe repeat k/v here if we start supporting MQA.

        att = q.matmul(k.t()) / self.head_dim**0.5
        mask = mask.broadcast_as(att.shape)
        att = masked_fill(att, mask, float("-inf"))
        att = nn.softmax(att, -1)
        y = att.matmul(v.contiguous())
        y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
        return self.attention_wo.matmul_t(y)

    def apply_rotary_emb(self, x: Tensor, index_pos: int):
        b_size, n_head, seq_len, n_embd = x.shape
        cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
        sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd // 2, 1))
        x = x.reshape((b_size, n_head, seq_len, n_embd // 2, 2))
        x0 = x.narrow(-1, 0, 1)
        x1 = x.narrow(-1, 1, 1)
        y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
        y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
        rope = candle.cat([y0, y1], -1)
        return rope.flatten_from(-2)


class QuantizedLlama(Module):
    def __init__(self, hparams: Dict[str, Any], all_tensors: Dict[str, QTensor]):
        super().__init__()
        self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
        self.norm = RmsNorm(all_tensors["norm.weight"])
        self.output = all_tensors["output.weight"]
        self.layers = ModuleList()
        rope_freq = hparams.get("rope_freq", 10000.0)
        cos_sin = precompute_freqs_cis(hparams, rope_freq, hparams["context_length"])
        for layer_idx in range(hparams["n_layer"]):
            layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
            self.layers.append(layer)

    def forward(self, token: Tensor, index_pos: int) -> Tensor:
        b_size, seq_len = token.shape
        vocab_size, hidden_size = self.tok_embeddings.shape
        token = token.reshape((b_size * seq_len,))
        x = self.tok_embeddings.index_select(token, 0)
        x = x.reshape((b_size, seq_len, hidden_size))
        # Causal mask: mask[q][k] == 1 for k > q, so future positions get filled
        # with -inf (the outer loop must run over query positions).
        mask = [int(j > i) for i in range(seq_len) for j in range(seq_len)]
        mask = candle.tensor(mask).reshape((seq_len, seq_len))

        for layer in self.layers:
            x = layer(x, mask, index_pos)
        x = self.norm(x)
        x = x.narrow(1, -1, 1).squeeze(1)
        x = self.output.matmul_t(x)
        return x
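(Aside, not part of the diff: a rough, hedged sketch of a generation loop driving QuantizedLlama. `hparams` and `all_tensors` would come from candle.utils.load_gguf, and `sample` is a hypothetical helper that picks a token id from the logits.)

model = QuantizedLlama(hparams, all_tensors)
prompt_ids = [1, 306, 4658]  # assumed token ids
tokens = candle.tensor(prompt_ids).reshape((1, len(prompt_ids)))
logits = model.forward(tokens, index_pos=0)  # fills the kv-caches for the whole prompt
index_pos = len(prompt_ids)
for _ in range(32):
    next_id = sample(logits)  # hypothetical sampling helper
    logits = model.forward(candle.tensor([next_id]).reshape((1, 1)), index_pos)
    index_pos += 1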
candle-pyo3/py_src/candle/nn/__init__.py (modified)
@@ -1,5 +1,5 @@
-# Generated content DO NOT EDIT
-from .. import nn
-
-silu = nn.silu
-softmax = nn.softmax
+from .module import Module
+from .container import Sequential, ModuleList, ModuleDict
+from .sparse import Embedding
+from .normalization import LayerNorm
+from .linear import Linear
candle-pyo3/py_src/candle/nn/container.py (new file, 483 lines)
@@ -0,0 +1,483 @@
# see https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/container.py
from .module import Module
from typing import (
    Any,
    Dict,
    Iterable,
    Iterator,
    Mapping,
    Optional,
    overload,
    Tuple,
    TypeVar,
    Union,
)
from collections import OrderedDict, abc as container_abcs
import operator
from itertools import chain, islice

__all__ = ["Sequential", "ModuleList", "ModuleDict"]

T = TypeVar("T", bound=Module)


def _addindent(s_: str, numSpaces: int):
    s = s_.split("\n")
    # don't do anything for single-line stuff
    if len(s) == 1:
        return s_
    first = s.pop(0)
    s = [(numSpaces * " ") + line for line in s]
    s = "\n".join(s)
    s = first + "\n" + s
    return s


class Sequential(Module):
    r"""A sequential container.

    Modules will be added to it in the order they are passed in the
    constructor. Alternatively, an ``OrderedDict`` of modules can be
    passed in. The ``forward()`` method of ``Sequential`` accepts any
    input and forwards it to the first module it contains. It then
    "chains" outputs to inputs sequentially for each subsequent module,
    finally returning the output of the last module.

    The value a ``Sequential`` provides over manually calling a sequence
    of modules is that it allows treating the whole container as a
    single module, such that performing a transformation on the
    ``Sequential`` applies to each of the modules it stores (which are
    each a registered submodule of the ``Sequential``).

    What's the difference between a ``Sequential`` and a
    :class:`candle.nn.ModuleList`? A ``ModuleList`` is exactly what it
    sounds like--a list for storing ``Module`` s! On the other hand,
    the layers in a ``Sequential`` are connected in a cascading way.
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    @overload
    def __init__(self, *args: Module) -> None:
        ...

    @overload
    def __init__(self, arg: "OrderedDict[str, Module]") -> None:
        ...

    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def _get_item_by_idx(self, iterator, idx) -> T:
        """Get the idx-th item of the iterator"""
        size = len(self)
        idx = operator.index(idx)
        if not -size <= idx < size:
            raise IndexError("index {} is out of range".format(idx))
        idx %= size
        return next(islice(iterator, idx, None))

    def __getitem__(self, idx: Union[slice, int]) -> Union["Sequential", T]:
        if isinstance(idx, slice):
            return self.__class__(OrderedDict(list(self._modules.items())[idx]))
        else:
            return self._get_item_by_idx(self._modules.values(), idx)

    def __setitem__(self, idx: int, module: Module) -> None:
        key: str = self._get_item_by_idx(self._modules.keys(), idx)
        return setattr(self, key, module)

    def __delitem__(self, idx: Union[slice, int]) -> None:
        if isinstance(idx, slice):
            for key in list(self._modules.keys())[idx]:
                delattr(self, key)
        else:
            key = self._get_item_by_idx(self._modules.keys(), idx)
            delattr(self, key)
        # To preserve numbering
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    def __len__(self) -> int:
        return len(self._modules)

    def __add__(self, other) -> "Sequential":
        if isinstance(other, Sequential):
            ret = Sequential()
            for layer in self:
                ret.append(layer)
            for layer in other:
                ret.append(layer)
            return ret
        else:
            raise ValueError(
                "add operator supports only objects of Sequential class, but {} is given.".format(str(type(other)))
            )

    def pop(self, key: Union[int, slice]) -> Module:
        v = self[key]
        del self[key]
        return v

    def __iadd__(self, other) -> "Sequential":
        if isinstance(other, Sequential):
            offset = len(self)
            for i, module in enumerate(other):
                self.add_module(str(i + offset), module)
            return self
        else:
            raise ValueError(
                "add operator supports only objects of Sequential class, but {} is given.".format(str(type(other)))
            )

    def __mul__(self, other: int) -> "Sequential":
        if not isinstance(other, int):
            raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
        elif other <= 0:
            raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
        else:
            combined = Sequential()
            offset = 0
            for _ in range(other):
                for module in self:
                    combined.add_module(str(offset), module)
                    offset += 1
            return combined

    def __rmul__(self, other: int) -> "Sequential":
        return self.__mul__(other)

    def __imul__(self, other: int) -> "Sequential":
        if not isinstance(other, int):
            raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
        elif other <= 0:
            raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
        else:
            len_original = len(self)
            offset = len(self)
            for _ in range(other - 1):
                for i in range(len_original):
                    self.add_module(str(i + offset), self._modules[str(i)])
                offset += len_original
            return self

    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    # NB: We can't really type check this function as the type of input
    # may change dynamically (as is tested in
    # TestScript.test_sequential_intermediary_types). Cannot annotate
    # with Any as TorchScript expects a more precise type
    def forward(self, input):
        for module in self:
            input = module(input)
        return input

    def append(self, module: Module) -> "Sequential":
        r"""Appends a given module to the end.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def insert(self, index: int, module: Module) -> "Sequential":
        if not isinstance(module, Module):
            raise AssertionError("module should be of type: {}".format(Module))
        n = len(self._modules)
        if not (-n <= index <= n):
            raise IndexError("Index out of range: {}".format(index))
        if index < 0:
            index += n
        for i in range(n, index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module
        return self

    def extend(self, sequential) -> "Sequential":
        for layer in sequential:
            self.append(layer)
        return self


class ModuleList(Module):
    r"""Holds submodules in a list.

    :class:`~candle.nn.ModuleList` can be indexed like a regular Python list, but
    modules it contains are properly registered, and will be visible by all
    :class:`~candle.nn.Module` methods.

    Args:
        modules (iterable, optional): an iterable of modules to add

    Example::

        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

            def forward(self, x):
                # ModuleList can act as an iterable, or be indexed using ints
                for i, l in enumerate(self.linears):
                    x = self.linears[i // 2](x) + l(x)
                return x
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self += modules

    def _get_abs_string_index(self, idx):
        """Get the absolute index for the list of modules"""
        idx = operator.index(idx)
        if not (-len(self) <= idx < len(self)):
            raise IndexError("index {} is out of range".format(idx))
        if idx < 0:
            idx += len(self)
        return str(idx)

    def __getitem__(self, idx: Union[int, slice]) -> Union[Module, "ModuleList"]:
        if isinstance(idx, slice):
            return self.__class__(list(self._modules.values())[idx])
        else:
            return self._modules[self._get_abs_string_index(idx)]

    def __setitem__(self, idx: int, module: Module) -> None:
        idx = self._get_abs_string_index(idx)
        return setattr(self, str(idx), module)

    def __delitem__(self, idx: Union[int, slice]) -> None:
        if isinstance(idx, slice):
            for k in range(len(self._modules))[idx]:
                delattr(self, str(k))
        else:
            delattr(self, self._get_abs_string_index(idx))
        # To preserve numbering, self._modules is being reconstructed with modules after deletion
        str_indices = [str(i) for i in range(len(self._modules))]
        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[Module]:
        return iter(self._modules.values())

    def __iadd__(self, modules: Iterable[Module]) -> "ModuleList":
        return self.extend(modules)

    def __add__(self, other: Iterable[Module]) -> "ModuleList":
        combined = ModuleList()
        for i, module in enumerate(chain(self, other)):
            combined.add_module(str(i), module)
        return combined

    def __repr__(self):
        """A custom repr for ModuleList that compresses repeated module representations"""
        list_of_reprs = [repr(item) for item in self]
        if len(list_of_reprs) == 0:
            return self._get_name() + "()"

        start_end_indices = [[0, 0]]
        repeated_blocks = [list_of_reprs[0]]
        for i, r in enumerate(list_of_reprs[1:], 1):
            if r == repeated_blocks[-1]:
                start_end_indices[-1][1] += 1
                continue

            start_end_indices.append([i, i])
            repeated_blocks.append(r)

        lines = []
        main_str = self._get_name() + "("
        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
            local_repr = f"({start_id}): {b}"  # default repr

            if start_id != end_id:
                n = end_id - start_id + 1
                local_repr = f"({start_id}-{end_id}): {n} x {b}"

            local_repr = _addindent(local_repr, 2)
            lines.append(local_repr)

        main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str

    def __dir__(self):
        keys = super().__dir__()
        keys = [key for key in keys if not key.isdigit()]
        return keys

    def insert(self, index: int, module: Module) -> None:
        r"""Insert a given module before a given index in the list.

        Args:
            index (int): index to insert.
            module (nn.Module): module to insert
        """
        for i in range(len(self._modules), index, -1):
            self._modules[str(i)] = self._modules[str(i - 1)]
        self._modules[str(index)] = module

    def append(self, module: Module) -> "ModuleList":
        r"""Appends a given module to the end of the list.

        Args:
            module (nn.Module): module to append
        """
        self.add_module(str(len(self)), module)
        return self

    def pop(self, key: Union[int, slice]) -> Module:
        v = self[key]
        del self[key]
        return v

    def extend(self, modules: Iterable[Module]) -> "ModuleList":
        r"""Appends modules from a Python iterable to the end of the list.

        Args:
            modules (iterable): iterable of modules to append
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError("ModuleList.extend should be called with an iterable, but got " + type(modules).__name__)
        offset = len(self)
        for i, module in enumerate(modules):
            self.add_module(str(offset + i), module)
        return self

    # remove forward altogether to fallback on Module's _forward_unimplemented


class ModuleDict(Module):
    r"""Holds submodules in a dictionary.

    :class:`~candle.nn.ModuleDict` can be indexed like a regular Python dictionary,
    but modules it contains are properly registered, and will be visible by all
    :class:`~candle.nn.Module` methods.

    :class:`~candle.nn.ModuleDict` is an **ordered** dictionary that respects

    * the order of insertion, and

    * in :meth:`~candle.nn.ModuleDict.update`, the order of the merged
      ``OrderedDict``, ``dict`` (started from Python 3.6) or another
      :class:`~candle.nn.ModuleDict` (the argument to
      :meth:`~candle.nn.ModuleDict.update`).

    Note that :meth:`~candle.nn.ModuleDict.update` with other unordered mapping
    types (e.g., Python's plain ``dict`` before Python version 3.6) does not
    preserve the order of the merged mapping.

    Args:
        modules (iterable, optional): a mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module)
    """

    _modules: Dict[str, Module]  # type: ignore[assignment]

    def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
        super().__init__()
        if modules is not None:
            self.update(modules)

    def __getitem__(self, key: str) -> Module:
        return self._modules[key]

    def __setitem__(self, key: str, module: Module) -> None:
        self.add_module(key, module)

    def __delitem__(self, key: str) -> None:
        del self._modules[key]

    def __len__(self) -> int:
        return len(self._modules)

    def __iter__(self) -> Iterator[str]:
        return iter(self._modules)

    def __contains__(self, key: str) -> bool:
        return key in self._modules

    def clear(self) -> None:
        """Remove all items from the ModuleDict."""
        self._modules.clear()

    def pop(self, key: str) -> Module:
        r"""Remove key from the ModuleDict and return its module.

        Args:
            key (str): key to pop from the ModuleDict
        """
        v = self[key]
        del self[key]
        return v

    def keys(self) -> Iterable[str]:
        r"""Return an iterable of the ModuleDict keys."""
        return self._modules.keys()

    def items(self) -> Iterable[Tuple[str, Module]]:
        r"""Return an iterable of the ModuleDict key/value pairs."""
        return self._modules.items()

    def values(self) -> Iterable[Module]:
        r"""Return an iterable of the ModuleDict values."""
        return self._modules.values()

    def update(self, modules: Mapping[str, Module]) -> None:
        r"""Update the :class:`~candle.nn.ModuleDict` with the key-value pairs from a
        mapping or an iterable, overwriting existing keys.

        .. note::
            If :attr:`modules` is an ``OrderedDict``, a :class:`~candle.nn.ModuleDict`, or
            an iterable of key-value pairs, the order of new elements in it is preserved.

        Args:
            modules (iterable): a mapping (dictionary) from string to :class:`~candle.nn.Module`,
                or an iterable of key-value pairs of type (string, :class:`~candle.nn.Module`)
        """
        if not isinstance(modules, container_abcs.Iterable):
            raise TypeError(
                "ModuleDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(modules).__name__
            )

        if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
            for key, module in modules.items():
                self[key] = module
        else:
            # modules here can be a list with two items
            for j, m in enumerate(modules):
                if not isinstance(m, container_abcs.Iterable):
                    raise TypeError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " should be Iterable; is " + type(m).__name__
                    )
                if not len(m) == 2:
                    raise ValueError(
                        "ModuleDict update sequence element "
                        "#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
                    )
                # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
                # that's too cumbersome to type correctly with overloads, so we add an ignore here
                self[m[0]] = m[1]  # type: ignore[assignment]

    # remove forward altogether to fallback on Module's _forward_unimplemented
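(Aside, not part of the diff: a small, hedged sketch of the containers in use; shapes and names are arbitrary, and `candle.randn` is assumed from the bindings.)

import candle
from candle.nn import Sequential, ModuleDict, Linear

mlp = Sequential(Linear(8, 16), Linear(16, 4))  # layers chained in order
heads = ModuleDict({"reg": Linear(4, 1), "cls": Linear(4, 3)})
features = mlp(candle.randn((2, 8)))
scores = heads["cls"](features)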
candle-pyo3/py_src/candle/nn/linear.py (new file, 119 lines)
@@ -0,0 +1,119 @@
import math
from typing import Any

import candle
from candle import Tensor
from .module import Module

# See https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/linear.py


class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = candle.randn(128, 20)
        >>> output = m(input)
        >>> print(output.shape)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input


class Linear(Module):
    r"""Applies a linear transformation to the incoming data: :math:`y = xA^T + b`

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(*, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(*, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`
    """

    __constants__ = ["in_features", "out_features"]
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Allow 'weight' to be quantized
        self._quantizable_buffers.add("weight")

        self.in_features = in_features
        self.out_features = out_features
        # TODO: Do actual initialization here: e.g. kaiming_uniform or xavier_uniform
        self.weight = candle.ones((out_features, in_features), **factory_kwargs)
        if bias:
            self.bias = candle.zeros((out_features,), **factory_kwargs)
        else:
            self.bias = None

    def forward(self, x: Tensor) -> Tensor:
        dims = x.shape
        last_dim = dims[-1]

        if isinstance(self.weight, candle.QTensor):
            if len(dims) < 3:
                matmul_result = self.weight.matmul_t(x)
            elif len(dims) == 3:
                b, n, m = dims
                output_shape = (b, n, self.out_features)
                re = x.reshape((b * n, m))
                matmul_result = self.weight.matmul_t(re).reshape(output_shape)
            else:
                raise NotImplementedError("'QTensor.matmul_t' is not implemented for more than 3 dimensions")

            # Add the bias once, after the (possibly reshaped) quantized matmul.
            if self.bias is not None:
                return matmul_result.broadcast_add(self.bias)
            return matmul_result
        else:
            if self.weight.shape[-1] == last_dim and len(dims) < 3:
                w = self.weight.t()
            else:
                batch_size = dims[0]
                w = self.weight.broadcast_left((batch_size,)).t()

            x = x.matmul(w)
            if self.bias is not None:
                x = x.broadcast_add(self.bias)
            return x

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}"
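(Aside, not part of the diff: because "weight" is registered in `_quantizable_buffers`, a quantized weight can be swapped in directly. A hedged sketch; q4_0 quantizes in blocks of 32 elements, so in_features is chosen accordingly, and `candle.randn` is assumed from the bindings.)

import candle
from candle.nn import Linear

layer = Linear(256, 256, bias=False)
layer.weight = layer.weight.quantize("q4_0")  # weight becomes a QTensor
y = layer(candle.randn((2, 256)))             # forward takes the matmul_t path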
candle-pyo3/py_src/candle/nn/module.py (new file, 702 lines)
@ -0,0 +1,702 @@
|
||||
from candle import Tensor, QTensor, DType
|
||||
from typing import (
|
||||
Dict,
|
||||
Tuple,
|
||||
Any,
|
||||
Optional,
|
||||
Union,
|
||||
Iterator,
|
||||
Set,
|
||||
overload,
|
||||
Mapping,
|
||||
TypeVar,
|
||||
List,
|
||||
)
|
||||
from collections import OrderedDict, namedtuple
|
||||
|
||||
TensorLike = Union[Tensor, QTensor]
|
||||
T = TypeVar("T", bound="Module")
|
||||
|
||||
|
||||
class _IncompatibleKeys(namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])):
|
||||
def __repr__(self):
|
||||
if not self.missing_keys and not self.unexpected_keys:
|
||||
return "<All keys matched successfully>"
|
||||
return super().__repr__()
|
||||
|
||||
__str__ = __repr__
|
||||
|
||||
|
||||
# see: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/module.py
|
||||
class Module:
|
||||
"""
|
||||
Pytorch like Module.
|
||||
|
||||
Base class for all neural network modules.
|
||||
|
||||
Your models should also subclass this class.
|
||||
"""
|
||||
|
||||
_modules: Dict[str, Optional["Module"]]
|
||||
_buffers: Dict[str, Optional[TensorLike]]
|
||||
_non_persistent_buffers_set: Set[str]
|
||||
_quantizable_buffers: Set[str]
|
||||
_version: int = 1
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
"""
|
||||
Initializes internal Module state
|
||||
"""
|
||||
super().__setattr__("_modules", OrderedDict())
|
||||
super().__setattr__("_buffers", OrderedDict())
|
||||
super().__setattr__("_non_persistent_buffers_set", set())
|
||||
super().__setattr__("_quantizable_buffers", set())
|
||||
|
||||
def __call__(self, *input):
|
||||
"""
|
||||
Call self as a function.
|
||||
"""
|
||||
return self.forward(*input)
|
||||
|
||||
def forward(self, *input):
|
||||
"""
|
||||
Defines the computation performed at every call.
|
||||
Should be overridden by all subclasses.
|
||||
"""
|
||||
pass
|
||||
|
||||
def children(self) -> Iterator["Module"]:
|
||||
r"""Returns an iterator over immediate children modules.
|
||||
|
||||
Yields:
|
||||
Module: a child module
|
||||
"""
|
||||
for name, module in self.named_children():
|
||||
yield module
|
||||
|
||||
def named_children(self) -> Iterator[Tuple[str, "Module"]]:
|
||||
r"""Returns an iterator over immediate children modules, yielding both
|
||||
the name of the module as well as the module itself.
|
||||
|
||||
Yields:
|
||||
(str, Module): Tuple containing a name and child module
|
||||
|
||||
Example::
|
||||
|
||||
>>> for name, module in model.named_children():
|
||||
>>> if name in ['conv4', 'conv5']:
|
||||
>>> print(module)
|
||||
|
||||
"""
|
||||
memo = set()
|
||||
for name, module in self._modules.items():
|
||||
if module is not None and module not in memo:
|
||||
memo.add(module)
|
||||
yield name, module
|
||||
|
||||
def add_module(self, name: str, module: Optional["Module"]) -> None:
|
||||
r"""Adds a child module to the current module.
|
||||
|
||||
The module can be accessed as an attribute using the given name.
|
||||
|
||||
Args:
|
||||
name (str): name of the child module. The child module can be
|
||||
accessed from this module using the given name
|
||||
module (Module): child module to be added to the module.
|
||||
"""
|
||||
if not isinstance(module, Module) and module is not None:
|
||||
raise TypeError(f"{str(module)} is not a Module subclass")
|
||||
elif not isinstance(name, str):
|
||||
raise TypeError(f"module name should be a string. Got {name}")
|
||||
elif hasattr(self, name) and name not in self._modules:
|
||||
raise KeyError(f"attribute '{name}' already exists")
|
||||
elif "." in name:
|
||||
raise KeyError(f'module name can\'t contain ".", got: {name}')
|
||||
elif name == "":
|
||||
raise KeyError('module name can\'t be empty string ""')
|
||||
self._modules[name] = module
|
||||
|
||||
def register_module(self, name: str, module: Optional["Module"]) -> None:
|
||||
r"""Alias for :func:`add_module`."""
|
||||
self.add_module(name, module)
|
||||
|
||||
def modules(self) -> Iterator["Module"]:
|
||||
r"""Returns an iterator over all modules in the network."""
|
||||
for _, module in self.named_modules():
|
||||
yield module
|
||||
|
||||
def named_modules(
|
||||
self,
|
||||
memo: Optional[Set["Module"]] = None,
|
||||
prefix: str = "",
|
||||
remove_duplicate: bool = True,
|
||||
):
|
||||
r"""Returns an iterator over all modules in the network, yielding
|
||||
both the name of the module as well as the module itself.
|
||||
|
||||
Args:
|
||||
memo: a memo to store the set of modules already added to the result
|
||||
prefix: a prefix that will be added to the name of the module
|
||||
remove_duplicate: whether to remove the duplicated module instances in the result
|
||||
or not
|
||||
|
||||
Yields:
|
||||
(str, Module): Tuple of name and module
|
||||
|
||||
Note:
|
||||
Duplicate modules are returned only once. In the following
|
||||
example, ``l`` will be returned only once.
|
||||
"""
|
||||
|
||||
if memo is None:
|
||||
memo = set()
|
||||
if self not in memo:
|
||||
if remove_duplicate:
|
||||
memo.add(self)
|
||||
yield prefix, self
|
||||
for name, module in self._modules.items():
|
||||
if module is None:
|
||||
continue
|
||||
submodule_prefix = prefix + ("." if prefix else "") + name
|
||||
for m in module.named_modules(memo, submodule_prefix, remove_duplicate):
|
||||
yield m
|
||||
|
||||
def buffers(self, recurse: bool = True) -> Iterator[TensorLike]:
|
||||
"""
|
||||
Returns an iterator over module buffers.
|
||||
"""
|
||||
for name, buf in self.named_buffers(recurse=recurse):
|
||||
yield buf
|
||||
|
||||
def named_buffers(
|
||||
self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
|
||||
) -> Iterator[Tuple[str, TensorLike]]:
|
||||
r"""Returns an iterator over module buffers, yielding both the
|
||||
name of the buffer as well as the buffer itself.
|
||||
|
||||
Args:
|
||||
prefix (str): prefix to prepend to all buffer names.
|
||||
recurse (bool, optional): if True, then yields buffers of this module
|
||||
and all submodules. Otherwise, yields only buffers that
|
||||
are direct members of this module. Defaults to True.
|
||||
remove_duplicate (bool, optional): whether to remove the duplicated buffers in the result. Defaults to True.
|
||||
|
||||
Yields:
|
||||
(str, Tensor): Tuple containing the name and buffer
|
||||
|
||||
Example::
|
||||
|
||||
>>> for name, buf in self.named_buffers():
|
||||
>>> if name in ['running_var']:
|
||||
>>> print(buf.size())
|
||||
|
||||
"""
|
||||
gen = self._named_members(
|
||||
lambda module: module._buffers.items(),
|
||||
prefix=prefix,
|
||||
recurse=recurse,
|
||||
remove_duplicate=remove_duplicate,
|
||||
)
|
||||
yield from gen
|
||||
|
||||
# The user can pass an optional arbitrary mappable object to `state_dict`, in which case `state_dict` returns
|
||||
# back that same object. But if they pass nothing, an `OrderedDict` is created and returned.
|
||||
T_destination = TypeVar("T_destination", bound=Dict[str, Any])
|
||||
|
||||
@overload
|
||||
def state_dict(self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...) -> T_destination:
|
||||
...
|
||||
|
||||
@overload
|
||||
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> Dict[str, Any]:
|
||||
...
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
|
||||
r"""Returns a dictionary containing references to the whole state of the module.
|
||||
|
||||
Both parameters and persistent buffers (e.g. running averages) are
|
||||
included. Keys are corresponding parameter and buffer names.
|
||||
Parameters and buffers set to ``None`` are not included.
|
||||
|
||||
.. note::
|
||||
The returned object is a shallow copy. It contains references
|
||||
to the module's parameters and buffers.
|
||||
|
||||
.. warning::
|
||||
Currently ``state_dict()`` also accepts positional arguments for
|
||||
``destination``, ``prefix`` and ``keep_vars`` in order. However,
|
||||
this is being deprecated and keyword arguments will be enforced in
|
||||
future releases.
|
||||
|
||||
.. warning::
|
||||
Please avoid the use of argument ``destination`` as it is not
|
||||
designed for end-users.
|
||||
|
||||
Args:
|
||||
destination (dict, optional): If provided, the state of module will
|
||||
be updated into the dict and the same object is returned.
|
||||
Otherwise, an ``OrderedDict`` will be created and returned.
|
||||
Default: ``None``.
|
||||
prefix (str, optional): a prefix added to parameter and buffer
|
||||
names to compose the keys in state_dict. Default: ``''``.
|
||||
keep_vars (bool, optional): by default the :class:`~candle.Tensor` s
|
||||
returned in the state dict are detached from autograd. If it's
|
||||
set to ``True``, detaching will not be performed.
|
||||
Default: ``False``.
|
||||
|
||||
Returns:
|
||||
dict:
|
||||
a dictionary containing a whole state of the module
|
||||
|
||||
Example::
|
||||
|
||||
>>> # xdoctest: +SKIP("undefined vars")
|
||||
>>> module.state_dict().keys()
|
||||
['bias', 'weight']
|
||||
|
||||
"""
|
||||
|
||||
# TODO: Remove `args` and the parsing logic when BC allows.
|
||||
if len(args) > 0:
|
||||
if destination is None:
|
||||
destination = args[0]
|
||||
if len(args) > 1 and prefix == "":
|
||||
prefix = args[1]
|
||||
if len(args) > 2 and keep_vars is False:
|
||||
keep_vars = args[2]
|
||||
|
||||
if destination is None:
|
||||
destination = OrderedDict()
|
||||
destination._metadata = OrderedDict()
|
||||
|
||||
local_metadata = dict(version=self._version)
|
||||
if hasattr(destination, "_metadata"):
|
||||
destination._metadata[prefix[:-1]] = local_metadata
|
||||
self._save_to_state_dict(destination, prefix, keep_vars)
|
||||
for name, module in self._modules.items():
|
||||
if module is not None:
|
||||
module.state_dict(
|
||||
destination=destination,
|
||||
prefix=prefix + name + ".",
|
||||
keep_vars=keep_vars,
|
||||
)
|
||||
return destination
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~candle.nn.Module.state_dict`.
|
||||
|
||||
In rare cases, subclasses can achieve class-specific behavior by
|
||||
overriding this method with custom logic.
|
||||
|
||||
Args:
|
||||
destination (dict): a dict where state will be stored
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
for name, buf in self._buffers.items():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
if isinstance(buf, Tensor):
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
else:
|
||||
destination[prefix + name] = buf
|
||||
|
||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into
|
||||
this module and its descendants. If :attr:`strict` is ``True``, then
|
||||
the keys of :attr:`state_dict` must exactly match the keys returned
|
||||
by this module's :meth:`~candle.nn.Module.state_dict` function.
|
||||
|
||||
.. warning::
|
||||
If :attr:`assign` is ``True`` the optimizer must be created after
|
||||
the call to :attr:`load_state_dict`.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
strict (bool, optional): whether to strictly enforce that the keys
|
||||
in :attr:`state_dict` match the keys returned by this module's
|
||||
:meth:`~candle.nn.Module.state_dict` function. Default: ``True``
|
||||
assign (bool, optional): whether to assign items in the state
|
||||
dictionary to their corresponding keys in the module instead
|
||||
of copying them inplace into the module's current parameters and buffers.
|
||||
When ``False``, the properties of the tensors in the current
|
||||
module are preserved while when ``True``, the properties of the
|
||||
Tensors in the state dict are preserved.
|
||||
Default: ``False``
|
||||
|
||||
Returns:
|
||||
``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
|
||||
* **missing_keys** is a list of str containing the missing keys
|
||||
* **unexpected_keys** is a list of str containing the unexpected keys
|
||||
|
||||
Note:
|
||||
If a parameter or buffer is registered as ``None`` and its corresponding key
|
||||
exists in :attr:`state_dict`, :meth:`load_state_dict` will raise a
|
||||
``RuntimeError``.
|
||||
"""
|
||||
if not isinstance(state_dict, Mapping):
|
||||
raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
|
||||
|
||||
missing_keys: List[str] = []
|
||||
unexpected_keys: List[str] = []
|
||||
error_msgs: List[str] = []
|
||||
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
metadata = getattr(state_dict, "_metadata", None)
|
||||
state_dict = OrderedDict(state_dict)
|
||||
if metadata is not None:
|
||||
# mypy isn't aware that "_metadata" exists in state_dict
|
||||
state_dict._metadata = metadata # type: ignore[attr-defined]
|
||||
|
||||
def load(module, local_state_dict, prefix=""):
|
||||
local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
|
||||
if assign:
|
||||
local_metadata["assign_to_params_buffers"] = assign
|
||||
module._load_from_state_dict(
|
||||
local_state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
True,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
)
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
child_prefix = prefix + name + "."
|
||||
child_state_dict = {k: v for k, v in local_state_dict.items() if k.startswith(child_prefix)}
|
||||
load(child, child_state_dict, child_prefix)
|
||||
|
||||
load(self, state_dict)
|
||||
del load
|
||||
|
||||
if strict:
|
||||
if len(unexpected_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0,
|
||||
"Unexpected key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in unexpected_keys)),
|
||||
)
|
||||
if len(missing_keys) > 0:
|
||||
error_msgs.insert(
|
||||
0,
|
||||
"Missing key(s) in state_dict: {}. ".format(", ".join(f'"{k}"' for k in missing_keys)),
|
||||
)
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
raise RuntimeError(
|
||||
"Error(s) in loading state_dict for {}:\n\t{}".format(self.__class__.__name__, "\n\t".join(error_msgs))
|
||||
)
|
||||
return _IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
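# Hedged sketch of the strict flag: the recursive loader always records
# missing/unexpected keys, strict only decides whether they raise.
import candle
from candle.nn import Linear

layer = Linear(10, 20)
result = layer.load_state_dict({"weight": candle.ones((20, 10))}, strict=False)
print(result.missing_keys)  # ["bias"]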
|
||||
def _load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~candle.nn.Module.load_state_dict`. Metadata saved for this
|
||||
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
||||
For state dicts without metadata, :attr:`local_metadata` is empty.
|
||||
Subclasses can achieve class-specific backward compatible loading using
|
||||
the version number at `local_metadata.get("version", None)`.
|
||||
Additionally, :attr:`local_metadata` can also contain the key
|
||||
`assign_to_params_buffers` that indicates whether keys should be
|
||||
assigned their corresponding tensor in the state_dict.
|
||||
|
||||
.. note::
|
||||
:attr:`state_dict` is not the same object as the input
|
||||
:attr:`state_dict` to :meth:`~candle.nn.Module.load_state_dict`. So
|
||||
it can be modified.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module,
|
||||
taken from ``state_dict._metadata``; empty for state dicts without metadata
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~candle.nn.Module.load_state_dict`
|
||||
"""
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = persistent_buffers.items()
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
for name, param in local_state.items():
|
||||
key = prefix + name
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
if not isinstance(input_param, (Tensor, QTensor)):
|
||||
error_msgs.append(
|
||||
f'While copying the parameter named "{key}", '
|
||||
"expected Tensor-like object from checkpoint but "
|
||||
f"received {type(input_param)}"
|
||||
)
|
||||
continue
|
||||
|
||||
if input_param.shape != param.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append(
|
||||
"size mismatch for {}: copying a param with shape {} from checkpoint, "
|
||||
"the shape in current model is {}.".format(key, input_param.shape, param.shape)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
# Shape checks are already done above -> Just assign tensor
|
||||
setattr(self, name, input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
f'While copying the parameter named "{key}", '
|
||||
f"whose dimensions in the model are {param.shape} and "
|
||||
f"whose dimensions in the checkpoint are {input_param.shape}, "
|
||||
f"an exception occurred : {ex.args}."
|
||||
)
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix):
|
||||
input_name = key[len(prefix) :]
|
||||
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def _named_members(self, get_members_fn, prefix="", recurse=True, remove_duplicate: bool = True):
|
||||
r"""Helper method for yielding various names + members of modules."""
|
||||
memo = set()
|
||||
modules = self.named_modules(prefix=prefix, remove_duplicate=remove_duplicate) if recurse else [(prefix, self)]
|
||||
for module_prefix, module in modules:
|
||||
members = get_members_fn(module)
|
||||
for k, v in members:
|
||||
if v is None or v in memo:
|
||||
continue
|
||||
if remove_duplicate:
|
||||
memo.add(v)
|
||||
name = module_prefix + ("." if module_prefix else "") + k
|
||||
yield name, v
|
||||
|
||||
def _get_name(self):
|
||||
return self.__class__.__name__
|
||||
|
||||
def _apply(self, fn):
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
|
||||
return self
|
||||
|
||||
def __move_tensor_to_device(self, tensor: TensorLike, device: str):
|
||||
if isinstance(tensor, Tensor):
|
||||
return tensor.to_device(device)
|
||||
else:
|
||||
raise NotImplementedError("Cannot offload QTensor to cuda, yet!")
|
||||
|
||||
def device(self) -> str:
|
||||
"""
|
||||
Gets the device of the module, by inspecting its tensors.
|
||||
"""
|
||||
tensor = next(self.buffers(), None)
|
||||
if isinstance(tensor, Tensor):
|
||||
return tensor.device
|
||||
else:
|
||||
# QTensors can only be on the CPU
|
||||
return "cpu"
|
||||
|
||||
def cuda(self: T) -> T:
|
||||
r"""Moves all model parameters and buffers to the GPU.
|
||||
|
||||
This also makes associated parameters and buffers different objects. So
|
||||
it should be called before constructing optimizer if the module will
|
||||
live on GPU while being optimized.
|
||||
|
||||
.. note::
|
||||
This method modifies the module in-place.
|
||||
|
||||
Returns:
|
||||
Module: self
|
||||
"""
|
||||
|
||||
def to_cuda(t: TensorLike):
|
||||
return self.__move_tensor_to_device(t, "cuda")
|
||||
|
||||
return self._apply(to_cuda)
|
||||
|
||||
def cpu(self: T) -> T:
|
||||
r"""Moves all model parameters and buffers to the CPU.
|
||||
|
||||
.. note::
|
||||
This method modifies the module in-place.
|
||||
|
||||
Returns:
|
||||
Module: self
|
||||
"""
|
||||
|
||||
def to_cpu(t: TensorLike):
|
||||
return self.__move_tensor_to_device(t, "cpu")
|
||||
|
||||
return self._apply(to_cpu)
|
||||
|
||||
def __cast_tensor(self, tensor: TensorLike, dtype: Union[DType, str]):
|
||||
if isinstance(tensor, Tensor):
|
||||
return tensor.to_dtype(dtype)
|
||||
else:
|
||||
raise TypeError("candle.Module.to only accepts Tensor dtypes, but got desired dtype={}".format(dtype))
|
||||
|
||||
def type(self: T, dst_type: Union[DType, str]) -> T:
|
||||
r"""Casts all parameters and buffers to :attr:`dst_type`.
|
||||
|
||||
.. note::
|
||||
This method modifies the module in-place.
|
||||
|
||||
Args:
|
||||
dst_type (type or string): the desired type
|
||||
|
||||
Returns:
|
||||
Module: self
|
||||
"""
|
||||
|
||||
def cast(t: TensorLike):
|
||||
return self.__cast_tensor(t, dst_type)
|
||||
|
||||
return self._apply(cast)
|
||||
|
||||
@overload
|
||||
def to(
|
||||
self: T,
|
||||
device: str = ...,
|
||||
dtype: Optional[Union[DType, str]] = ...,
|
||||
) -> T:
|
||||
...
|
||||
|
||||
@overload
|
||||
def to(self: T, dtype: Union[DType, str]) -> T:
|
||||
...
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
r"""Moves and/or casts the parameters and buffers.
|
||||
|
||||
This can be called as
|
||||
|
||||
.. function:: to(device=None, dtype=None)
|
||||
:noindex:
|
||||
|
||||
.. function:: to(dtype)
|
||||
:noindex:
|
||||
|
||||
See below for examples.
|
||||
|
||||
.. note::
|
||||
This method modifies the module in-place.
|
||||
|
||||
Args:
|
||||
device (:class:`candle.device`): the desired device of the parameters
|
||||
and buffers in this module
|
||||
dtype (:class:`candle.dtype`): the desired floating point dtype of
|
||||
the parameters and buffers in this module
|
||||
|
||||
Returns:
|
||||
Module: self
|
||||
"""
|
||||
|
||||
device = None
|
||||
dtype = None
|
||||
|
||||
if args:
|
||||
for arg in args:
|
||||
# Assuming arg can be a string representing a device or a dtype
|
||||
|
||||
if isinstance(arg, str):
|
||||
lower_arg = str(arg).lower()
|
||||
if lower_arg.startswith("cuda") or lower_arg == "cpu":
|
||||
device = lower_arg
|
||||
else:
|
||||
dtype = arg
|
||||
elif isinstance(arg, DType):
|
||||
dtype = str(arg)
|
||||
else:
|
||||
raise TypeError("Module.to() received an invalid combination of arguments. Got: {}".format(args))
|
||||
|
||||
if kwargs:
|
||||
device = kwargs.get("device", device)
|
||||
dtype = str(kwargs.get("dtype", dtype))
|
||||
|
||||
if device:
|
||||
device = device.lower()
|
||||
|
||||
if dtype:
|
||||
dtype = dtype.lower()
|
||||
if dtype not in ["f32", "f16", "f64"]:
|
||||
raise TypeError(
|
||||
"candle.Module.to only accepts floating point" "dtypes, but got desired dtype={}".format(dtype)
|
||||
)
|
||||
|
||||
def convert(t):
|
||||
if dtype:
|
||||
t = self.__cast_tensor(t, dtype)
|
||||
if device:
|
||||
t = self.__move_tensor_to_device(t, device)
|
||||
return t
|
||||
|
||||
return self._apply(convert)
|
||||
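# Hedged sketch of the overloads above: a bare string is treated as a
# device if it starts with "cuda" or is "cpu", and as a dtype otherwise.
import candle
from candle.nn import Linear
from candle.utils import cuda_is_available

layer = Linear(4, 4)
layer.to("f16")                    # cast only
layer.to(device="cpu")             # move only
if cuda_is_available():
    layer.to("cuda", dtype="f32")  # move and cast in one pass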
|
||||
def __setattr__(self, __name: str, __value: Any) -> None:
|
||||
if isinstance(__value, Module):
|
||||
self._modules[__name] = __value
|
||||
elif isinstance(__value, QTensor):
|
||||
if __name in self._quantizable_buffers:
|
||||
qdtype = __value.ggml_dtype.lower()
|
||||
if qdtype in ["f32", "f16"]:
|
||||
# It is faster to just dequantize the tensor here and use the normal tensor operations
|
||||
dequant = __value.dequantize()
|
||||
if qdtype == "f16":
|
||||
dequant = dequant.to_dtype("f16")
|
||||
self._buffers[__name] = dequant
|
||||
else:
|
||||
self._buffers[__name] = __value
|
||||
else:
|
||||
# We expect a normal tensor here => dequantize it
|
||||
self._buffers[__name] = __value.dequantize()
|
||||
elif isinstance(__value, Tensor):
|
||||
self._buffers[__name] = __value
|
||||
else:
|
||||
super().__setattr__(__name, __value)
|
||||
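# Hedged sketch of the assignment hook above: quantized weights survive
# only on buffers explicitly marked quantizable, anything else is
# dequantized on assignment.
import candle
from candle.nn import Module

class M(Module):
    def __init__(self):
        super().__init__()
        self.w = candle.randn((16, 256))
        self._quantizable_buffers.add("w")

m = M()
m.w = candle.ones((16, 256)).quantize("q4_0")
print(type(m.w))  # QTensor, kept quantized because "w" is marked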
|
||||
def __getattr__(self, __name: str) -> Any:
|
||||
if "_modules" in self.__dict__:
|
||||
modules = self.__dict__["_modules"]
|
||||
if __name in modules:
|
||||
return modules[__name]
|
||||
if "_buffers" in self.__dict__:
|
||||
tensors = self.__dict__["_buffers"]
|
||||
if __name in tensors:
|
||||
return tensors[__name]
|
||||
return super().__getattribute__(__name)
|
||||
|
||||
def __delattr__(self, name):
|
||||
if name in self._buffers:
|
||||
del self._buffers[name]
|
||||
elif name in self._modules:
|
||||
del self._modules[name]
|
||||
else:
|
||||
super().__delattr__(name)
|
54
candle-pyo3/py_src/candle/nn/normalization.py
Normal file
@ -0,0 +1,54 @@
|
||||
import candle
|
||||
from candle import Tensor
|
||||
from .module import Module
|
||||
from typing import Union, List, Tuple, Optional, Any
|
||||
import numbers
|
||||
|
||||
_shape_t = Union[int, List[int]]
|
||||
|
||||
|
||||
class LayerNorm(Module):
|
||||
r"""Applies Layer Normalization over a mini-batch of inputs as described in
|
||||
the paper `Layer Normalization <https://arxiv.org/abs/1607.06450>`_
|
||||
|
||||
.. math::
|
||||
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta
|
||||
"""
|
||||
__constants__ = ["normalized_shape", "eps"]
|
||||
normalized_shape: Tuple[int, ...]
|
||||
eps: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
normalized_shape: _shape_t,
|
||||
eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
if isinstance(normalized_shape, numbers.Integral):
|
||||
normalized_shape = (normalized_shape,)
|
||||
self.normalized_shape = tuple(normalized_shape)
|
||||
self.eps = eps
|
||||
|
||||
self.weight = candle.ones(normalized_shape, **factory_kwargs)
|
||||
if bias:
|
||||
self.bias = candle.zeros(normalized_shape, **factory_kwargs)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
mean_x = input.sum_keepdim(2) / float(self.normalized_shape[-1])
|
||||
x = input.broadcast_sub(mean_x)
|
||||
norm_x = x.sqr().sum_keepdim(2) / float(self.normalized_shape[-1])
|
||||
x_normed = x.broadcast_div((norm_x + self.eps).sqrt())
|
||||
x = x_normed.broadcast_mul(self.weight)
|
||||
|
||||
if self.bias is not None:
|
||||
x = x.broadcast_add(self.bias)
|
||||
return x
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "{normalized_shape}, eps={eps}, " "elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
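# Hedged usage sketch: the forward above reduces over dim 2, so it
# assumes a rank-3 (batch, seq, hidden) input.
import candle
from candle.nn import LayerNorm

ln = LayerNorm(768)
x = candle.randn((1, 16, 768))
y = ln.forward(x)
print(y.shape)  # (1, 16, 768), normalized over the hidden dim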
39
candle-pyo3/py_src/candle/nn/sparse.py
Normal file
@ -0,0 +1,39 @@
|
||||
from .module import Module
|
||||
from typing import Optional, Tuple, Any
|
||||
from candle import Tensor
|
||||
import candle
|
||||
|
||||
|
||||
class Embedding(Module):
|
||||
"""A simple lookup table that stores embeddings of a fixed dictionary and size.
|
||||
|
||||
This module is often used to store word embeddings and retrieve them using indices.
|
||||
The input to the module is a list of indices, and the output is the corresponding
|
||||
word embeddings.
|
||||
|
||||
Args:
|
||||
num_embeddings (int): size of the dictionary of embeddings
|
||||
embedding_dim (int): the size of each embedding vector
|
||||
|
||||
Attributes:
|
||||
weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
|
||||
initialized from :math:`\mathcal{N}(0, 1)`
|
||||
|
||||
Shape:
|
||||
- Input: :math:`(*)`, integer tensor of arbitrary shape containing the indices to extract
|
||||
- Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
|
||||
"""
|
||||
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, device=None) -> None:
|
||||
factory_kwargs = {"device": device}
|
||||
super().__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.weight = candle.randn((num_embeddings, embedding_dim), **factory_kwargs)
|
||||
|
||||
def forward(self, indexes: Tensor) -> Tensor:
|
||||
final_dims = list(indexes.shape)
|
||||
final_dims.append(self.embedding_dim)
|
||||
indexes = indexes.flatten_all()
|
||||
values = self.weight.index_select(indexes, 0)
|
||||
return values.reshape(final_dims)
|
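# Hedged usage sketch, assuming integer indices (e.g. u32) as required
# by index_select:
import candle
from candle.nn import Embedding

emb = Embedding(num_embeddings=100, embedding_dim=32)
ids = candle.Tensor([[1, 2, 3], [4, 5, 6]]).to_dtype("u32")
print(emb.forward(ids).shape)  # (2, 3, 32)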
@ -2,7 +2,7 @@ from typing import TypeVar, Union, Sequence
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
_ArrayLike = Union[
|
||||
_T,
|
||||
Sequence[_T],
|
||||
Sequence[Sequence[_T]],
|
||||
@ -10,7 +10,7 @@ _ArrayLike = Union[
|
||||
Sequence[Sequence[Sequence[Sequence[_T]]]],
|
||||
]
|
||||
|
||||
CPU:str = "cpu"
|
||||
CUDA:str = "cuda"
|
||||
CPU: str = "cpu"
|
||||
CUDA: str = "cuda"
|
||||
|
||||
Device = TypeVar("Device", CPU, CUDA)
|
||||
|
@ -28,3 +28,7 @@ features = ["pyo3/extension-module"]
|
||||
[tool.black]
|
||||
line-length = 119
|
||||
target-version = ['py35']
|
||||
|
||||
[project.optional-dependencies]
|
||||
testing = ["pytest", "black==22.3"]
|
||||
huggingface = ["transformers>=4.33.3", "huggingface-hub>=0.17.3"]
|
@ -2,181 +2,59 @@
|
||||
import sys
|
||||
from typing import Dict, Tuple, Any
|
||||
import candle
|
||||
from candle import Tensor, QTensor, utils, nn
|
||||
from candle.models.llama import QuantizedLlama
|
||||
from candle import utils
|
||||
|
||||
MAX_SEQ_LEN = 4096
|
||||
|
||||
def masked_fill(on_false:Tensor, mask:Tensor, on_true:Tensor):
|
||||
shape = mask.shape
|
||||
on_true = candle.tensor(on_true).broadcast_as(shape)
|
||||
return mask.where_cond(on_true, on_false)
|
||||
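# Hedged illustration of masked_fill with a 3x3 causal mask: ones mark
# future positions (column > row), which get filled with -inf.
import candle

seq_len = 3
mask = candle.tensor([int(j > i) for i in range(seq_len) for j in range(seq_len)])
mask = mask.reshape((seq_len, seq_len))
# mask.values() == [[0, 1, 1], [0, 0, 1], [0, 0, 0]]
att = candle.ones((seq_len, seq_len))
att = masked_fill(att, mask, float("-inf"))  # upper triangle becomes -inf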
|
||||
class RmsNorm:
|
||||
def __init__(self, qtensor:QTensor):
|
||||
self.weight = qtensor.dequantize()
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
b_size, seq_len, hidden_size = x.shape
|
||||
norm_x = x.sqr().sum_keepdim(2) / hidden_size
|
||||
x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
|
||||
return x_normed.broadcast_mul(self.weight)
|
||||
|
||||
class QuantizedLayer:
|
||||
def __init__(self, layer_idx:int, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor], cos_sin:Tuple[Tensor,Tensor]):
|
||||
p = f"layers.{layer_idx}"
|
||||
self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
|
||||
self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
|
||||
self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
|
||||
self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
|
||||
self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
|
||||
self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
|
||||
self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
|
||||
self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
|
||||
self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])
|
||||
|
||||
self.n_head = hparams["n_head"]
|
||||
self.n_kv_head = self.n_head
|
||||
self.head_dim = hparams["n_embd"] // self.n_head
|
||||
|
||||
self.kv_cache = None
|
||||
self.cos = cos_sin[0]
|
||||
self.sin = cos_sin[1]
|
||||
|
||||
def __call__(self, x:Tensor, mask:Tensor, index_pos:int):
|
||||
residual = x
|
||||
x = self.attn_norm(x)
|
||||
attn = self.forward_attn(x, mask, index_pos)
|
||||
x = attn + residual
|
||||
|
||||
residual = x
|
||||
x = self.ffn_norm(x)
|
||||
w1 = self.ffw1.matmul_t(x)
|
||||
w3 = self.ffw3.matmul_t(x)
|
||||
mlp = self.ffw2.matmul_t(nn.silu(w1) * w3)
|
||||
|
||||
return mlp + residual
|
||||
|
||||
def forward_attn(self, x:Tensor, mask:Tensor, index_pos:int):
|
||||
b_size, seq_len, n_embd = x.shape
|
||||
q = self.attention_wq.matmul_t(x)
|
||||
k = self.attention_wk.matmul_t(x)
|
||||
v = self.attention_wv.matmul_t(x)
|
||||
|
||||
q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
|
||||
k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
|
||||
v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
|
||||
|
||||
q = self.apply_rotary_emb(q, index_pos)
|
||||
k = self.apply_rotary_emb(k, index_pos)
|
||||
|
||||
if self.kv_cache is not None and index_pos > 0:
|
||||
prev_k, prev_v = self.kv_cache
|
||||
k = candle.cat([prev_k, k], 2).contiguous()
|
||||
v = candle.cat([prev_v, v], 2).contiguous()
|
||||
|
||||
self.kv_cache = (k, v)
|
||||
|
||||
# TODO: maybe repeat k/v here if we start supporting MQA.
|
||||
|
||||
att = q.matmul(k.t()) / self.head_dim**0.5
|
||||
mask = mask.broadcast_as(att.shape)
|
||||
att = masked_fill(att, mask, float("-inf"))
|
||||
att = nn.softmax(att, -1)
|
||||
y = att.matmul(v.contiguous())
|
||||
y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
|
||||
return self.attention_wo.matmul_t(y)
|
||||
|
||||
def apply_rotary_emb(self, x:Tensor, index_pos:int):
|
||||
(b_size, n_head, seq_len, n_embd) = x.shape
|
||||
cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
|
||||
sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
|
||||
x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2))
|
||||
x0 = x.narrow(-1, 0, 1)
|
||||
x1 = x.narrow(-1, 1, 1)
|
||||
y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
|
||||
y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
|
||||
rope = candle.cat([y0, y1], -1)
|
||||
return rope.flatten_from(-2)
|
||||
|
||||
def precompute_freqs_cis(hparams, freq_base):
|
||||
head_dim = hparams["n_embd"] // hparams["n_head"]
|
||||
theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
|
||||
theta = candle.tensor(theta)
|
||||
idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
|
||||
idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
|
||||
m = idx_theta.matmul(theta.unsqueeze(0))
|
||||
return (m.cos(), m.sin())
|
||||
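# Shape check (hedged): with n_embd=4096 and n_head=32 the head size is
# 128, theta has head_dim / 2 = 64 entries, and the outer product gives
# cos/sin tables of shape (MAX_SEQ_LEN, 64), indexed by position.
cos, sin = precompute_freqs_cis({"n_embd": 4096, "n_head": 32}, 10000.0)
assert cos.shape == (MAX_SEQ_LEN, 64)
assert sin.shape == (MAX_SEQ_LEN, 64)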
|
||||
class QuantizedLlama:
|
||||
def __init__(self, hparams:Dict[str,Any], all_tensors:Dict[str,QTensor]):
|
||||
self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
|
||||
self.norm = RmsNorm(all_tensors["norm.weight"])
|
||||
self.output = all_tensors["output.weight"]
|
||||
self.layers = []
|
||||
rope_freq = hparams.get("rope_freq", 10000.)
|
||||
cos_sin = precompute_freqs_cis(hparams, rope_freq)
|
||||
for layer_idx in range(hparams["n_layer"]):
|
||||
layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
|
||||
self.layers.append(layer)
|
||||
|
||||
def __call__(self, token:Tensor, index_pos:int):
|
||||
b_size, seq_len = token.shape
|
||||
vocab_size, hidden_size = self.tok_embeddings.shape
|
||||
token = token.reshape((b_size * seq_len,))
|
||||
x = self.tok_embeddings.index_select(token, 0)
|
||||
x = x.reshape((b_size, seq_len, hidden_size))
|
||||
|
||||
mask = [int(j > i) for i in range(seq_len) for j in range(seq_len)]
|
||||
mask = candle.tensor(mask).reshape((seq_len, seq_len))
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, index_pos)
|
||||
x = self.norm(x)
|
||||
x = x.narrow(1, -1, 1).squeeze(1)
|
||||
x = self.output.matmul_t(x)
|
||||
return x
|
||||
|
||||
def gguf_rename(tensor_name:str):
|
||||
if tensor_name == 'token_embd.weight': return 'tok_embeddings.weight'
|
||||
if tensor_name == 'output_norm.weight': return 'norm.weight'
|
||||
tensor_name = tensor_name.replace('blk.', 'layers.')
|
||||
tensor_name = tensor_name.replace('.attn_q.', '.attention.wq.')
|
||||
tensor_name = tensor_name.replace('.attn_k.', '.attention.wk.')
|
||||
tensor_name = tensor_name.replace('.attn_v.', '.attention.wv.')
|
||||
tensor_name = tensor_name.replace('.attn_output.', '.attention.wo.')
|
||||
tensor_name = tensor_name.replace('.ffn_gate.', '.feed_forward.w1.')
|
||||
tensor_name = tensor_name.replace('.ffn_down.', '.feed_forward.w2.')
|
||||
tensor_name = tensor_name.replace('.ffn_up.', '.feed_forward.w3.')
|
||||
tensor_name = tensor_name.replace('.attn_norm.', '.attention_norm.')
|
||||
def gguf_rename(tensor_name: str):
|
||||
if tensor_name == "token_embd.weight":
|
||||
return "tok_embeddings.weight"
|
||||
if tensor_name == "output_norm.weight":
|
||||
return "norm.weight"
|
||||
tensor_name = tensor_name.replace("blk.", "layers.")
|
||||
tensor_name = tensor_name.replace(".attn_q.", ".attention.wq.")
|
||||
tensor_name = tensor_name.replace(".attn_k.", ".attention.wk.")
|
||||
tensor_name = tensor_name.replace(".attn_v.", ".attention.wv.")
|
||||
tensor_name = tensor_name.replace(".attn_output.", ".attention.wo.")
|
||||
tensor_name = tensor_name.replace(".ffn_gate.", ".feed_forward.w1.")
|
||||
tensor_name = tensor_name.replace(".ffn_down.", ".feed_forward.w2.")
|
||||
tensor_name = tensor_name.replace(".ffn_up.", ".feed_forward.w3.")
|
||||
tensor_name = tensor_name.replace(".attn_norm.", ".attention_norm.")
|
||||
return tensor_name
|
||||
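# Worked examples of the mapping implemented above:
assert gguf_rename("blk.0.attn_q.weight") == "layers.0.attention.wq.weight"
assert gguf_rename("blk.3.ffn_down.weight") == "layers.3.feed_forward.w2.weight"
assert gguf_rename("output_norm.weight") == "norm.weight"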
|
||||
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
raise ValueError("missing weight file argument")
|
||||
|
||||
filename = sys.argv[1]
|
||||
print(f"reading model file {filename}")
|
||||
if filename.endswith("gguf"):
|
||||
all_tensors, metadata = utils.load_gguf(sys.argv[1])
|
||||
all_tensors, metadata = utils.load_gguf(filename)
|
||||
vocab = metadata["tokenizer.ggml.tokens"]
|
||||
for i, v in enumerate(vocab):
|
||||
vocab[i] = '\n' if v == '<0x0A>' else v.replace('▁', ' ')
|
||||
vocab[i] = "\n" if v == "<0x0A>" else v.replace("▁", " ")
|
||||
hparams = {k: v for (k, v) in metadata.items() if not k.startswith("tokenizer")}
|
||||
print(hparams)
|
||||
hparams = {
|
||||
'n_vocab': len(vocab),
|
||||
'n_embd': metadata['llama.embedding_length'],
|
||||
'n_mult': 256,
|
||||
'n_head': metadata['llama.attention.head_count'],
|
||||
'n_head_kv': metadata['llama.attention.head_count_kv'],
|
||||
'n_layer': metadata['llama.block_count'],
|
||||
'n_rot': metadata['llama.rope.dimension_count'],
|
||||
'rope_freq': metadata.get('llama.rope.freq_base', 10000.),
|
||||
'ftype': metadata['general.file_type'],
|
||||
"n_vocab": len(vocab),
|
||||
"n_embd": metadata["llama.embedding_length"],
|
||||
"n_mult": 256,
|
||||
"n_head": metadata["llama.attention.head_count"],
|
||||
"n_head_kv": metadata["llama.attention.head_count_kv"],
|
||||
"n_layer": metadata["llama.block_count"],
|
||||
"n_rot": metadata["llama.rope.dimension_count"],
|
||||
"rope_freq": metadata.get("llama.rope.freq_base", 10000.0),
|
||||
"ftype": metadata["general.file_type"],
|
||||
"context_length": metadata["llama.context_length"],
|
||||
}
|
||||
all_tensors = { gguf_rename(k): v for k, v in all_tensors.items() }
|
||||
|
||||
all_tensors = {gguf_rename(k): v for k, v in all_tensors.items()}
|
||||
else:
|
||||
all_tensors, hparams, vocab = utils.load_ggml(sys.argv[1])
|
||||
all_tensors, hparams, vocab = utils.load_ggml(filename)
|
||||
hparams["context_length"] = 2048
|
||||
|
||||
print(hparams)
|
||||
model = QuantizedLlama(hparams, all_tensors)
|
||||
print("model built, starting inference")
|
||||
@ -185,13 +63,14 @@ def main():
|
||||
for token_idx in range(500):
|
||||
last_token = tokens[-1]
|
||||
lt = candle.tensor([last_token]).unsqueeze(0)
|
||||
logits = model(lt, len(tokens))
|
||||
logits = model.forward(lt, len(tokens))
|
||||
# Greedy sampling for now
|
||||
# pr = candle.nn.softmax(logits, -1)
|
||||
m = logits.get(0).argmax_keepdim(-1)
|
||||
next_token = m.values()[0]
|
||||
print(vocab[next_token], end='', flush=True)
|
||||
print(vocab[next_token], end="", flush=True)
|
||||
tokens.append(next_token)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -3,6 +3,7 @@ use pyo3::exceptions::{PyTypeError, PyValueError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{IntoPyDict, PyDict, PyTuple};
|
||||
use pyo3::ToPyObject;
|
||||
use std::os::raw::c_long;
|
||||
use std::sync::Arc;
|
||||
|
||||
use half::{bf16, f16};
|
||||
@ -196,6 +197,12 @@ trait MapDType {
|
||||
}
|
||||
}
|
||||
|
||||
enum Indexer {
|
||||
Index(usize),
|
||||
Slice(usize, usize),
|
||||
Ellipsis,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PyTensor {
|
||||
#[new]
|
||||
@ -436,6 +443,95 @@ impl PyTensor {
|
||||
))
|
||||
}
|
||||
|
||||
#[getter]
|
||||
/// Index a tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn __getitem__(&self, py: Python, idx: PyObject) -> PyResult<Self> {
|
||||
let mut indexers: Vec<Indexer> = vec![];
|
||||
let dims = self.0.shape().dims();
|
||||
|
||||
let to_absolute_index = |index: isize, current_dim: usize| {
|
||||
// Convert a negative index to an absolute one, e.g. tensor[-1] -> tensor[dims[0] - 1]
|
||||
let actual_index = if index < 0 {
|
||||
dims[current_dim] as isize + index
|
||||
} else {
|
||||
index
|
||||
};
|
||||
|
||||
// Check that the index is in range
|
||||
if actual_index < 0 || actual_index >= dims[current_dim] as isize {
|
||||
return Err(PyTypeError::new_err(format!(
|
||||
"index out of range for dimension '{i}' with indexer '{value}'",
|
||||
i = current_dim,
|
||||
value = index
|
||||
)));
|
||||
}
|
||||
Ok(actual_index as usize)
|
||||
};
|
||||
if let Ok(index) = idx.extract(py) {
|
||||
// Handle a single index e.g. tensor[0] or tensor[-1]
|
||||
indexers.push(Indexer::Index(to_absolute_index(index, 0)?));
|
||||
} else if let Ok(slice) = idx.downcast::<pyo3::types::PySlice>(py) {
|
||||
// Handle a single slice e.g. tensor[0:1] or tensor[0:-1]
|
||||
let index = slice.indices(dims[0] as c_long)?;
|
||||
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
|
||||
} else if let Ok(tuple) = idx.downcast::<pyo3::types::PyTuple>(py) {
|
||||
// Handle multiple indices e.g. tensor[0,0] or tensor[0:1,0:1]
|
||||
|
||||
if tuple.len() > dims.len() {
|
||||
return Err(PyTypeError::new_err("provided too many indices"));
|
||||
}
|
||||
|
||||
for (i, item) in tuple.iter().enumerate() {
|
||||
if item.is_ellipsis() {
|
||||
// Handle '...' e.g. tensor[..., 0]
|
||||
|
||||
if i > 0 {
|
||||
return Err(PyTypeError::new_err("Ellipsis ('...') can only be used at the start of an indexing operation"));
|
||||
}
|
||||
indexers.push(Indexer::Ellipsis);
|
||||
} else if let Ok(slice) = item.downcast::<pyo3::types::PySlice>() {
|
||||
// Handle slice
|
||||
let index = slice.indices(dims[i] as c_long)?;
|
||||
indexers.push(Indexer::Slice(index.start as usize, index.stop as usize));
|
||||
} else if let Ok(index) = item.extract::<isize>() {
|
||||
indexers.push(Indexer::Index(to_absolute_index(index, i)?));
|
||||
} else {
|
||||
return Err(PyTypeError::new_err("unsupported index"));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(PyTypeError::new_err("unsupported index"));
|
||||
}
|
||||
|
||||
let mut x = self.0.clone();
|
||||
let mut current_dim = 0;
|
||||
// Apply the indexers
|
||||
for indexer in indexers.iter() {
|
||||
x = match indexer {
|
||||
Indexer::Index(n) => x
|
||||
.narrow(current_dim, *n, 1)
|
||||
.map_err(wrap_err)?
|
||||
.squeeze(current_dim)
|
||||
.map_err(wrap_err)?,
|
||||
Indexer::Slice(start, stop) => {
|
||||
let out = x
|
||||
.narrow(current_dim, *start, stop.saturating_sub(*start))
|
||||
.map_err(wrap_err)?;
|
||||
current_dim += 1;
|
||||
out
|
||||
}
|
||||
Indexer::Ellipsis => {
|
||||
// Ellipsis is a special case: all the dimensions it covers are selected as-is, so advance current_dim past them to align the remaining indexers with the trailing dimensions
|
||||
current_dim += dims.len() - (indexers.len() - 1);
|
||||
x
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self(x))
|
||||
}
|
||||
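// Python-side behavior enabled by the indexers above (hedged sketch,
// mirroring the slicing tests added in this PR):
//   t = candle.Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
//   t[0].values()      -> [3.0, 1.0, 4.0, 1.0]  (Index, then squeeze)
//   t[:, 1].values()   -> [1.0, 9.0]            (Slice, then Index)
//   t[..., 0].values() -> [3.0, 5.0]            (Ellipsis skips leading dims)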
|
||||
/// Add two tensors.
|
||||
/// &RETURNS&: Tensor
|
||||
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
@ -697,7 +793,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: QTensor
|
||||
fn quantize(&self, quantized_dtype: &str) -> PyResult<PyQTensor> {
|
||||
use ::candle::quantized;
|
||||
let res = match quantized_dtype {
|
||||
let res = match quantized_dtype.to_lowercase().as_str() {
|
||||
"q2k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ2K>(self),
|
||||
"q3k" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ3K>(self),
|
||||
"q4_0" => quantized::QTensor::quantize::<quantized::k_quants::BlockQ4_0>(self),
|
||||
@ -1137,9 +1233,39 @@ fn silu(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||
/// Applies the Gaussian Error Linear Unit (GELU) function to a given tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn gelu(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
let s = tensor.0.gelu_erf().map_err(wrap_err)?;
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||
/// Applies the Rectified Linear Unit (ReLU) function to a given tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn relu(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
let s = tensor.0.relu().map_err(wrap_err)?;
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
#[pyfunction]
|
||||
#[pyo3(text_signature = "(tensor:Tensor)")]
|
||||
/// Applies the tanh function to a given tensor.
|
||||
/// &RETURNS&: Tensor
|
||||
fn tanh(tensor: PyTensor) -> PyResult<PyTensor> {
|
||||
let s = tensor.0.tanh().map_err(wrap_err)?;
|
||||
Ok(PyTensor(s))
|
||||
}
|
||||
|
||||
fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(silu, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(softmax, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(gelu, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(relu, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(tanh, m)?)?;
|
||||
Ok(())
|
||||
}
|
||||
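// Python-side view of the functional module registered above (hedged,
// assuming it is re-exported as `candle.functional` by the package):
//   import candle
//   t = candle.Tensor([-1.0, 0.0, 1.0])
//   candle.functional.relu(t).values()  -> [0.0, 0.0, 1.0]
//   candle.functional.silu(t)           -> x * sigmoid(x), elementwise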
|
||||
@ -1148,8 +1274,8 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let utils = PyModule::new(py, "utils")?;
|
||||
candle_utils(py, utils)?;
|
||||
m.add_submodule(utils)?;
|
||||
let nn = PyModule::new(py, "nn")?;
|
||||
candle_nn_m(py, nn)?;
|
||||
let nn = PyModule::new(py, "functional")?;
|
||||
candle_functional_m(py, nn)?;
|
||||
m.add_submodule(nn)?;
|
||||
m.add_class::<PyTensor>()?;
|
||||
m.add_class::<PyQTensor>()?;
|
||||
|
@ -1,4 +1,4 @@
|
||||
#See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
|
||||
# See: https://raw.githubusercontent.com/huggingface/tokenizers/main/bindings/python/stub.py
|
||||
import argparse
|
||||
import inspect
|
||||
import os
|
||||
@ -23,7 +23,7 @@ def do_indent(text: Optional[str], indent: str):
|
||||
return text.replace("\n", f"\n{indent}")
|
||||
|
||||
|
||||
def function(obj, indent:str, text_signature:str=None):
|
||||
def function(obj, indent: str, text_signature: str = None):
|
||||
if text_signature is None:
|
||||
text_signature = obj.__text_signature__
|
||||
|
||||
@ -32,12 +32,12 @@ def function(obj, indent:str, text_signature:str=None):
|
||||
if doc_string is None:
|
||||
doc_string = ""
|
||||
|
||||
# Check if we have a return type annotation in the docstring
|
||||
return_type = None
|
||||
doc_lines = doc_string.split("\n")
|
||||
if doc_lines[-1].lstrip().startswith(RETURN_TYPE_MARKER):
|
||||
# Extract the return type and remove it from the docstring
|
||||
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER):].strip()
|
||||
return_type = doc_lines[-1].lstrip()[len(RETURN_TYPE_MARKER) :].strip()
|
||||
doc_string = "\n".join(doc_lines[:-1])
|
||||
|
||||
string = ""
|
||||
@ -115,7 +115,7 @@ def pyi_file(obj, indent=""):
|
||||
body += f"{indent+INDENT}pass\n"
|
||||
body += "\n"
|
||||
|
||||
for (name, fn) in fns:
|
||||
for name, fn in fns:
|
||||
body += pyi_file(fn, indent=indent)
|
||||
|
||||
if not body:
|
||||
@ -221,12 +221,12 @@ if __name__ == "__main__":
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
#Enable execution from the candle and candle-pyo3 directories
|
||||
# Enable execution from the candle and candle-pyo3 directories
|
||||
cwd = Path.cwd()
|
||||
directory = "py_src/candle/"
|
||||
if cwd.name != "candle-pyo3":
|
||||
directory = f"candle-pyo3/{directory}"
|
||||
|
||||
|
||||
import candle
|
||||
|
||||
write(candle.candle, directory, "candle", check=args.check)
|
||||
|
@ -7,7 +7,7 @@ print(t + t)
|
||||
|
||||
t = candle.Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
|
||||
print(t)
|
||||
print(t+t)
|
||||
print(t + t)
|
||||
|
||||
t = t.reshape([2, 4])
|
||||
print(t.matmul(t.t()))
|
||||
|
0
candle-pyo3/tests/__init__.py
Normal file
38
candle-pyo3/tests/bindings/test_linear.py
Normal file
@ -0,0 +1,38 @@
|
||||
import candle
|
||||
from candle import Tensor
|
||||
from candle.nn import Linear
|
||||
|
||||
|
||||
def test_linear_layer_can_be_constructed():
|
||||
linear = Linear(10, 10)
|
||||
assert linear is not None
|
||||
|
||||
|
||||
def test_linear_layer_can_forward_a_singular_input():
|
||||
linear = Linear(384, 1536)
|
||||
input_tensor = candle.randn((8, 384))
|
||||
output = linear.forward(input_tensor)
|
||||
assert output.shape == (8, 1536)
|
||||
|
||||
|
||||
def test_linear_layer_can_forward_a_batched_input():
|
||||
linear = Linear(384, 1536)
|
||||
input_tensor = candle.randn((16, 8, 384))
|
||||
output = linear.forward(input_tensor)
|
||||
assert output.shape == (16, 8, 1536)
|
||||
|
||||
|
||||
def test_quantized_linear_layer_can_forward_a_singular_input():
|
||||
linear = Linear(384, 1536)
|
||||
linear.weight = linear.weight.quantize("q4_0")
|
||||
input_tensor = candle.randn((8, 384))
|
||||
output = linear.forward(input_tensor)
|
||||
assert output.shape == (8, 1536)
|
||||
|
||||
|
||||
def test_quantized_linear_layer_can_forward_a_batched_input():
|
||||
linear = Linear(384, 1536)
|
||||
linear.weight = linear.weight.quantize("q4_0")
|
||||
input_tensor = candle.randn((16, 8, 384))
|
||||
output = linear.forward(input_tensor)
|
||||
assert output.shape == (16, 8, 1536)
|
161
candle-pyo3/tests/bindings/test_module.py
Normal file
@ -0,0 +1,161 @@
|
||||
import candle
|
||||
from candle import Tensor, QTensor
|
||||
from candle.nn import Module, Linear
|
||||
from candle.utils import cuda_is_available
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_module_can_be_constructed():
|
||||
class A(Module):
|
||||
pass
|
||||
|
||||
a = A()
|
||||
assert a is not None
|
||||
assert len(list(a.buffers())) == 0
|
||||
|
||||
|
||||
def test_module_registers_tensors():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.t = Tensor(42.0)
|
||||
|
||||
a = A()
|
||||
named_buffers = dict(a.named_buffers())
|
||||
assert len(named_buffers) == 1
|
||||
assert "t" in named_buffers
|
||||
|
||||
|
||||
def test_module_registers_submodules():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = Linear(10, 20)
|
||||
|
||||
a = A()
|
||||
named_modules = dict(a.named_modules())
|
||||
named_buffers = dict(a.named_buffers())
|
||||
assert len(named_buffers) == 2
|
||||
assert "linear" in named_modules
|
||||
assert "linear.weight" in named_buffers
|
||||
assert "linear.bias" in named_buffers
|
||||
|
||||
|
||||
def test_module_can_dump_statedict():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = Linear(10, 20)
|
||||
self.t = Tensor(42.0)
|
||||
|
||||
a = A()
|
||||
state_dict = a.state_dict()
|
||||
assert hasattr(state_dict, "_metadata")
|
||||
assert "t" in state_dict
|
||||
assert "linear.weight" in state_dict
|
||||
assert "linear.bias" in state_dict
|
||||
assert len(state_dict) == 3
|
||||
|
||||
|
||||
def test_module_can_load_statedict():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = Linear(10, 20)
|
||||
self.t = Tensor(42.0)
|
||||
|
||||
statedict = {
|
||||
"linear.weight": candle.ones((20, 10)),
|
||||
"linear.bias": candle.zeros((20,)),
|
||||
"t": Tensor(42.0),
|
||||
}
|
||||
a = A()
|
||||
a.load_state_dict(statedict)
|
||||
|
||||
|
||||
def test_module_throws_on_shape_mismatch():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.t = Tensor(42.0)
|
||||
|
||||
statedict = {
|
||||
"t": candle.ones((20,)),
|
||||
}
|
||||
a = A()
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
a.load_state_dict(statedict)
|
||||
assert "size mismatch" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_module_throws_on_missing_key():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.t = Tensor(42.0)
|
||||
|
||||
statedict = {
|
||||
"not_t": Tensor(42.0),
|
||||
}
|
||||
|
||||
a = A()
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
a.load_state_dict(statedict)
|
||||
assert 'Missing key(s) in state_dict: "t".' in str(excinfo.value)
|
||||
|
||||
|
||||
def test_module_can_load_quantized_tensors():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.t = candle.randn((16, 256))
|
||||
self._quantizable_buffers.add("t")
|
||||
|
||||
statedict = {
|
||||
"t": candle.ones((16, 256)).quantize("q4_0"),
|
||||
}
|
||||
a = A()
|
||||
a.load_state_dict(statedict)
|
||||
assert isinstance(a.t, QTensor)
|
||||
assert a.t.ggml_dtype == "Q4_0"
|
||||
|
||||
|
||||
def test_module_dequantizes_tensors_automatically():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.t = candle.randn((16, 256))
|
||||
|
||||
statedict = {
|
||||
"t": candle.ones((16, 256)).quantize("q4_0"),
|
||||
}
|
||||
a = A()
|
||||
a.load_state_dict(statedict)
|
||||
assert isinstance(a.t, Tensor)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
|
||||
def test_module_can_be_moved_to_cuda():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.t = candle.randn((16, 256))
|
||||
|
||||
a = A()
|
||||
a.cuda()
|
||||
assert a.t.device == "cuda"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
|
||||
def test_module_can_be_moved_from_cuda_to_cpu():
|
||||
class A(Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.t = candle.randn((16, 256))
|
||||
|
||||
a = A()
|
||||
a.cuda()
|
||||
assert a.t.device == "cuda"
|
||||
a.cpu()
|
||||
assert a.t.device == "cpu"
|
74
candle-pyo3/tests/native/test_tensor.py
Normal file
@ -0,0 +1,74 @@
|
||||
import candle
|
||||
from candle import Tensor
|
||||
|
||||
|
||||
def test_tensor_can_be_constructed():
|
||||
t = Tensor(42.0)
|
||||
assert t.values() == 42.0
|
||||
|
||||
|
||||
def test_tensor_can_be_constructed_from_list():
|
||||
t = Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
|
||||
assert t.values() == [3.0, 1, 4, 1, 5, 9, 2, 6]
|
||||
|
||||
|
||||
def test_tensor_can_be_constructed_from_list_of_lists():
|
||||
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
|
||||
assert t.values() == [[3.0, 1, 4, 1], [5, 9, 2, 6]]
|
||||
|
||||
|
||||
def test_tensor_can_be_quantized():
|
||||
t = candle.randn((16, 256))
|
||||
for format in [
|
||||
"q4_0",
|
||||
"q4_1",
|
||||
"q5_0",
|
||||
"q5_1",
|
||||
"q8_0",
|
||||
"q2k",
|
||||
"q3k",
|
||||
"q4k",
|
||||
"q5k",
|
||||
"q8k",
|
||||
]:
|
||||
for formatted_format in [format.upper(), format.lower()]:
|
||||
quant_t = t.quantize(formatted_format)
|
||||
assert quant_t.ggml_dtype.lower() == format.lower()
|
||||
assert quant_t.shape == t.shape
|
||||
|
||||
|
||||
def test_tensor_can_be_indexed():
|
||||
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
|
||||
assert t[0].values() == [3.0, 1.0, 4.0, 1.0]
|
||||
assert t[1].values() == [5.0, 9.0, 2.0, 6.0]
|
||||
assert t[-1].values() == [5.0, 9.0, 2.0, 6.0]
|
||||
assert t[-2].values() == [3.0, 1.0, 4.0, 1.0]
|
||||
|
||||
|
||||
def test_tensor_can_be_sliced():
|
||||
t = Tensor([3.0, 1, 4, 10, 5, 9, 2, 6])
|
||||
|
||||
assert t[0:4].values() == [3.0, 1.0, 4.0, 10.0]
|
||||
assert t[4:8].values() == [5.0, 9.0, 2.0, 6.0]
|
||||
assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
|
||||
assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
|
||||
assert t[-4:-2].values() == [5.0, 9.0]
|
||||
|
||||
|
||||
def test_tensor_can_be_sliced_2d():
|
||||
t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
|
||||
assert t[:, 0].values() == [3.0, 5]
|
||||
assert t[:, 1].values() == [1.0, 9.0]
|
||||
assert t[0, 0].values() == 3.0
|
||||
assert t[:, -1].values() == [1.0, 6.0]
|
||||
assert t[:, -4].values() == [3.0, 5]
|
||||
assert t[..., 0].values() == [3.0, 5]
|
||||
|
||||
|
||||
def test_tensor_can_be_sliced_3d():
|
||||
t = Tensor([[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]]])
|
||||
assert t[:, :, 0].values() == [[1, 5], [9, 13]]
|
||||
assert t[:, :, 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
||||
assert t[:, 0, 0].values() == [1, 9]
|
||||
assert t[..., 0].values() == [[1, 5], [9, 13]]
|
||||
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
51
candle-pyo3/tests/native/test_utils.py
Normal file
@ -0,0 +1,51 @@
|
||||
import candle
|
||||
from candle import Tensor, QTensor
|
||||
from candle.utils import load_safetensors, save_gguf, load_gguf, save_safetensors
|
||||
from pathlib import Path
|
||||
|
||||
TEST_DIR = Path(__file__).parent.parent / "_workdir"
|
||||
TEST_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
def test_can_roundtrip_safetensors():
|
||||
tensors = {
|
||||
"a": candle.randn((16, 256)),
|
||||
"b": candle.randn((16, 16)),
|
||||
}
|
||||
|
||||
file = str(TEST_DIR / "test.safetensors")
|
||||
save_safetensors(file, tensors)
|
||||
loaded_tensors = load_safetensors(file)
|
||||
assert set(tensors.keys()) == set(loaded_tensors.keys())
|
||||
for key in tensors.keys():
|
||||
assert tensors[key].values() == loaded_tensors[key].values(), "Values are not equal"
|
||||
assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal"
|
||||
assert str(tensors[key].dtype) == str(loaded_tensors[key].dtype), "Dtypes are not equal"
|
||||
|
||||
|
||||
def test_can_roundtrip_gguf():
|
||||
metadata = {
|
||||
"a": 1,
|
||||
"b": "foo",
|
||||
"c": [1, 2, 3],
|
||||
"d": [[1, 2], [3, 4]],
|
||||
}
|
||||
|
||||
tensors = {
|
||||
"a": candle.randn((16, 256)).quantize("q4_0"),
|
||||
"b": candle.randn((16, 16)).quantize("f32"),
|
||||
}
|
||||
|
||||
file = str(TEST_DIR / "test.gguf")
|
||||
save_gguf(file, tensors, metadata)
|
||||
loaded_tensors, loaded_metadata = load_gguf(file)
|
||||
|
||||
assert set(metadata.keys()) == set(loaded_metadata.keys())
|
||||
for key in metadata.keys():
|
||||
assert metadata[key] == loaded_metadata[key]
|
||||
|
||||
assert set(tensors.keys()) == set(loaded_tensors.keys())
|
||||
for key in tensors.keys():
|
||||
assert tensors[key].dequantize().values() == loaded_tensors[key].dequantize().values(), "Values are not equal"
|
||||
assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal"
|
||||
assert str(tensors[key].ggml_dtype) == str(loaded_tensors[key].ggml_dtype), "Dtypes are not equal"
|