/*
 * Copyright (c) 2001 Daniel Hartmeier
 * Copyright (c) 2005 Marcus Glocker
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    - Redistributions of source code must retain the above copyright
 *      notice, this list of conditions and the following disclaimer.
 *    - Redistributions in binary form must reproduce the above
 *      copyright notice, this list of conditions and the following
 *      disclaimer in the documentation and/or other materials provided
 *      with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
 * ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
 * POSSIBILITY OF SUCH DAMAGE.
 *
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/wait.h>

#include <netinet/in.h>
#include <arpa/inet.h>

#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>

/* solaris quirks */
#ifdef __sun
#define INADDR_NONE	((in_addr_t) - 1)
#endif

/*
 * prototypes
 *
 * Forward declarations for the functions defined further down; main()
 * and the signal handler call several of them before their definitions.
 */
int	open_proxy(void);
int	sync_write(int, const char *, int);
void	logx(const char *, ...);
void	handle_client(int);
void	read_config(FILE *f);
void	parse_host(char *, char **, unsigned *);
char	*base64e(char *);

/*
 * global variables
 */
/* pid of the daemonized main process; logx() prints "***" for it */
int			parentpid;
/* log file opened (append) by main(); stdout/stderr are dup'ed onto it */
FILE			*logfile = NULL;
/* pid file written by main() and kept open while the daemon runs */
FILE			*pidfile = NULL;
static const char	*version = "1.3";
/* default config file name; replaced by the -c option */
static const char	*cfgfile = "gotthard.conf";
/* set to 1 by the SIGINT/SIGTERM handler; ends the accept loop */
volatile sig_atomic_t	quit = 0;

/*
 * Runtime configuration.  The address/port pairs and auth are filled in
 * by read_config(); pid, log and path are built by main() from $HOME.
 */
struct config {
	char		*pid;		/* path of the pid file */
	char		*log;		/* path of the log file */
	char		*path;		/* directory $HOME/.gotthard */
	char		*addr_local;	/* "listen" host: local bind address */
	char		*addr_proxy;	/* "proxy" host: HTTP proxy to use */
	char		*addr_ext;	/* "external" host: CONNECT target */
	char		*auth;		/* base64 of the "auth" value, sent as
					   Proxy-Authorization: Basic */
	unsigned	port_local;	/* "listen" port */
	unsigned	port_proxy;	/* "proxy" port */
	unsigned	port_ext;	/* "external" port */
} config;

/*
 * Base64 alphabet used by base64e().  The array has exactly 64 entries:
 * the string literal's terminating NUL does not fit and is dropped,
 * which is valid in C.
 */
static const unsigned char encode[64] =
	"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/*
 * signal handler
 */
/*
 * signal handler
 *
 * SIGINT/SIGTERM set `quit' so the accept loop in main() shuts down.
 * SIGCHLD reaps every exited child without blocking and logs how it
 * ended.  Every other signal the handler is installed for (SIGHUP,
 * SIGQUIT, SIGPIPE, SIGALRM) is ignored.
 *
 * errno is saved on entry and restored on exit: waitpid() returns -1
 * with ECHILD once no children remain, and the logging may modify errno
 * too, which would otherwise corrupt the errno value seen by the
 * interrupted code (e.g. the EINTR check after select()).
 *
 * NOTE(review): logx() uses stdio and localtime(), which are not
 * async-signal-safe; consider deferring child-exit logging to the main
 * loop.
 */
void
handle_signal(int sigraised)
{
	int	saved_errno = errno;

	switch (sigraised) {
	case SIGINT:
	case SIGTERM:
		quit = 1;
		break;
	case SIGCHLD: {
		pid_t	pid;
		int	status;

		while ((pid = waitpid(-1, &status, WNOHANG)) > 0) {
			/* pid_t width varies by platform: print as long */
			if (WIFEXITED(status))
				logx("child %ld exited with status %i",
				    (long)pid, WEXITSTATUS(status));
			else if (WIFSIGNALED(status))
				logx("child %ld exited due to signal %i",
				    (long)pid, WTERMSIG(status));
			else
				logx("child %ld stopped", (long)pid);
		}
		break;
	}
	case SIGHUP:
	case SIGQUIT:
	case SIGPIPE:
	case SIGALRM:
	default:
		/* ignore */
		break;
	}

	errno = saved_errno;
}

/*
 * usage
 */
/*
 * usage
 *
 * With a non-zero mode (the -v option), print "<progname> <version>".
 * With mode 0 (-h or an unrecognised option), print the option summary.
 * The output goes to stderr, and the process then exits with status 1
 * in both cases.
 */
void
usage(int mode)
{
	static const char	*progname = "gotthard";

	if (mode == 0) {
		fprintf(stderr, "usage: %s ", progname);
		fputs("[-hv] [-c configfile]\n\n", stderr);
		fputs("options:\n", stderr);
		fputs("  -h\t\t: This help.\n", stderr);
		fputs("  -v\t\t: Shows version.\n", stderr);
		fputs("  -c configfile\t: Use an alternate configfile.\n",
		    stderr);
	} else {
		fprintf(stderr, "%s %s\n", progname, version);
	}

	exit(1);
}

/*
 * ssh through https proxy tunnel
 */
/*
 * ssh through https proxy tunnel
 *
 * Reads the configuration ("gotthard.conf" or the file given with -c),
 * detaches from the terminal (fork, setsid, fork), records its pid in
 * $HOME/.gotthard/gotthard.pid and redirects stdout/stderr to
 * $HOME/.gotthard/gotthard.log.  It then accepts TCP connections on the
 * "listen" address and forks one child per connection to run
 * handle_client().  The loop ends when SIGINT/SIGTERM sets `quit' or on
 * a fatal error; the pid file is removed on the way out.
 */
