diff --git a/.gitignore b/.gitignore index 920b466..627fefc 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,4 @@ *.o /strfry /strfry-db/*.mdb +.vscode/* diff --git a/TODO b/TODO index 88c68b4..b5fb7eb 100644 --- a/TODO +++ b/TODO @@ -6,7 +6,7 @@ 0.2 release ? why isn't the LMDB mapping CLOEXEC - plugin for stream + ? plugin for stream: make sure bortloff@github didn't make a mess of it fix sync * logging of bytes up/down * up/both directions diff --git a/src/PluginWritePolicy.h b/src/PluginWritePolicy.h index 5230927..452c4a1 100644 --- a/src/PluginWritePolicy.h +++ b/src/PluginWritePolicy.h @@ -53,7 +53,7 @@ struct PluginWritePolicy { std::unique_ptr running; - WritePolicyResult acceptEvent(std::string_view jsonStr, uint64_t receivedAt, EventSourceType sourceType, std::string_view sourceInfo, std::string &okMsg) { + WritePolicyResult acceptEvent(const tao::json::value &evJson, uint64_t receivedAt, EventSourceType sourceType, std::string_view sourceInfo, std::string &okMsg) { const auto &pluginPath = cfg().relay__writePolicy__plugin; if (pluginPath.size() == 0) { @@ -81,7 +81,7 @@ struct PluginWritePolicy { auto request = tao::json::value({ { "type", "new" }, - { "event", tao::json::from_string(jsonStr) }, + { "event", evJson }, { "receivedAt", receivedAt / 1000000 }, { "sourceType", eventSourceTypeToStr(sourceType) }, { "sourceInfo", sourceType == EventSourceType::IP4 || sourceType == EventSourceType::IP6 ? renderIP(sourceInfo) : sourceInfo }, diff --git a/src/RelayWriter.cpp b/src/RelayWriter.cpp index 217148d..c1e2639 100644 --- a/src/RelayWriter.cpp +++ b/src/RelayWriter.cpp @@ -17,9 +17,10 @@ void RelayServer::runWriter(ThreadPool::Thread &thr) { for (auto &newMsg : newMsgs) { if (auto msg = std::get_if(&newMsg.msg)) { + tao::json::value evJson = tao::json::from_string(msg->jsonStr); EventSourceType sourceType = msg->ipAddr.size() == 4 ? EventSourceType::IP4 : EventSourceType::IP6; std::string okMsg; - auto res = writePolicy.acceptEvent(msg->jsonStr, msg->receivedAt, sourceType, msg->ipAddr, okMsg); + auto res = writePolicy.acceptEvent(evJson, msg->receivedAt, sourceType, msg->ipAddr, okMsg); if (res == WritePolicyResult::Accept) { newEvents.emplace_back(std::move(msg->flatStr), std::move(msg->jsonStr), msg->receivedAt, sourceType, std::move(msg->ipAddr), msg); diff --git a/src/WSConnection.h b/src/WSConnection.h index 73ecb5e..9ecfdfb 100644 --- a/src/WSConnection.h +++ b/src/WSConnection.h @@ -25,6 +25,7 @@ class WSConnection { std::function onTrigger; bool reconnect = true; uint64_t reconnectDelayMilliseconds = 5'000; + std::string remoteAddr; // Should only be called from the websocket thread (ie within an onConnect or onMessage callback) void send(std::string_view msg, uWS::OpCode op = uWS::OpCode::TEXT, size_t *compressedSize = nullptr) { @@ -57,8 +58,8 @@ class WSConnection { currWs = nullptr; } - std::string addr = ws->getAddress().address; - LI << "Connected to " << addr; + remoteAddr = ws->getAddress().address; + LI << "Connected to " << remoteAddr; { int optval = 1; diff --git a/src/cmd_stream.cpp b/src/cmd_stream.cpp index 96799d7..28b678d 100644 --- a/src/cmd_stream.cpp +++ b/src/cmd_stream.cpp @@ -10,6 +10,8 @@ #include "WSConnection.h" #include "events.h" +#include "PluginWritePolicy.h" + static const char USAGE[] = R"( @@ -30,12 +32,14 @@ void cmd_stream(const std::vector &subArgs) { if (dir != "up" && dir != "down" && dir != "both") throw herr("invalid direction: ", dir, ". Should be one of up/down/both"); - flat_hash_set downloadedIds; WriterPipeline writer; WSConnection ws(url); Decompressor decomp; + PluginWritePolicy writePolicy; + + ws.onConnect = [&]{ if (dir == "down" || dir == "both") { auto encoded = tao::json::to_string(tao::json::value::array({ "REQ", "sub", tao::json::value({ { "limit", 0 } }) })); @@ -63,8 +67,16 @@ void cmd_stream(const std::vector &subArgs) { if (dir == "down" || dir == "both") { if (origJson.get_array().size() < 3) throw herr("array too short"); auto &evJson = origJson.at(2); - downloadedIds.emplace(from_hex(evJson.at("id").get_string())); - writer.inbox.push_move({ std::move(evJson), EventSourceType::Stream, url }); + + std::string okMsg; + auto res = writePolicy.acceptEvent(evJson, hoytech::curr_time_s(), EventSourceType::Stream, ws.remoteAddr, okMsg); + if (res == WritePolicyResult::Accept) { + downloadedIds.emplace(from_hex(evJson.at("id").get_string())); + writer.inbox.push_move({ std::move(evJson), EventSourceType::Stream, url }); + } else { + LI << "[" << ws.remoteAddr << "] write policy blocked event " << evJson.at("id").get_string() << ": " << okMsg; + } + } else { LW << "Unexpected EVENT"; }