/**
 **	File ......... SocketHandler.cpp
 **	Published ....  2004-02-13
 **	Author ....... grymse@alhem.net
**/
/*
Copyright (C) 2004,2005  Anders Hedstrom

This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.

This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
*/
#include <stdio.h>
#ifdef _WIN32
#pragma warning(disable:4786)
#include <stdlib.h>
#else
#include <errno.h>
#endif

#include "TcpSocket.h"
#include <StdLog.h>
#include "SocketHandler.h"
#include "UdpSocket.h"
#include "PoolSocket.h"
#include "ResolvSocket.h"
#include "ResolvServer.h"


// Debug-trace helper: when _DEBUG is defined DEB(x) expands to x itself,
// otherwise it expands to nothing, so the wrapped statements disappear
// from non-debug builds.
#ifdef _DEBUG
#define DEB(x) x
#else
#define DEB(x) 
#endif


/**
 * Construct a handler with no sockets and empty read/write/exception sets.
 *
 * @param p  Log sink used by LogError(); it is stored as-is and checked
 *           for NULL before every use.
 */
SocketHandler::SocketHandler(StdLog *p)
:m_stdlog(p)
,m_maxsock(0)
,m_host("")
,m_ip(0)
,m_preverror(-1)
,m_slave(false)
,m_local_resolved(false)
,m_socks4_host(0)
,m_socks4_port(0)
,m_bTryDirect(false)
,m_resolv_id(0)
,m_resolver(NULL)
{
	// Resolve() reads m_resolver_port even when EnableResolver() was never
	// called, so give it a defined value instead of leaving it indeterminate.
	// Assigned in the body (not the initializer list) because the member
	// declaration order is not visible from this file.
	m_resolver_port = 0;
	FD_ZERO(&m_rfds);
	FD_ZERO(&m_wfds);
	FD_ZERO(&m_efds);
}


/**
 * Tear the handler down.
 *
 * The resolver (if one was enabled) is asked to quit first. Unless this
 * handler is a slave, every socket still present in m_sockets is closed
 * and, when flagged DeleteByHandler(), deleted. Finally the resolver
 * object itself is released.
 */
SocketHandler::~SocketHandler()
{
	if (m_resolver)
		m_resolver -> Quit();
	if (!m_slave)
	{
		for (socket_m::iterator it = m_sockets.begin(); it != m_sockets.end(); ++it)
		{
			Socket *p = it -> second;
			// Select() treats NULL map entries as possible; be equally
			// defensive here instead of dereferencing blindly.
			if (!p)
				continue;
			p -> Close();
//			p -> OnDelete(); // hey, I turn this back on. what's the worst that could happen??!!
			// MinionSocket breaks, calling MinderHandler methods in OnDelete -
			// MinderHandler is already gone when that happens...
			if (p -> DeleteByHandler())
			{
				delete p;
			}
		}
	}
	// deleting a NULL pointer is a no-op, so no guard is needed
	delete m_resolver;
}


void SocketHandler::ResolveLocal()
{
	char h[256];

	// get local hostname and translate into ip-address
	*h = 0;
	gethostname(h,255);
	{
		Socket zl(*this);
		if (zl.u2ip(h, m_ip))
		{
			zl.l2ip(m_ip, m_addr);
		}
	}
#ifdef IPPROTO_IPV6
	memset(&m_local_ip6, 0, sizeof(m_local_ip6));
	{
		Socket zl(*this);
		zl.SetIpv6();
		if (zl.u2ip(h, m_local_ip6))
		{
			zl.l2ip(m_local_ip6, m_local_addr6);
		}
	}
#endif
	m_host = h;
	m_local_resolved = true;
}


/**
 * Queue a socket for inclusion in the handler.
 *
 * The socket is stored in m_add keyed by its descriptor; Select() later
 * moves it into the active map. A socket whose descriptor is
 * INVALID_SOCKET is rejected with a fatal log entry and is not stored.
 *
 * @param p  Socket to queue.
 */
void SocketHandler::Add(Socket *p)
{
	const SOCKET fd = p -> GetSocket();
	if (fd != INVALID_SOCKET)
	{
DEB(		printf("%s: add socket %d\n",m_slave ? "slave" : "master",fd);)
		m_add[fd] = p;
		return;
	}
	LogError(p, "Add", -1, "Invalid socket", LOG_LEVEL_FATAL);
}


