/**
** File ......... SSLSocket.cpp
** Published .... 2004-02-13
** Author ....... grymse@alhem.net
**/
/*
Copyright (C) 2004,2005 Anders Hedstrom
This program is free software; you can redistribute it and/or
modify it under the terms of the GNU General Public License
as published by the Free Software Foundation; either version 2
of the License, or (at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
*/
#ifdef _WIN32
#pragma warning(disable:4786)
#endif
#ifdef HAVE_OPENSSL
#include <assert.h>
#include "SSLSocket.h"
#include <openssl/rand.h>
#ifdef _WIN32
#define strcasecmp stricmp
#endif
#define DEB(x)
#define D2(x)
// Definitions of the class-level (static) data members.
// bio_err: error-output BIO; stays NULL until InitializeContext() creates it.
BIO *SSLSocket::bio_err(NULL);
// m_password: private-key pass phrase handed out by password_cb(); starts empty.
std::string SSLSocket::m_password;
/**
 * Construct an SSL socket attached to handler `h`.
 * No OpenSSL objects are created here: the context, SSL object and BIO
 * pointers start out NULL and both role flags start out false.
 */
SSLSocket::SSLSocket(SocketHandler& h)
	: TcpSocket(h),
	  m_context(NULL),
	  is_client(false),
	  is_server(false),
	  m_ssl(NULL),
	  m_sbio(NULL)
{
D2(	printf("SSLSocket()\n");)
}
/**
 * Release the OpenSSL objects held by this socket.
 * The SSL object is freed first, then the context; each is skipped when
 * it was never created. m_sbio is not freed here
 * (NOTE(review): presumably owned by m_ssl after SSL_set_bio() - confirm).
 */
SSLSocket::~SSLSocket()
{
D2(	printf("~SSLSocket()\n");)
	if (m_ssl != NULL)
	{
DEB(		printf("SSL_free()\n");)
		SSL_free(m_ssl);
	}
	if (m_context != NULL)
	{
DEB(		printf("SSL_CTX_free()\n");)
		SSL_CTX_free(m_context);
	}
}
/**
 * Called once the outgoing connection is established (client role).
 *
 * Switches the socket to non-blocking mode, initializes the client SSL
 * context on first use (via InitAsClient()), creates the SSL object and
 * a socket BIO for it, enters the "SSL connecting" state and makes the
 * first handshake attempt through SSLCheckConnect(). OnSSLInitDone() is
 * invoked if the handshake completes immediately.
 *
 * The socket is flagged close-and-delete when a context already exists
 * before the client role was set up, when no context could be created,
 * or when the SSL object / BIO cannot be allocated.
 */
void SSLSocket::OnConnect()
{
D2(	printf("SSLSocket::OnConnect()\n");)
	SetNonblocking(true);
	if (!is_client)
	{
		if (m_context)
		{
DEB(			printf("SSL Context already initialized - closing socket\n");)
			SetCloseAndDelete(true);
			return;
		}
DEB(		printf("InitAsClient()\n");)
		InitAsClient();
		is_client = true;
	}
	if (!m_context)
	{
		// context initialization failed - nothing to handshake with
		SetCloseAndDelete();
		return;
	}
	/* Connect the SSL socket */
	m_ssl = SSL_new(m_context);
	if (!m_ssl)
	{
DEB(		printf(" m_ssl is NULL\n");)
		// a NULL SSL object must never reach SSL_set_bio()/SSL_connect()
		SetCloseAndDelete(true);
		return;
	}
	m_sbio = BIO_new_socket(GetSocket(), BIO_NOCLOSE);
	if (!m_sbio)
	{
DEB(		printf(" m_sbio is NULL\n");)
		// without a BIO the SSL object is unusable: drop it so that no
		// later callback tries to read, write or shut down through it
		SSL_free(m_ssl);
		m_ssl = NULL;
		SetCloseAndDelete(true);
		return;
	}
	SSL_set_bio(m_ssl, m_sbio, m_sbio);
	SetSSLConnecting();
	if (SSLCheckConnect())
	{
		OnSSLInitDone();
	}
}
void SSLSocket::InitAsClient()
{
InitializeContext();
}
void SSLSocket::OnAccept()
{
D2( printf("SSLSocket::OnAccept()\n");)
SetNonblocking(true);
if (!is_server)
{
if (m_context)
{
DEB( printf("SSL Context already initialized - closing socket\n");)
SetCloseAndDelete(true);
return;
}
InitAsServer();
is_server = true;
}
if (m_context)
{
m_ssl = SSL_new(m_context);
if (!m_ssl)
{
DEB( printf(" m_ssl is NULL\n");)
}
m_sbio = BIO_new_socket(GetSocket(), BIO_NOCLOSE);
if (!m_sbio)
{
DEB( printf(" m_sbio is NULL\n");)
}
SSL_set_bio(m_ssl, m_sbio, m_sbio);
SetSSLConnecting();
if (SSLCheckConnect())
{
OnSSLInitDone();
}
}
}
/**
 * Set up the SSL context for the server role.
 * This implementation provides no server setup: it only trips an
 * assertion (which compiles to nothing when NDEBUG is defined).
 * NOTE(review): presumably meant to be overridden by server-side
 * subclasses - confirm against the class declaration.
 */