int
main(int argc, char *argv[])
{
	int 			r, ch, val, cfgflag, listen_fd;
	char			tmp[1024];
	const char		*home;
	pid_t			pid;
	socklen_t		len;
	struct sockaddr_in	sa;
	FILE			*configfile;

	cfgflag = 0;
	listen_fd = -1;
	configfile = NULL;
	memset(&config, 0, sizeof(config));

	/*
	 * get command line options
	 */
	while ((ch = getopt(argc, argv, "hvc:")) != -1) {
		switch (ch) {
		case 'c':
			cfgfile = optarg;
			cfgflag = 1;
			break;
		case 'v':
			usage(1);
			break;
		case 'h':
			/* FALLTHROUGH */
		default:
			usage(0);
			/* NOTREACHED */
		}
	}

	/*
	 * read config file
	 */
	configfile = fopen(cfgfile, "rb");
	if (!configfile) {
		fprintf(stderr, "%s: %s\n", cfgfile, strerror(errno));
		exit(1);
	}
	read_config(configfile);
	fclose(configfile);
	configfile = NULL;
	if (!config.addr_local || !config.addr_proxy || !config.addr_ext ||
	    !config.port_local || !config.port_proxy || !config.port_ext) {
		fprintf(stderr, "config file incomplete\n");
		exit(1);
	}

	/*
	 * prepare pid file and log file in executers home directory
	 * (passing a NULL $HOME to "%s" would be undefined behaviour)
	 */
	home = getenv("HOME");
	if (home == NULL || *home == '\0') {
		fprintf(stderr, "HOME is not set\n");
		exit(1);
	}
	snprintf(tmp, sizeof(tmp), "%s/.gotthard", home);
	config.path = strdup(tmp);
	snprintf(tmp, sizeof(tmp), "%s/.gotthard/gotthard.log", home);
	config.log = strdup(tmp);
	snprintf(tmp, sizeof(tmp), "%s/.gotthard/gotthard.pid", home);
	config.pid = strdup(tmp);
	if (!config.path || !config.log || !config.pid) {
		fprintf(stderr, "out of memory\n");
		exit(1);
	}
	mkdir(config.path, 0755);

	/*
	 * daemonize
	 */
	if ((pid = fork()) < 0) {
		perror("fork");
		goto error;
	} else if (pid > 0)
		exit(0);
	if ((pid = setsid()) == -1) {
		perror("setsid");
		goto error;
	}
	if ((pid = fork()) < 0) {
		perror("fork");
		goto error;
	} else if (pid > 0)
		exit(0);
	parentpid = getpid();
	if (chdir("/")) {
		perror("chdir");
		goto error;
	}
	umask(022);

	/*
	 * install signal handler
	 * NOTE(review): signal() has System V one-shot semantics on some
	 * platforms (e.g. Solaris); sigaction() would be more robust.
	 */
	signal(SIGHUP, handle_signal);
	signal(SIGINT, handle_signal);
	signal(SIGTERM, handle_signal);
	signal(SIGCHLD, handle_signal);
	signal(SIGQUIT, handle_signal);
	signal(SIGPIPE, handle_signal);
	signal(SIGALRM, handle_signal);

	if (config.pid) {
		if ((pidfile = fopen(config.pid, "rb"))) {
			char	s[128];
			size_t	n;
			long	oldpid;

			/* fread() does not terminate the buffer */
			n = fread(s, 1, sizeof(s) - 1, pidfile);
			s[n] = '\0';
			fclose(pidfile);
			pidfile = NULL;
			oldpid = atol(s);
			/*
			 * Probe with signal 0 so a stale pid that now
			 * belongs to an unrelated process is not hit with
			 * a real signal.  A pid <= 0 (empty/corrupt file)
			 * would address a whole process group: treat as
			 * stale.
			 */
			if (oldpid <= 0 ||
			    (kill((pid_t)oldpid, 0) == -1 && errno == ESRCH))
				fprintf(stderr, "overwriting stale pid file\n");
			else {
				fprintf(stdout, "%s already runs (pid %ld)\n",
				    argv[0], oldpid);
				goto error;
			}
		}

		pidfile = fopen(config.pid, "wb");
		if (!pidfile) {
			perror("fopen");
			goto error;
		}
		fprintf(pidfile, "%i\n", parentpid);
		fflush(pidfile);
	}

	if (config.log) {
		if (!(logfile = fopen(config.log, "ab"))) {
			perror("fopen");
			goto error;
		}
	}
	/* stdin -> /dev/null, stdout/stderr -> log file (or /dev/null) */
	close(0);
	close(1);
	close(2);
	open("/dev/null", O_RDWR);
	dup(logfile ? fileno(logfile) : 0);
	dup(logfile ? fileno(logfile) : 0);

	/* argv[1]/argv[2] need not be "-c file" (e.g. "-cfile") */
	if (cfgflag)
		logx("STARTED as %s -c %s", argv[0], cfgfile);
	else
		logx("STARTED as %s", argv[0]);
	logx("  pid      %s [%i]", config.pid, parentpid);
	logx("  log      %s", config.log);
	logx("  listen   %s:%u", config.addr_local, config.port_local);
	logx("  proxy    %s:%u", config.addr_proxy, config.port_proxy);
	logx("  external %s:%u", config.addr_ext, config.port_ext);

	if ((listen_fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		perror("socket");
		goto error;
	}

	if (fcntl(listen_fd, F_SETFL, fcntl(listen_fd, F_GETFL) | O_NONBLOCK)) {
		perror("fcntl");
		goto error;
	}

	val = 1;
	if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, (const char *)&val,
	    sizeof(val))) {
		perror("setsockopt");
		goto error;
	}

	memset(&sa, 0, sizeof(sa));
	sa.sin_family = AF_INET;
	sa.sin_addr.s_addr = inet_addr(config.addr_local);
	sa.sin_port = htons(config.port_local);
	if (bind(listen_fd, (const struct sockaddr *)&sa, sizeof(sa))) {
		perror("bind");
		goto error;
	}

	if (listen(listen_fd, 1)) {
		perror("listen");
		goto error;
	}

	/*
	 * handle incoming client connections
	 */
	while (!quit) {
		fd_set		readfds;
		struct timeval	tv;
		int		client_fd;

		FD_ZERO(&readfds);
		FD_SET(listen_fd, &readfds);
		tv.tv_sec = 10;
		tv.tv_usec = 0;
		r = select(listen_fd + 1, &readfds, NULL, NULL, &tv);
		if (r < 0) {
			if (errno == EINTR)
				continue;
			perror("select");
			break;
		}
		if (r == 0 || !FD_ISSET(listen_fd, &readfds))
			continue;	/* timeout: re-check quit */

		memset(&sa, 0, sizeof(sa));
		len = sizeof(sa);
		client_fd = accept(listen_fd, (struct sockaddr *)&sa, &len);
		if (client_fd < 0) {
			/*
			 * The listening socket is non-blocking: a client
			 * that went away between select() and accept()
			 * yields EAGAIN/ECONNABORTED.  That must not take
			 * the whole daemon down.
			 */
			if (errno == EINTR || errno == EAGAIN ||
			    errno == EWOULDBLOCK || errno == ECONNABORTED)
				continue;
			perror("accept");
			break;
		}
		if (len != sizeof(sa)) {
			logx("accept: unexpected address length %u",
			    (unsigned)len);
			close(client_fd);
			continue;
		}

		pid = fork();
		if (pid < 0) {
			/* possibly transient; drop this client only */
			perror("fork");
			close(client_fd);
			continue;
		}
		if (pid) {
			close(client_fd);
			continue;
		}

		/* child: serve this one connection, then exit */
		close(listen_fd);
		logx("connection from %s:%i",
		    inet_ntoa(sa.sin_addr), ntohs(sa.sin_port));
		handle_client(client_fd);
		close(client_fd);
		return 0;
	}

