// stdlib
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <errno.h>
// system
#include <fcntl.h>
#include <poll.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <netinet/ip.h>
// C++
#include <string>
#include <vector>
#include <map>
// Write one diagnostic line (the given text plus a newline) to stderr.
static void msg(const char *msg) {
    FILE *const sink = stderr;
    (void)fprintf(sink, "%s\n", msg);
}
// Write a diagnostic line to stderr, prefixed with the current errno value.
static void msg_errno(const char *msg) {
    // snapshot errno before calling into stdio
    const int saved_errno = errno;
    (void)fprintf(stderr, "[errno:%d] %s\n", saved_errno, msg);
}
// Report a fatal error (current errno plus the given text) on stderr,
// then terminate the process with abort(). Does not return.
static void die(const char *msg) {
    const int saved_errno = errno;
    (void)fprintf(stderr, "[%d] %s\n", saved_errno, msg);
    abort();
}
// Switch the descriptor `fd` to non-blocking mode by adding O_NONBLOCK to its
// file status flags. Any fcntl() failure is fatal: die() is called, which
// aborts the process.
static void fd_set_nb(int fd) {
    // fcntl() signals failure through its return value (-1). errno is only
    // meaningful after a call that reported failure, so test the result
    // instead of clearing and inspecting errno.
    int flags = fcntl(fd, F_GETFL, 0);
    if (flags == -1) {
        die("fcntl error");
        return;
    }
    if (fcntl(fd, F_SETFL, flags | O_NONBLOCK) == -1) {
        die("fcntl error");
    }
}
// Largest accepted message body: 32 MiB (likely larger than the kernel buffer).
const size_t k_max_msg = 32u * 1024u * 1024u;
// State kept for one client connection.
struct Conn {
    // socket descriptor; -1 until one is assigned
    int fd{-1};

    // What the application wants next; the event loop turns these into
    // poll() event flags and closes the connection when want_close is set.
    bool want_read{false};
    bool want_write{false};
    bool want_close{false};