void SSLSocket::InitAsServer()
{
assert(!"not implemented.\n");
}
bool SSLSocket::SSLCheckConnect()
{
if (is_client) // SSL_connect
{
D2( printf("SSLSocket::SSLCheckConnect() is_client\n");)
int r = SSL_connect(m_ssl);
D2( printf(" SSLCheckConnect is_client, SSL_connect returns %d\n",r);)
if (r > 0)
{
SetSSLConnecting(false);
CheckCertificateChain( "");//ServerHOST);
SetNonblocking(false);
return true;
}
else
if (!r)
{
SetSSLConnecting(false);
SetCloseAndDelete();
}
else
{
r = SSL_get_error(m_ssl, r);
if (r != SSL_ERROR_WANT_READ && r != SSL_ERROR_WANT_WRITE)
{
D2( printf("SSL_connect() failed - closing socket, return code: %d\n",r);)
SetSSLConnecting(false);
SetCloseAndDelete(true);
}
}
}
else
if (is_server)
{
D2( printf("SSLSocket::SSLCheckConnect() is_server\n");)
int r = SSL_accept(m_ssl);
D2( printf(" SSLCheckConnect is_server, SSL_accept returns %d\n",r);)
if (r > 0)
{
SetSSLConnecting(false);
CheckCertificateChain( "");//ClientHOST);
SetNonblocking(false);
return true;
}
else
if (!r)
{
SetSSLConnecting(false);
SetCloseAndDelete();
}
else
{
r = SSL_get_error(m_ssl, r);
if (r != SSL_ERROR_WANT_READ && r != SSL_ERROR_WANT_WRITE)
{
D2( printf("SSL_accept() failed - closing socket, return code: %d\n",r);)
SetSSLConnecting(false);
SetCloseAndDelete(true);
}
}
}
else
{
D2( printf("SSLSocket::SSLCheckConnect() is_NOTHING!!!!!!\n");)
SetSSLConnecting(false);
SetCloseAndDelete();
}
return false;
}
void SSLSocket::OnRead()
{
D2( printf("SSLSocket::OnRead()\n");)
if (!Ready())
return;
char buf[TCP_BUFSIZE_READ];
int n = SSL_read(m_ssl, buf, TCP_BUFSIZE_READ);
if (n == -1)
{
n = SSL_get_error(m_ssl, n);
switch (n)
{
case SSL_ERROR_NONE:
if (!ibuf.Write(buf, n))
{
// overflow
}
break;
case SSL_ERROR_ZERO_RETURN:
DEB( printf("SSL_read() returns zero - closing socket\n");)
SetCloseAndDelete(true);
break;
default:
{
DEB( printf("SSL read problem, errcode = %d\n",n);)
}
}
SetCloseAndDelete(true); // %!
DEB( perror("read() error");)
}
else
if (!n)
{
SetCloseAndDelete(true);
DEB( printf("read() returns 0\n");)
}
else
{
DEB( printf("SSLSocket OnRead read %d bytes\n",n);)
// for (size_t i = 0; i < n; i++)
// printf("%c",buf[i]);
if (!ibuf.Write(buf,n))
{
// overflow
DEB( printf(" *** overflow ibuf Write\n");)
}
}
}
void SSLSocket::OnWrite()
{
/*
if (!Handler().Valid(this))
return;
if (!Ready())
return;
*/
D2( printf("SSLSocket::OnWrite()\n");)
int n = SSL_write(m_ssl,obuf.GetStart(),obuf.GetL());
DEB( printf("OnWrite: %d bytes sent\n",n);)
if (n == -1)
{
SetCloseAndDelete(true);
DEB( perror("write() error");)
}
else
if (!n)
{
SetCloseAndDelete(true);
DEB( printf("write() returns 0\n");)
}
else
{
DEB( printf(" %d bytes written\n",n);)
obuf.Remove(n);
}
{
bool br;
bool bw;
bool bx;
Handler().Get(GetSocket(), br, bw, bx);
if (obuf.GetLength())
Set(br, true);
else
Set(br, false);
}
}
/**
 * Create the SSL context for this socket (no certificate / key loaded).
 *
 * On the very first call in the process (detected through the static
 * bio_err pointer) the OpenSSL library is initialized and an error BIO
 * writing to stderr is created.
 *
 * @param meth_in  SSL method to build the context with; when NULL,
 *                 SSLv3_method() is used.
 *
 * A context left over from an earlier call is released before the new
 * one is stored, so repeated calls do not leak. m_context is NULL
 * afterwards if SSL_CTX_new() fails.
 *
 * NOTE(review): SSLv3 is prohibited by RFC 7568 and absent from many
 * OpenSSL builds; the default is kept for compatibility - consider
 * passing a negotiating method explicitly.
 */
