Mirror of https://github.com/huggingface/candle.git, synced 2025-06-15 10:26:33 +00:00

* Some first `Module` implementations
* Add `state_dict` and `load_state_dict` functionality
* Move modules around and create `candle.nn.Linear`
* Add `nn.Embedding` and `nn.LayerNorm`
* Add BERT implementation
* Batch q-matmul
* Automatically dequantize `QTensors` if a `Tensor` is expected
* Add Module `.to()`, `.cuda()`, `.cpu()` and `.type()` functionality
* Unit tests for `Module`, `Tensor` and `candle.utils`
* Add `pytorch`-like slicing to `Tensor`
* Cleanup and BERT fixes
* `black` formatting + unit test for `nn.Linear`
* Refactor slicing implementation
162 lines · 3.9 KiB · Python
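The file below contains the unit tests for the new `Module` bindings. As a rough usage sketch of the API these tests exercise (the `TinyModel` class is made up for illustration, and the slicing lines assume the `pytorch`-like indexing added here behaves as in PyTorch):

```python
import candle
from candle import Tensor
from candle.nn import Module, Linear
from candle.utils import cuda_is_available


class TinyModel(Module):
    # Hypothetical example module, not part of the repository.
    def __init__(self):
        super().__init__()
        self.linear = Linear(10, 20)
        self.scale = Tensor(42.0)


model = TinyModel()

# Dump and reload parameters, PyTorch-style.
state = model.state_dict()  # keys: "linear.weight", "linear.bias", "scale"
model.load_state_dict(state)

# Move every registered tensor between devices.
if cuda_is_available():
    model.cuda()
    model.cpu()

# PyTorch-like slicing on tensors (assumed semantics, based on the
# "Add `pytorch`-like slicing to `Tensor`" item above).
t = candle.randn((3, 4))
first_row = t[0]
first_column = t[:, 0]
```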
import candle
from candle import Tensor, QTensor
from candle.nn import Module, Linear
from candle.utils import cuda_is_available

import pytest


def test_module_can_be_constructed():
    class A(Module):
        pass

    a = A()
    assert a is not None
    assert len(list(a.buffers())) == 0


def test_module_registers_tensors():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = Tensor(42.0)

    a = A()
    named_buffers = dict(a.named_buffers())
    assert len(named_buffers) == 1
    assert "t" in named_buffers


def test_module_registers_submodules():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.linear = Linear(10, 20)

    a = A()
    named_modules = dict(a.named_modules())
    named_buffers = dict(a.named_buffers())
    assert len(named_buffers) == 2
    assert "linear" in named_modules
    assert "linear.weight" in named_buffers
    assert "linear.bias" in named_buffers


def test_module_can_dump_statedict():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.linear = Linear(10, 20)
            self.t = Tensor(42.0)

    a = A()
    state_dict = a.state_dict()
    assert hasattr(state_dict, "_metadata")
    assert "t" in state_dict
    assert "linear.weight" in state_dict
    assert "linear.bias" in state_dict
    assert len(state_dict) == 3


def test_module_can_load_statedict():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.linear = Linear(10, 20)
            self.t = Tensor(42.0)

    statedict = {
        "linear.weight": candle.ones((20, 10)),
        "linear.bias": candle.zeros((20,)),
        "t": Tensor(42.0),
    }
    a = A()
    a.load_state_dict(statedict)


def test_module_throws_on_shape_mismatch():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = Tensor(42.0)

    statedict = {
        "t": candle.ones((20,)),
    }
    a = A()
    with pytest.raises(RuntimeError) as excinfo:
        a.load_state_dict(statedict)
    assert "size mismatch" in str(excinfo.value)


def test_module_throws_on_missing_key():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = Tensor(42.0)

    statedict = {
        "not_t": Tensor(42.0),
    }

    a = A()
    with pytest.raises(RuntimeError) as excinfo:
        a.load_state_dict(statedict)
    assert 'Missing key(s) in state_dict: "t".' in str(excinfo.value)


def test_module_can_load_quantized_tensors():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = candle.randn((16, 256))
            # Mark "t" as quantizable so a QTensor from the state dict is kept as-is.
            self._quantizable_buffers.add("t")

    statedict = {
        "t": candle.ones((16, 256)).quantize("q4_0"),
    }
    a = A()
    a.load_state_dict(statedict)
    assert isinstance(a.t, QTensor)
    assert a.t.ggml_dtype == "Q4_0"


def test_module_dequantizes_tensors_automatically():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = candle.randn((16, 256))

    statedict = {
        "t": candle.ones((16, 256)).quantize("q4_0"),
    }
    a = A()
    a.load_state_dict(statedict)
    # "t" is not registered as quantizable, so the QTensor is dequantized on load.
    assert isinstance(a.t, Tensor)


@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_module_can_be_moved_to_cuda():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = candle.randn((16, 256))

    a = A()
    a.cuda()
    assert a.t.device == "cuda"


@pytest.mark.skipif(not cuda_is_available(), reason="CUDA is not available")
def test_module_can_be_moved_from_cuda_to_cpu():
    class A(Module):
        def __init__(self):
            super().__init__()
            self.t = candle.randn((16, 256))

    a = A()
    a.cuda()
    assert a.t.device == "cuda"
    a.cpu()
    assert a.t.device == "cpu"