# Copyright (c) 2001-2004 Twisted Matrix Laboratories. # See LICENSE for details. from twisted.trial import unittest from twisted.web import server, static, client, error, util, resource from twisted.internet import reactor, defer, interfaces from twisted.python.util import sibpath from twisted.python import components try: from twisted.internet import ssl except: ssl = None import os serverCallID = None class LongTimeTakingResource(resource.Resource): def render(self, request): global serverCallID serverCallID = reactor.callLater(1, self.writeIt, request) return server.NOT_DONE_YET def writeIt(self, request): request.write("hello!!!") request.finish() class CookieMirrorResource(resource.Resource): def render(self, request): l = [] for k,v in request.received_cookies.items(): l.append((k, v)) l.sort() return repr(l) class RawCookieMirrorResource(resource.Resource): def render(self, request): return repr(request.getHeader('cookie')) class ErrorResource(resource.Resource): def render(self, request): request.setResponseCode(401) if request.args.get("showlength"): request.setHeader("content-length", "0") return "" class NoLengthResource(resource.Resource): def render(self, request): return "nolength" class HostHeaderResource(resource.Resource): def render(self, request): return request.received_headers["host"] class PayloadResource(resource.Resource): def render(self, request): data = request.content.read() if len(data) != 100 or int(request.received_headers["content-length"]) != 100: return "ERROR" return data class BrokenDownloadResource(resource.Resource): def render(self, request): # only sends 3 bytes even though it claims to send 5 request.setHeader("content-length", "5") request.write('abc') return '' class WebClientTestCase(unittest.TestCase): def _listen(self, site): return reactor.listenTCP(0, site, interface="127.0.0.1") def setUp(self): name = str(id(self)) + "_webclient" if not os.path.exists(name): os.mkdir(name) f = open(os.path.join(name, "file"), "wb") f.write("0123456789") f.close() r = static.File(name) r.putChild("redirect", util.Redirect("/file")) r.putChild("wait", LongTimeTakingResource()) r.putChild("error", ErrorResource()) r.putChild("nolength", NoLengthResource()) r.putChild("host", HostHeaderResource()) r.putChild("payload", PayloadResource()) r.putChild("broken", BrokenDownloadResource()) site = server.Site(r, timeout=None) self.port = self._listen(site) self.portno = self.port.getHost().port def tearDown(self): if serverCallID and serverCallID.active(): serverCallID.cancel() return self.port.stopListening() def getURL(self, path): return "http://127.0.0.1:%d/%s" % (self.portno, path) def testPayload(self): s = "0123456789" * 10 return client.getPage(self.getURL("payload"), postdata=s ).addCallback(self.assertEquals, s ) def testBrokenDownload(self): # test what happens when download gets disconnected in the middle d = client.getPage(self.getURL("broken")) d = unittest.assertFailure(d, client.PartialDownloadError) d.addCallback(lambda exc: self.assertEquals(exc.response, "abc")) return d def testHostHeader(self): # if we pass Host header explicitly, it should be used, otherwise # it should extract from url return defer.gatherResults([ client.getPage(self.getURL("host")).addCallback(self.assertEquals, "127.0.0.1"), client.getPage(self.getURL("host"), headers={"Host": "www.example.com"}).addCallback(self.assertEquals, "www.example.com")]) def testGetPage(self): return client.getPage(self.getURL("file")).addCallback(self.assertEquals, "0123456789") def testTimeoutNotTriggering(self): # Test that when the timeout doesn't trigger, things work as expected. d = client.getPage(self.getURL("wait"), timeout=100) d.addCallback(self.assertEquals, "hello!!!") return d def testTimeoutTriggering(self): # Test that when the timeout does trigger, we get a defer.TimeoutError. return unittest.assertFailure( client.getPage(self.getURL("wait"), timeout=0.5), defer.TimeoutError) def testDownloadPage(self): downloads = [] downloadData = [("file", self.mktemp(), "0123456789"), ("nolength", self.mktemp(), "nolength")] for (url, name, data) in downloadData: d = client.downloadPage(self.getURL(url), name) d.addCallback(self._cbDownloadPageTest, data, name) downloads.append(d) return defer.gatherResults(downloads) def _cbDownloadPageTest(self, ignored, data, name): bytes = file(name, "rb").read() self.assertEquals(bytes, data) def testDownloadPageError1(self): class errorfile: def write(self, data): raise IOError, "badness happened during write" def close(self): pass ef = errorfile() return unittest.assertFailure( client.downloadPage(self.getURL("file"), ef), IOError) def testDownloadPageError2(self): class errorfile: def write(self, data): pass def close(self): raise IOError, "badness happened during close" ef = errorfile() return unittest.assertFailure( client.downloadPage(self.getURL("file"), ef), IOError) def testDownloadPageError3(self): if not os.geteuid(): raise unittest.SkipTest('this does not work as root') # make sure failures in open() are caught too. This is tricky. # Might only work on posix. tmpfile = open("unwritable", "wb") tmpfile.close() os.chmod("unwritable", 0) # make it unwritable (to us) d = unittest.assertFailure( client.downloadPage(self.getURL("file"), "unwritable"), IOError) d.addBoth(self._cleanupDownloadPageError3) return d def _cleanupDownloadPageError3(self, ignored): os.chmod("unwritable", 0700) os.unlink("unwritable") return ignored def _downloadTest(self, method): dl = [] for (url, code) in [("nosuchfile", "404"), ("error", "401"), ("error?showlength=1", "401")]: d = method(url) d = unittest.assertFailure(d, error.Error) d.addCallback(lambda exc, code=code: self.assertEquals(exc.args[0], code)) dl.append(d) return defer.DeferredList(dl, fireOnOneErrback=True) def testServerError(self): return self._downloadTest(lambda url: client.getPage(self.getURL(url))) def testDownloadServerError(self): return self._downloadTest(lambda url: client.downloadPage(self.getURL(url), url.split('?')[0])) def testFactoryInfo(self): url = self.getURL('file') scheme, host, port, path = client._parse(url) factory = client.HTTPClientFactory(url) reactor.connectTCP(host, port, factory) return factory.deferred.addCallback(self._cbFactoryInfo, factory) def _cbFactoryInfo(self, ignoredResult, factory): self.assertEquals(factory.status, '200') self.assert_(factory.version.startswith('HTTP/')) self.assertEquals(factory.message, 'OK') self.assertEquals(factory.response_headers['content-length'][0], '10') def testRedirect(self): return client.getPage(self.getURL("redirect")).addCallback(self._cbRedirect) def _cbRedirect(self, pageData): self.assertEquals(pageData, "0123456789") d = unittest.assertFailure( client.getPage(self.getURL("redirect"), followRedirect=0), error.PageRedirect) d.addCallback(self._cbCheckLocation) return d def _cbCheckLocation(self, exc): self.assertEquals(exc.location, "/file") def testPartial(self): name = self.mktemp() f = open(name, "wb") f.write("abcd") f.close() downloads = [] partialDownload = [(True, "abcd456789"), (True, "abcd456789"), (False, "0123456789")] d = defer.succeed(None) for (partial, expectedData) in partialDownload: d.addCallback(self._cbRunPartial, name, partial) d.addCallback(self._cbPartialTest, expectedData, name) return d testPartial.skip = "Cannot test until webserver can serve partial data properly" def _cbRunPartial(self, ignored, name, partial): return client.downloadPage(self.getURL("file"), name, supportPartial=partial) def _cbPartialTest(self, ignored, expectedData, filename): bytes = file(filename, "rb").read() self.assertEquals(bytes, expectedData) class WebClientSSLTestCase(WebClientTestCase): def _listen(self, site): from twisted import test return reactor.listenSSL(0, site, contextFactory=ssl.DefaultOpenSSLContextFactory( sibpath(test.__file__, 'server.pem'), sibpath(test.__file__, 'server.pem'), ), interface="127.0.0.1") def getURL(self, path): return "https://127.0.0.1:%d/%s" % (self.portno, path) def testFactoryInfo(self): url = self.getURL('file') scheme, host, port, path = client._parse(url) factory = client.HTTPClientFactory(url) reactor.connectSSL(host, port, factory, ssl.ClientContextFactory()) # The base class defines _cbFactoryInfo correctly for this return factory.deferred.addCallback(self._cbFactoryInfo, factory) class WebClientRedirectBetweenSSLandPlainText(unittest.TestCase): def getHTTPS(self, path): return "https://127.0.0.1:%d/%s" % (self.tlsPortno, path) def getHTTP(self, path): return "http://127.0.0.1:%d/%s" % (self.plainPortno, path) def setUp(self): plainRoot = static.Data('not me', 'text/plain') tlsRoot = static.Data('me neither', 'text/plain') plainSite = server.Site(plainRoot, timeout=None) tlsSite = server.Site(tlsRoot, timeout=None) from twisted import test self.tlsPort = reactor.listenSSL(0, tlsSite, contextFactory=ssl.DefaultOpenSSLContextFactory( sibpath(test.__file__, 'server.pem'), sibpath(test.__file__, 'server.pem'), ), interface="127.0.0.1") self.plainPort = reactor.listenTCP(0, plainSite, interface="127.0.0.1") self.plainPortno = self.plainPort.getHost().port self.tlsPortno = self.tlsPort.getHost().port plainRoot.putChild('one', util.Redirect(self.getHTTPS('two'))) tlsRoot.putChild('two', util.Redirect(self.getHTTP('three'))) plainRoot.putChild('three', util.Redirect(self.getHTTPS('four'))) tlsRoot.putChild('four', static.Data('FOUND IT!', 'text/plain')) def tearDown(self): ds = map(defer.maybeDeferred, [self.plainPort.stopListening, self.tlsPort.stopListening]) return defer.gatherResults(ds) def testHoppingAround(self): return client.getPage(self.getHTTP("one") ).addCallback(self.assertEquals, "FOUND IT!" ) class FakeTransport: disconnecting = False def __init__(self): self.data = [] def write(self, stuff): self.data.append(stuff) class CookieTestCase(unittest.TestCase): def _listen(self, site): return reactor.listenTCP(0, site, interface="127.0.0.1") def setUp(self): root = static.Data('El toro!', 'text/plain') root.putChild("cookiemirror", CookieMirrorResource()) root.putChild("rawcookiemirror", RawCookieMirrorResource()) site = server.Site(root, timeout=None) self.port = self._listen(site) self.portno = self.port.getHost().port def tearDown(self): return self.port.stopListening() def getHTTP(self, path): return "http://127.0.0.1:%d/%s" % (self.portno, path) def testNoCookies(self): return client.getPage(self.getHTTP("cookiemirror") ).addCallback(self.assertEquals, "[]" ) def testSomeCookies(self): cookies = {'foo': 'bar', 'baz': 'quux'} return client.getPage(self.getHTTP("cookiemirror"), cookies=cookies ).addCallback(self.assertEquals, "[('baz', 'quux'), ('foo', 'bar')]" ) def testRawNoCookies(self): return client.getPage(self.getHTTP("rawcookiemirror") ).addCallback(self.assertEquals, "None" ) def testRawSomeCookies(self): cookies = {'foo': 'bar', 'baz': 'quux'} return client.getPage(self.getHTTP("rawcookiemirror"), cookies=cookies ).addCallback(self.assertEquals, "'foo=bar; baz=quux'" ) def testCookieHeaderParsing(self): d = defer.Deferred() factory = client.HTTPClientFactory('http://foo.example.com/') proto = factory.buildProtocol('127.42.42.42') proto.transport = FakeTransport() proto.connectionMade() for line in [ '200 Ok', 'Squash: yes', 'Hands: stolen', 'Set-Cookie: CUSTOMER=WILE_E_COYOTE; path=/; expires=Wednesday, 09-Nov-99 23:12:40 GMT', 'Set-Cookie: PART_NUMBER=ROCKET_LAUNCHER_0001; path=/', 'Set-Cookie: SHIPPING=FEDEX; path=/foo', '', 'body', 'more body', ]: proto.dataReceived(line + '\r\n') self.assertEquals(proto.transport.data, ['GET / HTTP/1.0\r\n', 'Host: foo.example.com\r\n', 'User-Agent: Twisted PageGetter\r\n', '\r\n']) self.assertEquals(factory.cookies, { 'CUSTOMER': 'WILE_E_COYOTE', 'PART_NUMBER': 'ROCKET_LAUNCHER_0001', 'SHIPPING': 'FEDEX', }) if ssl is None or not hasattr(ssl, 'DefaultOpenSSLContextFactory'): for case in [WebClientSSLTestCase, WebClientRedirectBetweenSSLandPlainText]: case.skip = "OpenSSL not present" if not interfaces.IReactorSSL(reactor, None): for case in [WebClientSSLTestCase, WebClientRedirectBetweenSSLandPlainText]: case.skip = "Reactor doesn't support SSL"