# Copyright 2009-2010 10gen, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the connection module.""" import datetime import os import signal import sys import time import thread import unittest sys.path[0:0] = [""] from nose.plugins.skip import SkipTest from bson.son import SON from bson.tz_util import utc from pymongo.connection import Connection from pymongo.database import Database from pymongo.errors import (AutoReconnect, ConfigurationError, ConnectionFailure, InvalidName, InvalidURI, OperationFailure) from test import version def get_connection(*args, **kwargs): host = os.environ.get("DB_IP", "localhost") port = int(os.environ.get("DB_PORT", 27017)) return Connection(host, port, *args, **kwargs) class TestConnection(unittest.TestCase): def setUp(self): self.host = os.environ.get("DB_IP", "localhost") self.port = int(os.environ.get("DB_PORT", 27017)) def test_types(self): self.assertRaises(TypeError, Connection, 1) self.assertRaises(TypeError, Connection, 1.14) self.assertRaises(TypeError, Connection, "localhost", "27017") self.assertRaises(TypeError, Connection, "localhost", 1.14) self.assertRaises(TypeError, Connection, "localhost", []) self.assertRaises(ConfigurationError, Connection, []) def test_constants(self): Connection.HOST = self.host Connection.PORT = self.port self.assert_(Connection()) Connection.HOST = "somedomainthatdoesntexist.org" Connection.PORT = 123456789 self.assertRaises(ConnectionFailure, Connection, connectTimeoutMS=600) self.assert_(Connection(self.host, self.port)) Connection.HOST = self.host Connection.PORT = self.port self.assert_(Connection()) def test_connect(self): self.assertRaises(ConnectionFailure, Connection, "somedomainthatdoesntexist.org", connectTimeoutMS=600) self.assertRaises(ConnectionFailure, Connection, self.host, 123456789) self.assert_(Connection(self.host, self.port)) def test_host_w_port(self): self.assert_(Connection("%s:%d" % (self.host, self.port))) self.assertRaises(ConnectionFailure, Connection, "%s:1234567" % (self.host,), self.port) def test_repr(self): self.assertEqual(repr(Connection(self.host, self.port)), "Connection('%s', %d)" % (self.host, self.port)) def test_getters(self): self.assertEqual(Connection(self.host, self.port).host, self.host) self.assertEqual(Connection(self.host, self.port).port, self.port) self.assertEqual(set([(self.host, self.port)]), Connection(self.host, self.port).nodes) def test_get_db(self): connection = Connection(self.host, self.port) def make_db(base, name): return base[name] self.assertRaises(InvalidName, make_db, connection, "") self.assertRaises(InvalidName, make_db, connection, "te$t") self.assertRaises(InvalidName, make_db, connection, "te.t") self.assertRaises(InvalidName, make_db, connection, "te\\t") self.assertRaises(InvalidName, make_db, connection, "te/t") self.assertRaises(InvalidName, make_db, connection, "te st") self.assert_(isinstance(connection.test, Database)) self.assertEqual(connection.test, connection["test"]) self.assertEqual(connection.test, Database(connection, "test")) def test_database_names(self): connection = Connection(self.host, self.port) connection.pymongo_test.test.save({"dummy": u"object"}) connection.pymongo_test_mike.test.save({"dummy": u"object"}) dbs = connection.database_names() self.assert_("pymongo_test" in dbs) self.assert_("pymongo_test_mike" in dbs) def test_drop_database(self): connection = Connection(self.host, self.port) self.assertRaises(TypeError, connection.drop_database, 5) self.assertRaises(TypeError, connection.drop_database, None) connection.pymongo_test.test.save({"dummy": u"object"}) dbs = connection.database_names() self.assert_("pymongo_test" in dbs) connection.drop_database("pymongo_test") dbs = connection.database_names() self.assert_("pymongo_test" not in dbs) connection.pymongo_test.test.save({"dummy": u"object"}) dbs = connection.database_names() self.assert_("pymongo_test" in dbs) connection.drop_database(connection.pymongo_test) dbs = connection.database_names() self.assert_("pymongo_test" not in dbs) def test_copy_db(self): c = Connection(self.host, self.port) self.assertRaises(TypeError, c.copy_database, 4, "foo") self.assertRaises(TypeError, c.copy_database, "foo", 4) self.assertRaises(InvalidName, c.copy_database, "foo", "$foo") c.pymongo_test.test.drop() c.drop_database("pymongo_test1") c.drop_database("pymongo_test2") c.pymongo_test.test.insert({"foo": "bar"}) self.assertFalse("pymongo_test1" in c.database_names()) self.assertFalse("pymongo_test2" in c.database_names()) c.copy_database("pymongo_test", "pymongo_test1") self.assert_("pymongo_test1" in c.database_names()) self.assertEqual("bar", c.pymongo_test1.test.find_one()["foo"]) c.copy_database("pymongo_test", "pymongo_test2", "%s:%d" % (self.host, self.port)) self.assert_("pymongo_test2" in c.database_names()) self.assertEqual("bar", c.pymongo_test2.test.find_one()["foo"]) if version.at_least(c, (1, 3, 3, 1)): c.drop_database("pymongo_test1") c.pymongo_test.add_user("mike", "password") self.assertRaises(OperationFailure, c.copy_database, "pymongo_test", "pymongo_test1", username="foo", password="bar") self.assertFalse("pymongo_test1" in c.database_names()) self.assertRaises(OperationFailure, c.copy_database, "pymongo_test", "pymongo_test1", username="mike", password="bar") self.assertFalse("pymongo_test1" in c.database_names()) c.copy_database("pymongo_test", "pymongo_test1", username="mike", password="password") self.assert_("pymongo_test1" in c.database_names()) self.assertEqual("bar", c.pymongo_test1.test.find_one()["foo"]) def test_iteration(self): connection = Connection(self.host, self.port) def iterate(): [a for a in connection] self.assertRaises(TypeError, iterate) # TODO this test is probably very dependent on the machine its running on # due to timing issues, but I want to get something in here. def test_low_network_timeout(self): c = None i = 0 n = 10 while c is None and i < n: try: c = Connection(self.host, self.port, network_timeout=0.0001) except AutoReconnect: i += 1 if i == n: raise SkipTest() coll = c.pymongo_test.test for _ in range(1000): try: coll.find_one() except AutoReconnect: pass except AssertionError: self.fail() def test_disconnect(self): c = Connection(self.host, self.port) coll = c.foo.bar c.disconnect() c.disconnect() coll.count() c.disconnect() c.disconnect() coll.count() def test_from_uri(self): c = Connection(self.host, self.port) self.assertEqual(c, Connection("mongodb://%s:%d" % (self.host, self.port))) c.admin.system.users.remove({}) c.pymongo_test.system.users.remove({}) c.admin.add_user("admin", "pass") c.admin.authenticate("admin", "pass") c.pymongo_test.add_user("user", "pass") self.assertRaises(ConfigurationError, Connection, "mongodb://foo:bar@%s:%d" % (self.host, self.port)) self.assertRaises(ConfigurationError, Connection, "mongodb://admin:bar@%s:%d" % (self.host, self.port)) self.assertRaises(ConfigurationError, Connection, "mongodb://user:pass@%s:%d" % (self.host, self.port)) Connection("mongodb://admin:pass@%s:%d" % (self.host, self.port)) self.assertRaises(ConfigurationError, Connection, "mongodb://admin:pass@%s:%d/pymongo_test" % (self.host, self.port)) self.assertRaises(ConfigurationError, Connection, "mongodb://user:foo@%s:%d/pymongo_test" % (self.host, self.port)) Connection("mongodb://user:pass@%s:%d/pymongo_test" % (self.host, self.port)) self.assert_(Connection("mongodb://%s:%d" % (self.host, self.port), slave_okay=True).slave_okay) self.assert_(Connection("mongodb://%s:%d/?slaveok=true;w=2" % (self.host, self.port)).slave_okay) c.admin.system.users.remove({}) c.pymongo_test.system.users.remove({}) def test_fork(self): # Test using a connection before and after a fork. if sys.platform == "win32": raise SkipTest() try: from multiprocessing import Process, Pipe except ImportError: raise SkipTest() db = Connection(self.host, self.port).pymongo_test # Failure occurs if the connection is used before the fork db.test.find_one() db.connection.end_request() def loop(pipe): while True: try: db.test.insert({"a": "b"}, safe=True) for _ in db.test.find(): pass except: pipe.send(True) os._exit(1) cp1, cc1 = Pipe() cp2, cc2 = Pipe() p1 = Process(target=loop, args=(cc1,)) p2 = Process(target=loop, args=(cc2,)) p1.start() p2.start() p1.join(1) p2.join(1) p1.terminate() p2.terminate() p1.join() p2.join() cc1.close() cc2.close() # recv will only have data if the subprocess failed try: cp1.recv() self.fail() except EOFError: pass try: cp2.recv() self.fail() except EOFError: pass def test_document_class(self): c = Connection(self.host, self.port) db = c.pymongo_test db.test.insert({"x": 1}) self.assertEqual(dict, c.document_class) self.assert_(isinstance(db.test.find_one(), dict)) self.assertFalse(isinstance(db.test.find_one(), SON)) c.document_class = SON self.assertEqual(SON, c.document_class) self.assert_(isinstance(db.test.find_one(), SON)) self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON)) c = Connection(self.host, self.port, document_class=SON) db = c.pymongo_test self.assertEqual(SON, c.document_class) self.assert_(isinstance(db.test.find_one(), SON)) self.assertFalse(isinstance(db.test.find_one(as_class=dict), SON)) c.document_class = dict self.assertEqual(dict, c.document_class) self.assert_(isinstance(db.test.find_one(), dict)) self.assertFalse(isinstance(db.test.find_one(), SON)) def test_timeouts(self): conn = Connection(self.host, self.port, connectTimeoutMS=300) self.assertEquals(0.3, conn._Connection__pool.conn_timeout) conn = Connection(self.host, self.port, socketTimeoutMS=300) self.assertEquals(0.3, conn._Connection__pool.net_timeout) def test_network_timeout(self): no_timeout = Connection(self.host, self.port) timeout_sec = 1 timeout = Connection(self.host, self.port, network_timeout=timeout_sec) no_timeout.pymongo_test.drop_collection("test") no_timeout.pymongo_test.test.insert({"x": 1}, safe=True) # A $where clause that takes a second longer than the timeout where_func = """function (doc) { var d = new Date().getTime() + (%f + 1) * 1000;; var x = new Date().getTime(); while (x < d) { x = new Date().getTime(); } return true; }""" % timeout_sec def get_x(db): return db.test.find().where(where_func).next()["x"] self.assertEqual(1, get_x(no_timeout.pymongo_test)) self.assertRaises(ConnectionFailure, get_x, timeout.pymongo_test) def get_x_timeout(db, t): return db.test.find( network_timeout=t).where(where_func).next()["x"] self.assertEqual(1, get_x_timeout(timeout.pymongo_test, None)) self.assertRaises(ConnectionFailure, get_x_timeout, no_timeout.pymongo_test, 0.1) def test_tz_aware(self): aware = Connection(self.host, self.port, tz_aware=True) naive = Connection(self.host, self.port) aware.pymongo_test.drop_collection("test") now = datetime.datetime.utcnow() aware.pymongo_test.test.insert({"x": now}, safe=True) self.assertEqual(None, naive.pymongo_test.test.find_one()["x"].tzinfo) self.assertEqual(utc, aware.pymongo_test.test.find_one()["x"].tzinfo) self.assertEqual( aware.pymongo_test.test.find_one()["x"].replace(tzinfo=None), naive.pymongo_test.test.find_one()["x"]) def test_ipv6(self): try: connection = Connection("[::1]") except: # Either mongod was started without --ipv6 # or the OS doesn't support it (or both). raise SkipTest() # Try a few simple things connection = Connection("mongodb://[::1]:%d" % (self.port,)) connection = Connection("mongodb://[::1]:%d/" "?slaveOk=true" % (self.port,)) connection = Connection("[::1]:%d," "localhost:%d" % (self.port, self.port)) connection = Connection("localhost:%d," "[::1]:%d" % (self.port, self.port)) connection.pymongo_test.test.save({"dummy": u"object"}) connection.pymongo_test_bernie.test.save({"dummy": u"object"}) dbs = connection.database_names() self.assert_("pymongo_test" in dbs) self.assert_("pymongo_test_bernie" in dbs) def test_fsync_lock_unlock(self): c = get_connection() self.assertFalse(c.is_locked) # async flushing not supported on windows... if sys.platform not in ('cygwin', 'win32'): c.fsync(async=True) self.assertFalse(c.is_locked) c.fsync(lock=True) self.assertTrue(c.is_locked) locked = True c.unlock() for _ in xrange(5): locked = c.is_locked if not locked: break time.sleep(1) self.assertFalse(locked) def test_contextlib(self): if sys.version_info < (2, 6): raise SkipTest() import contextlib conn = get_connection() conn.pymongo_test.drop_collection("test") conn.pymongo_test.test.insert({"foo": "bar"}) self.assertNotEqual(None, conn._Connection__pool.sock) self.assertEqual(0, len(conn._Connection__pool.sockets)) # We need exec here because if the Python version is less than 2.6 # these with-statements won't even compile. exec """ with contextlib.closing(conn): self.assertEquals("bar", conn.pymongo_test.test.find_one()["foo"]) """ self.assertEqual(None, conn._Connection__pool.sock) self.assertEqual(0, len(conn._Connection__pool.sockets)) exec """ with get_connection() as connection: self.assertEquals("bar", connection.pymongo_test.test.find_one()["foo"]) """ self.assertEqual(None, connection._Connection__pool.sock) self.assertEqual(0, len(connection._Connection__pool.sockets)) def test_interrupt_signal(self): # Test fix for PYTHON-294 -- make sure Connection closes its # socket if it gets an interrupt while waiting to recv() from it. c = get_connection() db = c.pymongo_test # A $where clause which takes 1.5 sec to execute where = '''function() { var d = new Date((new Date()).getTime() + 1.5 * 1000); while (d > (new Date())) { }; return true; }''' # Need exactly 1 document so find() will execute its $where clause once db.drop_collection('foo') db.foo.insert({'_id': 1}, safe=True) old_signal_handler = None try: # Platform-specific hacks for raising a KeyboardInterrupt on the main # thread while find() is in-progress: On Windows, SIGALRM is unavailable # so we use second thread. In our Bamboo setup on Linux, the thread # technique causes an error in the test at sock.recv(): # TypeError: 'int' object is not callable # We don't know what causes this in Bamboo, so we hack around it. if sys.platform == 'win32': def interrupter(): time.sleep(0.25) # Raises KeyboardInterrupt in the main thread thread.interrupt_main() thread.start_new_thread(interrupter, ()) else: # Convert SIGALRM to SIGINT -- it's hard to schedule a SIGINT for one # second in the future, but easy to schedule SIGALRM. def sigalarm(num, frame): raise KeyboardInterrupt old_signal_handler = signal.signal(signal.SIGALRM, sigalarm) signal.alarm(1) raised = False try: # Will be interrupted by a KeyboardInterrupt. db.foo.find({'$where': where}).next() except KeyboardInterrupt: raised = True # Can't use self.assertRaises() because it doesn't catch system # exceptions self.assert_(raised, "Didn't raise expected KeyboardInterrupt") # Raises AssertionError due to PYTHON-294 -- Mongo's response to the # previous find() is still waiting to be read on the socket, so the # request id's don't match. self.assertEqual( {'_id': 1}, db.foo.find().next() ) finally: if old_signal_handler: signal.signal(signal.SIGALRM, old_signal_handler) if __name__ == "__main__": unittest.main()