# Copyright (c) 2001-2004 Twisted Matrix Laboratories. # See LICENSE for details. import socket from twisted.persisted import styles from twisted.internet.base import BaseConnector from twisted.internet import defer, interfaces, error from twisted.python import failure from abstract import ConnectedSocket from ops import ConnectExOp from util import StateEventMachineType from zope.interface import implements class ClientSocket(ConnectedSocket): def __init__(self, sock, protocol, sf): ConnectedSocket.__init__(self, sock, protocol, sf) self.repstr = '<%s to %s at %x>' % (self.__class__, self.sf.addr, id(self)) self.logstr = protocol.__class__.__name__+",client" self.startReading() class _SubConnector: state = "connecting" socket = None def __init__(self, sf): self.sf = sf def startConnecting(self): d = defer.maybeDeferred(self.sf.resolveAddress) d.addCallback(self._cbResolveDone) d.addErrback(self._ebResolveErr) def _cbResolveDone(self, addr): if self.state == "dead": return try: skt = socket.socket(*self.sf.sockinfo) except socket.error, se: raise error.ConnectBindError(se[0], se[1]) try: if self.sf.bindAddress is None: self.sf.bindAddress = ("", 0) # necessary for ConnectEx skt.bind(self.sf.bindAddress) except socket.error, se: raise error.ConnectBindError(se[0], se[1]) self.socket = skt op = ConnectExOp(self) op.initiateOp(self.socket, addr) def _ebResolveErr(self, fail): if self.state == "dead": return self.sf.connectionFailed(fail) def connectDone(self): if self.state == "dead": return self.sf.connectionSuccess() def connectErr(self, err): if self.state == "dead": return self.sf.connectionFailed(err) class SocketConnector(styles.Ephemeral, object): __metaclass__ = StateEventMachineType implements(interfaces.IConnector) transport_class = ClientSocket events = ["stopConnecting", "disconnect", "connect"] sockinfo = None factoryStarted = False timeoutID = None def __init__(self, addr, factory, timeout, bindAddress): from twisted.internet import reactor self.state = "disconnected" self.addr = addr self.factory = factory self.timeout = timeout self.bindAddress = bindAddress self.reactor = reactor self.prepareAddress() def handle_connecting_stopConnecting(self): self.connectionFailed(failure.Failure(error.UserError())) def handle_disconnected_stopConnecting(self): raise error.NotConnectingError handle_connected_stopConnecting = handle_disconnected_stopConnecting handle_connecting_disconnect = handle_connecting_stopConnecting def handle_connected_disconnect(self): self.transport.loseConnection() def handle_disconnected_disconnect(self): pass def handle_connecting_connect(self): raise RuntimeError, "can't connect in this state" handle_connected_connect = handle_connecting_connect def handle_disconnected_connect(self): self.state = "connecting" if not self.factoryStarted: self.factory.doStart() self.factoryStarted = True if self.timeout is not None: self.timeoutID = self.reactor.callLater(self.timeout, self.connectionFailed, failure.Failure(error.TimeoutError())) self.sub = _SubConnector(self) self.sub.startConnecting() self.factory.startedConnecting(self) def prepareAddress(self): raise NotImplementedError def resolveAddress(self): raise NotImplementedError def connectionLost(self, reason): self.state = "disconnected" self.factory.clientConnectionLost(self, reason) if self.state == "disconnected": # factory hasn't called our connect() method self.factory.doStop() self.factoryStarted = 0 def connectionFailed(self, reason): if self.sub.socket: self.sub.socket.close() self.sub.state = "dead" del self.sub self.state = "disconnected" self.cancelTimeout() self.factory.clientConnectionFailed(self, reason) if self.state == "disconnected": # factory hasn't called our connect() method self.factory.doStop() self.factoryStarted = 0 def cancelTimeout(self): if self.timeoutID: try: self.timeoutID.cancel() except ValueError: pass del self.timeoutID def connectionSuccess(self): socket = self.sub.socket self.sub.state = "dead" del self.sub self.state = "connected" self.cancelTimeout() p = self.factory.buildProtocol(self.buildAddress(socket.getpeername())) self.transport = self.transport_class(socket, p, self) p.makeConnection(self.transport)