diff --git a/.gitignore b/.gitignore index 51ac26b..918b7c1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ *.o proxy compile_commands.json +.cache diff --git a/proxlib b/proxlib index 7093a67..0f28968 100755 Binary files a/proxlib and b/proxlib differ diff --git a/proxlib.c b/proxlib.c index 07aff7c..4b44edc 100644 --- a/proxlib.c +++ b/proxlib.c @@ -7,6 +7,7 @@ #include #include #include +#include #include "proxlib.h" #include "parslib/parslib.h" @@ -98,21 +99,21 @@ int pull_content_length(int fd, int len, int *msgbuff_len, char **msgbuff) { int line_len = len; char *line = (char *) calloc(1, line_len); if (!line) { - return err_mem; + return ERR_MEM; } int bytes = 0; do { ret = recv(fd, line+bytes, line_len-bytes, MSG_WAITALL); if (ret < 0) { - return err_recv; + return ERR_MEM; } bytes += ret; } while (bytes < line_len); *msgbuff = (char *) realloc(*msgbuff, *msgbuff_len+line_len); if (!*msgbuff) { - return err_mem; + return ERR_MEM; } memcpy(*msgbuff+*msgbuff_len, line, line_len); @@ -166,9 +167,6 @@ int pull_chunked_encoding(int fd, int *msgbuff_len, char **msgbuff) { memcpy(*msgbuff+*msgbuff_len, line, line_len); *msgbuff_len += line_len; - if (debug == 1) { - fprintf(stdout, "debug - [upstream] received chunk:%d\n", line_len); - } free(line); } return 0; @@ -176,137 +174,27 @@ int pull_chunked_encoding(int fd, int *msgbuff_len, char **msgbuff) { } void do_err(void) { - fprintf(stderr, "[%s] failed with error code %d=%s\n", - states_str[statem], err, errs_str[err]); -} - -int do_fwd_clt(struct conn *conn) { - int bytes = 0; - int ret = 0; - while (bytes < conn->srvbuff_len) { - ret = write(conn->cltfd, conn->srvbuff+bytes, conn->srvbuff_len-bytes); - if (ret < 0) - return -1; - bytes += ret; - } - - return 0; -} - -int do_rcv_srv(struct conn *conn) { - int ret = 0; - char *line = NULL; - char *msgbuff = NULL; - int line_len = 0; - int msgbuff_len = 0; - - // response line - ret = read_line(conn->srvfd, &line_len, &line, &msgbuff_len, &msgbuff); - if (ret < 0) { - return err_recv; - } - - if (debug == 1) { - fprintf(stdout, "debug - [upstream] received line: %s\n", line); - } - - ret = parestitl(line, line_len, &(conn->srvres.titl)); - if (ret < 0) { - return err_parstitle; - } - - if (debug == 1) { - fprintf(stdout, "debug - [upstream] parsed response line\n"); - } - - free(line); - - // headers - int next_header = 1; - while (next_header) { - ret = read_line(conn->srvfd, &line_len, &line, &msgbuff_len, &msgbuff); - if (ret < 0) { - return err_recv; - } - - if (line_len == 0) { - if (debug == 1) { - fprintf(stdout, "debug - [upstream] reached end of headers\n"); - } - next_header = 0; - continue; - } - - if (debug == 1) { - fprintf(stdout, "debug - [upstream] received line: %s\n", line); - } - - ret = parshfield(line, line_len, conn->srvres.hentries); - if (ret < 0) { - return err_parsheader; - } - - if (debug == 1) { - fprintf(stdout, "debug - parsed header field\n"); - } - - free(line); - } - - // body - struct httpares *res = &conn->srvres; - struct point *content_length_entry = &res->hentries[header_content_length]; - struct point *transfer_encoding_entry = &res->hentries[header_transfer_encoding]; - if (content_length_entry->er) { - int content_length = 0; - - ret = stoin(content_length_entry->er, content_length_entry->len, &content_length); - if (ret < 0) { - return err_pars; - } - - ret = pull_content_length(conn->srvfd, content_length, &msgbuff_len, &msgbuff); - if (ret < 0) { - return err_recv; - } - fprintf(stdout, "Successfully received normal body from server\n"); - } else if (transfer_encoding_entry->er && strcmp(transfer_encoding_entry->er, "chunked") == 0) { - ret = pull_chunked_encoding(conn->srvfd, &msgbuff_len, &msgbuff); - if (ret < 0) { - return err_recv; - } - fprintf(stdout, "Successfully received chunked body from server\n"); - } else { - return err_support; - } - - fprintf(stdout, "srvbuff:%p+srvbuff_len:%d\n", conn->srvbuff, conn->srvbuff_len); - conn->srvbuff = msgbuff; - conn->srvbuff_len = msgbuff_len; - - return 0; + fprintf(stderr, "failed with error code %d\n", err); } int do_con_srv(struct conn *conn) { + statem = state_con_srv; + int ret = 0; struct httpareq *req = &conn->cltreq; struct point *host = &req->hentries[header_host]; if (host->er == NULL) { - return err_pars; + return ERR_PARS; } struct hostinfo *info = (struct hostinfo *) calloc(1, sizeof(struct hostinfo)); if (!info) { - return err_mem; + return ERR_MEM; } ret = pahostinfo(host->er, host->len, info); if (ret < 0) { - return err_pars; - } - - if (debug <= 2) { - fprintf(stdout, "Establishing connection with upstream: %.*s : %.*s\n", info->hostname_len, info->hostname, info->service_len, info->service); + return ERR_PARS; } struct addrinfo hints; @@ -321,7 +209,7 @@ int do_con_srv(struct conn *conn) { free(info->hostname); free(info->service); free(info); - return err_pars; + return ERR_PARS; } ret = conn->srvfd = socket(res->ai_family, res->ai_socktype, @@ -331,7 +219,7 @@ int do_con_srv(struct conn *conn) { free(info->hostname); free(info->service); free(info); - return err_pars; + return ERR_PARS; } ret = connect(conn->srvfd, res->ai_addr, res->ai_addrlen); @@ -340,26 +228,15 @@ int do_con_srv(struct conn *conn) { free(info->hostname); free(info->service); free(info); - return err_pars; + return ERR_PARS; } return ret; } -int do_fwd_srv(struct conn *conn) { - int bytes = 0; - int ret = 0; - while (bytes < conn->cltbuff_len) { - ret = write(conn->srvfd, conn->cltbuff+bytes, conn->cltbuff_len-bytes); - if (ret < 0) - return -1; - bytes += ret; - } - - return 0; -} - int do_rcv_clt(struct conn *conn) { + statem = state_rcv_clt; + int ret = 0; char *line = NULL; char *msgbuff = NULL; @@ -367,23 +244,14 @@ int do_rcv_clt(struct conn *conn) { int msgbuff_len = 0; // request line - fprintf(stdout, "debug - listening for new lines from client\n"); ret = read_line(conn->cltfd, &line_len, &line, &msgbuff_len, &msgbuff); if (ret < 0) { - return err_recv; - } - - if (debug == 1) { - fprintf(stdout, "debug - received line of %d bytes from client\n", line_len); + return ERR_RECV; } ret = pareqtitl(line, line_len, &(conn->cltreq.titl)); if (ret < 0) { - return err_parstitle; - } - - if (debug == 1) { - fprintf(stdout, "[do_rcv_clt] parsed request line\n"); + return ERR_PARSTITLE; } free(line); @@ -393,28 +261,17 @@ int do_rcv_clt(struct conn *conn) { while (next_header) { ret = read_line(conn->cltfd, &line_len, &line, &msgbuff_len, &msgbuff); if (ret < 0) { - return err_recv; + return ERR_RECV; } if (line_len == 0) { - if (debug == 1) { - fprintf(stdout, "[do_rcv_clt] reached end of headers for the client\n"); - } next_header = 0; continue; } - if (debug == 1) { - fprintf(stdout, "debug - received line: %s\n", line); - } - ret = parshfield(line, line_len, conn->cltreq.hentries); if (ret < 0) { - return err_parsheader; - } - - if (debug == 1) { - fprintf(stdout, "debug - parsed header field\n"); + return ERR_PARSHEADER; } free(line); @@ -429,21 +286,18 @@ int do_rcv_clt(struct conn *conn) { ret = stoin(content_length_entry->er, content_length_entry->len, &content_length); if (ret < 0) { - return err_pars; + return ERR_PARS; } ret = pull_content_length(conn->srvfd, content_length, &msgbuff_len, &msgbuff); if (ret < 0) { - return err_recv; + return ERR_RECV; } - - fprintf(stdout, "Successfully received normal body from server\n"); } else if (transfer_encoding_entry->er && strcmp(transfer_encoding_entry->er, "chunked") == 0) { ret = pull_chunked_encoding(conn->srvfd, &msgbuff_len, &msgbuff); if (ret < 0) { - return err_recv; + return ERR_RECV; } - fprintf(stdout, "Successfully received chunked body from server\n"); } conn->cltbuff = msgbuff; @@ -452,52 +306,125 @@ int do_rcv_clt(struct conn *conn) { return 0; } -void do_clear(struct conn *conn) { - statem = state_rcv_clt; - frepareq(&conn->cltreq); - frepares(&conn->srvres); - free(conn->cltbuff); - free(conn->srvbuff); -} +int read_buffer(int fd, char **buff, int *len) { + char *tmp = (char *) malloc(RELAY_BUFFER_SIZE); + if (!tmp) { + return ERR_MEM; + } + + memset(tmp, 0, RELAY_BUFFER_SIZE); + int bytes = recv(fd, tmp, RELAY_BUFFER_SIZE, 0); + if (bytes <= 0) { + free(tmp); + return ERR_RECV; + } + + *buff = realloc(*buff, *len+bytes); + if (!*buff) { + free(tmp); + return ERR_MEM; + } + + memcpy(*buff+*len, tmp, bytes); + *len += bytes; + + return 0; +} + +int write_buffer(int fd, char **buff, int *len) { + if (*len <= 0) { + *len = 0; + return 0; + } + + int writen = send(fd, *buff, *len, 0); + if (writen < 0) { + return ERR_SEND; + } + + char *trunc = (char *) malloc(*len-writen); + if (!trunc) { + return ERR_MEM; + } + + memcpy(trunc, *buff+writen, *len-writen); + + char *tofree = *buff; // FIXME: any better solution? + *buff = trunc; + *len -= writen; + free(tofree); + + return 0; +} void do_statem(struct conn *conn) { - int ret = 0; + int ret = 0; - for (int counter = 0; counter < MAX_BOUND; counter++) { - switch (statem) { - case state_rcv_clt: - ret = do_rcv_clt(conn); - break; - case state_con_srv: - ret = do_con_srv(conn); - break; - case state_fwd_srv: - ret = do_fwd_srv(conn); - break; - case state_rcv_srv: - ret = do_rcv_srv(conn); - break; - case state_fwd_clt: - ret = do_fwd_clt(conn); - break; - } - - if (ret > 0) { - err = ret; - } - - if (err) { - do_err(); - break; - } - - if (statem == state_fwd_clt) { - do_clear(conn); - continue; - } - - statem++; + ret = do_rcv_clt(conn); + if (ret < 0) { + err = ret; + do_err(); } + + ret = do_con_srv(conn); + if (ret < 0) { + err = ret; + do_err(); + } + + // relay the data between the two sockets until the end of time + ssize_t bytes_received; + struct pollfd fds[2]; + for (;;) { + memset(fds, 0, 2*sizeof(struct pollfd)); + fds[0].fd = conn->cltfd; + fds[1].fd = conn->srvfd; + + if (conn->srvbuff_len > 0) { + fds[0].events |= POLLOUT; + } + if (conn->cltbuff_len > 0) { + fds[1].events |= POLLOUT; + } + if (!conn->srvbuff_len) { + fds[1].events |= POLLIN; + } + if (!conn->cltbuff_len) { + fds[0].events |= POLLIN; + } + + ret = poll(fds, 2, 1000); + + if (fds[1].revents & POLLOUT) { + ret = write_buffer(conn->srvfd, &conn->cltbuff, &conn->cltbuff_len); + } + if (fds[1].revents & POLLIN) { + ret = read_buffer(conn->srvfd, &conn->srvbuff, &conn->srvbuff_len); + } + if (fds[0].revents & POLLIN) { + ret = read_buffer(conn->cltfd, &conn->cltbuff, &conn->cltbuff_len); + } + if (fds[0].revents & POLLOUT) { + ret = write_buffer(conn->cltfd, &conn->srvbuff, &conn->srvbuff_len); + } + if (fds[0].revents & POLLHUP) { + break; + } + if (ret < 0) { + break; + } + } + + if (conn->cltbuff_len > 0) { + write_buffer(conn->srvfd, &conn->cltbuff, &conn->cltbuff_len); + } + if (conn->srvbuff_len > 0) { + write_buffer(conn->cltfd, &conn->srvbuff, &conn->srvbuff_len); + } + + close(conn->cltfd); + close(conn->srvfd); + exit(0); // child die } int do_srv(void) { @@ -535,9 +462,8 @@ int do_srv(void) { return -1; } - fprintf(stdout, "Listening on port %d\n", PROXY_PORT); - for (;;) { + fprintf(stdout, "listening for sockets\n"); struct sockaddr_in new_clt_addr; socklen_t new_clt_addr_len= sizeof(new_clt_addr); int new_clt_sock; @@ -549,6 +475,7 @@ int do_srv(void) { "client\n"); return -1; } + fprintf(stdout, "accepted new client socket\n"); ret = fork(); if (ret < 0) { @@ -558,12 +485,11 @@ int do_srv(void) { } if (ret > 0) { - fprintf(stdout, "[PROGRAM] Successfully forked a new child process" - " with PID %d\n", ret); + fprintf(stdout, "+new request process:%d(pid)\n", ret); continue; } - // child + // request process struct conn *conn = (struct conn *) calloc(1, sizeof(struct conn)); if (!conn) { fprintf(stderr, "Not enough dynamic memory to establish connection\n"); @@ -574,7 +500,6 @@ int do_srv(void) { statem = state_rcv_clt; do_statem(conn); free(conn); - return 0; } @@ -588,7 +513,10 @@ int main(int argc, char *argv[]) { return -1; } - return do_srv(); + ret = do_srv(); + if (ret < 0) { + return -1; + } fretres(); } diff --git a/proxlib.h b/proxlib.h index 88fe106..93864e2 100644 --- a/proxlib.h +++ b/proxlib.h @@ -9,41 +9,35 @@ #define PROXY_PORT 2020 #define PROXY_CONN 20 +#define RELAY_BUFFER_SIZE 1024*2 +#define RELAY_POLL_TIMEOUT 1000 enum states { state_rcv_clt = 0, state_con_srv, state_fwd_srv, state_rcv_srv, - state_fwd_clt + state_fwd_clt, + state_ok }; -enum errs { - err_generic = 1, - err_mem, - err_recv, - err_pars, - err_parstitle, - err_parsheader, - err_support -}; +#define ERR_GENERIC -1 +#define ERR_MEM -2 +#define ERR_RECV -3 +#define ERR_SEND -4 +#define ERR_PARS -5 +#define ERR_PARSTITLE -6 +#define ERR_PARSHEADER -7 +#define ERR_SUPPORT -8 +#define ERR_TIMEOUT -9 char *states_str[] = { "state_rcv_clt", "state_con_srv", "state_fwd_srv", "state_rcv_srv", - "state_fwd_clt" -}; - -char *errs_str[] = { - "err_generic", - "err_mem", - "err_recv", - "err_pars", - "err_parstitle", - "err_parsheader", - "err_support" + "state_fwd_clt", + "state_ok" }; struct conn {