# Copyright 2009-2010 10gen, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test built-in connection pooling."""

import os
import random
import sys
import threading
import time
import unittest

# Put the current directory first on the path so sibling test modules
# (e.g. test_connection below) are importable.
sys.path[0:0] = [""]

from nose.plugins.skip import SkipTest

from pymongo.connection import Connection, _Pool
from pymongo.errors import ConfigurationError
from test_connection import get_connection

# Number of operations each worker thread performs.
N = 50
# Scratch database; TestPooling.setUp drops and re-seeds it.
DB = "pymongo-pooling-tests"


class MongoThread(threading.Thread):
    """Base class for worker threads sharing one test case's connection.

    Subclasses put their work in run_mongo_thread(). An exception raised in
    a thread is not propagated to the thread that calls join(), so run()
    stores it in ``self.error`` where run_cases() can check it.
    """

    def __init__(self, test_case):
        threading.Thread.__init__(self)
        self.connection = test_case.c
        self.db = self.connection[DB]
        self.ut = test_case
        # The exception raised by run_mongo_thread(), or None on success.
        self.error = None

    def run(self):
        try:
            self.run_mongo_thread()
        except Exception:
            self.error = sys.exc_info()[1]
            # Re-raise so the threading module still prints the traceback.
            raise

    def run_mongo_thread(self):
        """Do this thread's work; subclasses must override."""
        raise NotImplementedError()


class SaveAndFind(MongoThread):
    """Save random values with safe=True and read each back by its _id."""

    def run_mongo_thread(self):
        for _ in xrange(N):
            rand = random.randint(0, N)
            oid = self.db.sf.save({"x": rand}, safe=True)
            self.ut.assertEqual(rand, self.db.sf.find_one(oid)["x"])
            self.connection.end_request()


class Unique(MongoThread):
    """Insert documents without an _id; db.error() must report nothing."""

    def run_mongo_thread(self):
        for _ in xrange(N):
            self.db.unique.insert({})
            self.ut.assertEqual(None, self.db.error())
            self.connection.end_request()


class NonUnique(MongoThread):
    """Insert the duplicate _id "mike" seeded in setUp; db.error() must
    report a failure each time."""

    def run_mongo_thread(self):
        for _ in xrange(N):
            self.db.unique.insert({"_id": "mike"})
            self.ut.assertNotEqual(None, self.db.error())
            self.connection.end_request()


class Disconnect(MongoThread):
    """Repeatedly call disconnect() on the shared connection."""

    def run_mongo_thread(self):
        for _ in xrange(N):
            self.connection.disconnect()


class NoRequest(MongoThread):
    """Insert the duplicate _id "mike" without ever calling end_request();
    every insert must still be reported as an error by db.error()."""

    def run_mongo_thread(self):
        # Count inserts of the duplicate key that db.error() did not flag.
        unexpected_successes = 0
        for _ in xrange(N):
            self.db.unique.insert({"_id": "mike"})
            if self.db.error() is None:
                unexpected_successes += 1
        self.ut.assertEqual(0, unexpected_successes)


def run_cases(ut, cases):
    """Run 10 threads of each class in ``cases`` on ``ut``'s connection.

    Every thread is started before any is joined, so threads of different
    classes may overlap. Fails ``ut`` if any thread recorded an error.
    """
    threads = []
    for case in cases:
        for _ in range(10):
            thread = case(ut)
            thread.start()
            threads.append(thread)

    for thread in threads:
        thread.join()

    # getattr keeps threads that override run() (and so never set
    # ``error``) behaving as before: treated as passed.
    failed = [t for t in threads if getattr(t, "error", None) is not None]
    if failed:
        ut.fail("%d of %d threads failed; first error: %r"
                % (len(failed), len(threads), failed[0].error))


class OneOp(threading.Thread):
    """Check pool socket counts around one query in a separate thread.

    Expects the connection's pool to start with exactly one socket in
    ``sockets``. Any failed check is stored in ``self.error`` because
    exceptions raised in a thread do not reach join().
    """

    def __init__(self, connection):
        threading.Thread.__init__(self)
        self.c = connection
        # The exception raised by run(), or None on success.
        self.error = None

    def run(self):
        try:
            assert len(self.c._Connection__pool.sockets) == 1
            # The query takes the pooled socket for this thread...
            self.c.test.test.find_one()
            assert len(self.c._Connection__pool.sockets) == 0
            # ...and end_request() puts it back.
            self.c.end_request()
            assert len(self.c._Connection__pool.sockets) == 1
        except Exception:
            self.error = sys.exc_info()[1]
            raise


class CreateAndReleaseSocket(threading.Thread):
    """Run one query, sleep one second, then call end_request().

    The sleep makes the threads started by test_max_pool_size overlap.
    """

    def __init__(self, connection):
        threading.Thread.__init__(self)
        self.c = connection

    def run(self):
        self.c.test.test.find_one()
        time.sleep(1)
        self.c.end_request()


