/*-
 * Copyright (c) 2001-2005 Christian S.J. Peron
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 */
#ifndef lint
/* RCS/CVS revision string for this file; compiled out when `lint' is defined. */
static const char rcsid[] =
        "@(#) $Header: /usr/cvs/ipex/tcp_session.c,v 1.18 2005/06/24 02:55:05 modulus Exp $";
#endif

#include "ipex_includes.h"

/*
 * Index arithmetic for a binary heap stored in a 0-based array:
 * parent, left child and right child of the node at position npos.
 * heap_extract() uses HLEFT/HRIGHT while sifting down.
 */
#define HPARENT(npos) ((int)(((npos) - 1) / 2))
#define HLEFT(npos) (((npos) * 2) + 1)
#define HRIGHT(npos) (((npos) * 2) + 2)

/*
 * Global queue of tracked TCP sessions.  Entries are struct conq
 * objects linked through their `glue' member.
 */
TAILQ_HEAD(tailhead, conq) head = TAILQ_HEAD_INITIALIZER(head);
/*
 * Key used to build and parse session dump file names.  Despite the
 * name it holds five unsigned ints: tcp_genfname() fills them with the
 * source address, destination address, source port, destination port
 * and a timestamp, in that order.
 */
typedef struct {
	unsigned int values[5];
} uint128_t;

/*
 * Fallback for systems whose <sys/queue.h> lacks TAILQ_FOREACH_SAFE.
 * The successor is stored in tvar before the loop body runs, so the
 * body may remove and free var without breaking the traversal.
 */
#ifndef TAILQ_FOREACH_SAFE
#define TAILQ_FOREACH_SAFE(var, head, field, tvar)			\
	for ((var) = TAILQ_FIRST((head));				\
	    (var) && ((tvar) = TAILQ_NEXT((var), field), 1);		\
	    (var) = (tvar))
#endif

/*
 * Truncate s in place at its first '.' character.
 *
 * Only the first dot is replaced by a NUL; everything after it is left
 * untouched in the buffer.  A string without a dot is not modified.
 * Returns s.
 */
char *
chomp(char *s)
{
	char *dot = strchr(s, '.');

	if (dot != NULL)
		*dot = '\0';
	return (s);
}

/*
 * Build the dump file name for the TCP session a packet belongs to.
 *
 * The name is the hexadecimal rendering of the source address, the
 * destination address, the source port, the destination port (all taken
 * from the headers as-is, without byte swapping) and the current time,
 * followed by ".dmp".
 *
 * data and h are not used.  p points at the start of the captured
 * frame; the IP header is found opts.loffset bytes into it.
 *
 * Returns a pointer to a static buffer that is overwritten by the next
 * call.
 */
char *
tcp_genfname(u_char *data, struct pcap_pkthdr *h, u_char *p)
{
	static char fname[256];
	struct ip *iph;
	struct tcphdr *tcph;
	unsigned hdrlen;
	uint128_t key;

	iph = (struct ip *)(p + opts.loffset);
	/* ip_hl counts 32-bit words. */
	hdrlen = (iph->ip_hl * 0x4);
	tcph = (struct tcphdr *)(((char *)iph) + hdrlen);

	memcpy(&key.values[0], &iph->ip_src, sizeof(struct in_addr));
	memcpy(&key.values[1], &iph->ip_dst, sizeof(struct in_addr));
	key.values[2] = tcph->th_sport;
	key.values[3] = tcph->th_dport;
	key.values[4] = time(NULL);

	snprintf(fname, sizeof(fname), "%08x%08x%04x%04x%08x.dmp",
	    key.values[0], key.values[1], key.values[2], key.values[3],
	    key.values[4]);
	return (fname);
}

/*
 * Print a human readable summary of a session dump file.
 *
 * filename is the dump file name WITHOUT the ".dmp" suffix, i.e. the
 * hexadecimal key produced by tcp_genfname().  The key is decoded back
 * into the two endpoints and the creation time, and the size of
 * "<filename>.dmp" is reported.
 *
 * Exits through err(3) if the file cannot be opened or stat'ed.  If the
 * name cannot be fully decoded a warning is printed and the fields that
 * could not be parsed are reported as zero.
 */
