From ec4a52af488c3cc7117aca27fce2f3b26ccf52a8 Mon Sep 17 00:00:00 2001 From: Vitaliy Filippov Date: Sun, 26 Apr 2020 01:33:50 +0300 Subject: [PATCH] Fix websocket (and timer!) bugs --- osd_cluster.cpp | 5 ++ osd_http.cpp | 122 ++++++++++++++++++++++++++++++++------------ osd_http.h | 1 + osd_main.cpp | 4 +- timerfd_manager.cpp | 14 ++--- 5 files changed, 107 insertions(+), 39 deletions(-) diff --git a/osd_cluster.cpp b/osd_cluster.cpp index 33537317..0222df37 100644 --- a/osd_cluster.cpp +++ b/osd_cluster.cpp @@ -401,6 +401,11 @@ void osd_t::force_stop() exit(0); }); } + else + { + printf("[OSD %lu] Force stopping\n", this->osd_num); + exit(0); + } } void osd_t::load_pgs() diff --git a/osd_http.cpp b/osd_http.cpp index bb880f87..076ab492 100644 --- a/osd_http.cpp +++ b/osd_http.cpp @@ -13,6 +13,8 @@ #define READ_BUFFER_SIZE 9000 static int extract_port(std::string & host); +static std::string strtolower(const std::string & in); +static std::string trim(const std::string & in); static std::string ws_format_frame(int type, uint64_t size); static bool ws_parse_frame(std::string & buf, int & type, std::string & res); @@ -26,10 +28,12 @@ struct http_co_t int request_timeout = 0; std::string host; std::string request; + std::string ws_outbox; std::string response; bool want_streaming; http_response_t parsed; + uint64_t target_response_size = 0; int state = 0; int peer_fd = -1; @@ -47,7 +51,7 @@ struct http_co_t websocket_t ws; ~http_co_t(); - void connect(); + void start_connection(); void handle_connect_result(); void submit_read(); void submit_send(); @@ -60,7 +64,7 @@ struct http_co_t #define HTTP_CO_REQUEST_SENT 3 #define HTTP_CO_HEADERS_RECEIVED 4 #define HTTP_CO_WEBSOCKET 5 -#define HTTP_CO_STREAMING_CHUNKED 6 +#define HTTP_CO_CHUNKED 6 void osd_t::http_request(std::string host, std::string request, bool streaming, std::function callback) { @@ -75,7 +79,7 @@ void osd_t::http_request(std::string host, std::string request, bool streaming, handler->request = request; handler->callback = callback; handler->ws.co = handler; - handler->connect(); + handler->start_connection(); } void osd_t::http_request_json(std::string host, std::string request, @@ -90,14 +94,14 @@ void osd_t::http_request_json(std::string host, std::string request, } if (res->status_code != 200) { - callback("HTTP "+std::to_string(res->status_code)+" "+res->status_line+" body: "+res->body, json11::Json()); + callback("HTTP "+std::to_string(res->status_code)+" "+res->status_line+" body: "+trim(res->body), json11::Json()); return; } std::string json_err; json11::Json data = json11::Json::parse(res->body, json_err); if (json_err != "") { - callback("Bad JSON: "+json_err+" (response: "+res->body+")", json11::Json()); + callback("Bad JSON: "+json_err+" (response: "+trim(res->body)+")", json11::Json()); return; } callback(std::string(), data); @@ -124,7 +128,7 @@ websocket_t* osd_t::open_websocket(std::string host, std::string path, std::func handler->request = request; handler->callback = callback; handler->ws.co = handler; - handler->connect(); + handler->start_connection(); return &handler->ws; } @@ -133,6 +137,11 @@ void websocket_t::post_message(int type, const std::string & msg) co->post_message(type, msg); } +void websocket_t::close() +{ + delete co; +} + http_co_t::~http_co_t() { if (timeout_id >= 0) @@ -183,7 +192,7 @@ http_co_t::~http_co_t() callback(&parsed); } -void http_co_t::connect() +void http_co_t::start_connection() { int port = extract_port(host); struct sockaddr_in addr; @@ -224,11 +233,11 @@ void http_co_t::connect() } else { - if (epoll_events & EPOLLIN) + if (this->epoll_events & EPOLLIN) { submit_read(); } - else if (epoll_events & (EPOLLRDHUP|EPOLLERR)) + else if (this->epoll_events & (EPOLLRDHUP|EPOLLERR)) { delete this; } @@ -396,21 +405,44 @@ void http_co_t::handle_read() { // Don't care about validating the key state = HTTP_CO_WEBSOCKET; - request = ""; + request = ws_outbox; + ws_outbox = ""; sent = 0; + submit_send(); } - else if (want_streaming && parsed.headers["transfer-encoding"] == "chunked") + else if (parsed.headers["transfer-encoding"] == "chunked") { - state = HTTP_CO_STREAMING_CHUNKED; + state = HTTP_CO_CHUNKED; + } + else if (parsed.headers["connection"] != "close") + { + target_response_size = stoull_full(parsed.headers["content-length"]); + if (!target_response_size) + { + // Sorry, unsupported response + delete this; + return; + } } } } - if (state == HTTP_CO_STREAMING_CHUNKED && response.size() > 0) + if (state == HTTP_CO_HEADERS_RECEIVED && target_response_size > 0 && response.size() >= target_response_size) + { + delete this; + return; + } + if (state == HTTP_CO_CHUNKED && response.size() > 0) { int prev = 0, pos = 0; while ((pos = response.find("\r\n", prev)) >= prev) { uint64_t len = strtoull(response.c_str()+prev, NULL, 16); + if (!len) + { + // Zero length chunk indicates EOF + delete this; + return; + } if (response.size() < pos+2+len+2) { break; @@ -422,7 +454,7 @@ void http_co_t::handle_read() { response = response.substr(prev); } - if (parsed.body.size() > 0) + if (want_streaming && parsed.body.size() > 0) { callback(&parsed); parsed.body = ""; @@ -440,9 +472,17 @@ void http_co_t::handle_read() void http_co_t::post_message(int type, const std::string & msg) { - request += ws_format_frame(type, msg.size()); - request += msg; - submit_send(); + if (state == HTTP_CO_WEBSOCKET) + { + request += ws_format_frame(type, msg.size()); + request += msg; + submit_send(); + } + else + { + ws_outbox += ws_format_frame(type, msg.size()); + ws_outbox += msg; + } } uint64_t stoull_full(const std::string & str, int base) @@ -487,13 +527,12 @@ void parse_http_headers(std::string & res, http_response_t *parsed) int p2 = header.find(":"); if (p2 >= 0) { - std::string key = header.substr(0, p2); - for (int i = 0; i < key.length(); i++) - key[i] = tolower(key[i]); + std::string key = strtolower(header.substr(0, p2)); int p3 = p2+1; while (p3 < header.length() && isblank(header[p3])) p3++; - parsed->headers[key] = header.substr(p3); + parsed->headers[key] = key == "connection" || key == "upgrade" || key == "transfer-encoding" + ? strtolower(header.substr(p3)) : header.substr(p3); } prev = pos+2; } @@ -541,9 +580,9 @@ static bool ws_parse_frame(std::string & buf, int & type, std::string & res) return false; } type = buf[0] & ~0x80; - bool mask = buf[1] & 0x80; + bool mask = !!(buf[1] & 0x80); hdr += mask ? 4 : 0; - uint64_t len = (buf[1] & ~0x80); + uint64_t len = ((uint8_t)buf[1] & ~0x80); if (len == 126) { hdr += 2; @@ -551,7 +590,7 @@ static bool ws_parse_frame(std::string & buf, int & type, std::string & res) { return false; } - len = ((uint64_t)buf[2] << 8) | ((uint64_t)buf[3] << 0); + len = ((uint64_t)(uint8_t)buf[2] << 8) | ((uint64_t)(uint8_t)buf[3] << 0); } else if (len == 127) { @@ -560,14 +599,14 @@ static bool ws_parse_frame(std::string & buf, int & type, std::string & res) { return false; } - len = ((uint64_t)buf[2] << 56) | - ((uint64_t)buf[3] << 48) | - ((uint64_t)buf[4] << 40) | - ((uint64_t)buf[5] << 32) | - ((uint64_t)buf[6] << 24) | - ((uint64_t)buf[7] << 16) | - ((uint64_t)buf[8] << 8) | - ((uint64_t)buf[9] << 0); + len = ((uint64_t)(uint8_t)buf[2] << 56) | + ((uint64_t)(uint8_t)buf[3] << 48) | + ((uint64_t)(uint8_t)buf[4] << 40) | + ((uint64_t)(uint8_t)buf[5] << 32) | + ((uint64_t)(uint8_t)buf[6] << 24) | + ((uint64_t)(uint8_t)buf[7] << 16) | + ((uint64_t)(uint8_t)buf[8] << 8) | + ((uint64_t)(uint8_t)buf[9] << 0); } if (buf.size() < hdr+len) { @@ -633,3 +672,22 @@ static int extract_port(std::string & host) } return port; } + +static std::string strtolower(const std::string & in) +{ + std::string s = in; + for (int i = 0; i < s.length(); i++) + { + s[i] = tolower(s[i]); + } + return s; +} + +static std::string trim(const std::string & in) +{ + int begin = in.find_first_not_of(" \n\r\t"); + if (begin == -1) + return ""; + int end = in.find_last_not_of(" \n\r\t"); + return in.substr(begin, end+1-begin); +} diff --git a/osd_http.h b/osd_http.h index 71c51317..10c1e866 100644 --- a/osd_http.h +++ b/osd_http.h @@ -27,6 +27,7 @@ struct websocket_t { http_co_t *co; void post_message(int type, const std::string & msg); + void close(); }; void parse_http_headers(std::string & res, http_response_t *parsed); diff --git a/osd_main.cpp b/osd_main.cpp index f1146ca2..15085be3 100644 --- a/osd_main.cpp +++ b/osd_main.cpp @@ -3,11 +3,13 @@ #include static osd_t *osd = NULL; +static bool force_stopping = false; static void handle_sigint(int sig) { - if (osd) + if (osd && !force_stopping) { + force_stopping = true; osd->force_stop(); return; } diff --git a/timerfd_manager.cpp b/timerfd_manager.cpp index 7c2f95bf..298dee58 100644 --- a/timerfd_manager.cpp +++ b/timerfd_manager.cpp @@ -150,15 +150,17 @@ void timerfd_manager_t::set_wait() read(timerfd, &n, 8); if (nearest >= 0) { - if (!timers[nearest].repeat) - { - timers.erase(timers.begin()+nearest, timers.begin()+nearest+1); - } - else + int nearest_id = timers[nearest].id; + auto cb = timers[nearest].callback; + if (timers[nearest].repeat) { inc_timer(timers[nearest]); } - timers[nearest].callback(timers[nearest].id); + else + { + timers.erase(timers.begin()+nearest, timers.begin()+nearest+1); + } + cb(nearest_id); nearest = -1; } wait_state = 0;