/**
 * Report whether a descriptor is currently registered in the handler's
 * read, write and exception sets.
 *
 * @param s  Descriptor to query.
 * @param r  Receives true when s is in the read set.
 * @param w  Receives true when s is in the write set.
 * @param e  Receives true when s is in the exception set.
 *
 * The three outputs are left untouched when s is negative.
 */
void SocketHandler::Get(SOCKET s,bool& r,bool& w,bool& e)
{
	if (s < 0)
	{
		return;
	}
	r = FD_ISSET(s, &m_rfds) != 0;
	w = FD_ISSET(s, &m_wfds) != 0;
	e = FD_ISSET(s, &m_efds) != 0;
}


/**
 * Register or unregister a descriptor in the handler's read, write and
 * exception sets.
 *
 * @param s           Descriptor to update; negative values are ignored.
 * @param bRead       true to watch for readability, false to stop.
 * @param bWrite      true to watch for writability, false to stop.
 * @param bException  true to watch for exceptions, false to stop.
 */
void SocketHandler::Set(SOCKET s,bool bRead,bool bWrite,bool bException)
{
	if (s < 0)
	{
		return;
	}
	// Process the three sets in the fixed order read, write, exception.
	fd_set *targets[3] = { &m_rfds, &m_wfds, &m_efds };
	const bool wanted[3] = { bRead, bWrite, bException };
	for (int k = 0; k < 3; ++k)
	{
		if (!wanted[k])
		{
			FD_CLR(s, targets[k]);
		}
		else if (!FD_ISSET(s, targets[k]))
		{
			FD_SET(s, targets[k]);
		}
	}
}


/**
 * Run one pass of the handler's event loop.
 *
 * 1. Sockets queued by Add() are moved from m_add into m_sockets and
 *    registered in the fd sets (as long as fewer than FD_SETSIZE sockets
 *    are active).
 * 2. select() is called with the given timeout and the resulting events
 *    are dispatched to the sockets' callbacks.
 * 3. Sockets that asked to be detached, timed out while connecting, or are
 *    flagged CloseAndDelete() are removed from the handler.
 *
 * @param sec   Seconds part of the select() timeout.
 * @param usec  Microseconds part of the select() timeout.
 * @return      The value returned by select() (-1 is treated as an error).
 */
