"""Test the SimplePooledDB module.

Note:
We don't test performance here, so the test does not indicate
whether SimplePooledDB actually will help in improving performance or not.
We also do not test any real world DB-API 2 module; we just
mock the basic connection functionality of an arbitrary module.

Copyright and credit info:

* This test was contributed by Christoph Zwerschke

"""

__version__ = '0.9.4'
__revision__ = "$Rev: 6696 $"
__date__ = "$Date: 2007-07-07 11:02:24 -0600 (Sat, 07 Jul 2007) $"


import sys
import unittest

# This module also serves as a mock object for an arbitrary DB-API 2 module:

dbModule = sys.modules[__name__]

# Module-level DB-API 2 thread safety level of the mock module.
# The tests overwrite this global (via my_dbpool) to exercise the pool
# with different levels.
threadsafety = 1


def connect(database, user):
    """Mock DB-API 2 connect(): return a new mock Connection."""
    return Connection(database, user)


class Connection:
    """Mock DB-API 2 connection that only counts opened cursors."""

    def __init__(self, database, user):
        self.database = database
        self.user = user
        self.open_cursors = 0

    def close(self):
        """Reset the cursor counter (the mock has nothing else to close)."""
        self.open_cursors = 0

    def cursor(self):
        """Count one more opened cursor (no cursor object is returned)."""
        self.open_cursors += 1


# Put the directory two levels up on the path so that DBUtils can be
# imported (assumes the tests are run from their own directory).
sys.path.insert(1, '../..')
from DBUtils import SimplePooledDB


class TestSimplePooledDB(unittest.TestCase):
    """Exercise SimplePooledDB.PooledDB against the mock module above."""

    def my_dbpool(self, mythreadsafety, maxConnections):
        """Create a pool on the mock module with the given thread safety.

        The module-level ``threadsafety`` global is overwritten first,
        since the pool checks that attribute of the DB-API 2 module it is
        given (see test1_no_threadsafety).
        """
        global threadsafety
        threadsafety = mythreadsafety
        return SimplePooledDB.PooledDB(dbModule, maxConnections,
            'SimplePooledDBTestDB', 'SimplePooledDBTestUser')

    def test0_check_version(self):
        self.assertEqual(SimplePooledDB.__version__, __version__)

    def test1_no_threadsafety(self):
        # Thread safety levels outside 1..3 must be rejected.
        for threadsafety in (None, -1, 0, 4):
            self.assertRaises(SimplePooledDB.NotSupportedError,
                self.my_dbpool, threadsafety, 1)

    def test2_create_connection(self):
        for threadsafety in (1, 2, 3):
            dbpool = self.my_dbpool(threadsafety, 1)
            db = dbpool.connection()
            self.assertTrue(hasattr(db, 'cursor'))
            self.assertTrue(hasattr(db, 'open_cursors'))
            self.assertEqual(db.open_cursors, 0)
            self.assertTrue(hasattr(db, 'database'))
            self.assertEqual(db.database, 'SimplePooledDBTestDB')
            self.assertTrue(hasattr(db, 'user'))
            self.assertEqual(db.user, 'SimplePooledDBTestUser')
            db.cursor()
            self.assertEqual(db.open_cursors, 1)

    def test3_close_connection(self):
        for threadsafety in (1, 2, 3):
            dbpool = self.my_dbpool(threadsafety, 1)
            db = dbpool.connection()
            self.assertEqual(db.open_cursors, 0)
            db.cursor()
            self.assertEqual(db.open_cursors, 1)
            db.close()
            # After close() the pooled object must no longer expose the
            # underlying connection's attributes.
            self.assertFalse(hasattr(db, 'open_cursors'))
            # With a single pooled connection, the test expects the same
            # underlying connection back, so its cursor count is kept.
            db = dbpool.connection()
            self.assertTrue(hasattr(db, 'database'))
            self.assertEqual(db.database, 'SimplePooledDBTestDB')
            self.assertTrue(hasattr(db, 'user'))
            self.assertEqual(db.user, 'SimplePooledDBTestUser')
            self.assertEqual(db.open_cursors, 1)
            db.cursor()
            self.assertEqual(db.open_cursors, 2)

    def test4_two_connections(self):
        for threadsafety in (1, 2, 3):
            dbpool = self.my_dbpool(threadsafety, 2)
            db1 = dbpool.connection()
            for i in range(5):
                db1.cursor()
            db2 = dbpool.connection()
            self.assertNotEqual(db1, db2)
            for i in range(7):
                db2.cursor()
            self.assertEqual(db1.open_cursors, 5)
            self.assertEqual(db2.open_cursors, 7)
            db1.close()
            # The test expects the released connection (5 cursors) to be
            # handed out again: 5 + 3 = 8.
            db1 = dbpool.connection()
            self.assertNotEqual(db1, db2)
            self.assertTrue(hasattr(db1, 'cursor'))
            for i in range(3):
                db1.cursor()
            self.assertEqual(db1.open_cursors, 8)
            db2.cursor()
            self.assertEqual(db2.open_cursors, 8)

    def test5_threadsafety_1(self):
        dbpool = self.my_dbpool(1, 2)
        try:
            from Queue import Queue, Empty  # Python 2
        except ImportError:
            from queue import Queue, Empty  # Python 3
        queue = Queue(3)

        def connection():
            queue.put(dbpool.connection())

        from threading import Thread
        # Three threads compete for two pooled connections; the test
        # expects the third one to block until a connection is released.
        for i in range(3):
            thread = Thread(target=connection)
            # Daemonic, so a thread still blocked after a failed assertion
            # cannot keep the interpreter from exiting.
            thread.daemon = True
            thread.start()
        # The TypeError fallbacks support Queue.get() without a timeout
        # argument (very old Python versions).
        try:
            db1 = queue.get(1, 1)
            db2 = queue.get(1, 1)
        except TypeError:
            db1 = queue.get(1)
            db2 = queue.get(1)
        self.assertNotEqual(db1, db2)
        self.assertNotEqual(db1._con, db2._con)
        # The third thread must still be waiting for a free connection.
        try:
            self.assertRaises(Empty, queue.get, 1, 0.1)
        except TypeError:
            self.assertRaises(Empty, queue.get, 0)
        # Releasing a connection lets the third thread proceed.
        db2.close()
        try:
            db3 = queue.get(1, 1)
        except TypeError:
            db3 = queue.get(1)
        self.assertNotEqual(db1, db3)
        self.assertNotEqual(db1._con, db3._con)

    def test6_threadsafety_2(self):
        for threadsafety in (2, 3):
            dbpool = self.my_dbpool(threadsafety, 2)
            db1 = dbpool.connection()
            db2 = dbpool.connection()
            # For shareable connections the test expects the pool to hand
            # out its two connections in turn, so cursors split evenly.
            for i in range(100):
                dbpool.connection().cursor()
            self.assertEqual(db1.open_cursors, 50)
            self.assertEqual(db2.open_cursors, 50)


if __name__ == '__main__':
    unittest.main()