proxy: implement a relaying system

also remove now defunct functions +
rewrite state machine
This commit is contained in:
Kevin J. 2024-09-17 18:40:01 +02:00
parent 009913283c
commit dbcf0d8ef4
4 changed files with 161 additions and 238 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
*.o *.o
proxy proxy
compile_commands.json compile_commands.json
.cache

BIN
proxlib

Binary file not shown.

358
proxlib.c
View File

@ -7,6 +7,7 @@
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <netdb.h> #include <netdb.h>
#include <poll.h>
#include "proxlib.h" #include "proxlib.h"
#include "parslib/parslib.h" #include "parslib/parslib.h"
@ -98,21 +99,21 @@ int pull_content_length(int fd, int len, int *msgbuff_len, char **msgbuff) {
int line_len = len; int line_len = len;
char *line = (char *) calloc(1, line_len); char *line = (char *) calloc(1, line_len);
if (!line) { if (!line) {
return err_mem; return ERR_MEM;
} }
int bytes = 0; int bytes = 0;
do { do {
ret = recv(fd, line+bytes, line_len-bytes, MSG_WAITALL); ret = recv(fd, line+bytes, line_len-bytes, MSG_WAITALL);
if (ret < 0) { if (ret < 0) {
return err_recv; return ERR_MEM;
} }
bytes += ret; bytes += ret;
} while (bytes < line_len); } while (bytes < line_len);
*msgbuff = (char *) realloc(*msgbuff, *msgbuff_len+line_len); *msgbuff = (char *) realloc(*msgbuff, *msgbuff_len+line_len);
if (!*msgbuff) { if (!*msgbuff) {
return err_mem; return ERR_MEM;
} }
memcpy(*msgbuff+*msgbuff_len, line, line_len); memcpy(*msgbuff+*msgbuff_len, line, line_len);
@ -166,9 +167,6 @@ int pull_chunked_encoding(int fd, int *msgbuff_len, char **msgbuff) {
memcpy(*msgbuff+*msgbuff_len, line, line_len); memcpy(*msgbuff+*msgbuff_len, line, line_len);
*msgbuff_len += line_len; *msgbuff_len += line_len;
if (debug == 1) {
fprintf(stdout, "debug - [upstream] received chunk:%d\n", line_len);
}
free(line); free(line);
} }
return 0; return 0;
@ -176,137 +174,27 @@ int pull_chunked_encoding(int fd, int *msgbuff_len, char **msgbuff) {
} }
void do_err(void) { void do_err(void) {
fprintf(stderr, "[%s] failed with error code %d=%s\n", fprintf(stderr, "failed with error code %d\n", err);
states_str[statem], err, errs_str[err]);
}
int do_fwd_clt(struct conn *conn) {
int bytes = 0;
int ret = 0;
while (bytes < conn->srvbuff_len) {
ret = write(conn->cltfd, conn->srvbuff+bytes, conn->srvbuff_len-bytes);
if (ret < 0)
return -1;
bytes += ret;
}
return 0;
}
int do_rcv_srv(struct conn *conn) {
int ret = 0;
char *line = NULL;
char *msgbuff = NULL;
int line_len = 0;
int msgbuff_len = 0;
// response line
ret = read_line(conn->srvfd, &line_len, &line, &msgbuff_len, &msgbuff);
if (ret < 0) {
return err_recv;
}
if (debug == 1) {
fprintf(stdout, "debug - [upstream] received line: %s\n", line);
}
ret = parestitl(line, line_len, &(conn->srvres.titl));
if (ret < 0) {
return err_parstitle;
}
if (debug == 1) {
fprintf(stdout, "debug - [upstream] parsed response line\n");
}
free(line);
// headers
int next_header = 1;
while (next_header) {
ret = read_line(conn->srvfd, &line_len, &line, &msgbuff_len, &msgbuff);
if (ret < 0) {
return err_recv;
}
if (line_len == 0) {
if (debug == 1) {
fprintf(stdout, "debug - [upstream] reached end of headers\n");
}
next_header = 0;
continue;
}
if (debug == 1) {
fprintf(stdout, "debug - [upstream] received line: %s\n", line);
}
ret = parshfield(line, line_len, conn->srvres.hentries);
if (ret < 0) {
return err_parsheader;
}
if (debug == 1) {
fprintf(stdout, "debug - parsed header field\n");
}
free(line);
}
// body
struct httpares *res = &conn->srvres;
struct point *content_length_entry = &res->hentries[header_content_length];
struct point *transfer_encoding_entry = &res->hentries[header_transfer_encoding];
if (content_length_entry->er) {
int content_length = 0;
ret = stoin(content_length_entry->er, content_length_entry->len, &content_length);
if (ret < 0) {
return err_pars;
}
ret = pull_content_length(conn->srvfd, content_length, &msgbuff_len, &msgbuff);
if (ret < 0) {
return err_recv;
}
fprintf(stdout, "Successfully received normal body from server\n");
} else if (transfer_encoding_entry->er && strcmp(transfer_encoding_entry->er, "chunked") == 0) {
ret = pull_chunked_encoding(conn->srvfd, &msgbuff_len, &msgbuff);
if (ret < 0) {
return err_recv;
}
fprintf(stdout, "Successfully received chunked body from server\n");
} else {
return err_support;
}
fprintf(stdout, "srvbuff:%p+srvbuff_len:%d\n", conn->srvbuff, conn->srvbuff_len);
conn->srvbuff = msgbuff;
conn->srvbuff_len = msgbuff_len;
return 0;
} }
int do_con_srv(struct conn *conn) { int do_con_srv(struct conn *conn) {
statem = state_con_srv;
int ret = 0; int ret = 0;
struct httpareq *req = &conn->cltreq; struct httpareq *req = &conn->cltreq;
struct point *host = &req->hentries[header_host]; struct point *host = &req->hentries[header_host];
if (host->er == NULL) { if (host->er == NULL) {
return err_pars; return ERR_PARS;
} }
struct hostinfo *info = (struct hostinfo *) calloc(1, sizeof(struct hostinfo)); struct hostinfo *info = (struct hostinfo *) calloc(1, sizeof(struct hostinfo));
if (!info) { if (!info) {
return err_mem; return ERR_MEM;
} }
ret = pahostinfo(host->er, host->len, info); ret = pahostinfo(host->er, host->len, info);
if (ret < 0) { if (ret < 0) {
return err_pars; return ERR_PARS;
}
if (debug <= 2) {
fprintf(stdout, "Establishing connection with upstream: %.*s : %.*s\n", info->hostname_len, info->hostname, info->service_len, info->service);
} }
struct addrinfo hints; struct addrinfo hints;
@ -321,7 +209,7 @@ int do_con_srv(struct conn *conn) {
free(info->hostname); free(info->hostname);
free(info->service); free(info->service);
free(info); free(info);
return err_pars; return ERR_PARS;
} }
ret = conn->srvfd = socket(res->ai_family, res->ai_socktype, ret = conn->srvfd = socket(res->ai_family, res->ai_socktype,
@ -331,7 +219,7 @@ int do_con_srv(struct conn *conn) {
free(info->hostname); free(info->hostname);
free(info->service); free(info->service);
free(info); free(info);
return err_pars; return ERR_PARS;
} }
ret = connect(conn->srvfd, res->ai_addr, res->ai_addrlen); ret = connect(conn->srvfd, res->ai_addr, res->ai_addrlen);
@ -340,26 +228,15 @@ int do_con_srv(struct conn *conn) {
free(info->hostname); free(info->hostname);
free(info->service); free(info->service);
free(info); free(info);
return err_pars; return ERR_PARS;
} }
return ret; return ret;
} }
int do_fwd_srv(struct conn *conn) {
int bytes = 0;
int ret = 0;
while (bytes < conn->cltbuff_len) {
ret = write(conn->srvfd, conn->cltbuff+bytes, conn->cltbuff_len-bytes);
if (ret < 0)
return -1;
bytes += ret;
}
return 0;
}
int do_rcv_clt(struct conn *conn) { int do_rcv_clt(struct conn *conn) {
statem = state_rcv_clt;
int ret = 0; int ret = 0;
char *line = NULL; char *line = NULL;
char *msgbuff = NULL; char *msgbuff = NULL;
@ -367,23 +244,14 @@ int do_rcv_clt(struct conn *conn) {
int msgbuff_len = 0; int msgbuff_len = 0;
// request line // request line
fprintf(stdout, "debug - listening for new lines from client\n");
ret = read_line(conn->cltfd, &line_len, &line, &msgbuff_len, &msgbuff); ret = read_line(conn->cltfd, &line_len, &line, &msgbuff_len, &msgbuff);
if (ret < 0) { if (ret < 0) {
return err_recv; return ERR_RECV;
}
if (debug == 1) {
fprintf(stdout, "debug - received line of %d bytes from client\n", line_len);
} }
ret = pareqtitl(line, line_len, &(conn->cltreq.titl)); ret = pareqtitl(line, line_len, &(conn->cltreq.titl));
if (ret < 0) { if (ret < 0) {
return err_parstitle; return ERR_PARSTITLE;
}
if (debug == 1) {
fprintf(stdout, "[do_rcv_clt] parsed request line\n");
} }
free(line); free(line);
@ -393,28 +261,17 @@ int do_rcv_clt(struct conn *conn) {
while (next_header) { while (next_header) {
ret = read_line(conn->cltfd, &line_len, &line, &msgbuff_len, &msgbuff); ret = read_line(conn->cltfd, &line_len, &line, &msgbuff_len, &msgbuff);
if (ret < 0) { if (ret < 0) {
return err_recv; return ERR_RECV;
} }
if (line_len == 0) { if (line_len == 0) {
if (debug == 1) {
fprintf(stdout, "[do_rcv_clt] reached end of headers for the client\n");
}
next_header = 0; next_header = 0;
continue; continue;
} }
if (debug == 1) {
fprintf(stdout, "debug - received line: %s\n", line);
}
ret = parshfield(line, line_len, conn->cltreq.hentries); ret = parshfield(line, line_len, conn->cltreq.hentries);
if (ret < 0) { if (ret < 0) {
return err_parsheader; return ERR_PARSHEADER;
}
if (debug == 1) {
fprintf(stdout, "debug - parsed header field\n");
} }
free(line); free(line);
@ -429,21 +286,18 @@ int do_rcv_clt(struct conn *conn) {
ret = stoin(content_length_entry->er, content_length_entry->len, &content_length); ret = stoin(content_length_entry->er, content_length_entry->len, &content_length);
if (ret < 0) { if (ret < 0) {
return err_pars; return ERR_PARS;
} }
ret = pull_content_length(conn->srvfd, content_length, &msgbuff_len, &msgbuff); ret = pull_content_length(conn->srvfd, content_length, &msgbuff_len, &msgbuff);
if (ret < 0) { if (ret < 0) {
return err_recv; return ERR_RECV;
} }
fprintf(stdout, "Successfully received normal body from server\n");
} else if (transfer_encoding_entry->er && strcmp(transfer_encoding_entry->er, "chunked") == 0) { } else if (transfer_encoding_entry->er && strcmp(transfer_encoding_entry->er, "chunked") == 0) {
ret = pull_chunked_encoding(conn->srvfd, &msgbuff_len, &msgbuff); ret = pull_chunked_encoding(conn->srvfd, &msgbuff_len, &msgbuff);
if (ret < 0) { if (ret < 0) {
return err_recv; return ERR_RECV;
} }
fprintf(stdout, "Successfully received chunked body from server\n");
} }
conn->cltbuff = msgbuff; conn->cltbuff = msgbuff;
@ -452,52 +306,125 @@ int do_rcv_clt(struct conn *conn) {
return 0; return 0;
} }
void do_clear(struct conn *conn) { int read_buffer(int fd, char **buff, int *len) {
statem = state_rcv_clt; char *tmp = (char *) malloc(RELAY_BUFFER_SIZE);
frepareq(&conn->cltreq); if (!tmp) {
frepares(&conn->srvres); return ERR_MEM;
free(conn->cltbuff); }
free(conn->srvbuff);
memset(tmp, 0, RELAY_BUFFER_SIZE);
int bytes = recv(fd, tmp, RELAY_BUFFER_SIZE, 0);
if (bytes <= 0) {
free(tmp);
return ERR_RECV;
}
*buff = realloc(*buff, *len+bytes);
if (!*buff) {
free(tmp);
return ERR_MEM;
}
memcpy(*buff+*len, tmp, bytes);
*len += bytes;
return 0;
}
int write_buffer(int fd, char **buff, int *len) {
if (*len <= 0) {
*len = 0;
return 0;
}
int writen = send(fd, *buff, *len, 0);
if (writen < 0) {
return ERR_SEND;
}
char *trunc = (char *) malloc(*len-writen);
if (!trunc) {
return ERR_MEM;
}
memcpy(trunc, *buff+writen, *len-writen);
char *tofree = *buff; // FIXME: any better solution?
*buff = trunc;
*len -= writen;
free(tofree);
return 0;
} }
void do_statem(struct conn *conn) { void do_statem(struct conn *conn) {
int ret = 0; int ret = 0;
for (int counter = 0; counter < MAX_BOUND; counter++) { ret = do_rcv_clt(conn);
switch (statem) { if (ret < 0) {
case state_rcv_clt: err = ret;
ret = do_rcv_clt(conn); do_err();
break;
case state_con_srv:
ret = do_con_srv(conn);
break;
case state_fwd_srv:
ret = do_fwd_srv(conn);
break;
case state_rcv_srv:
ret = do_rcv_srv(conn);
break;
case state_fwd_clt:
ret = do_fwd_clt(conn);
break;
}
if (ret > 0) {
err = ret;
}
if (err) {
do_err();
break;
}
if (statem == state_fwd_clt) {
do_clear(conn);
continue;
}
statem++;
} }
ret = do_con_srv(conn);
if (ret < 0) {
err = ret;
do_err();
}
// relay the data between the two sockets until the end of time
ssize_t bytes_received;
struct pollfd fds[2];
for (;;) {
memset(fds, 0, 2*sizeof(struct pollfd));
fds[0].fd = conn->cltfd;
fds[1].fd = conn->srvfd;
if (conn->srvbuff_len > 0) {
fds[0].events |= POLLOUT;
}
if (conn->cltbuff_len > 0) {
fds[1].events |= POLLOUT;
}
if (!conn->srvbuff_len) {
fds[1].events |= POLLIN;
}
if (!conn->cltbuff_len) {
fds[0].events |= POLLIN;
}
ret = poll(fds, 2, 1000);
if (fds[1].revents & POLLOUT) {
ret = write_buffer(conn->srvfd, &conn->cltbuff, &conn->cltbuff_len);
}
if (fds[1].revents & POLLIN) {
ret = read_buffer(conn->srvfd, &conn->srvbuff, &conn->srvbuff_len);
}
if (fds[0].revents & POLLIN) {
ret = read_buffer(conn->cltfd, &conn->cltbuff, &conn->cltbuff_len);
}
if (fds[0].revents & POLLOUT) {
ret = write_buffer(conn->cltfd, &conn->srvbuff, &conn->srvbuff_len);
}
if (fds[0].revents & POLLHUP) {
break;
}
if (ret < 0) {
break;
}
}
if (conn->cltbuff_len > 0) {
write_buffer(conn->srvfd, &conn->cltbuff, &conn->cltbuff_len);
}
if (conn->srvbuff_len > 0) {
write_buffer(conn->cltfd, &conn->srvbuff, &conn->srvbuff_len);
}
close(conn->cltfd);
close(conn->srvfd);
exit(0); // child die
} }
int do_srv(void) { int do_srv(void) {
@ -535,9 +462,8 @@ int do_srv(void) {
return -1; return -1;
} }
fprintf(stdout, "Listening on port %d\n", PROXY_PORT);
for (;;) { for (;;) {
fprintf(stdout, "listening for sockets\n");
struct sockaddr_in new_clt_addr; struct sockaddr_in new_clt_addr;
socklen_t new_clt_addr_len= sizeof(new_clt_addr); socklen_t new_clt_addr_len= sizeof(new_clt_addr);
int new_clt_sock; int new_clt_sock;
@ -549,6 +475,7 @@ int do_srv(void) {
"client\n"); "client\n");
return -1; return -1;
} }
fprintf(stdout, "accepted new client socket\n");
ret = fork(); ret = fork();
if (ret < 0) { if (ret < 0) {
@ -558,12 +485,11 @@ int do_srv(void) {
} }
if (ret > 0) { if (ret > 0) {
fprintf(stdout, "[PROGRAM] Successfully forked a new child process" fprintf(stdout, "+new request process:%d(pid)\n", ret);
" with PID %d\n", ret);
continue; continue;
} }
// child // request process
struct conn *conn = (struct conn *) calloc(1, sizeof(struct conn)); struct conn *conn = (struct conn *) calloc(1, sizeof(struct conn));
if (!conn) { if (!conn) {
fprintf(stderr, "Not enough dynamic memory to establish connection\n"); fprintf(stderr, "Not enough dynamic memory to establish connection\n");
@ -574,7 +500,6 @@ int do_srv(void) {
statem = state_rcv_clt; statem = state_rcv_clt;
do_statem(conn); do_statem(conn);
free(conn); free(conn);
return 0; return 0;
} }
@ -588,7 +513,10 @@ int main(int argc, char *argv[]) {
return -1; return -1;
} }
return do_srv(); ret = do_srv();
if (ret < 0) {
return -1;
}
fretres(); fretres();
} }

View File

@ -9,41 +9,35 @@
#define PROXY_PORT 2020 #define PROXY_PORT 2020
#define PROXY_CONN 20 #define PROXY_CONN 20
#define RELAY_BUFFER_SIZE 1024*2
#define RELAY_POLL_TIMEOUT 1000
enum states { enum states {
state_rcv_clt = 0, state_rcv_clt = 0,
state_con_srv, state_con_srv,
state_fwd_srv, state_fwd_srv,
state_rcv_srv, state_rcv_srv,
state_fwd_clt state_fwd_clt,
state_ok
}; };
enum errs { #define ERR_GENERIC -1
err_generic = 1, #define ERR_MEM -2
err_mem, #define ERR_RECV -3
err_recv, #define ERR_SEND -4
err_pars, #define ERR_PARS -5
err_parstitle, #define ERR_PARSTITLE -6
err_parsheader, #define ERR_PARSHEADER -7
err_support #define ERR_SUPPORT -8
}; #define ERR_TIMEOUT -9
char *states_str[] = { char *states_str[] = {
"state_rcv_clt", "state_rcv_clt",
"state_con_srv", "state_con_srv",
"state_fwd_srv", "state_fwd_srv",
"state_rcv_srv", "state_rcv_srv",
"state_fwd_clt" "state_fwd_clt",
}; "state_ok"
char *errs_str[] = {
"err_generic",
"err_mem",
"err_recv",
"err_pars",
"err_parstitle",
"err_parsheader",
"err_support"
}; };
struct conn { struct conn {