void
disassemble_fname(char *filename)
{
	uint128_t hash;
	char path[256];
	int fd;
	struct stat sb;
	time_t created;

	snprintf(path, sizeof(path), "%s.dmp", filename);
	if ((fd = open(path, O_RDONLY)) < 0)
		err(1, "open failed");
	if (fstat(fd, &sb) < 0)
		err(1, "stat failed");
	close(fd);

	/* Start from zero so a short parse never prints indeterminate data. */
	memset(&hash, 0, sizeof(hash));
	if (sscanf(filename, "%08x%08x%04x%04x%08x",
	    &hash.values[0], &hash.values[1],
	    &hash.values[2], &hash.values[3],
	    &hash.values[4]) != 5)
		warnx("%s: not a valid session file name", filename);

	/*
	 * Go through a real time_t object: time_t may be wider than
	 * unsigned int, so pointing ctime() at values[4] directly would
	 * read past the end of the array.
	 */
	created = (time_t)hash.values[4];

	printf(";;\n");
	printf(";; Filename: %s.dmp\n", filename);
	/* Ports were stored as found in the TCP header; swap for display. */
	printf(";; Host(s)/Port(s): %s:%u and",
	    hostname(4, &hash.values[0]),
	    (unsigned)ntohs((unsigned short)hash.values[2]));
	printf(" %s:%u\n",
	    hostname(4, &hash.values[1]),
	    (unsigned)ntohs((unsigned short)hash.values[3]));
	printf(";; Time: %s", ctime(&created));
	/* off_t has no printf conversion of its own; widen explicitly. */
	printf(";; Size: %llu bytes\n\n", (unsigned long long)sb.st_size);
}

/*
 * Order two sessions by hit count.
 *
 * Returns 1 when conep has strictly more hits than ctwop and -1 in
 * every other case, including equal hit counts (0 is never returned).
 */
int
compare(struct conq *conep, struct conq *ctwop)
{
	return ((conep->hits > ctwop->hits) ? 1 : -1);
}

/*
 * Reset the global session queue to the empty state.
 *
 * None of the arguments are used: the heap object and both callbacks
 * are accepted only to keep the heap-style signature.
 */
void
heap_init(Heap *heap, int (*cmp)(const void *key1, const void *key2),
    void (*unalloc)(void *data))
{
	(void)heap;
	(void)cmp;
	(void)unalloc;

	TAILQ_INIT(&head);
}

/*
 * Remove the root element of the array-backed heap.
 *
 * On success the removed element is stored in *data and 0 is returned;
 * -1 is returned when the heap is empty.  The backing array is shrunk
 * by one slot (or released entirely when the last element is removed),
 * the former last element is moved to the root and then sifted down,
 * always swapping with the child for which heap->cmp() reports a
 * value greater than zero.
 *
 * The result of memrealloc() is stored without a check here;
 * presumably that wrapper handles failure itself -- TODO confirm.
 */
int
heap_extract(Heap *heap, void **data)
{
	int cur, left, right, best;
	void *last, *swap;

	if (HEAP_SIZE(heap) == 0)
		return (-1);

	*data = heap->btree[0];
	last = heap->btree[HEAP_SIZE(heap) - 1];

	if (!(HEAP_SIZE(heap) - 1 > 0)) {
		/* That was the only element: drop the array altogether. */
		free(heap->btree);
		heap->btree = NULL;
		heap->size = 0;
		return (0);
	}

	swap = memrealloc(heap->btree, (HEAP_SIZE(heap) - 1) *
	    sizeof(void *));
	heap->btree = swap;
	heap->size--;

	/* Re-seat the old tail at the root and push it down. */
	heap->btree[0] = last;
	for (cur = 0; ; cur = best) {
		left = HLEFT(cur);
		right = HRIGHT(cur);

		best = cur;
		if (left < HEAP_SIZE(heap) &&
		    heap->cmp(heap->btree[left], heap->btree[cur]) > 0)
			best = left;
		if (right < HEAP_SIZE(heap) &&
		    heap->cmp(heap->btree[right], heap->btree[best]) > 0)
			best = right;

		/* Neither child outranks the current node: done. */
		if (best == cur)
			break;

		swap = heap->btree[best];
		heap->btree[best] = heap->btree[cur];
		heap->btree[cur] = swap;
	}
	return (0);
}

