Fix websocket (and timer!) bugs

trace-sqes
Vitaliy Filippov 2020-04-26 01:33:50 +03:00
parent 268b497c0b
commit ec4a52af48
5 changed files with 107 additions and 39 deletions

View File

@ -401,6 +401,11 @@ void osd_t::force_stop()
exit(0); exit(0);
}); });
} }
else
{
printf("[OSD %lu] Force stopping\n", this->osd_num);
exit(0);
}
} }
void osd_t::load_pgs() void osd_t::load_pgs()

View File

@ -13,6 +13,8 @@
#define READ_BUFFER_SIZE 9000 #define READ_BUFFER_SIZE 9000
static int extract_port(std::string & host); static int extract_port(std::string & host);
static std::string strtolower(const std::string & in);
static std::string trim(const std::string & in);
static std::string ws_format_frame(int type, uint64_t size); static std::string ws_format_frame(int type, uint64_t size);
static bool ws_parse_frame(std::string & buf, int & type, std::string & res); static bool ws_parse_frame(std::string & buf, int & type, std::string & res);
@ -26,10 +28,12 @@ struct http_co_t
int request_timeout = 0; int request_timeout = 0;
std::string host; std::string host;
std::string request; std::string request;
std::string ws_outbox;
std::string response; std::string response;
bool want_streaming; bool want_streaming;
http_response_t parsed; http_response_t parsed;
uint64_t target_response_size = 0;
int state = 0; int state = 0;
int peer_fd = -1; int peer_fd = -1;
@ -47,7 +51,7 @@ struct http_co_t
websocket_t ws; websocket_t ws;
~http_co_t(); ~http_co_t();
void connect(); void start_connection();
void handle_connect_result(); void handle_connect_result();
void submit_read(); void submit_read();
void submit_send(); void submit_send();
@ -60,7 +64,7 @@ struct http_co_t
#define HTTP_CO_REQUEST_SENT 3 #define HTTP_CO_REQUEST_SENT 3
#define HTTP_CO_HEADERS_RECEIVED 4 #define HTTP_CO_HEADERS_RECEIVED 4
#define HTTP_CO_WEBSOCKET 5 #define HTTP_CO_WEBSOCKET 5
#define HTTP_CO_STREAMING_CHUNKED 6 #define HTTP_CO_CHUNKED 6
void osd_t::http_request(std::string host, std::string request, bool streaming, std::function<void(const http_response_t *response)> callback) void osd_t::http_request(std::string host, std::string request, bool streaming, std::function<void(const http_response_t *response)> callback)
{ {
@ -75,7 +79,7 @@ void osd_t::http_request(std::string host, std::string request, bool streaming,
handler->request = request; handler->request = request;
handler->callback = callback; handler->callback = callback;
handler->ws.co = handler; handler->ws.co = handler;
handler->connect(); handler->start_connection();
} }
void osd_t::http_request_json(std::string host, std::string request, void osd_t::http_request_json(std::string host, std::string request,
@ -90,14 +94,14 @@ void osd_t::http_request_json(std::string host, std::string request,
} }
if (res->status_code != 200) if (res->status_code != 200)
{ {
callback("HTTP "+std::to_string(res->status_code)+" "+res->status_line+" body: "+res->body, json11::Json()); callback("HTTP "+std::to_string(res->status_code)+" "+res->status_line+" body: "+trim(res->body), json11::Json());
return; return;
} }
std::string json_err; std::string json_err;
json11::Json data = json11::Json::parse(res->body, json_err); json11::Json data = json11::Json::parse(res->body, json_err);
if (json_err != "") if (json_err != "")
{ {
callback("Bad JSON: "+json_err+" (response: "+res->body+")", json11::Json()); callback("Bad JSON: "+json_err+" (response: "+trim(res->body)+")", json11::Json());
return; return;
} }
callback(std::string(), data); callback(std::string(), data);
@ -124,7 +128,7 @@ websocket_t* osd_t::open_websocket(std::string host, std::string path, std::func
handler->request = request; handler->request = request;
handler->callback = callback; handler->callback = callback;
handler->ws.co = handler; handler->ws.co = handler;
handler->connect(); handler->start_connection();
return &handler->ws; return &handler->ws;
} }
@ -133,6 +137,11 @@ void websocket_t::post_message(int type, const std::string & msg)
co->post_message(type, msg); co->post_message(type, msg);
} }
void websocket_t::close()
{
delete co;
}
http_co_t::~http_co_t() http_co_t::~http_co_t()
{ {
if (timeout_id >= 0) if (timeout_id >= 0)
@ -183,7 +192,7 @@ http_co_t::~http_co_t()
callback(&parsed); callback(&parsed);
} }
void http_co_t::connect() void http_co_t::start_connection()
{ {
int port = extract_port(host); int port = extract_port(host);
struct sockaddr_in addr; struct sockaddr_in addr;
@ -224,11 +233,11 @@ void http_co_t::connect()
} }
else else
{ {
if (epoll_events & EPOLLIN) if (this->epoll_events & EPOLLIN)
{ {
submit_read(); submit_read();
} }
else if (epoll_events & (EPOLLRDHUP|EPOLLERR)) else if (this->epoll_events & (EPOLLRDHUP|EPOLLERR))
{ {
delete this; delete this;
} }
@ -396,21 +405,44 @@ void http_co_t::handle_read()
{ {
// Don't care about validating the key // Don't care about validating the key
state = HTTP_CO_WEBSOCKET; state = HTTP_CO_WEBSOCKET;
request = ""; request = ws_outbox;
ws_outbox = "";
sent = 0; sent = 0;
submit_send();
} }
else if (want_streaming && parsed.headers["transfer-encoding"] == "chunked") else if (parsed.headers["transfer-encoding"] == "chunked")
{ {
state = HTTP_CO_STREAMING_CHUNKED; state = HTTP_CO_CHUNKED;
}
else if (parsed.headers["connection"] != "close")
{
target_response_size = stoull_full(parsed.headers["content-length"]);
if (!target_response_size)
{
// Sorry, unsupported response
delete this;
return;
}
} }
} }
} }
if (state == HTTP_CO_STREAMING_CHUNKED && response.size() > 0) if (state == HTTP_CO_HEADERS_RECEIVED && target_response_size > 0 && response.size() >= target_response_size)
{
delete this;
return;
}
if (state == HTTP_CO_CHUNKED && response.size() > 0)
{ {
int prev = 0, pos = 0; int prev = 0, pos = 0;
while ((pos = response.find("\r\n", prev)) >= prev) while ((pos = response.find("\r\n", prev)) >= prev)
{ {
uint64_t len = strtoull(response.c_str()+prev, NULL, 16); uint64_t len = strtoull(response.c_str()+prev, NULL, 16);
if (!len)
{
// Zero length chunk indicates EOF
delete this;
return;
}
if (response.size() < pos+2+len+2) if (response.size() < pos+2+len+2)
{ {
break; break;
@ -422,7 +454,7 @@ void http_co_t::handle_read()
{ {
response = response.substr(prev); response = response.substr(prev);
} }
if (parsed.body.size() > 0) if (want_streaming && parsed.body.size() > 0)
{ {
callback(&parsed); callback(&parsed);
parsed.body = ""; parsed.body = "";
@ -440,9 +472,17 @@ void http_co_t::handle_read()
void http_co_t::post_message(int type, const std::string & msg) void http_co_t::post_message(int type, const std::string & msg)
{ {
request += ws_format_frame(type, msg.size()); if (state == HTTP_CO_WEBSOCKET)
request += msg; {
submit_send(); request += ws_format_frame(type, msg.size());
request += msg;
submit_send();
}
else
{
ws_outbox += ws_format_frame(type, msg.size());
ws_outbox += msg;
}
} }
uint64_t stoull_full(const std::string & str, int base) uint64_t stoull_full(const std::string & str, int base)
@ -487,13 +527,12 @@ void parse_http_headers(std::string & res, http_response_t *parsed)
int p2 = header.find(":"); int p2 = header.find(":");
if (p2 >= 0) if (p2 >= 0)
{ {
std::string key = header.substr(0, p2); std::string key = strtolower(header.substr(0, p2));
for (int i = 0; i < key.length(); i++)
key[i] = tolower(key[i]);
int p3 = p2+1; int p3 = p2+1;
while (p3 < header.length() && isblank(header[p3])) while (p3 < header.length() && isblank(header[p3]))
p3++; p3++;
parsed->headers[key] = header.substr(p3); parsed->headers[key] = key == "connection" || key == "upgrade" || key == "transfer-encoding"
? strtolower(header.substr(p3)) : header.substr(p3);
} }
prev = pos+2; prev = pos+2;
} }
@ -541,9 +580,9 @@ static bool ws_parse_frame(std::string & buf, int & type, std::string & res)
return false; return false;
} }
type = buf[0] & ~0x80; type = buf[0] & ~0x80;
bool mask = buf[1] & 0x80; bool mask = !!(buf[1] & 0x80);
hdr += mask ? 4 : 0; hdr += mask ? 4 : 0;
uint64_t len = (buf[1] & ~0x80); uint64_t len = ((uint8_t)buf[1] & ~0x80);
if (len == 126) if (len == 126)
{ {
hdr += 2; hdr += 2;
@ -551,7 +590,7 @@ static bool ws_parse_frame(std::string & buf, int & type, std::string & res)
{ {
return false; return false;
} }
len = ((uint64_t)buf[2] << 8) | ((uint64_t)buf[3] << 0); len = ((uint64_t)(uint8_t)buf[2] << 8) | ((uint64_t)(uint8_t)buf[3] << 0);
} }
else if (len == 127) else if (len == 127)
{ {
@ -560,14 +599,14 @@ static bool ws_parse_frame(std::string & buf, int & type, std::string & res)
{ {
return false; return false;
} }
len = ((uint64_t)buf[2] << 56) | len = ((uint64_t)(uint8_t)buf[2] << 56) |
((uint64_t)buf[3] << 48) | ((uint64_t)(uint8_t)buf[3] << 48) |
((uint64_t)buf[4] << 40) | ((uint64_t)(uint8_t)buf[4] << 40) |
((uint64_t)buf[5] << 32) | ((uint64_t)(uint8_t)buf[5] << 32) |
((uint64_t)buf[6] << 24) | ((uint64_t)(uint8_t)buf[6] << 24) |
((uint64_t)buf[7] << 16) | ((uint64_t)(uint8_t)buf[7] << 16) |
((uint64_t)buf[8] << 8) | ((uint64_t)(uint8_t)buf[8] << 8) |
((uint64_t)buf[9] << 0); ((uint64_t)(uint8_t)buf[9] << 0);
} }
if (buf.size() < hdr+len) if (buf.size() < hdr+len)
{ {
@ -633,3 +672,22 @@ static int extract_port(std::string & host)
} }
return port; return port;
} }
static std::string strtolower(const std::string & in)
{
std::string s = in;
for (int i = 0; i < s.length(); i++)
{
s[i] = tolower(s[i]);
}
return s;
}
static std::string trim(const std::string & in)
{
int begin = in.find_first_not_of(" \n\r\t");
if (begin == -1)
return "";
int end = in.find_last_not_of(" \n\r\t");
return in.substr(begin, end+1-begin);
}

View File

@ -27,6 +27,7 @@ struct websocket_t
{ {
http_co_t *co; http_co_t *co;
void post_message(int type, const std::string & msg); void post_message(int type, const std::string & msg);
void close();
}; };
void parse_http_headers(std::string & res, http_response_t *parsed); void parse_http_headers(std::string & res, http_response_t *parsed);

View File

@ -3,11 +3,13 @@
#include <signal.h> #include <signal.h>
static osd_t *osd = NULL; static osd_t *osd = NULL;
static bool force_stopping = false;
static void handle_sigint(int sig) static void handle_sigint(int sig)
{ {
if (osd) if (osd && !force_stopping)
{ {
force_stopping = true;
osd->force_stop(); osd->force_stop();
return; return;
} }

View File

@ -150,15 +150,17 @@ void timerfd_manager_t::set_wait()
read(timerfd, &n, 8); read(timerfd, &n, 8);
if (nearest >= 0) if (nearest >= 0)
{ {
if (!timers[nearest].repeat) int nearest_id = timers[nearest].id;
{ auto cb = timers[nearest].callback;
timers.erase(timers.begin()+nearest, timers.begin()+nearest+1); if (timers[nearest].repeat)
}
else
{ {
inc_timer(timers[nearest]); inc_timer(timers[nearest]);
} }
timers[nearest].callback(timers[nearest].id); else
{
timers.erase(timers.begin()+nearest, timers.begin()+nearest+1);
}
cb(nearest_id);
nearest = -1; nearest = -1;
} }
wait_state = 0; wait_state = 0;