diff --git a/net.cc b/net.cc index 6e5745a..cfe4e69 100644 --- a/net.cc +++ b/net.cc @@ -23,6 +23,7 @@ #include #include +#include #include #include #include @@ -232,6 +233,10 @@ int getRecvSocket(const char* bindhost, int port) { PLOG_E("socket(AF_INET6)"); return -1; } + if (fcntl(sockfd, F_SETFL, O_NONBLOCK)) { + PLOG_E("fcntl(%d, F_SETFL, O_NONBLOCK)", sockfd); + return -1; + } int so = 1; if (setsockopt(sockfd, SOL_SOCKET, SO_REUSEADDR, &so, sizeof(so)) == -1) { PLOG_E("setsockopt(%d, SO_REUSEADDR)", sockfd); @@ -264,7 +269,7 @@ int getRecvSocket(const char* bindhost, int port) { int acceptConn(int listenfd) { struct sockaddr_in6 cli_addr; socklen_t socklen = sizeof(cli_addr); - int connfd = accept(listenfd, (struct sockaddr*)&cli_addr, &socklen); + int connfd = accept4(listenfd, (struct sockaddr*)&cli_addr, &socklen, SOCK_NONBLOCK); if (connfd == -1) { if (errno != EINTR) { PLOG_E("accept(%d)", listenfd); diff --git a/nsjail.cc b/nsjail.cc index 691f4dd..ce11e07 100644 --- a/nsjail.cc +++ b/nsjail.cc @@ -21,6 +21,8 @@ #include "nsjail.h" +#include +#include #include #include #include @@ -31,7 +33,10 @@ #include #include +#include +#include #include +#include #include "cmdline.h" #include "logs.h" @@ -47,10 +52,7 @@ static __thread int sigFatal = 0; static __thread bool showProc = false; static void sigHandler(int sig) { - if (sig == SIGALRM) { - return; - } - if (sig == SIGCHLD) { + if (sig == SIGALRM || sig == SIGCHLD || sig == SIGPIPE) { return; } if (sig == SIGUSR1 || sig == SIGQUIT) { @@ -74,7 +76,7 @@ static bool setSigHandler(int sig) { if (sig == SIGTTIN || sig == SIGTTOU) { sa.sa_handler = SIG_IGN; - }; + } if (sigaction(sig, &sa, NULL) == -1) { PLOG_E("sigaction(%d)", sig); return false; @@ -115,6 +117,72 @@ static bool setTimer(nsjconf_t* nsjconf) { return true; } +static bool pipeTraffic(nsjconf_t* nsjconf, int listenfd) { + std::vector fds; + fds.reserve(nsjconf->pipes.size() * 2 + 1); + for (const auto& p : nsjconf->pipes) { + fds.push_back({ + .fd = p.first, + .events = POLLIN, + .revents = 0, + }); + fds.push_back({ + .fd = p.second, + .events = POLLOUT, + .revents = 0, + }); + } + fds.push_back({ + .fd = listenfd, + .events = POLLIN, + .revents = 0, + }); + LOG_D("Waiting for fd activity"); + while (poll(fds.data(), fds.size(), -1) > 0) { + if (fds.back().revents != 0) { + LOG_D("New connection ready"); + return true; + } + bool cleanup = false; + for (size_t i = 0; i < fds.size() - 1; i += 2) { + bool read_ready = fds[i].events == 0 || (fds[i].revents & POLLIN) == POLLIN; + bool write_ready = + fds[i + 1].events == 0 || (fds[i + 1].revents & POLLOUT) == POLLOUT; + if (read_ready && write_ready) { + if (splice(fds[i].fd, nullptr, fds[i + 1].fd, nullptr, 4096, + SPLICE_F_NONBLOCK) == -1 && + errno != EAGAIN) { + PLOG_E("splice fd pair #%ld {%d, %d}\n", i / 2, fds[i].fd, + fds[i + 1].fd); + } + fds[i].events = POLLIN; + fds[i + 1].events = POLLOUT; + } else if (read_ready) { + LOG_D("Read ready on %ld", i / 2); + fds[i].events = 0; + } else if (write_ready) { + LOG_D("Write ready on %ld", i / 2); + fds[i + 1].events = 0; + } + if ((fds[i].revents & (POLLHUP | POLLERR)) != 0 || + (fds[i + 1].revents & (POLLHUP | POLLERR)) != 0) { + LOG_D("Hangup on %ld", i / 2); + cleanup = true; + close(fds[i].fd); + close(fds[i + 1].fd); + nsjconf->pipes[i / 2] = {0, 0}; + } + } + if (cleanup) { + break; + } + } + nsjconf->pipes.erase( + std::remove(nsjconf->pipes.begin(), nsjconf->pipes.end(), std::pair(0, 0)), + nsjconf->pipes.end()); + return false; +} + static int listenMode(nsjconf_t* nsjconf) { int listenfd = net::getRecvSocket(nsjconf->bindhost.c_str(), nsjconf->port); if (listenfd == -1) { @@ -131,10 +199,21 @@ static int listenMode(nsjconf_t* nsjconf) { showProc = false; subproc::displayProc(nsjconf); } - int connfd = net::acceptConn(listenfd); - if (connfd >= 0) { - subproc::runChild(nsjconf, connfd, connfd, connfd); - close(connfd); + if (pipeTraffic(nsjconf, listenfd)) { + int connfd = net::acceptConn(listenfd); + if (connfd >= 0) { + int in[2]; + int out[2]; + if (pipe(in) != 0 || pipe(out) != 0) { + PLOG_E("pipe"); + continue; + } + nsjconf->pipes.emplace_back(connfd, in[1]); + nsjconf->pipes.emplace_back(out[0], connfd); + subproc::runChild(nsjconf, in[0], out[1], out[1]); + close(in[0]); + close(out[1]); + } } subproc::reapProc(nsjconf); } diff --git a/nsjail.h b/nsjail.h index 68b1253..98c3661 100644 --- a/nsjail.h +++ b/nsjail.h @@ -44,6 +44,7 @@ static const int nssigs[] = { SIGTERM, SIGTTIN, SIGTTOU, + SIGPIPE, }; struct pids_t { @@ -157,6 +158,7 @@ struct nsjconf_t { std::vector openfds; std::vector caps; std::vector ifaces; + std::vector> pipes; }; #endif /* _NSJAIL_H */