class TestPooling(unittest.TestCase):

    def setUp(self):
        self.c = get_connection()

        # Reset the db and seed the document the (Non)Unique threads use.
        self.c.drop_database(DB)
        self.c[DB].unique.insert({"_id": "mike"})
        self.c[DB].unique.find_one()

    def tearDown(self):
        self.c = None

    def test_max_pool_size_validation(self):
        self.assertRaises(ValueError, Connection, max_pool_size=-1)
        self.assertRaises(ConfigurationError, Connection,
                          max_pool_size='foo')
        c = Connection(max_pool_size=100)
        self.assertEqual(c.max_pool_size, 100)

    def test_no_disconnect(self):
        run_cases(self, [NoRequest, NonUnique, Unique, SaveAndFind])

    def test_simple_disconnect(self):
        self.c.test.stuff.find()
        self.assertEqual(0, len(self.c._Connection__pool.sockets))
        self.assertNotEqual(None, self.c._Connection__pool.sock)
        self.c.end_request()
        self.assertEqual(1, len(self.c._Connection__pool.sockets))
        self.assertEqual(None, self.c._Connection__pool.sock)
        self.c.disconnect()
        self.assertEqual(0, len(self.c._Connection__pool.sockets))
        self.assertEqual(None, self.c._Connection__pool.sock)

    def test_disconnect(self):
        run_cases(self, [SaveAndFind, Disconnect, Unique])

    def test_independent_pools(self):
        # A standalone _Pool is unaffected by the connection's end_request().
        p = _Pool(10, None, None, False)
        self.assertEqual([], p.sockets)
        self.c.end_request()
        self.assertEqual([], p.sockets)

    def test_dependent_pools(self):
        c = get_connection()
        self.assertEqual(1, len(c._Connection__pool.sockets))
        c.test.test.find_one()
        self.assertEqual(0, len(c._Connection__pool.sockets))
        c.end_request()
        self.assertEqual(1, len(c._Connection__pool.sockets))

        t = OneOp(c)
        t.start()
        t.join()
        # Surface any assertion that failed inside the thread.
        self.assertEqual(None, t.error)

        self.assertEqual(1, len(c._Connection__pool.sockets))
        c.test.test.find_one()
        self.assertEqual(0, len(c._Connection__pool.sockets))

    def test_multiple_connections(self):
        a = get_connection()
        b = get_connection()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        self.assertEqual(1, len(b._Connection__pool.sockets))

        a.test.test.find_one()
        a.end_request()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        self.assertEqual(1, len(b._Connection__pool.sockets))
        a_sock = a._Connection__pool.sockets[0]

        b.end_request()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        self.assertEqual(1, len(b._Connection__pool.sockets))

        b.test.test.find_one()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        self.assertEqual(0, len(b._Connection__pool.sockets))

        b.end_request()
        b_sock = b._Connection__pool.sockets[0]
        b.test.test.find_one()
        a.test.test.find_one()
        self.assertEqual(b_sock,
                         b._Connection__pool.get_socket(b.host, b.port)[0])
        self.assertEqual(a_sock,
                         a._Connection__pool.get_socket(a.host, a.port)[0])

    def test_pool_with_fork(self):
        if sys.platform == "win32":
            raise SkipTest()

        try:
            from multiprocessing import Process, Pipe
        except ImportError:
            raise SkipTest()

        a = get_connection()
        a.test.test.find_one()
        a.end_request()
        self.assertEqual(1, len(a._Connection__pool.sockets))
        a_sock = a._Connection__pool.sockets[0]

        def loop(pipe):
            # Runs in a child process; sends back its pooled socket's name.
            c = get_connection()
            self.assertEqual(1, len(c._Connection__pool.sockets))
            c.test.test.find_one()
            self.assertEqual(0, len(c._Connection__pool.sockets))
            c.end_request()
            self.assertEqual(1, len(c._Connection__pool.sockets))
            pipe.send(c._Connection__pool.sockets[0].getsockname())

        cp1, cc1 = Pipe()
        cp2, cc2 = Pipe()

        p1 = Process(target=loop, args=(cc1,))
        p2 = Process(target=loop, args=(cc2,))

        p1.start()
        p2.start()

        p1.join(1)
        p2.join(1)

        p1.terminate()
        p2.terminate()

        p1.join()
        p2.join()

        # Close the parent's copies of the child ends so recv() raises
        # EOFError instead of blocking if a child never sent anything.
        cc1.close()
        cc2.close()

        b_sock = cp1.recv()
        c_sock = cp2.recv()
        cp1.close()
        cp2.close()

        self.assertNotEqual(a_sock.getsockname(), b_sock)
        self.assertNotEqual(a_sock.getsockname(), c_sock)
        self.assertNotEqual(b_sock, c_sock)
        self.assertEqual(a_sock,
                         a._Connection__pool.get_socket(a.host, a.port)[0])

    def test_max_pool_size(self):
        c = get_connection(max_pool_size=4)

        threads = []
        for _ in range(40):
            t = CreateAndReleaseSocket(c)
            t.start()
            threads.append(t)

        for t in threads:
            t.join()

        # There's a race condition, so be lenient: accept 1..7 sockets.
        self.assertTrue(abs(4 - len(c._Connection__pool.sockets)) < 4)


if __name__ == "__main__":
    unittest.main()