"""Test the PersistentPg module. Note: We don't test performance here, so the test does not predicate whether PersistentPg actually will help in improving performance or not. We also assume that the underlying SteadyPg connections are tested. Copyright and credit info: * This test was contributed by Christoph Zwerschke """ __version__ = '0.9.4' __revision__ = "$Rev: 6696 $" __date__ = "$Date: 2007-07-07 11:02:24 -0600 (Sat, 07 Jul 2007) $" import sys import unittest sys.path.insert(1, '../..') # The TestSteadyPg module serves as a mock object for the pg API module: from DBUtils.Testing import TestSteadyPg from DBUtils.PersistentPg import PersistentPg class TestPersistentPg(unittest.TestCase): def test0_CheckVersion(self): TestPersistentPgVersion = __version__ from DBUtils.PersistentPg import __version__ as PersistentPgVersion self.assertEqual(PersistentPgVersion, TestPersistentPgVersion) def test1_PersistentPgClose(self): for closeable in (0, 1): persist = PersistentPg() persist._closeable = closeable db = persist.connection() self.assert_(db._con.db and db._con.valid) db.close() self.assert_(closeable ^ (db._con.db is not None and db._con.valid)) db.close() self.assert_(closeable ^ (db._con.db is not None and db._con.valid)) db._close() self.assert_(not db._con.db or not db._con.valid) db._close() self.assert_(not db._con.db or not db._con.valid) def test2_PersistentPgThreads(self): numThreads = 3 persist = PersistentPg() from Queue import Queue, Empty queryQueue, resultQueue = [], [] for i in range(numThreads): queryQueue.append(Queue(1)) resultQueue.append(Queue(1)) def runQueries(i): this_db = persist.connection().db while 1: try: try: q = queryQueue[i].get(1, 1) except TypeError: q = queryQueue[i].get(1) except Empty: q = None if not q: break db = persist.connection() if db.db != this_db: r = 'error - not persistent' else: if q == 'ping': r = 'ok - thread alive' elif q == 'close': db.db.close() r = 'ok - connection closed' else: r = db.query(q) r = '%d(%d): %s' % (i, db._usage, r) try: resultQueue[i].put(r, 1, 0.1) except TypeError: resultQueue[i].put(r, 1) db.close() from threading import Thread threads = [] for i in range(numThreads): thread = Thread(target=runQueries, args=(i,)) threads.append(thread) thread.start() for i in range(numThreads): try: queryQueue[i].put('ping', 1, 0.1) except TypeError: queryQueue[i].put('ping', 1) for i in range(numThreads): try: r = resultQueue[i].get(1, 0.1) except TypeError: r = resultQueue[i].get(1) self.assertEqual(r, '%d(0): ok - thread alive' % i) self.assert_(threads[i].isAlive()) for i in range(numThreads): for j in range(i + 1): try: queryQueue[i].put('select test%d' % j, 1, 0.1) r = resultQueue[i].get(1, 0.1) except TypeError: queryQueue[i].put('select test%d' % j, 1) r = resultQueue[i].get(1) self.assertEqual(r, '%d(%d): test%d' % (i, j + 1, j)) try: queryQueue[1].put('select test4', 1, 0.1) r = resultQueue[1].get(1, 0.1) except TypeError: queryQueue[1].put('select test4', 1) r = resultQueue[1].get(1) self.assertEqual(r, '1(3): test4') try: queryQueue[1].put('close', 1, 0.1) r = resultQueue[1].get(1, 0.1) except TypeError: queryQueue[1].put('close', 1) r = resultQueue[1].get(1) self.assertEqual(r, '1(3): ok - connection closed') for j in range(2): try: queryQueue[1].put('select test%d' % j, 1, 0.1) r = resultQueue[1].get(1, 0.1) except TypeError: queryQueue[1].put('select test%d' % j, 1) r = resultQueue[1].get(1) self.assertEqual(r, '1(%d): test%d' % (j + 1, j)) for i in range(numThreads): self.assert_(threads[i].isAlive()) try: queryQueue[i].put('ping', 1, 0.1) except TypeError: queryQueue[i].put('ping', 1) for i in range(numThreads): try: r = resultQueue[i].get(1, 0.1) except TypeError: r = resultQueue[i].get(1) self.assertEqual(r, '%d(%d): ok - thread alive' % (i, i + 1)) self.assert_(threads[i].isAlive()) for i in range(numThreads): try: queryQueue[i].put(None, 1, 1) except TypeError: queryQueue[i].put(None, 1) def test3_PersistentPgMaxUsage(self): persist = PersistentPg(20) db = persist.connection() self.assertEqual(db._maxusage, 20) for i in range(100): r = db.query('select test%d' % i) self.assertEqual(r, 'test%d' % i) self.assert_(db.db.status) j = i % 20 + 1 self.assertEqual(db._usage, j) self.assertEqual(db.num_queries, j) def test4_PersistentPgSetSession(self): persist = PersistentPg(3, ('set datestyle',)) db = persist.connection() self.assertEqual(db._maxusage, 3) self.assertEqual(db._setsession_sql, ('set datestyle',)) self.assertEqual(db.db.session, ['datestyle']) db.query('set test') for i in range(3): self.assertEqual(db.db.session, ['datestyle', 'test']) db.query('select test') self.assertEqual(db.db.session, ['datestyle']) if __name__ == '__main__': unittest.main()