    // bytes received from the socket, waiting to be parsed
    std::vector<uint8_t> incoming;
    // serialized responses waiting to be written to the socket
    std::vector<uint8_t> outgoing;
};
// Append the `len` bytes starting at `data` to the end of `buf`.
static void
buf_append(std::vector<uint8_t> &buf, const uint8_t *data, size_t len) {
    const uint8_t *const first = data;
    const uint8_t *const last = first + len;
    buf.insert(buf.end(), first, last);
}
// Drop the first `n` bytes of `buf`; the remaining bytes shift to the front.
// `n` must not exceed buf.size().
static void buf_consume(std::vector<uint8_t> &buf, size_t n) {
    std::vector<uint8_t>::iterator head = buf.begin();
    buf.erase(head, head + n);
}
// application callback when the listening socket is ready
//
// Accepts one pending connection on the listening socket `fd`, logs the
// peer's IPv4 address and port, switches the new socket to non-blocking mode
// and wraps it in a Conn allocated with `new` that starts out wanting to read.
// Returns NULL (after logging the errno) when accept() fails.
static Conn *handle_accept(int fd) {
    // accept
    struct sockaddr_in client_addr = {};
    socklen_t socklen = sizeof(client_addr);
    int connfd = accept(fd, (struct sockaddr *)&client_addr, &socklen);
    if (connfd < 0) {
        msg_errno("accept() error");
        return NULL;
    }
    // s_addr is kept in network byte order, so its bytes in memory are the
    // dotted-quad octets in display order. Reading them bytewise prints the
    // correct address on any host endianness (shifting the integer value
    // only works on little-endian machines).
    uint8_t octet[4];
    memcpy(octet, &client_addr.sin_addr.s_addr, sizeof(octet));
    fprintf(stderr, "new client from %u.%u.%u.%u:%u\n",
        (unsigned)octet[0], (unsigned)octet[1],
        (unsigned)octet[2], (unsigned)octet[3],
        (unsigned)ntohs(client_addr.sin_port)
    );
    // set the new connection fd to nonblocking mode
    fd_set_nb(connfd);
    // create a `struct Conn`; a fresh connection begins by reading a request
    Conn *conn = new Conn();
    conn->fd = connfd;
    conn->want_read = true;
    return conn;
}
// Upper bound on the number of strings accepted in a single request.
const size_t k_max_args = 200000;
// Read a 4-byte unsigned integer (host byte order) from the cursor.
//   cur: read position; advanced by 4 on success, untouched on failure
//   end: one past the last readable byte
//   out: receives the value on success
// Returns false when fewer than 4 bytes remain.
static bool read_u32(const uint8_t *&cur, const uint8_t *end, uint32_t &out) {
    // Compare the remaining length instead of computing `cur + 4`, which
    // would form a pointer beyond one-past-the-end (undefined behavior)
    // whenever the buffer is short.
    if (end - cur < 4) {
        return false;
    }
    memcpy(&out, cur, 4);
    cur += 4;
    return true;
}
// Copy `n` bytes from the cursor into `out`.
//   cur: read position; advanced by n on success, untouched on failure
//   end: one past the last readable byte
//   n:   number of bytes to take (comes from the wire, so it is untrusted)
//   out: replaced with the bytes read on success
// Returns false when fewer than `n` bytes remain.
static bool
read_str(const uint8_t *&cur, const uint8_t *end, size_t n, std::string &out) {
    // Check `n` against the remaining length. Computing `cur + n` with a
    // caller-supplied `n` can run past the buffer or wrap around the address
    // space, letting an oversized length slip through the bounds check.
    if (cur > end || n > (size_t)(end - cur)) {
        return false;
    }
    out.assign(cur, cur + n);
    cur += n;
    return true;
}
// Request body layout: a string count followed by length-prefixed strings.
// +------+-----+------+-----+------+-----+-----+------+
// | nstr | len | str1 | len | str2 | ... | len | strn |
// +------+-----+------+-----+------+-----+-----+------+
//
// Decodes the `size` bytes at `data` and appends each string to `out`.
// Returns 0 on success. Returns -1 when the input is truncated, the string
// count exceeds k_max_args, or bytes are left over after the last string.
static int32_t
parse_req(const uint8_t *data, size_t size, std::vector<std::string> &out) {
    const uint8_t *const limit = data + size;

    uint32_t count = 0;
    if (!read_u32(data, limit, count) || count > k_max_args) {
        return -1;  // missing header, or over the safety limit
    }

    for (; out.size() < count;) {
        uint32_t slen = 0;
        if (!read_u32(data, limit, slen)) {
            return -1;
        }
        // add the slot first, then fill it in place
        out.push_back(std::string());
        std::string &slot = out.back();
        if (!read_str(data, limit, slen, slot)) {
            return -1;
        }
    }

    // anything left after the last string is trailing garbage
    return (data == limit) ? 0 : -1;
}
// Codes stored in Response::status
enum {
    // default status of a Response
    RES_OK = 0,
    // error
    RES_ERR = 1,
    // key not found
    RES_NX = 2,
};
// Result of one request. Serialized form (after the length prefix):
// +--------+---------+
// | status | data... |
// +--------+---------+
struct Response {
    // one of the RES_* codes; starts as 0
    uint32_t status{0};
    // optional payload bytes
    std::vector<uint8_t> data{};
};
// placeholder; implemented later
// Global key -> value table used by do_request(): "get" looks a key up,
// "set" stores a value under a key, and "del" erases a key.
static std::map<std::string, std::string> g_data;
// Execute one parsed command against g_data and fill in `out`.
//   get <key>         -> out.data = value, or out.status = RES_NX if absent
//   set <key> <value> -> stores the value (cmd[2] is swapped into the map,
//                        so the caller's copy of it is left changed)
//   del <key>         -> erases the key if present
// Anything else sets out.status = RES_ERR. On success out.status is not
// modified.
static void do_request(std::vector<std::string> &cmd, Response &out) {
    const size_t argc = cmd.size();

    if (argc == 2 && cmd[0] == "get") {
        std::map<std::string, std::string>::iterator hit = g_data.find(cmd[1]);
        if (hit != g_data.end()) {
            out.data.assign(hit->second.begin(), hit->second.end());
        } else {
            out.status = RES_NX;    // not found
        }
        return;
    }

    if (argc == 3 && cmd[0] == "set") {
        g_data[cmd[1]].swap(cmd[2]);
        return;
    }

    if (argc == 2 && cmd[0] == "del") {
        g_data.erase(cmd[1]);
        return;
    }

    out.status = RES_ERR;           // unrecognized command
}
// Serialize `resp` onto the end of `out` as:
//   4-byte length (status + payload bytes) | 4-byte status | payload
// Both integers are written in host byte order.
static void make_response(const Response &resp, std::vector<uint8_t> &out) {
    const size_t payload_len = resp.data.size();
    const uint32_t frame_len = 4 + (uint32_t)payload_len;

    const uint8_t *const len_bytes = (const uint8_t *)&frame_len;
    const uint8_t *const status_bytes = (const uint8_t *)&resp.status;

    buf_append(out, len_bytes, 4);
    buf_append(out, status_bytes, 4);
    buf_append(out, resp.data.data(), payload_len);
}
// process 1 request if there is enough data
//
// Looks for one complete message (4-byte length header + body) at the front
// of conn->incoming. If one is present, it is parsed and executed, the
// response is appended to conn->outgoing, and the message is removed from
// conn->incoming.
// Returns true when a request was consumed, so the caller may try again for
// pipelined requests. Returns false when more data is needed, or when the
// connection was flagged with want_close because of an oversized or
// malformed message.
static bool try_one_request(Conn *conn) {
    // try to parse the protocol: message header
    if (conn->incoming.size() < 4) {
        return false;   // want read
    }
    uint32_t len = 0;
    memcpy(&len, conn->incoming.data(), 4);
    if (len > k_max_msg) {
        msg("too long");
        conn->want_close = true;
        return false;   // want close
    }
    // message body
    if (4 + len > conn->incoming.size()) {
        return false;   // want read
    }
    // Use data() + 4 rather than &incoming[4]: for a zero-length body the
    // buffer may hold exactly 4 bytes, and operator[] with index == size()
    // violates its precondition (hardened standard libraries abort on it).
    // A one-past-the-end pointer from data() is valid to form.
    const uint8_t *request = conn->incoming.data() + 4;
    // got one request, do some application logic
    std::vector<std::string> cmd;
    if (parse_req(request, len, cmd) < 0) {
        msg("bad request");
        conn->want_close = true;
        return false;   // want close
    }
    Response resp;
    do_request(cmd, resp);
    make_response(resp, conn->outgoing);
    // Application logic done: remove only this message. Later bytes in the
    // buffer may belong to the next pipelined request, so do not clear it.
    buf_consume(conn->incoming, 4 + len);
    return true;        // success
}
// Socket-writable callback: push as much of conn->outgoing to the socket as
// a single write() accepts. Once the buffer is fully drained the connection
// switches back to wanting to read. A write error other than EAGAIN flags
// the connection for closing.
static void handle_write(Conn *conn) {
    assert(conn->outgoing.size() > 0);
    std::vector<uint8_t> &pending = conn->outgoing;

    const ssize_t sent = write(conn->fd, &pending[0], pending.size());
    if (sent < 0) {
        if (errno == EAGAIN) {
            return;     // socket was not actually ready; retry later
        }
        msg_errno("write() error");
        conn->want_close = true;
        return;
    }

    // drop the bytes the kernel accepted
    buf_consume(pending, (size_t)sent);

    // everything flushed: go back to reading; otherwise keep want_write as is
    if (pending.empty()) {
        conn->want_read = true;
        conn->want_write = false;
    }
}
// Socket-readable callback: pull available bytes into conn->incoming, run
// every complete request found there, and if any response was produced
// switch the connection to writing and attempt the write right away.
// EOF and read errors other than EAGAIN flag the connection for closing.
static void handle_read(Conn *conn) {
    uint8_t buf[64 * 1024];
    const ssize_t got = read(conn->fd, buf, sizeof(buf));

    if (got < 0) {
        if (errno == EAGAIN) {
            return;     // socket was not actually ready
        }
        msg_errno("read() error");
        conn->want_close = true;
        return;
    }

    if (got == 0) {
        // EOF: clean if nothing is buffered, otherwise the peer left mid-message
        msg(conn->incoming.empty() ? "client closed" : "unexpected EOF");
        conn->want_close = true;
        return;
    }

    // stash the new bytes, then process requests until none is complete;
    // looping matters because one read may carry several pipelined requests
    buf_append(conn->incoming, buf, (size_t)got);
    while (try_one_request(conn)) {
    }

    if (conn->outgoing.empty()) {
        return;         // nothing to send; keep want_read as is
    }

    // A response is pending. In a request-response protocol the socket is
    // likely writable already, so write now instead of waiting for the next
    // event-loop iteration.
    conn->want_read = false;
    conn->want_write = true;
    handle_write(conn);
}
int main() {
// the listening socket
int fd = socket(AF_INET, SOCK_STREAM, 0);
if (fd < 0) {
die("socket()");
}
int val = 1;
setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val));
// bind
struct sockaddr_in addr = {};
addr.sin_family = AF_INET;
addr.sin_port = ntohs(1234);
addr.sin_addr.s_addr = ntohl(0); // wildcard address 0.0.0.0
int rv = bind(fd, (const sockaddr *)&addr, sizeof(addr));
if (rv) {
die("bind()");
}
// set the listen fd to nonblocking mode
fd_set_nb(fd);
// listen
rv = listen(fd, SOMAXCONN);
if (rv) {
die("listen()");
}
// a map of all client connections, keyed by fd
std::vector<Conn *> fd2conn;
// the event loop
std::vector<struct pollfd> poll_args;
while (true) {
// prepare the arguments of the poll()
poll_args.clear();
// put the listening sockets in the first position
struct pollfd pfd = {fd, POLLIN, 0};
poll_args.push_back(pfd);
// the rest are connection sockets
for (Conn *conn : fd2conn) {
if (!conn) {
continue;
}
// always poll() for error
struct pollfd pfd = {conn->fd, POLLERR, 0};
// poll() flags from the application's intent
if (conn->want_read) {
pfd.events |= POLLIN;
}
if (conn->want_write) {
pfd.events |= POLLOUT;
}
poll_args.push_back(pfd);
}
// wait for readiness
int rv = poll(poll_args.data(), (nfds_t)poll_args.size(), -1);
if (rv < 0 && errno == EINTR) {
continue; // not an error
}
if (rv < 0) {
die("poll");
}
// handle the listening socket
if (poll_args[0].revents) {
if (Conn *conn = handle_accept(fd)) {
// put it into the map
if (fd2conn.size() <= (size_t)conn->fd) {
fd2conn.resize(conn->fd + 1);
}
assert(!fd2conn[conn->fd]);
fd2conn[conn->fd] = conn;
}
}
// handle connection sockets
for (size_t i = 1; i < poll_args.size(); ++i) { // note: skip the 1st
uint32_t ready = poll_args[i].revents;
if (ready == 0) {
continue;
}
Conn *conn = fd2conn[poll_args[i].fd];
if (ready & POLLIN) {
assert(conn->want_read);
handle_read(conn); // application logic
}
if (ready & POLLOUT) {
assert(conn->want_write);
handle_write(conn); // application logic
}
// close the socket from socket error or application logic
if ((ready & POLLERR) || conn->want_close) {
(void)close(conn->fd);
fd2conn[conn->fd] = NULL;
delete conn;
}
} // for each connection sockets
} // the event loop
return 0;
}
// Source: redis/07/07_server.cpp
// (Error report | Ask questions) @ build-your-own.org
// Build Your Own Redis with C/C++