Mirror of https://github.com/huggingface/candle.git
Make the Python Wrapper more Hackable and simplify Quantization (#1010)
* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unit tests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch`-like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit test for `nn.Linear`
* Refactor slicing implementation
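For context, a minimal sketch of the workflow these changes enable, distilled from the tests added below (all names used here, `candle.randn`, `candle.nn.Linear`, `Tensor.quantize` and `Linear.forward`, appear in those tests):

import candle
from candle.nn import Linear

# Construct a layer and quantize its weight in place. Per the commit notes,
# forward then uses the batched q-matmul path and QTensors are dequantized
# automatically wherever a plain Tensor is expected.
linear = Linear(384, 1536)
linear.weight = linear.weight.quantize("q4_0")

output = linear.forward(candle.randn((8, 384)))
assert output.shape == (8, 1536)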
candle-pyo3/tests/__init__.py (new file, empty)
candle-pyo3/tests/bindings/test_linear.py (new file)
@@ -0,0 +1,38 @@
import candle
from candle import Tensor
from candle.nn import Linear


def test_linear_layer_can_be_constructed():
    linear = Linear(10, 10)
    assert linear is not None


def test_linear_layer_can_forward_a_singular_input():
    linear = Linear(384, 1536)
    input_tensor = candle.randn((8, 384))
    output = linear.forward(input_tensor)
    assert output.shape == (8, 1536)


def test_linear_layer_can_forward_a_batched_input():
    linear = Linear(384, 1536)
    input_tensor = candle.randn((16, 8, 384))
    output = linear.forward(input_tensor)
    assert output.shape == (16, 8, 1536)


def test_quantized_linear_layer_can_forward_a_singular_input():
    linear = Linear(384, 1536)
    linear.weight = linear.weight.quantize("q4_0")
    input_tensor = candle.randn((8, 384))
    output = linear.forward(input_tensor)
    assert output.shape == (8, 1536)


def test_quantized_linear_layer_can_forward_a_batched_input():
    linear = Linear(384, 1536)
    linear.weight = linear.weight.quantize("q4_0")
    input_tensor = candle.randn((16, 8, 384))
    output = linear.forward(input_tensor)
    assert output.shape == (16, 8, 1536)
candle-pyo3/tests/bindings/test_module.py (new file)
@@ -0,0 +1,161 @@
import candle
from candle import Tensor, QTensor
from candle.nn import Module, Linear
from candle.utils import cuda_is_available

import pytest


def test_module_can_be_constructed():
    class A(Module):
        pass

    a = A()
    assert a is not None
    assert len(list(a.buffers())) == 0


def test_module_registers_tensors():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = Tensor(42.0)

    a = A()
    named_buffers = dict(a.named_buffers())
    assert len(named_buffers) == 1
    assert "t" in named_buffers


def test_module_registers_submodules():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.linear = Linear(10, 20)

    a = A()
    named_modules = dict(a.named_modules())
    named_buffers = dict(a.named_buffers())
    assert len(named_buffers) == 2
    assert "linear" in named_modules
    assert "linear.weight" in named_buffers
    assert "linear.bias" in named_buffers


def test_module_can_dump_statedict():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.linear = Linear(10, 20)
            self.t = Tensor(42.0)

    a = A()
    state_dict = a.state_dict()
    assert hasattr(state_dict, "_metadata")
    assert "t" in state_dict
    assert "linear.weight" in state_dict
    assert "linear.bias" in state_dict
    assert len(state_dict) == 3


def test_module_can_load_statedict():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.linear = Linear(10, 20)
            self.t = Tensor(42.0)

    statedict = {
        "linear.weight": candle.ones((20, 10)),
        "linear.bias": candle.zeros((20,)),
        "t": Tensor(42.0),
    }
    a = A()
    a.load_state_dict(statedict)


def test_module_throws_on_shape_mismatch():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = Tensor(42.0)

    statedict = {
        "t": candle.ones((20,)),
    }
    a = A()
    with pytest.raises(RuntimeError) as excinfo:
        a.load_state_dict(statedict)
    assert "size mismatch" in str(excinfo.value)


def test_module_throws_on_missing_key():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = Tensor(42.0)

    statedict = {
        "not_t": Tensor(42.0),
    }

    a = A()
    with pytest.raises(RuntimeError) as excinfo:
        a.load_state_dict(statedict)
    assert 'Missing key(s) in state_dict: "t".' in str(excinfo.value)


def test_module_can_load_quantized_tensors():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = candle.randn((16, 256))
            self._quantizable_buffers.add("t")

    statedict = {
        "t": candle.ones((16, 256)).quantize("q4_0"),
    }
    a = A()
    a.load_state_dict(statedict)
    assert isinstance(a.t, QTensor)
    assert a.t.ggml_dtype == "Q4_0"


def test_module_dequantizes_tensors_automatically():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = candle.randn((16, 256))

    statedict = {
        "t": candle.ones((16, 256)).quantize("q4_0"),
    }
    a = A()
    a.load_state_dict(statedict)
    assert isinstance(a.t, Tensor)


@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_module_can_be_moved_to_cuda():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = candle.randn((16, 256))

    a = A()
    a.cuda()
    assert a.t.device == "cuda"


@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_module_can_be_moved_from_cuda_to_cpu():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = candle.randn((16, 256))

    a = A()
    a.cuda()
    assert a.t.device == "cuda"
    a.cpu()
    assert a.t.device == "cpu"
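Taken together, the tests above imply a PyTorch-style composition pattern. A minimal sketch under those assumptions (`TinyMLP` and its attribute names are hypothetical; `Module`, `Linear`, `state_dict` and `load_state_dict` are the APIs exercised above):

import candle
from candle.nn import Module, Linear

class TinyMLP(Module):  # hypothetical example module, not part of the commit
    def __init__(self):
        super().__init__()
        self.fc1 = Linear(384, 1536)
        self.fc2 = Linear(1536, 384)

    def forward(self, x):
        return self.fc2.forward(self.fc1.forward(x))

model = TinyMLP()
# Submodule parameters are registered automatically, so they round-trip
# through a flat dict keyed by dotted paths ("fc1.weight", "fc1.bias", ...).
model.load_state_dict(model.state_dict())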
candle-pyo3/tests/native/test_tensor.py (new file)
@@ -0,0 +1,74 @@
import candle
from candle import Tensor


def test_tensor_can_be_constructed():
    t = Tensor(42.0)
    assert t.values() == 42.0


def test_tensor_can_be_constructed_from_list():
    t = Tensor([3.0, 1, 4, 1, 5, 9, 2, 6])
    assert t.values() == [3.0, 1, 4, 1, 5, 9, 2, 6]


def test_tensor_can_be_constructed_from_list_of_lists():
    t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
    assert t.values() == [[3.0, 1, 4, 1], [5, 9, 2, 6]]


def test_tensor_can_be_quantized():
    t = candle.randn((16, 256))
    for format in [
        "q4_0",
        "q4_1",
        "q5_0",
        "q5_1",
        "q8_0",
        "q2k",
        "q3k",
        "q4k",
        "q5k",
        "q8k",
    ]:
        for formatted_format in [format.upper(), format.lower()]:
            quant_t = t.quantize(formatted_format)
            assert quant_t.ggml_dtype.lower() == format.lower()
            assert quant_t.shape == t.shape


def test_tensor_can_be_indexed():
    t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
    assert t[0].values() == [3.0, 1.0, 4.0, 1.0]
    assert t[1].values() == [5.0, 9.0, 2.0, 6.0]
    assert t[-1].values() == [5.0, 9.0, 2.0, 6.0]
    assert t[-2].values() == [3.0, 1.0, 4.0, 1.0]


def test_tensor_can_be_sliced():
    t = Tensor([3.0, 1, 4, 10, 5, 9, 2, 6])

    assert t[0:4].values() == [3.0, 1.0, 4.0, 10.0]
    assert t[4:8].values() == [5.0, 9.0, 2.0, 6.0]
    assert t[-4:].values() == [5.0, 9.0, 2.0, 6.0]
    assert t[:-4].values() == [3.0, 1.0, 4.0, 10.0]
    assert t[-4:-2].values() == [5.0, 9.0]


def test_tensor_can_be_sliced_2d():
    t = Tensor([[3.0, 1, 4, 1], [5, 9, 2, 6]])
    assert t[:, 0].values() == [3.0, 5]
    assert t[:, 1].values() == [1.0, 9.0]
    assert t[0, 0].values() == 3.0
    assert t[:, -1].values() == [1.0, 6.0]
    assert t[:, -4].values() == [3.0, 5]
    assert t[..., 0].values() == [3.0, 5]


def test_tensor_can_be_sliced_3d():
    t = Tensor([[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]]])
    assert t[:, :, 0].values() == [[1, 5], [9, 13]]
    assert t[:, :, 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
    assert t[:, 0, 0].values() == [1, 9]
    assert t[..., 0].values() == [[1, 5], [9, 13]]
    assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
candle-pyo3/tests/native/test_utils.py (new file)
@@ -0,0 +1,51 @@
import candle
from candle import Tensor, QTensor
from candle.utils import load_safetensors, save_gguf, load_gguf, save_safetensors
from pathlib import Path

TEST_DIR = Path(__file__).parent.parent / "_workdir"
TEST_DIR.mkdir(exist_ok=True)


def test_can_roundtrip_safetensors():
    tensors = {
        "a": candle.randn((16, 256)),
        "b": candle.randn((16, 16)),
    }

    file = str(TEST_DIR / "test.safetensors")
    save_safetensors(file, tensors)
    loaded_tensors = load_safetensors(file)
    assert set(tensors.keys()) == set(loaded_tensors.keys())
    for key in tensors.keys():
        assert tensors[key].values() == loaded_tensors[key].values(), "Values are not equal"
        assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal"
        assert str(tensors[key].dtype) == str(loaded_tensors[key].dtype), "Dtypes are not equal"


def test_can_roundtrip_gguf():
    metadata = {
        "a": 1,
        "b": "foo",
        "c": [1, 2, 3],
        "d": [[1, 2], [3, 4]],
    }

    tensors = {
        "a": candle.randn((16, 256)).quantize("q4_0"),
        "b": candle.randn((16, 16)).quantize("f32"),
    }

    file = str(TEST_DIR / "test.gguf")
    save_gguf(file, tensors, metadata)
    loaded_tensors, loaded_metadata = load_gguf(file)

    assert set(metadata.keys()) == set(loaded_metadata.keys())
    for key in metadata.keys():
        assert metadata[key] == loaded_metadata[key]

    assert set(tensors.keys()) == set(loaded_tensors.keys())
    for key in tensors.keys():
        assert tensors[key].dequantize().values() == loaded_tensors[key].dequantize().values(), "Values are not equal"
        assert tensors[key].shape == loaded_tensors[key].shape, "Shapes are not equal"
        assert str(tensors[key].ggml_dtype) == str(loaded_tensors[key].ggml_dtype), "Dtypes are not equal"