/*
 * Drop every session that has not been written to for more than
 * opts.Tflag seconds.
 *
 * Sessions that are still open get their dump file closed here.
 * Sessions already marked dead had their dump file closed by
 * close_tcpsession(); for those only the queue entry is released, since
 * closing the same dumper a second time would be a double close.
 *
 * Always returns 0.
 */
int
tcp_sessionreaper(void)
{
	struct conq *cq, *next;
	time_t now;

	now = time(NULL);
	TAILQ_FOREACH_SAFE(cq, &head, glue, next) {
		if (!((now - cq->lastwrite) > opts.Tflag))
			continue;
		TAILQ_REMOVE(&head, cq, glue);
		if (!cq->dead) {
			printf("NOTE: closing TCP session file %s (timeout)\n",
			    cq->fname);
			pcap_dump_close(cq->dumper);
		}
		free(cq->rev);
		free(cq);
	}
	return (0);
}

/*
 * Close a session's dump file and mark the session dead.
 *
 * The entry stays on the session queue (other code in this file checks
 * the dead flag and removes entries later).  Calling this on a session
 * that is already dead is a no-op, so the dumper is never closed twice.
 *
 * Always returns 0.
 */
int
close_tcpsession(struct conq *cqp)
{
	/* Already closed: the dumper handle is no longer valid. */
	if (cqp->dead)
		return (0);
	/*
	 * Close the dump file and schedule it for removal
	 * from the TCP session queue.
	 */
	printf("NOTE: closing TCP session file %s (RST/FIN)\n",
	    cqp->fname);
	pcap_dump_close(cqp->dumper);
	cqp->dead = 1;
	return (0);
}

int
get_tcpsession(u_char *packet, Heap *heap, struct conq **cq)
{
	struct ip *ip = (struct ip *)(packet + opts.loffset);
	unsigned ip_hl = (ip->ip_hl * 0x4);
	struct tcphdr *tcp = ((struct tcphdr *)(((char *)ip) + ip_hl));
	struct conq *ctx;

	TAILQ_FOREACH(ctx, &head, glue) {
		if ((memcmp(&ip->ip_src, &ctx->ip_src,
		    sizeof(struct in_addr)) == 0) &&
		    (tcp->th_sport == ctx->sport)) {
			if (ctx->dead) {
				return (-1);
			}
			*cq = ctx;
			return (1);
		}
		if ((memcmp(&ip->ip_src, &ctx->rev->ip_src,
		    sizeof(struct in_addr)) == 0) &&
		    (tcp->th_sport == ctx->rev->sport)) {
			if (ctx->dead) {
				return (-1);
			}
			*cq = ctx;
			return (1);
		}
	}
	return (0);
}

/*
 * Tear down the whole session queue: close every dump file that is
 * still open and free all entries.
 *
 * Sessions already marked dead had their dump file closed earlier, so
 * close_tcpsession() is skipped for them to avoid closing the same
 * dumper twice.
 *
 * heap is not used.
 */
void
heap_destory(Heap *heap)
{
	struct conq *cq, *next;

	(void)heap;

	TAILQ_FOREACH_SAFE(cq, &head, glue, next) {
		TAILQ_REMOVE(&head, cq, glue);
		if (!cq->dead)
			close_tcpsession(cq);
		free(cq->rev);
		free(cq);
	}
}

/*
 * Add a session to the front of the global session queue.
 *
 * data must point to a struct conq; heap is not used.
 * Always returns 0.
 */
int
heap_insert(Heap *heap, const void *data)
{
	struct conq *entry = (struct conq *)data;

	(void)heap;

	TAILQ_INSERT_HEAD(&head, entry, glue);
	return (0);
}