void SSLSocket::InitializeContext(SSL_METHOD *meth_in)
{
	if (!bio_err)
	{
		/* An error write context */
		bio_err = BIO_new_fp(stderr, BIO_NOCLOSE);
		/* Global system initialization*/
		SSL_library_init();
		SSL_load_error_strings();
		OpenSSL_add_all_algorithms();
	}
	/* Drop a context from an earlier call; an SSL object created from it
	   holds its own reference, so this cannot invalidate m_ssl */
	if (m_context)
	{
		SSL_CTX_free(m_context);
		m_context = NULL;
	}
	/* Create our context*/
	SSL_METHOD *meth = meth_in ? meth_in : SSLv3_method();
	m_context = SSL_CTX_new(meth);
	if (!m_context)
	{
DEB(		printf("Couldn't create SSL context\n");)
	}
	/* Load the CAs we trust*/
/*
	if (!(SSL_CTX_load_verify_locations(m_context, CA_LIST, 0)))
	{
DEB(		printf("Couldn't read CA list\n");)
	}
	SSL_CTX_set_verify_depth(m_context, 1);
	SSL_CTX_set_verify(m_context, SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb);
*/
	/* Load randomness (RAND_load_file returns bytes read, or -1 on error) */
	if (RAND_load_file(RANDOM, 1024*1024) <= 0)
	{
DEB(		printf("Couldn't load randomness\n");)
	}
}
void SSLSocket::InitializeContext(const std::string& keyfile,const std::string& password,SSL_METHOD *meth_in)
{
SSL_METHOD *meth;
if (!bio_err)
{
/* An error write context */
bio_err = BIO_new_fp(stderr, BIO_NOCLOSE);
/* Global system initialization*/
SSL_library_init();
SSL_load_error_strings();
OpenSSL_add_all_algorithms();
}
/* Create our context*/
meth = meth_in ? meth_in : SSLv3_method();
m_context = SSL_CTX_new(meth);
/* Load our keys and certificates*/
if (!(SSL_CTX_use_certificate_file(m_context, keyfile.c_str(), SSL_FILETYPE_PEM)))
{
DEB( printf("Couldn't read certificate file\n");)
}
m_password = password;
SSL_CTX_set_default_passwd_cb(m_context, password_cb);
if (!(SSL_CTX_use_PrivateKey_file(m_context, keyfile.c_str(), SSL_FILETYPE_PEM)))
{
DEB( printf("Couldn't read key file\n");)
}
/* Load the CAs we trust*/
/*
if (!(SSL_CTX_load_verify_locations(m_context, CA_LIST, 0)))
{
DEB( printf("Couldn't read CA list\n");)
}
SSL_CTX_set_verify_depth(m_context, 1);
SSL_CTX_set_verify(m_context, SSL_VERIFY_PEER|SSL_VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb);
*/
/* Load randomness */
if (!(RAND_load_file(RANDOM, 1024*1024)))
{
DEB( printf("Couldn't load randomness\n");)
}
}
// static
/**
 * Certificate verification callback: reports a failed verification step
 * on stderr and passes OpenSSL's verdict through unchanged.
 *
 * @param ok     verdict of the verification step (0 = failed).
 * @param store  verification context holding the error details.
 * @return `ok`, unmodified.
 *
 * The issuer / subject lines are only printed when the store actually
 * has a current certificate; OpenSSL may report errors for which no
 * certificate is relevant.
 */
