diff --git a/Makefile b/Makefile index 87d74c9..6a66431 100644 --- a/Makefile +++ b/Makefile @@ -4,3 +4,7 @@ OPT = -O3 -g include golpe/rules.mk LDLIBS += -lsecp256k1 -lb2 -lzstd + +test/xor: OPT=-O0 -g +test/xor: test/xor.cpp + $(CXX) $(CXXFLAGS) $(INCS) $(LDFLAGS) $(LDLIBS) $< -o $@ diff --git a/src/xor.h b/src/xor.h index 454f0c8..7997572 100644 --- a/src/xor.h +++ b/src/xor.h @@ -76,15 +76,25 @@ struct XorView { } void finalise() { - std::reverse(elems.begin(), elems.end()); // pushed in approximately descending order, so hopefully this speeds up the sort + std::reverse(elems.begin(), elems.end()); // typically pushed in approximately descending order so this may speed up the sort std::sort(elems.begin(), elems.end(), [&](const auto &a, const auto &b) { return a.getCompare(idSize) < b.getCompare(idSize); }); ready = true; } + std::string initialQuery() { + if (!ready) throw herr("xor view not ready"); + + std::string output; + splitRange(elems.begin(), elems.end(), output); + return output; + } + // FIXME: try/catch everywhere that calls this std::string handleQuery(std::string_view query, std::vector &haveIds, std::vector &needIds) { + if (!ready) throw herr("xor view not ready"); + std::string output; auto cmp = [&](const auto &a, const auto &b){ return a.getCompare(idSize) < b.getCompare(idSize); }; @@ -98,10 +108,11 @@ struct XorView { if (upperLength > idSize + 5) throw herr("upper too long"); XorElem upperKey(getBytes(query, upperLength)); - auto lower = std::lower_bound(elems.begin(), elems.end(), lowerKey, cmp); + auto lower = std::lower_bound(elems.begin(), elems.end(), lowerKey, cmp); // FIXME: start at prev upper? auto upper = std::upper_bound(elems.begin(), elems.end(), upperKey, cmp); // FIXME: start at lower? uint64_t mode = decodeVarInt(query); // 0 = range, 8 and above = n-8 inline IDs + std::cerr << "BING MODE = " << mode << std::endl; if (mode == 0) { XorElem theirXorSet(0, getBytes(query, idSize), idSize); @@ -109,65 +120,28 @@ struct XorView { XorElem ourXorSet; for (auto i = lower; i < upper; ++i) ourXorSet.doXor(*i); - if (theirXorSet.getId(idSize) != ourXorSet.getId(idSize)) { - // Split our range - uint64_t numElems = upper - lower; - const uint64_t buckets = 16; - - if (numElems < buckets * 2) { - output += encodeVarInt(numElems + 8); - for (auto it = lower; it < upper; ++it) output += it->getId(idSize); - } else { - uint64_t elemsPerBucket = numElems / buckets; - uint64_t bucketsWithExtra = numElems % buckets; - auto curr = lower; - - for (uint64_t i = 0; i < buckets; i++) { - { - auto k = getLowerKey(curr); - output += encodeVarInt(k.size()); - output += k; - } - - auto bucketEnd = curr + elemsPerBucket; - if (i < bucketsWithExtra) bucketEnd++; - - XorElem ourXorSet; - for (auto bucketEnd = curr + elemsPerBucket + (i < bucketsWithExtra ? 1 : 0); curr != bucketEnd; curr++) { - ourXorSet.doXor(*curr); - } - - output += ourXorSet.getId(idSize); - - { - auto k = getLowerKey(curr); - output += encodeVarInt(k.size()); - output += k; - } - } - } - } + if (theirXorSet.getId(idSize) != ourXorSet.getId(idSize)) splitRange(lower, upper, output); } else if (mode >= 8) { flat_hash_map theirElems; for (uint64_t i = 0; i < mode - 8; i++) { theirElems.emplace(XorElem(0, getBytes(query, idSize), idSize), false); } - for (auto it = lower; lower < upper; ++it) { + for (auto it = lower; it < upper; ++it) { auto e = theirElems.find(*it); if (e == theirElems.end()) { - // Id exists on our side, but not their side + // ID exists on our side, but not their side haveIds.emplace_back(it->getId(idSize)); } else { - // Id exists on both sides + // ID exists on both sides e->second = true; } } for (const auto &[k, v] : theirElems) { if (!v) { - // Id exists on their side, but not our side + // ID exists on their side, but not our side needIds.emplace_back(k.getId(idSize)); } } @@ -177,6 +151,47 @@ struct XorView { return output; } + private: + + void splitRange(std::vector::iterator lower, std::vector::iterator upper, std::string &output) { + // Split our range + uint64_t numElems = upper - lower; + const uint64_t buckets = 16; + + if (numElems < buckets * 2) { + appendBoundKey(getLowerKey(lower), output); + appendBoundKey(getUpperKey(upper), output); + + output += encodeVarInt(numElems + 8); + for (auto it = lower; it < upper; ++it) output += it->getId(idSize); + } else { + uint64_t elemsPerBucket = numElems / buckets; + uint64_t bucketsWithExtra = numElems % buckets; + auto curr = lower; + + for (uint64_t i = 0; i < buckets; i++) { + appendBoundKey(getLowerKey(curr), output); + + auto bucketEnd = curr + elemsPerBucket; + if (i < bucketsWithExtra) bucketEnd++; + + XorElem ourXorSet; + for (auto bucketEnd = curr + elemsPerBucket + (i < bucketsWithExtra ? 1 : 0); curr != bucketEnd; curr++) { + ourXorSet.doXor(*curr); + } + + appendBoundKey(getUpperKey(curr), output); + + output += ourXorSet.getId(idSize); + } + } + } + + void appendBoundKey(std::string k, std::string &output) { + output += encodeVarInt(k.size()); + output += k; + } + std::string getLowerKey(std::vector::iterator it) { if (it == elems.begin()) return std::string(1, '\0'); return minimalKeyDiff(it->getCompare(idSize), std::prev(it)->getCompare(idSize)); @@ -194,16 +209,6 @@ struct XorView { throw herr("couldn't compute shared prefix"); } - - std::string xorRange(uint64_t start, uint64_t len) { - XorElem output; - - for (uint64_t i = 0; i < len; i++) { - output.doXor(elems[i]); - } - - return std::string(output.getId(idSize)); - } }; @@ -211,7 +216,7 @@ namespace std { // inject specialization of std::hash template<> struct hash { std::size_t operator()(XorElem const &p) const { - return phmap::HashState().combine(0, p.getFull()); + return phmap::HashState().combine(0, p.getIdPadded()); } }; } diff --git a/test/.gitignore b/test/.gitignore new file mode 100644 index 0000000..93eed21 --- /dev/null +++ b/test/.gitignore @@ -0,0 +1 @@ +/xor diff --git a/test/xor.cpp b/test/xor.cpp new file mode 100644 index 0000000..abb5dd8 --- /dev/null +++ b/test/xor.cpp @@ -0,0 +1,31 @@ +#include + +#include "golpe.h" +#include "xor.h" + + +int main() { + XorView x1(16); + x1.addElem(1000, std::string(16, 'a')); + x1.addElem(2000, std::string(16, 'b')); + x1.finalise(); + + XorView x2(16); + x2.addElem(2000, std::string(16, 'b')); + x2.addElem(3000, std::string(16, 'c')); + x2.finalise(); + + { + auto q = x1.initialQuery(); + std::cout << to_hex(q) << std::endl; + + std::vector have, need; + auto q2 = x2.handleQuery(q, have, need); + + for (auto &s : have) std::cout << "HAVE: " << to_hex(s) << std::endl; + for (auto &s : need) std::cout << "NEED: " << to_hex(s) << std::endl; + std::cout << to_hex(q2) << std::endl; + } + + return 0; +}