/*
 * Per-packet entry point for TCP session dumping.
 *
 * Expired sessions are reaped first.  Non-TCP packets are then ignored.
 * For a TCP packet:
 *   - no known session: a new session is created, a dump file named by
 *     tcp_genfname() is opened and the packet is written to it;
 *   - known live session: the packet is appended; a FIN or RST packet
 *     is written and then the session is closed;
 *   - known dead session: the packet is dropped.
 *
 * h is the pcap header of the packet and p the captured frame; the IP
 * header starts opts.loffset bytes into p.
 */
void
handle_sessions(struct pcap_pkthdr *h, u_char *p)
{
	struct ip *ip = (struct ip *)(p + opts.loffset);
	struct tcphdr *tcp;
	struct conq *session = NULL;
	char *fname;
	int error;
	extern pcap_t *pd;

	tcp_sessionreaper();
	if (ip->ip_p != IPPROTO_TCP)
		return;
	/* ip_hl counts 32-bit words. */
	tcp = (struct tcphdr *)(((char *)ip) + (ip->ip_hl * 0x4));

	error = get_tcpsession(p, 0, &session);
	if (error == 0) {
		session = memalloc(sizeof(struct conq));
		fname = tcp_genfname(0, h, p);
		if ((session->dumper = pcap_dump_open(pd,
		    fname)) == NULL) {
			fprintf(stderr,
			    "fatal: pcap_dump_open: %s",
			    pcap_geterr(pd));
			cleanup();
			/*
			 * Should cleanup() return, do not go on with a
			 * NULL dumper.
			 */
			free(session);
			return;
		}
		/* Set explicitly rather than relying on zeroed memory. */
		session->hits = 1;
		session->dead = 0;
		/*
		 * NOTE(review): fname is tcp_genfname()'s static buffer,
		 * so every session ends up pointing at the most recently
		 * generated name.  Keeping a private copy also requires
		 * the code that frees sessions to release it.
		 */
		session->fname = fname;
		session->ip_src.s_addr = ip->ip_src.s_addr;
		session->ip_dst.s_addr = ip->ip_dst.s_addr;
		session->sport = tcp->th_sport;
		session->dport = tcp->th_dport;
		/* Mirrored endpoints, used to match the reply direction. */
		session->rev = memalloc(sizeof(struct conq));
		session->rev->ip_dst.s_addr = ip->ip_src.s_addr;
		session->rev->ip_src.s_addr = ip->ip_dst.s_addr;
		session->rev->sport = tcp->th_dport;
		session->rev->dport = tcp->th_sport;
		heap_insert(0, session);
		pcap_dump((u_char *)session->dumper, h, p);
		session->lastwrite = time(0);
		(void)pcap_dump_flush(session->dumper);
		return;
	}
	/* Dead session (negative) or anything unexpected: drop it. */
	if (error != 1)
		return;

	if (tcp->th_flags & (TH_FIN | TH_RST)) {
		/*
		 * write packet which closed the session
		 */
		pcap_dump((u_char *)session->dumper, h, p);
		close_tcpsession(session);
		return;
	}
	session->hits++;
	pcap_dump((u_char *)session->dumper, h, p);
	session->lastwrite = time(0);
	(void)pcap_dump_flush(session->dumper);
}

/*
 * Number of slots in the flow table.  Must stay a power of two:
 * hash_packet() reduces the hash with `& (NBUCKETS - 1)'.
 */
#define	NBUCKETS	16384
/*
 * One slot of the flow table used by statetrack().  Addresses and ports
 * are stored exactly as read from the packet headers.
 */
struct tcp_flow_id {
	struct timeval	tv;		/* when the slot was last filled */
	u_int32_t	dst_ip;
	u_int32_t	src_ip;
	u_int16_t	dst_port;
	u_int16_t	src_port;
	u_int32_t	seq;	/* ntohl(th_seq) of the recorded packet; TCP sequence numbers are 32 bits */
	int		counter;	/* packets seen in this slot since it was filled */
};

/*
 * Map a flow to its slot in the flow table.
 *
 * The four address/port fields are XORed together and reduced to the
 * range [0, NBUCKETS).  Because XOR is commutative, swapping source and
 * destination yields the same slot, so both directions of a connection
 * share one bucket.
 */