int SSLSocket::verify_cb(int ok, X509_STORE_CTX *store)
{
	if (!ok)
	{
		char data[256];
		X509 *cert = X509_STORE_CTX_get_current_cert(store);
		int depth = X509_STORE_CTX_get_error_depth(store);
		int err = X509_STORE_CTX_get_error(store);
		fprintf(stderr, "-Error with certificate at depth: %i\n", depth);
		if (cert)
		{
			X509_NAME_oneline(X509_get_issuer_name(cert), data, sizeof(data));
			fprintf(stderr, " issuer = %s\n", data);
			X509_NAME_oneline(X509_get_subject_name(cert), data, sizeof(data));
			fprintf(stderr, " subject = %s\n", data);
		}
		fprintf(stderr, " err %i:%s\n", err, X509_verify_cert_error_string(err));
	}
	return ok;
}
// static
/**
 * Pass-phrase callback: copies the stored m_password into OpenSSL's
 * buffer.
 *
 * @param buf       destination buffer.
 * @param num       size of `buf` in bytes.
 * @param rwflag    unused.
 * @param userdata  unused.
 * @return length of the pass phrase (terminator not counted), or 0 when
 *         the buffer cannot hold the pass phrase plus its terminator.
 */
int SSLSocket::password_cb(char *buf,int num,int rwflag,void *userdata)
{
	const size_t needed = m_password.size() + 1; // text + '\0'
	const size_t available = (size_t)num;
	if (available < needed)
	{
		return 0;
	}
	strcpy(buf, m_password.c_str());
	return m_password.size();
}
bool SSLSocket::CheckCertificateChain(const std::string& host)
{
X509 *peer;
char peer_CN[256];
if (SSL_get_verify_result(m_ssl) != X509_V_OK)
{
DEB( printf("Certificate doesn't verify\n");)
return false;
}
/*Check the cert chain. The chain length
is automatically checked by OpenSSL when we
set the verify depth in the ctx */
/*Check the common name*/
peer = SSL_get_peer_certificate(m_ssl);
if (!peer)
{
return false;
}
X509_NAME_get_text_by_NID(X509_get_subject_name(peer), NID_commonName, peer_CN, 256);
if (host.size() && strcasecmp(peer_CN,host.c_str()))
{
DEB( printf("Common name doesn't match host name\n");)
return false;
}
return true;
}
/**
 * Close the connection.
 * When an SSL object exists, SSL_shutdown() is called on it first; the
 * underlying socket is then closed through Socket::Close().
 *
 * @return whatever Socket::Close() returns.
 */
int SSLSocket::Close()
{
D2(	printf("SSLSocket::Close()\n");)
	if (m_ssl != NULL)
	{
		SSL_shutdown(m_ssl);
	}
	const int rc = Socket::Close();
	return rc;
}
#endif // HAVE_OPENSSL
// syntax highlighted by Code2HTML, v. 0.9.1