error:
	if (quit)
		logx("SIGINT/SIGTERM, terminating gracefully");
	if (listen_fd >= 0)
		close(listen_fd);
	/* pidfile is only non-NULL once this process has written it */
	if (pidfile) {
		fclose(pidfile);
		unlink(config.pid);
	}
	if (logfile) {
		fflush(logfile);
		fclose(logfile);
	}

	return 0;
}

/*
 * Move the data of one read() from `from' to `to' and add the byte count
 * to *count.  `who' names the peer behind `from' in log messages.
 * Returns 0 to keep relaying, or 1 when the connection should be torn
 * down (EOF, read error or incomplete write).
 */
static int
relay_chunk(int from, int to, char *buf, size_t size,
    unsigned long *count, const char *who)
{
	ssize_t	n;

	n = read(from, buf, size);
	if (n < 0) {
		/* spurious wakeup on a non-blocking socket, or a signal */
		if (errno == EINTR || errno == EAGAIN || errno == EWOULDBLOCK)
			return 0;
		logx("read from %s failed: %s", who, strerror(errno));
		return 1;
	}
	if (n == 0) {
		logx("connection closed by %s", who);
		return 1;
	}
	if (sync_write(to, buf, (int)n)) {
		logx("incomplete write");
		return 1;
	}
	*count += (unsigned long)n;
	return 0;
}

/*
 * Open a tunnel through the HTTP proxy (open_proxy()) and relay bytes in
 * both directions between it and `client_fd'.  Relaying stops when
 * either side closes, an I/O error occurs or a write cannot be
 * completed.  Byte counts and elapsed seconds are then logged.  The
 * proxy socket is closed here; `client_fd' is left for the caller to
 * close.
 */
void
handle_client(int client_fd)
{
	int		r, max, proxy_fd;
	char		buf[65535];
	struct timeval	tv;
	fd_set		readfds;
	time_t		t;
	unsigned long	bytes_in, bytes_out;

	proxy_fd = open_proxy();
	/* open_proxy() signals failure with 0; reject negatives as well */
	if (proxy_fd <= 0) {
		logx("handle_client() can't open proxy connection");
		return;
	}

	if (fcntl(proxy_fd, F_SETFL, fcntl(proxy_fd, F_GETFL) | O_NONBLOCK))
		perror("fcntl");

	max = client_fd > proxy_fd ? client_fd : proxy_fd;

	t = time(NULL);
	bytes_in = bytes_out = 0;
	for (;;) {
		FD_ZERO(&readfds);
		FD_SET(proxy_fd, &readfds);
		FD_SET(client_fd, &readfds);
		tv.tv_sec = 10;
		tv.tv_usec = 0;
		r = select(max + 1, &readfds, NULL, NULL, &tv);
		if (r < 0) {
			if (errno == EINTR)
				continue;
			perror("select");
			break;
		}
		if (r == 0)
			continue;	/* idle timeout: keep waiting */

		if (FD_ISSET(proxy_fd, &readfds) &&
		    relay_chunk(proxy_fd, client_fd, buf, sizeof(buf),
		    &bytes_in, "proxy"))
			break;
		if (FD_ISSET(client_fd, &readfds) &&
		    relay_chunk(client_fd, proxy_fd, buf, sizeof(buf),
		    &bytes_out, "client"))
			break;
	}
	close(proxy_fd);

	/* time_t has no portable printf conversion; print as long */
	logx("%lu bytes in, %lu bytes out, %ld seconds",
	    bytes_in, bytes_out, (long)(time(NULL) - t));
}