static __inline int
hash_packet(struct tcp_flow_id *id)
{
	u_int32_t mixed;

	mixed = id->dst_ip ^ id->src_ip ^ id->dst_port ^ id->src_port;
	return (mixed & (NBUCKETS - 1));
}

/*
 * Flow table indexed by hash_packet().  There is no collision chaining:
 * statetrack() overwrites whatever a slot held before.
 */
static struct tcp_flow_id flows[NBUCKETS];

/*
 * Track TCP connection establishment and print one line per connection
 * that completes its handshake.
 *
 * pargs is not used.  h is the pcap header of the packet and p the
 * captured frame; the IP header starts opts.loffset bytes into p.
 * The function always returns NULL.
 */
pcap_handler
statetrack(u_char *pargs, struct pcap_pkthdr *h, u_char *p)
{
	struct ip *ip = (struct ip *)(p + opts.loffset);
	unsigned ip_hl = (ip->ip_hl * 0x4);
	struct tcphdr *tcp = ((struct tcphdr *)(((char *)ip) + ip_hl));
	struct tcp_flow_id id;
	u_int32_t hash;
	u_int32_t ack;

	(void)pargs;

	/*
	 * Clear the whole record first.  It is copied over the bucket
	 * below, and the fields not assigned here (tv, counter) would
	 * otherwise carry indeterminate stack contents into the table.
	 */
	memset(&id, 0, sizeof(id));
	id.dst_ip = ip->ip_dst.s_addr;
	id.src_ip = ip->ip_src.s_addr;
	id.dst_port = tcp->th_dport;
	id.src_port = tcp->th_sport;
	id.seq = ntohl(tcp->th_seq);
	hash = hash_packet(&id);
	/*
	 * If we are dealing with a bucket that has a timestamp in it,
	 * we have started to process a potential state. If the current
	 * timestamp minus the state creation timestamp is greater than
	 * what the user specified as a timeout, initialize the counter
	 * and timestamp to zero.
	 */
	if (flows[hash].tv.tv_sec != 0 &&
	    (time(NULL) - flows[hash].tv.tv_sec) > opts.tflag) {
		flows[hash].counter = 0;
		flows[hash].tv.tv_sec = 0;
	}
	if ((tcp->th_flags & TH_SYN) != 0 &&
	    (tcp->th_flags & TH_ACK) != 0) {
		/* Record the SYN+ACK; it is packet number one of the flow. */
		flows[hash] = id;
		gettimeofday(&flows[hash].tv, NULL);
		flows[hash].counter = 1;
		return (NULL);
	}
	/*
	 * In order to register the creation of a new TCP session, we watch
	 * only for packets with the TCP SYN+ACK flags set. This prevents
	 * registration of millions of TCP connections in the event of a SYN
	 * flood.
	 *
	 * We then look at the next packet in the sequence, and all of the
	 * following conditions must hold true before we register it:
	 *
	 * 1) Packet destination address must be equal to the source IP
	 *    address of the initial SYN+ACK. And vice versa for the source
	 *    address.
	 * 2) The packet must not have the RST flag set. If so then this could
	 *    be an indicator that there is a SYN+ACK port/host scan going on.
	 * 3) The packet's acknowledgement number must be equal to the sequence
	 *    number plus one of the original SYN+ACK packet.
	 *
	 * Although there are some edge cases which could result in invalid TCP
	 * connections being processed, this appears to work in the
	 * common case.
	 */
	if (flows[hash].counter > 0)
		flows[hash].counter++;
	if (flows[hash].counter >= 2) {
		ack = ntohl(tcp->th_ack);
		if (flows[hash].src_ip == id.dst_ip &&
		    flows[hash].dst_ip == id.src_ip &&
		    (tcp->th_flags & TH_RST) == 0 &&
		    ack == (flows[hash].seq + 1)) {
			(void)printf("%s flowid: %08x %s > ",
			    time_stamp(&h->ts, opts.thiszone), hash,
			    hostname(4, &id.src_ip));
			(void)printf("%s %s\n", hostname(4, &id.dst_ip),
			    portname(opts.pflag, "tcp", id.dst_port));
		}
		/* Whatever the outcome, the slot is done with this flow. */
		flows[hash].counter = 0;
	}
	return (NULL);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */