// distribution boxbackup-0.10 (svn version: 494)
//  
// Copyright (c) 2003 - 2006
//      Ben Summers and contributors.  All rights reserved.
//  
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// 1. Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
// 2. Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
// 3. All use of this software and associated advertising materials must 
//    display the following acknowledgment:
//        This product includes software developed by Ben Summers.
// 4. The names of the Authors may not be used to endorse or promote
//    products derived from this software without specific prior written
//    permission.
// 
// [Where legally impermissible the Authors do not disclaim liability for 
// direct physical injury or death caused solely by defects in the software 
// unless it is modified by a third party.]
// 
// THIS SOFTWARE IS PROVIDED BY THE AUTHORS ``AS IS'' AND ANY EXPRESS OR
// IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED.  IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY DIRECT,
// INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
// HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
// STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.
//  
//  
//  
// --------------------------------------------------------------------------
//
// File
//		Name:    SocketStreamTLS.cpp
//		Purpose: Socket stream encrpyted and authenticated by TLS
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------

#include "Box.h"

#define TLS_CLASS_IMPLEMENTATION_CPP
#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <errno.h>
#include <fcntl.h>

#ifndef WIN32
#include <poll.h>
#endif

#include "SocketStreamTLS.h"
#include "SSLLib.h"
#include "ServerException.h"
#include "TLSContext.h"

#include "MemLeakFindOn.h"

// Allow 5 minutes to handshake (in milliseconds)
#define TLS_HANDSHAKE_TIMEOUT (5*60*1000)

// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::SocketStreamTLS()
//		Purpose: Constructor
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
SocketStreamTLS::SocketStreamTLS()
	: mpSSL(0), mpBIO(0)
{
	ResetCounters();
}

// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::SocketStreamTLS(int)
//		Purpose: Constructor, taking previously connected socket
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
SocketStreamTLS::SocketStreamTLS(int socket)
	: SocketStream(socket),
	  mpSSL(0), mpBIO(0)
{
}

// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::~SocketStreamTLS()
//		Purpose: Destructor
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
SocketStreamTLS::~SocketStreamTLS()
{
	if(mpSSL)
	{
		// Attempt to close to avoid problems
		Close();
		
		// And if that didn't work...
		if(mpSSL)
		{
			::SSL_free(mpSSL);
			mpSSL = 0;
			mpBIO = 0;	// implicity freed by the SSL_free call
		}
	}
	
	// If we only got to creating that BIO.
	if(mpBIO)
	{
		::BIO_free(mpBIO);
		mpBIO = 0;
	}
}


// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::Open(const TLSContext &, int, const char *, int)
//		Purpose: Open connection, and perform TLS handshake
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
void SocketStreamTLS::Open(const TLSContext &rContext, int Type, const char *Name, int Port)
{
	SocketStream::Open(Type, Name, Port);
	Handshake(rContext);
	ResetCounters();
}


// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::Handshake(const TLSContext &, bool)
//		Purpose: Perform TLS handshake
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
void SocketStreamTLS::Handshake(const TLSContext &rContext, bool IsServer)
{
	if(mpBIO || mpSSL) {THROW_EXCEPTION(ServerException, TLSAlreadyHandshaked)}

	// Create a BIO for this socket
	mpBIO = ::BIO_new(::BIO_s_socket());
	if(mpBIO == 0)
	{
		SSLLib::LogError("Create socket bio");
		THROW_EXCEPTION(ServerException, TLSAllocationFailed)
	}

	tOSSocketHandle socket = GetSocketHandle();
	BIO_set_fd(mpBIO, socket, BIO_NOCLOSE);
	
	// Then the SSL object
	mpSSL = ::SSL_new(rContext.GetRawContext());
	if(mpSSL == 0)
	{
		SSLLib::LogError("Create ssl");
		THROW_EXCEPTION(ServerException, TLSAllocationFailed)
	}

#ifndef WIN32
	// Make the socket non-blocking so timeouts on Read work
	// This is more portable than using ioctl with FIONBIO
	int statusFlags = 0;
	if(::fcntl(socket, F_GETFL, &statusFlags) < 0
	   || ::fcntl(socket, F_SETFL, statusFlags | O_NONBLOCK) == -1)
	{
		THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed)
	}