int SocketHandler::Select(long sec,long usec)
{
	struct timeval tv;
	int n;

	// Phase 1: activate pending sockets.
	while (m_add.size() && m_sockets.size() < FD_SETSIZE )
	{
		socket_m::iterator it = m_add.begin();
		SOCKET s = (*it).first;
		Socket *p = (*it).second;
		// call Open before Add'ing a socket...
		if (p -> Connecting())
		{
			// a connect in progress is signalled through writability
			Set(s,false,true);
		}
		else
		{
			if (p -> IsDisableRead())
				Set(s, false, false);
			else
				Set(s,true,false);
		}
		m_maxsock = (s > m_maxsock) ? s : m_maxsock;
		m_sockets[s] = p;
		m_add.erase(it);
	}
	// Take the working copies only AFTER the pending sockets have been
	// registered above. Copying earlier left freshly added sockets out of
	// this select() call, so their events were not seen until the next
	// invocation (after a full timeout if nothing else was ready).
	fd_set rfds = m_rfds;
	fd_set wfds = m_wfds;
	fd_set efds = m_efds;
	tv.tv_sec = sec;
	tv.tv_usec = usec;
	n = select( (int)(m_maxsock + 1),&rfds,&wfds,&efds,&tv);
	if (n == -1)
	{
		LogError(NULL, "select", Errno, StrError(Errno));
#ifdef _WIN32
DEB(
		int errcode = Errno;
		if (errcode != m_preverror)
		{
			printf("  select() errcode = %d\n",errcode);
			m_preverror = errcode;
			for (size_t i = 0; i <= m_maxsock; i++)
			{
				if (FD_ISSET(i, &m_rfds))
					printf("%4d: Read\n",i);
				if (FD_ISSET(i, &m_wfds))
					printf("%4d: Write\n",i);
				if (FD_ISSET(i, &m_efds))
					printf("%4d: Exception\n",i);
			}
		}
) // DEB
#else
DEB(		printf("slave: %s\n",m_slave ? "YES" : "NO");
		exit(-1);)
#endif
	}
	else
//	if (n > 0)
	{
		// Phase 2: dispatch. This loop also runs when n == 0 so that the
		// deferred OnConnect and SSL negotiation steps progress on timeouts.
		for (socket_m::iterator it2 = m_sockets.begin(); it2 != m_sockets.end(); it2++)
		{
			SOCKET i = (*it2).first;
			Socket *p = (*it2).second;
			if (p)
			{
				if (p -> CallOnConnect() && p -> Ready() )
				{
					if (p -> IsSSL()) // SSL Enabled socket
						p -> OnSSLConnect();
					else
					if (p -> Socks4())
						p -> OnSocks4Connect();
					else
						p -> OnConnect();
					p -> SetCallOnConnect( false );
				}
				// new SSL negotiate method
				if (p -> IsSSLNegotiate())
				{
					p -> SSLNegotiate();
				}
				else
				// old SSL method...
				if (p -> SSLConnecting())
				{
					if (p -> SSLCheckConnect())
					{
						p -> OnSSLInitDone();
					}
				}
				else
				if (n > 0)
				{
					if (FD_ISSET(i, &rfds))
					{
						TcpSocket *tcp = dynamic_cast<TcpSocket *>(p);
						// LockWrite (save total output buffer size)
						// Sockets with write lock won't call OnWrite in SendBuf
						// That will happen in UnlockWrite, if necessary
						p -> OnRead();
						bool need_more = false;
						// feed buffered input to the SOCKS4 handshake until it
						// asks for more data, the buffer is empty, or the
						// socket is flagged for deletion
						while (tcp && p -> Socks4() && tcp -> GetInputLength() && !need_more && !p -> CloseAndDelete())
						{
							need_more = p -> OnSocks4Read();
						}
						if (!p -> Socks4())
						{
							if (p -> LineProtocol())
							{
								p -> ReadLine();
							}
//							p -> Touch();
						}
						// UnlockWrite (call OnWrite if saved size == 0 && total output buffer size > 0)
					}
					if (FD_ISSET(i, &wfds))
					{
						if (p -> Connecting())
						{
							if (p -> CheckConnect())
							{
								if (p -> IsSSL()) // SSL Enabled socket
									p -> OnSSLConnect();
								else
								if (p -> Socks4())
									p -> OnSocks4Connect();
								else
									p -> OnConnect();
							}
							else
							{
								// failed
								if (p -> Socks4())
								{
									p -> OnSocks4ConnectFailed();
								}
								else
								{
//									LogError(p, "connect failed", Errno, StrError(Errno), LOG_LEVEL_FATAL);
									p -> SetCloseAndDelete( true );
									p -> OnConnectFailed();
								}
							}
//								p -> Touch();
						}
						else
						{
							p -> OnWrite();
//								p -> Touch();
						}
					}
					if (FD_ISSET(i, &efds))
					{
						p -> OnException();
					}
				}
			} // if (p)
		} // for
	}

	// Phase 3: housekeeping. Erasing from the map invalidates the iterator,
	// so after each removal the scan is restarted from the beginning.
	bool repeat;
	do
	{
		repeat = false;
		for (socket_m::iterator it3 = m_sockets.begin(); it3 != m_sockets.end(); it3++)
		{
//			SOCKET s = (*it3).first;
			Socket *p = (*it3).second;
			if (p)
			{
				if (!m_slave && p -> IsDetach())
				{
					Set(p -> GetSocket(), false, false, false);
					p -> DetachSocket();
					m_sockets.erase(it3);
					repeat = true;
					break;
				}
/*
				if (p && p -> Timeout() && p -> Inactive() > p -> Timeout())
				{
					p -> SetCloseAndDelete();
				}
*/
				if (p && p -> Connecting() && p -> GetConnectTime() > p -> GetConnectTimeout() )
				{
					LogError(p, "connect", -1, "connect timeout", LOG_LEVEL_FATAL);
					if (p -> Socks4())
					{
						p -> OnSocks4ConnectFailed();
						// retry direct connection
					}
					else
					{
						p -> SetCloseAndDelete(true);
						p -> OnConnectFailed();
					}
				}
				if (p && p -> CloseAndDelete() )
				{
//DEB(printf("%s: calling Close for socket %d\n",m_slave ? "slave" : "master",s);)
					if (p -> Retain() && !p -> Lost())
					{
						// keep the connection alive by handing it to a
						// PoolSocket instead of closing it
						PoolSocket *p2 = new PoolSocket(*this, p);
						p2 -> SetDeleteByHandler();
						Add(p2);
//printf("Adding PoolSocket...\n");
					}
					else
					{
						Set(p -> GetSocket(),false,false,false);
						p -> Close();
					}
					p -> OnDelete();
					if (p -> DeleteByHandler())
					{
						delete p;
					}
					m_sockets.erase(it3);
					repeat = true;
					break;
				}
			} // if (p)
		}
		if (repeat)
		{
			// a socket was removed: recompute the highest descriptor over
			// both the active and the pending sockets
			m_maxsock = 0;
			for (socket_m::iterator it = m_sockets.begin(); it != m_sockets.end(); it++)
			{
				SOCKET s = (*it).first;
				m_maxsock = s > m_maxsock ? s : m_maxsock;
			}
			for (socket_m::iterator it3 = m_add.begin(); it3 != m_add.end(); it3++)
			{
				SOCKET s = (*it3).first;
				m_maxsock = s > m_maxsock ? s : m_maxsock;
			}
		}
	} while (repeat);
	return n;
}


/**
 * @return The host name cached by ResolveLocal(). An error is logged
 *         (and the current, possibly empty, value still returned) when
 *         the local address has not been resolved yet.
 */
const std::string& SocketHandler::GetLocalHostname()
{
	const bool resolved = m_local_resolved;
	if (!resolved)
	{
		LogError(NULL, "GetLocalHostname", 0, "local address not resolved");
	}
	return m_host;
}


/**
 * @return The local IPv4 address cached by ResolveLocal(). An error is
 *         logged (and the current value still returned) when the local
 *         address has not been resolved yet.
 */
ipaddr_t SocketHandler::GetLocalIP()
{
	if (!m_local_resolved)
	{
		// report under this method's own name, not GetLocalHostname
		LogError(NULL, "GetLocalIP", 0, "local address not resolved");
	}
	return m_ip;
}


/**
 * @return The textual local address cached by ResolveLocal(). An error
 *         is logged (and the current value still returned) when the local
 *         address has not been resolved yet.
 */
const std::string& SocketHandler::GetLocalAddress()
{
	if (!m_local_resolved)
	{
		// report under this method's own name, not GetLocalHostname
		LogError(NULL, "GetLocalAddress", 0, "local address not resolved");
	}
	return m_addr;
}


/**
 * Check whether a socket pointer is currently held in the active socket
 * map (m_sockets). Sockets still waiting in m_add are not considered.
 *
 * @param p0  Pointer to look for.
 * @return    true when some entry of m_sockets holds exactly this pointer.
 */
bool SocketHandler::Valid(Socket *p0)
{
	socket_m::iterator cur = m_sockets.begin();
	while (cur != m_sockets.end())
	{
		if (cur -> second == p0)
		{
			return true;
		}
		cur++;
	}
	return false;
}


/**
 * Replace the log sink used by LogError().
 *
 * @param x  New log sink; stored as-is.
 */
void SocketHandler::RegStdLog(StdLog *x)
{
	this -> m_stdlog = x;
}


/**
 * @return Always true in this implementation.
 */
bool SocketHandler::OkToAccept()
{
	const bool accept_allowed = true;
	return accept_allowed;
}


/**
 * @return Number of sockets known to the handler: those already active
 *         in m_sockets plus those still queued in m_add.
 */
size_t SocketHandler::GetCount()
{
	const size_t active = m_sockets.size();
	const size_t pending = m_add.size();
	return active + pending;
}


void SocketHandler::SetSlave(bool x)
{
	m_slave = x;
}


/**
 * Forward an error report to the registered log sink, if there is one.
 *
 * @param p          Socket the error relates to (callers in this file pass
 *                   NULL when no socket applies).
 * @param user_text  Caller-supplied context string.
 * @param err        Numeric error code.
 * @param sys_err    Textual description of the error.
 * @param t          Log level.
 */
void SocketHandler::LogError(Socket *p,const std::string& user_text,int err,const std::string& sys_err,loglevel_t t)
{
	if (!m_stdlog)
	{
		// no sink registered: the report is dropped
		return;
	}
	m_stdlog -> error(this, p, user_text, err, sys_err, t);
}


#ifdef IPPROTO_IPV6
/**
 * @return The local IPv6 address cached by ResolveLocal(). An error is
 *         logged (and the current value still returned) when the local
 *         address has not been resolved yet.
 */
const struct in6_addr& SocketHandler::GetLocalIP6()
{
	if (!m_local_resolved)
	{
		// report under this method's own name, not GetLocalHostname
		LogError(NULL, "GetLocalIP6", 0, "local address not resolved");
	}
	return m_local_ip6;
}
#endif


/**
 * @return The textual local IPv6 address (m_local_addr6). An error is
 *         logged (and the current value still returned) when the local
 *         address has not been resolved yet.
 */
const std::string& SocketHandler::GetLocalAddress6()
{
	if (!m_local_resolved)
	{
		// report under this method's own name, not GetLocalHostname
		LogError(NULL, "GetLocalAddress6", 0, "local address not resolved");
	}
	return m_local_addr6;
}


/**
 * Search the active sockets for a pooled connection matching the given
 * socket type, protocol, remote address and remote port.
 *
 * On a match the PoolSocket is removed from m_sockets, flagged with
 * SetRetain() and handed to the caller, who becomes responsible for
 * deleting it.
 *
 * @param type      Socket type to match.
 * @param protocol  Socket protocol to match.
 * @param a         Remote address to match.
 * @param port      Remote port to match.
 * @return          The matching PoolSocket, or NULL when none is found.
 */
PoolSocket *SocketHandler::FindConnection(int type,const std::string& protocol,ipaddr_t a,port_t port)
{
	socket_m::iterator pos = m_sockets.begin();
	while (pos != m_sockets.end())
	{
		// only PoolSocket entries are candidates
		PoolSocket *candidate = dynamic_cast<PoolSocket *>(pos -> second);
		if (candidate &&
		    candidate -> GetSocketType() == type &&
		    candidate -> GetSocketProtocol() == protocol &&
		    candidate -> GetClientRemoteAddr() == a &&
		    candidate -> GetClientRemotePort() == port)
		{
DEB(printf("FindConnection() successful\n");)
			m_sockets.erase(pos);
			candidate -> SetRetain(); // avoid Close in Socket destructor
			return candidate; // Caller is responsible that this socket is deleted
		}
		pos++;
	}
DEB(printf("FindConnection() NOT successful\n");)
	return NULL;
}


/**
 * Store the SOCKS4 server address.
 *
 * @param a  Address to store in m_socks4_host.
 */
void SocketHandler::SetSocks4Host(ipaddr_t a)
{
	this -> m_socks4_host = a;
}


void SocketHandler::SetSocks4Host(const std::string& host)
{
	Socket s(*this);
	s.u2ip(host, m_socks4_host);
}


/**
 * Store the SOCKS4 server port.
 *
 * @param port  Port to store in m_socks4_port.
 */
void SocketHandler::SetSocks4Port(port_t port)
{
	this -> m_socks4_port = port;
}


void SocketHandler::SetSocks4Userid(const std::string& id)
{
	m_socks4_userid = id;
}


/**
 * Start an asynchronous resolve request.
 *
 * A ResolvSocket is created for the request, given a fresh id, and
 * connected to 127.0.0.1 on m_resolver_port; it is then queued with Add()
 * and owned by the handler.
 *
 * @param p     Socket passed to the ResolvSocket constructor.
 * @param host  Host name to resolve.
 * @param port  Port passed along with the request.
 * @return      The id assigned to this request.
 */
int SocketHandler::Resolve(Socket *p,const std::string& host,port_t port)
{
	// NOTE: no cache lookup is performed here; every call opens a request
	ResolvSocket *resolv = new ResolvSocket(*this, p);
	resolv -> SetId(++m_resolv_id);
	resolv -> SetHost(host);
	resolv -> SetPort(port);
	resolv -> SetDeleteByHandler();
	ipaddr_t local;
	resolv -> u2ip("127.0.0.1", local);
	if (!resolv -> Open(local, m_resolver_port))
	{
		LogError(resolv, "Resolve", -1, "Can't connect to local resolve server", LOG_LEVEL_FATAL);
	}
	if (resolv -> GetSocket() == INVALID_SOCKET)
	{
		// Add() refuses sockets without a valid descriptor and does not
		// take ownership of them, so the request object would leak.
		// Emit the same log entry Add() would have produced and free it.
		LogError(resolv, "Add", -1, "Invalid socket", LOG_LEVEL_FATAL);
		delete resolv;
		return m_resolv_id;
	}
	Add(resolv);
	return m_resolv_id;
}


/**
 * Create the resolver server on the given port, once.
 *
 * Calls made after a resolver already exists are ignored, so the port of
 * the first call wins.
 *
 * @param port  Port passed to the ResolvServer and remembered in
 *              m_resolver_port for later Resolve() calls.
 */
void SocketHandler::EnableResolver(port_t port)
{
	if (m_resolver)
	{
		return;
	}
	m_resolver_port = port;
	m_resolver = new ResolvServer(port);
}




// syntax highlighted by Code2HTML, v. 0.9.1