/*
 * Connect to the configured HTTP proxy (config.addr_proxy, which may be a
 * dotted quad or a host name, and config.port_proxy).  Then send an HTTP
 * CONNECT request for config.addr_ext:config.port_ext, adding a Basic
 * Proxy-Authorization header when config.auth is set.
 *
 * The reply is read one byte at a time, so no tunnelled data after the
 * header block is consumed.  The tunnel is accepted only if the status
 * line carries a 2xx code; an error reply such as 407 is logged and
 * rejected instead of being relayed to the client as if it were payload.
 *
 * Returns the connected socket, or 0 on failure.  On failure the socket
 * has been closed.
 */
int
open_proxy(void)
{
	int			i, fd, len, n, status, first;
	struct sockaddr_in	sa;
	char			buf[1024];
	const char		*sp;

	if ((fd = socket(AF_INET, SOCK_STREAM, 0)) < 0) {
		perror("socket");
		return 0;
	}

	memset(&sa, 0, sizeof(sa));
	sa.sin_family = AF_INET;
	sa.sin_addr.s_addr = inet_addr(config.addr_proxy);
	if (sa.sin_addr.s_addr == INADDR_NONE) {
		struct hostent *h = gethostbyname(config.addr_proxy);

		if (h == NULL || h->h_addrtype != AF_INET ||
		    h->h_addr_list[0] == NULL) {
			logx("invalid host name %s", config.addr_proxy);
			goto fail;
		}
		memcpy(&sa.sin_addr, h->h_addr_list[0], sizeof(sa.sin_addr));
	}
	sa.sin_port = htons(config.port_proxy);
	if (connect(fd, (struct sockaddr *)&sa, sizeof(sa))) {
		perror("connect");
		goto fail;
	}

	if (config.auth) {
		n = snprintf(buf, sizeof(buf),
		    "CONNECT %s:%u HTTP/1.0\r\nHost: %s\r\n"
		    "Proxy-Authorization: Basic %s\r\n\r\n",
		    config.addr_ext, config.port_ext, config.addr_ext,
		    config.auth);
	} else {
		n = snprintf(buf, sizeof(buf),
		    "CONNECT %s:%u HTTP/1.0\r\nHost: %s\r\n\r\n",
		    config.addr_ext, config.port_ext, config.addr_ext);
	}
	/* a truncated request would lack its terminating empty line */
	if (n < 0 || (size_t)n >= sizeof(buf)) {
		logx("CONNECT request too long");
		goto fail;
	}
	if (sync_write(fd, buf, n)) {
		logx("can't send CONNECT request");
		goto fail;
	}

	i = 0;
	first = 1;
	for (;;) {
		len = read(fd, buf + i, 1);
		if (len < 0) {
			if (errno == EINTR)
				continue;
			perror("read");
			break;
		}
		if (len == 0)
			break;
		if (buf[i] != '\n' && i < (int)sizeof(buf) - 1) {
			i++;
			continue;
		}

		/* end of a (possibly truncated) line: strip CR/LF */
		if (i && buf[i - 1] == '\r')
			buf[i - 1] = 0;
		else
			buf[i] = 0;
		i = 0;

		if (first) {
			/* status line: "HTTP/1.x <code> <reason>" */
			first = 0;
			status = 0;
			if (!strncmp(buf, "HTTP/", 5) &&
			    (sp = strchr(buf, ' ')) != NULL)
				status = atoi(sp + 1);
			if (status < 200 || status > 299) {
				logx("proxy refused CONNECT: %s", buf);
				goto fail;
			}
			continue;
		}
		if (!buf[0]) {
			/* empty line ends the header block */
			logx("connection through proxy %s:%u "
			    "established",
			    config.addr_proxy, config.port_proxy);
			return fd;
		}
	}

	logx("didn't detect proxy response");
fail:
	close(fd);
	return 0;
}

/*
 * Write all `len' bytes of `buf' to `fd'.  Before each write, wait with
 * select() until the descriptor is writable, using a 10 s timeout per
 * round that is retried indefinitely.  Works for both blocking and
 * non-blocking descriptors.  Interrupted (EINTR) or would-block (EAGAIN)
 * writes are retried rather than treated as failures.
 *
 * Returns 0 when everything was written, or 1 on error: a negative
 * length, a descriptor outside [0, FD_SETSIZE), or a select()/write()
 * failure (reported with perror()).
 */
int
sync_write(int fd, const char *buf, int len)
{
	int		off, r;
	fd_set		writefds;
	struct timeval	tv;

	/* FD_SET() with an out-of-range descriptor is undefined */
	if (fd < 0 || fd >= FD_SETSIZE || len < 0)
		return 1;

	off = 0;
	while (off < len) {
		FD_ZERO(&writefds);
		FD_SET(fd, &writefds);
		tv.tv_sec = 10;
		tv.tv_usec = 0;
		r = select(fd + 1, NULL, &writefds, NULL, &tv);
		if (r < 0) {
			if (errno == EINTR)
				continue;
			perror("select");
			return 1;
		}
		if (r == 0 || !FD_ISSET(fd, &writefds))
			continue;	/* timed out: wait again */

		r = (int)write(fd, buf + off, (size_t)(len - off));
		if (r < 0) {
			if (errno == EINTR || errno == EAGAIN ||
			    errno == EWOULDBLOCK)
				continue;
			perror("write");
			return 1;
		}
		off += r;
	}

	return 0;
}

/*
 * Write one log line to stderr.  After daemonizing, main() dup()s the
 * log file onto stderr.  The line has three parts:
 *   - the local time as "YYYY.MM.DD HH:MM:SS";
 *   - "***" for the main daemon process (parentpid), or the process id
 *     for a child;
 *   - the printf-style message and a newline.
 * stderr is flushed after every line.
 *
 * errno is preserved, because this is called from the signal handler
 * and from error paths whose callers may still inspect errno.
 */
void
logx(const char *format, ...)
{
	time_t		now;
	struct tm	*tm;
	pid_t		self;
	va_list		ap;
	int		saved_errno = errno;

	now = time(NULL);
	tm = localtime(&now);
	if (tm)
		fprintf(stderr, "%4.4i.%2.2i.%2.2i %2.2i:%2.2i:%2.2i ",
		    tm->tm_year + 1900, tm->tm_mon + 1, tm->tm_mday,
		    tm->tm_hour, tm->tm_min, tm->tm_sec);

	self = getpid();
	if (self == parentpid)
		fprintf(stderr, "*** ");
	else
		/* pid_t width differs between platforms: print as long */
		fprintf(stderr, "%ld ", (long)self);

	va_start(ap, format);
	vfprintf(stderr, format, ap);
	va_end(ap);
	fprintf(stderr, "\n");
	fflush(stderr);

	errno = saved_errno;
}

/*
 * Apply one configuration line to the global `config'.
 *
 * A line has the form "<keyword> <value>", separated by white space.
 * Recognised keywords:
 *   - "auth": user:password, stored base64-encoded;
 *   - "listen", "proxy", "external": host:port, see parse_host().
 * Blank lines and lines whose first non-blank character is '#' are
 * ignored.  Unknown keywords are reported on stderr.  `b' is modified in
 * place.
 */
static void
read_config_line(char *b)
{
	char	*key, *val;
	size_t	n;

	/* cast: passing a negative char to isspace() is undefined */
	for (key = b; isspace((unsigned char)*key); key++)
		;
	if (*key == '\0' || *key == '#')
		return;

	for (val = key; *val && !isspace((unsigned char)*val); val++)
		;
	if (*val)
		*val++ = '\0';
	while (isspace((unsigned char)*val))
		val++;

	/* tolerate CRLF files: don't let '\r' end up in the value */
	n = strlen(val);
	if (n > 0 && val[n - 1] == '\r')
		val[n - 1] = '\0';

	if (!strcmp(key, "auth"))
		config.auth = strdup(base64e(val));
	else if (!strcmp(key, "listen"))
		parse_host(val, &config.addr_local, &config.port_local);
	else if (!strcmp(key, "proxy"))
		parse_host(val, &config.addr_proxy, &config.port_proxy);
	else if (!strcmp(key, "external"))
		parse_host(val, &config.addr_ext, &config.port_ext);
	else
		fprintf(stderr, "unknown option %s\n", key);
}

