# Copyright (c) 2001-2004 Twisted Matrix Laboratories. # See LICENSE for details. from sets import Set import warnings from twisted.internet import interfaces, defer, main from twisted.persisted import styles from twisted.python import log, failure from ops import ReadFileOp, WriteFileOp from util import StateEventMachineType from zope.interface import implements from socket import error as socket_error class ConnectedSocket(log.Logger, styles.Ephemeral, object): __metaclass__ = StateEventMachineType implements(interfaces.ITransport, interfaces.IProducer, interfaces.IConsumer) events = ["write", "loseConnection", "writeDone", "writeErr", "readDone", "readErr", "shutdown"] bufferSize = 2**2**2**2 producer = None writing = False reading = False write_shutdown = False read_shutdown = False producerPaused = False def __init__(self, socket, protocol, sockfactory): self.state = "connected" from twisted.internet import reactor self.socket = socket self.protocol = protocol self.sf = sockfactory self.writebuf = [] self.readbuf = reactor.AllocateReadBuffer(self.bufferSize) self.reactor = reactor self.bufferEvents = {"buffer full": Set(), "buffer empty": Set()} self.offset = 0 self.writeBufferedSize = 0 self.producerBuffer = [] self.read_op = ReadFileOp(self) self.write_op = WriteFileOp(self) # XXX: these two should be specified like before, with a class field def addBufferCallback(self, handler, event): self.bufferEvents[event].add(handler) def removeBufferCallback(self, handler, event): self.bufferEvents[event].remove(handler) def callBufferHandlers(self, event, *a, **kw): for i in self.bufferEvents[event].copy(): i(*a, **kw) def handle_connected_write(self, data): if self.writebuf and len(self.writebuf[-1]) < self.bufferSize: # mmmhhh silly heuristics self.writebuf[-1] += data else: self.writebuf.append(data) self.writeBufferedSize += len(data) if self.writeBufferedSize >= self.bufferSize: self.callBufferHandlers(event = "buffer full") if not self.writing: self.startWriting() handle_disconnecting_write = handle_connected_write def handle_disconnected_write(self, data): pass # blarf def writeSequence(self, iovec): self.write("".join(iovec)) def _cbDisconnecting(self): if self.producer: return self.removeBufferCallback(self._cbDisconnecting, "buffer empty") self.connectionLost(failure.Failure(main.CONNECTION_DONE)) def handle_connected_loseConnection(self): self.stopReading() if self.writing: self.addBufferCallback(self._cbDisconnecting, "buffer empty") self.state = "disconnecting" else: self.reactor.callLater(0, self.connectionLost, failure.Failure(main.CONNECTION_DONE)) def handle_disconnecting_loseConnection(self): pass handle_disconnected_loseConnection = handle_disconnecting_loseConnection def _cbWriteShutdown(self): self.removeBufferCallback(self._cbWriteShutdown, "buffer empty") self.socket.shutdown(1) def handle_connected_shutdown(self, write = False, read = False): if read and not self.read_shutdown: self.read_shutdown = True self.socket.shutdown(0) if write and not self.write_shutdown: self.write_shutdown = True # don't need to keep "we are shutting write side down", right? if self.writing: self.addBufferCallback(self._cbWriteShutdown, "buffer empty") else: self.socket.shutdown(1) def connectionLost(self, reason): # log.msg("connectionLost called with reason", reason, "for socket", id(self)) # import traceback # for i in traceback.format_stack(): # log.msg(i[:-1]) self.state = "disconnected" protocol = self.protocol del self.protocol # XXX: perhaps the following needs to be around to avoid resetting the connection ungracefully try: self.socket.shutdown(2) except socket_error: pass # this should call closesocket() and kill it dead! self.socket.close() del self.socket self.sf.connectionLost(reason) try: protocol.connectionLost(reason) except TypeError, e: # while this may break, it will only break on deprecated code # as opposed to other approaches that might've broken on # code that uses the new API (e.g. inspect). if e.args and e.args[0] == "connectionLost() takes exactly 1 argument (2 given)": warnings.warn("Protocol %s's connectionLost should accept a reason argument" % protocol, category=DeprecationWarning, stacklevel=2) protocol.connectionLost() else: raise def startReading(self): if self.state != "connected": return self.reading = True while self.producerBuffer: item = self.producerBuffer.pop(0) self.protocol.dataReceived(item) if not self.reading: return try: self.read_op.initiateOp(self.socket.fileno(), self.readbuf) except WindowsError, we: # log.msg("initiating read failed with args %s" % (we,)) self.reactor.callLater(0, self.connectionLost, failure.Failure(main.CONNECTION_DONE)) def stopReading(self): self.reading = False def handle_connected_readDone(self, bytes): if self.reading: self.protocol.dataReceived(self.readbuf[:bytes]) self.startReading() else: self.producerBuffer.append(self.readbuf[:bytes]) def handle_disconnecting_readDone(self, bytes): pass # a leftover read op from before we began disconnecting def handle_connected_readErr(self, ret, bytes): # log.msg("read failed with err %s" % (ret,)) self.connectionLost(failure.Failure(main.CONNECTION_DONE)) handle_disconnecting_readErr = handle_connected_readErr def handle_disconnected_readErr(self, ret, bytes): pass # no kicking the dead horse def handle_disconnected_readDone(self, bytes): pass # no kicking the dead horse def startWriting(self): self.writing = True b = buffer(self.writebuf[0], self.offset) # ll = map(len, self.writebuf) # log.msg("buffer lengths are", ll, "total", sum(ll)) try: self.write_op.initiateOp(self.socket.fileno(), b) except WindowsError, we: # log.msg("initiating write failed with args %s" % (we,)) self.reactor.callLater(0, self.connectionLost, failure.Failure(main.CONNECTION_DONE)) def stopWriting(self): self.writing = False def handle_connected_writeDone(self, bytes): self.offset += bytes self.writeBufferedSize -= bytes if self.offset == len(self.writebuf[0]): del self.writebuf[0] self.offset = 0 if self.writebuf == []: self.writing = False self.callBufferHandlers(event = "buffer empty") else: self.startWriting() handle_disconnecting_writeDone = handle_connected_writeDone def handle_connected_writeErr(self, ret, bytes): self.connectionLost(failure.Failure(main.CONNECTION_DONE)) handle_disconnecting_writeErr = handle_connected_writeErr def handle_disconnected_writeErr(self, ret, bytes): pass # no kicking the dead horse def handle_disconnected_writeDone(self, bytes): pass # no kicking the dead horse # consumer interface implementation def registerProducer(self, producer, streaming): if self.producer is not None: raise RuntimeError("Cannot register producer %s, because producer %s was never unregistered." % (producer, self.producer)) if self.state == "disconnected": producer.stopProducing() else: self.producer = producer self.streamingProducer = streaming self.addBufferCallback(self.milkProducer, "buffer empty") self.addBufferCallback(self.stfuProducer, "buffer full") if not streaming: self.producerPaused = False producer.resumeProducing() def milkProducer(self): if not self.streamingProducer or self.producerPaused: self.producerPaused = False self.producer.resumeProducing() def stfuProducer(self): self.producerPaused = True self.producer.pauseProducing() def unregisterProducer(self): self.removeBufferCallback(self.stfuProducer, "buffer full") self.removeBufferCallback(self.milkProducer, "buffer empty") self.producer = None def stopConsuming(self): self.unregisterProducer() self.loseConnection() # producer interface implementation def resumeProducing(self): self.startReading() def pauseProducing(self): self.stopReading() def stopProducing(self): self.loseConnection() def __repr__(self): return self.repstr def logPrefix(self): return self.logstr # groan, stupid LineReceiver and LineOnlyReceiver want to see this in a transport disconnecting = property(lambda self: self.state == "disconnecting") connected = property(lambda self: self.state == "connected")