#endif
	
	// FIXME: This is less portable than the above. However, it MAY be needed
	// for cygwin, which has/had bugs with fcntl
	//
	// int nonblocking = true;
	// if(::ioctl(socket, FIONBIO, &nonblocking) == -1)
	// {
	// 	THROW_EXCEPTION(ServerException, SocketSetNonBlockingFailed)
	// }

	// Set the two to know about each other
	::SSL_set_bio(mpSSL, mpBIO, mpBIO);

	bool waitingForHandshake = true;
	while(waitingForHandshake)
	{
		// Attempt to do the handshake
		int r = 0;
		if(IsServer)
		{
			r = ::SSL_accept(mpSSL);
		}
		else
		{
			r = ::SSL_connect(mpSSL);
		}

		// check return code
		int se;
		switch((se = ::SSL_get_error(mpSSL, r)))
		{
		case SSL_ERROR_NONE:
			// No error, handshake succeeded
			waitingForHandshake = false;
			break;

		case SSL_ERROR_WANT_READ:
		case SSL_ERROR_WANT_WRITE:
			// wait for the requried data
			if(WaitWhenRetryRequired(se, TLS_HANDSHAKE_TIMEOUT) == false)
			{
				// timed out
				THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeTimedOut)
			}
			break;
			
		default: // (and SSL_ERROR_ZERO_RETURN)
			// Error occured
			if(IsServer)
			{
				SSLLib::LogError("Accept");
				THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed)
			}
			else
			{
				SSLLib::LogError("Connect");
				THROW_EXCEPTION(ConnectionException, Conn_TLSHandshakeFailed)
			}
		}
	}
	
	// And that's it
}

// --------------------------------------------------------------------------
//
// Function
//		Name:    WaitWhenRetryRequired(int, int)
//		Purpose: Waits until the condition required by the TLS layer is met.
//				 Returns true if the condition is met, false if timed out.
//		Created: 2003/08/15
//
// --------------------------------------------------------------------------
bool SocketStreamTLS::WaitWhenRetryRequired(int SSLErrorCode, int Timeout)
{
	struct pollfd p;
	p.fd = GetSocketHandle();
	switch(SSLErrorCode)
	{
	case SSL_ERROR_WANT_READ:
		p.events = POLLIN;
		break;
		
	case SSL_ERROR_WANT_WRITE:
		p.events = POLLOUT;
		break;

	default:
		// Not good!
		THROW_EXCEPTION(ServerException, Internal)
		break;
	}
	p.revents = 0;
	switch(::poll(&p, 1, (Timeout == IOStream::TimeOutInfinite)?INFTIM:Timeout))
	{
	case -1:
		// error
		if(errno == EINTR)
		{
			// Signal. Do "time out"
			return false;
		}
		else
		{
			// Bad!
			THROW_EXCEPTION(ServerException, SocketPollError)
		}
		break;

	case 0:
		// Condition not met, timed out
		return false;
		break;

	default:
		// good to go!
		return true;
		break;
	}

	return true;
}

// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::Read(void *, int, int Timeout)
//		Purpose: See base class
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
int SocketStreamTLS::Read(void *pBuffer, int NBytes, int Timeout)
{
	if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}

	// Make sure zero byte reads work as expected
	if(NBytes == 0)
	{
		return 0;
	}
	
	while(true)
	{
		int r = ::SSL_read(mpSSL, pBuffer, NBytes);

		int se;
		switch((se = ::SSL_get_error(mpSSL, r)))
		{
		case SSL_ERROR_NONE:
			// No error, return number of bytes read
			mBytesRead += r;
			return r;
			break;

		case SSL_ERROR_ZERO_RETURN:
			// Connection closed
			MarkAsReadClosed();
			return 0;
			break;

		case SSL_ERROR_WANT_READ:
		case SSL_ERROR_WANT_WRITE:
			// wait for the requried data
			// Will only get once around this loop, so don't need to calculate timeout values
			if(WaitWhenRetryRequired(se, Timeout) == false)
			{
				// timed out
				return 0;
			}
			break;
			
		default:
			SSLLib::LogError("Read");
			THROW_EXCEPTION(ConnectionException, Conn_TLSReadFailed)
			break;
		}
	}
}

// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::Write(const void *, int)
//		Purpose: See base class
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
void SocketStreamTLS::Write(const void *pBuffer, int NBytes)
{
	if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}
	
	// Make sure zero byte writes work as expected
	if(NBytes == 0)
	{
		return;
	}
	
	// from man SSL_write
	//
	// SSL_write() will only return with success, when the
	// complete contents of buf of length num has been written.
	//
	// So no worries about partial writes and moving the buffer around
	
	while(true)
	{
		// try the write
		int r = ::SSL_write(mpSSL, pBuffer, NBytes);
		
		int se;
		switch((se = ::SSL_get_error(mpSSL, r)))
		{
		case SSL_ERROR_NONE:
			// No error, data sent, return success
			mBytesWritten += r;
			return;
			break;

		case SSL_ERROR_ZERO_RETURN:
			// Connection closed
			MarkAsWriteClosed();
			THROW_EXCEPTION(ConnectionException, Conn_TLSClosedWhenWriting)
			break;

		case SSL_ERROR_WANT_READ:
		case SSL_ERROR_WANT_WRITE:
			// wait for the requried data
			{
			#ifndef NDEBUG
				bool conditionmet = 
			#endif
				WaitWhenRetryRequired(se, IOStream::TimeOutInfinite);
				ASSERT(conditionmet);
			}
			break;
		
		default:
			SSLLib::LogError("Write");
			THROW_EXCEPTION(ConnectionException, Conn_TLSWriteFailed)
			break;
		}
	}
}

// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::Close()
//		Purpose: See base class
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
void SocketStreamTLS::Close()
{
	if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}

	// Base class to close
	SocketStream::Close();

	// Free resources
	::SSL_free(mpSSL);
	mpSSL = 0;
	mpBIO = 0;	// implicitly freed by SSL_free
}

// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::Shutdown()
//		Purpose: See base class
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
void SocketStreamTLS::Shutdown(bool Read, bool Write)
{
	if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}

	if(::SSL_shutdown(mpSSL) < 0)
	{
		SSLLib::LogError("Shutdown");
		THROW_EXCEPTION(ConnectionException, Conn_TLSShutdownFailed)
	}

	// Don't ask the base class to shutdown -- BIO does this, apparently.
}

// --------------------------------------------------------------------------
//
// Function
//		Name:    SocketStreamTLS::GetPeerCommonName()
//		Purpose: Returns the common name of the other end of the connection
//		Created: 2003/08/06
//
// --------------------------------------------------------------------------
std::string SocketStreamTLS::GetPeerCommonName()
{
	if(!mpSSL) {THROW_EXCEPTION(ServerException, TLSNoSSLObject)}

	// Get certificate
	X509 *cert = ::SSL_get_peer_certificate(mpSSL);
	if(cert == 0)
	{
		::X509_free(cert);
		THROW_EXCEPTION(ConnectionException, Conn_TLSNoPeerCertificate)
	}

	// Subject details	
	X509_NAME *subject = ::X509_get_subject_name(cert); 
	if(subject == 0)
	{
		::X509_free(cert);
		THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid)
	}
	
	// Common name
	char commonName[256];
	if(::X509_NAME_get_text_by_NID(subject, NID_commonName, commonName, sizeof(commonName)) <= 0)
	{
		::X509_free(cert);
		THROW_EXCEPTION(ConnectionException, Conn_TLSPeerCertificateInvalid)
	}
	// Terminate just in case
	commonName[sizeof(commonName)-1] = '\0';
	
	// Done.
	return std::string(commonName);
}


syntax highlighted by Code2HTML, v. 0.9.1