/*
 * Read the configuration from `f' line by line into the global `config'.
 * A final line without a trailing newline is still applied.  Lines
 * longer than 1023 bytes are reported and ignored rather than split.
 */
void
read_config(FILE *f)
{
	char	b[1024];
	size_t	i = 0;
	int	c, toolong = 0;

	for (;;) {
		c = getc(f);
		if (c != EOF && c != '\n') {
			if (i < sizeof(b) - 1)
				b[i++] = (char)c;
			else
				toolong = 1;
			continue;
		}

		/* end of line or end of file */
		b[i] = '\0';
		if (toolong)
			fprintf(stderr, "config line too long, ignored\n");
		else
			read_config_line(b);
		if (c == EOF)
			break;
		i = 0;
		toolong = 0;
	}
}

/*
 * Split a "host:port" string.
 *
 * Host part: everything before the first ':' (or the whole string when
 * there is no ':').  It is copied into newly allocated memory stored in
 * *addr, which the caller owns; *addr is NULL if allocation fails.  The
 * ':' in `s' is overwritten with a NUL.
 *
 * *port receives the port number, or 0 when the port is missing, is not
 * a decimal number, or lies outside 1..65535.  main() rejects a zero
 * port as an incomplete configuration.  Trailing white space after the
 * number (e.g. a '\r') is accepted.  A malformed, non-empty port is
 * reported on stderr.
 */
void
parse_host(char *s, char **addr, unsigned *port)
{
	char		*p, *end;
	size_t		n;
	unsigned long	v;

	p = strchr(s, ':');
	if (p)
		*p++ = '\0';
	else
		p = s + strlen(s);	/* no port given */

	/* malloc+memcpy rather than POSIX-only strdup() */
	n = strlen(s) + 1;
	if ((*addr = malloc(n)) != NULL)
		memcpy(*addr, s, n);

	/* strtoul, unlike atoi, lets us detect garbage and overflow */
	errno = 0;
	v = strtoul(p, &end, 10);
	while (*end && isspace((unsigned char)*end))
		end++;
	if (end == p || *end != '\0' || errno != 0 || v == 0 || v > 65535) {
		if (*p)
			fprintf(stderr, "invalid port %s\n", p);
		*port = 0;
	} else
		*port = (unsigned)v;
}

/*
 * Base64-encode a NUL-terminated string (standard alphabet from
 * `encode', '=' padding, no line breaks).
 *
 * Returns a pointer to a static buffer that the next call overwrites;
 * callers that keep the result must copy it.  The buffer fits the
 * encoding of up to 1023 input bytes, the longest value a config line
 * can carry.  Longer input is truncated on a 3-byte boundary instead of
 * overflowing the buffer.
 *
 * Groups are emitted only while input remains.  This avoids the
 * spurious trailing "AA==" the previous version appended to input whose
 * length is a multiple of three, including the empty string.
 */
char *
base64e(char *string)
{
	static char		buf[4 * ((1023 + 2) / 3) + 1];
	const unsigned char	*in = (const unsigned char *)string;
	size_t			k = 0;

	/* k + 4 < sizeof(buf): room for a full group plus the NUL */
	while (*in != '\0' && k + 4 < sizeof(buf)) {
		unsigned char	grp[3] = { 0, 0, 0 };
		int		n;	/* input bytes in this group: 1..3 */

		for (n = 0; n < 3 && *in != '\0'; n++)
			grp[n] = *in++;

		/* split the 24-bit group into four 6-bit indices */
		buf[k++] = encode[grp[0] >> 2];
		buf[k++] = encode[((grp[0] & 3) << 4) | (grp[1] >> 4)];
		buf[k++] = n < 2 ? '=' :
		    encode[((grp[1] & 15) << 2) | (grp[2] >> 6)];
		buf[k++] = n < 3 ? '=' : encode[grp[2] & 63];
	}
	buf[k] = '\0';

	return buf;
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */