# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
# See LICENSE for details
"""
This module tests twisted.conch.ssh.connection.
"""
from __future__ import division, absolute_import
import struct
from twisted.python.reflect import requireModule
cryptography = requireModule("cryptography")
from twisted.conch import error
if cryptography:
from twisted.conch.ssh import common, connection
else:
class connection:
class SSHConnection: pass
from twisted.conch.ssh import channel
from twisted.python.compat import long
from twisted.trial import unittest
from twisted.conch.test import test_userauth
class TestChannel(channel.SSHChannel):
    """
    A recording stand-in for twisted.conch.ssh.channel.SSHChannel: every
    event delivered to the channel is stored on the instance so that tests
    can inspect it afterwards.

    @ivar gotOpen: True once channelOpen has been called.
    @type gotOpen: L{bool}

    @ivar specificData: the channel-specific open data given to channelOpen.
    @type specificData: L{bytes}

    @ivar openFailureReason: the reason given to openFailed.
    @type openFailureReason: C{error.ConchError}

    @ivar inBuffer: a C{list} of the strings received by the channel.
    @type inBuffer: C{list}

    @ivar extBuffer: a C{list} of 2-tuples (type, extended data) received by
        the channel.
    @type extBuffer: C{list}

    @ivar numberRequests: how many 'test' requests have been made to this
        channel.
    @type numberRequests: L{int}

    @ivar gotEOF: True once the other side has sent EOF.
    @type gotEOF: L{bool}

    @ivar gotOneClose: True once the other side has closed the connection.
    @type gotOneClose: L{bool}

    @ivar gotClosed: True once the channel is closed.
    @type gotClosed: L{bool}
    """
    name = b"TestChannel"
    # Class-level default so the flag can be read before channelOpen runs.
    gotOpen = False

    def logPrefix(self):
        """
        Return the log prefix for this channel, including its id.
        """
        return "TestChannel %i" % self.id

    def channelOpen(self, specificData):
        """
        The channel has been opened: remember the open data and reset all of
        the recording state.
        """
        self.gotOpen = True
        self.specificData = specificData
        self.inBuffer, self.extBuffer = [], []
        self.numberRequests = 0
        self.gotEOF = self.gotOneClose = self.gotClosed = False

    def openFailed(self, reason):
        """
        The channel could not be opened. Remember why.
        """
        self.openFailureReason = reason

    def request_test(self, data):
        """
        Handle a 'test' channel request, counting it.

        @type data: L{bytes}
        @return: True if and only if C{data} equals C{b'data'}.
        """
        self.numberRequests = self.numberRequests + 1
        succeeded = data == b'data'
        return succeeded

    def dataReceived(self, data):
        """
        Record a chunk of ordinary data.
        """
        self.inBuffer.append(data)

    def extReceived(self, code, data):
        """
        Record a chunk of extended data together with its type code.
        """
        entry = (code, data)
        self.extBuffer.append(entry)

    def eofReceived(self):
        """
        Record that EOF was received.
        """
        self.gotEOF = True

    def closeReceived(self):
        """
        Record that a close was received.
        """
        self.gotOneClose = True

    def closed(self):
        """
        Record that the channel has been closed.
        """
        self.gotClosed = True
class TestAvatar:
    """
    A minimal stand-in for twisted.conch.avatar.ConchUser.
    """
    # Error code placed (deliberately) in the first ConchError argument by
    # lookupChannel for the b"conch-error-args" channel type.
    _ARGS_ERROR_CODE = 123

    def lookupChannel(self, channelType, windowSize, maxPacket, data):
        """
        Produce the channel the server asked for.

        A L{TestChannel} is returned when C{channelType} is the TestChannel
        name; any channel type which is not handled here yields L{None}.

        @raise error.ConchError: for the C{b"conch-error-args"} channel type.
        """
        if channelType == TestChannel.name:
            return TestChannel(
                remoteWindow=windowSize,
                remoteMaxPacket=maxPacket,
                data=data,
                avatar=self)
        if channelType == b"conch-error-args":
            # Raise a ConchError with backwards arguments to make sure the
            # connection fixes it for us.  This case should be deprecated and
            # deleted eventually, but only after all of Conch gets the argument
            # order right.
            raise error.ConchError(
                self._ARGS_ERROR_CODE, "error args in wrong order")
        return None

    def gotGlobalRequest(self, requestType, data):
        """
        Answer a global request made by the client.

        @return: True for C{b'TestGlobal'}; the tuple C{(True, data)} for
            C{b'TestData'}; False for every other request type.
        """
        if requestType == b'TestGlobal':
            return True
        if requestType == b'TestData':
            return True, data
        return False
class TestConnection(connection.SSHConnection):
    """
    An SSHConnection subclass providing the global request and channel
    handlers that the tests rely on.

    @ivar channel: the most recently created channel.
    @type channel: C{TestChannel}
    """

    if not cryptography:
        skip = "Cannot run without cryptography"

    def logPrefix(self):
        """
        Return the log prefix for this connection.
        """
        return "TestConnection"

    def global_TestGlobal(self, data):
        """
        Handle the 'TestGlobal' global request from the other side.

        @return: True, always.
        """
        return True

    def global_Test_Data(self, data):
        """
        Handle the 'Test-Data' global request from the other side.

        @return: a 2-tuple of True and the data that was received.
        """
        return True, data

    def channel_TestChannel(self, windowSize, maxPacket, data):
        """
        Handle a request from the other side for the TestChannel: build a
        C{TestChannel}, keep it as C{self.channel} and hand it back.
        """
        created = TestChannel(
            remoteWindow=windowSize,
            remoteMaxPacket=maxPacket,
            data=data)
        self.channel = created
        return self.channel

    def channel_ErrorChannel(self, windowSize, maxPacket, data):
        """
        Handle a request from the other side for the ErrorChannel, which
        always fails.

        @raise AssertionError: always.
        """
        raise AssertionError('no such thing')
class ConnectionTests(unittest.TestCase):
    """
    Tests for the SSH connection service, exercised through a
    L{TestConnection} attached to a fake transport whose avatar is a
    L{TestAvatar}.
    """

    if not cryptography:
        skip = "Cannot run without cryptography"
    if test_userauth.transport is None:
        skip = "Cannot run without both cryptography and pyasn1"

    def setUp(self):
        """
        Create a started L{TestConnection} wired to a fake transport.
        """
        self.transport = test_userauth.FakeTransport(None)
        self.transport.avatar = TestAvatar()
        self.conn = TestConnection()
        self.conn.transport = self.transport
        self.conn.serviceStarted()

    def _openChannel(self, channel):
        """
        Open the channel with the default connection.

        The packet recorded by C{openChannel} is dropped from the transport
        so that tests only see the packets they cause themselves, and the
        open is then confirmed with a remote channel id of 255.
        """
        self.conn.openChannel(channel)
        # Discard the last recorded packet (the one openChannel just sent).
        self.transport.packets = self.transport.packets[:-1]
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(
            struct.pack('>2L', channel.id, 255) +
            b'\x00\x02\x00\x00\x00\x00\x80\x00')

    def tearDown(self):
        """
        Stop the connection service.
        """
        self.conn.serviceStopped()

    def test_linkAvatar(self):
        """
        Test that the connection links itself to the avatar in the
        transport.
        """
        self.assertIs(self.transport.avatar.conn, self.conn)

    def test_serviceStopped(self):
        """
        Test that serviceStopped() closes any open channels.
        """
        channel1 = TestChannel()
        channel2 = TestChannel()
        self.conn.openChannel(channel1)
        self.conn.openChannel(channel2)
        # Only the first channel's open is confirmed.
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(b'\x00\x00\x00\x00' * 4)
        self.assertTrue(channel1.gotOpen)
        self.assertFalse(channel2.gotOpen)
        self.conn.serviceStopped()
        self.assertTrue(channel1.gotClosed)

    def test_GLOBAL_REQUEST(self):
        """
        Test that global request packets are dispatched to the global_*
        methods and the return values are translated into success or failure
        messages.
        """
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b'TestGlobal') + b'\xff')
        self.assertEqual(self.transport.packets,
                         [(connection.MSG_REQUEST_SUCCESS, b'')])
        self.transport.packets = []
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b'TestData') + b'\xff' +
                                     b'test data')
        self.assertEqual(self.transport.packets,
                         [(connection.MSG_REQUEST_SUCCESS, b'test data')])
        self.transport.packets = []
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b'TestBad') + b'\xff')
        self.assertEqual(self.transport.packets,
                         [(connection.MSG_REQUEST_FAILURE, b'')])
        self.transport.packets = []
        # With the want-reply byte cleared, nothing is sent back.
        self.conn.ssh_GLOBAL_REQUEST(common.NS(b'TestGlobal') + b'\x00')
        self.assertEqual(self.transport.packets, [])

    def test_REQUEST_SUCCESS(self):
        """
        Test that global request success packets cause the Deferred to be
        called back.
        """
        d = self.conn.sendGlobalRequest(b'request', b'data', True)
        self.conn.ssh_REQUEST_SUCCESS(b'data')

        def check(data):
            self.assertEqual(data, b'data')

        d.addCallback(check)
        d.addErrback(self.fail)
        return d

    def test_REQUEST_FAILURE(self):
        """
        Test that global request failure packets cause the Deferred to be
        erred back.
        """
        d = self.conn.sendGlobalRequest(b'request', b'data', True)
        self.conn.ssh_REQUEST_FAILURE(b'data')

        def check(f):
            self.assertEqual(f.value.data, b'data')

        d.addCallback(self.fail)
        d.addErrback(check)
        return d

    def test_CHANNEL_OPEN(self):
        """
        Test that open channel packets cause a channel to be created and
        opened or a failure message to be returned.
        """
        # Without an avatar the connection's own channel_* methods are used.
        del self.transport.avatar
        self.conn.ssh_CHANNEL_OPEN(common.NS(b'TestChannel') +
                                   b'\x00\x00\x00\x01' * 4)
        self.assertTrue(self.conn.channel.gotOpen)
        self.assertEqual(self.conn.channel.conn, self.conn)
        self.assertEqual(self.conn.channel.data, b'\x00\x00\x00\x01')
        self.assertEqual(self.conn.channel.specificData, b'\x00\x00\x00\x01')
        self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
        self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_OPEN_CONFIRMATION,
              b'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00'
              b'\x00\x00\x80\x00')])
        self.transport.packets = []
        self.conn.ssh_CHANNEL_OPEN(common.NS(b'BadChannel') +
                                   b'\x00\x00\x00\x02' * 4)
        self.flushLoggedErrors()
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_OPEN_FAILURE,
              b'\x00\x00\x00\x02\x00\x00\x00\x03' + common.NS(
                  b'unknown channel') + common.NS(b''))])
        self.transport.packets = []
        self.conn.ssh_CHANNEL_OPEN(common.NS(b'ErrorChannel') +
                                   b'\x00\x00\x00\x02' * 4)
        self.flushLoggedErrors()
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_OPEN_FAILURE,
              b'\x00\x00\x00\x02\x00\x00\x00\x02' + common.NS(
                  b'unknown failure') + common.NS(b''))])

    def _lookupChannelErrorTest(self, code):
        """
        Deliver a request for a channel open which will result in an exception
        being raised during channel lookup.  Assert that an error response is
        delivered as a result.

        @param code: the error code the avatar places in the raised
            L{error.ConchError}; the logged error and the failure packet are
            both expected to carry this value.
        """
        self.transport.avatar._ARGS_ERROR_CODE = code
        self.conn.ssh_CHANNEL_OPEN(
            common.NS(b'conch-error-args') + b'\x00\x00\x00\x01' * 4)
        errors = self.flushLoggedErrors(error.ConchError)
        self.assertEqual(
            len(errors), 1, "Expected one error, got: %r" % (errors,))
        self.assertEqual(
            errors[0].value.args, (code, "error args in wrong order"))
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_OPEN_FAILURE,
              # The response includes some bytes identifying the associated
              # request (channel 1), followed by the error code and the
              # error message.
              struct.pack('>2L', 1, code) + common.NS(
                  b'error args in wrong order') + common.NS(b''))])

    def test_lookupChannelError(self):
        """
        If a C{lookupChannel} implementation raises L{error.ConchError} with the
        arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
        sent in response to the message.

        This is a temporary work-around until L{error.ConchError} is given
        better attributes and all of the Conch code starts constructing
        instances of it properly.  Eventually this functionality should be
        deprecated and then removed.
        """
        self._lookupChannelErrorTest(123)

    def test_lookupChannelErrorLongCode(self):
        """
        Like L{test_lookupChannelError}, but for the case where the failure code
        is represented as a L{long} instead of a L{int}.
        """
        self._lookupChannelErrorTest(long(123))

    def test_CHANNEL_OPEN_CONFIRMATION(self):
        """
        Test that channel open confirmation packets cause the channel to be
        notified that it's open.
        """
        channel = TestChannel()
        self.conn.openChannel(channel)
        self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(b'\x00\x00\x00\x00' * 5)
        self.assertEqual(channel.remoteWindowLeft, 0)
        self.assertEqual(channel.remoteMaxPacket, 0)
        self.assertEqual(channel.specificData, b'\x00\x00\x00\x00')
        self.assertEqual(self.conn.channelsToRemoteChannel[channel],
                         0)
        self.assertEqual(self.conn.localToRemoteChannel[0], 0)

    def test_CHANNEL_OPEN_FAILURE(self):
        """
        Test that channel open failure packets cause the channel to be
        notified that its opening failed, and that the connection stops
        tracking the channel.
        """
        channel = TestChannel()
        self.conn.openChannel(channel)
        self.conn.ssh_CHANNEL_OPEN_FAILURE(b'\x00\x00\x00\x00\x00\x00\x00'
                                           b'\x01' + common.NS(b'failure!'))
        self.assertEqual(channel.openFailureReason.args, (b'failure!', 1))
        # Look the channel up by its local id rather than by the channel
        # object; NOTE(review): this relies on conn.channels being keyed by
        # local channel id - confirm against SSHConnection.
        self.assertIsNone(self.conn.channels.get(channel.id))

    def test_CHANNEL_WINDOW_ADJUST(self):
        """
        Test that channel window adjust messages add bytes to the channel
        window.
        """
        channel = TestChannel()
        self._openChannel(channel)
        oldWindowSize = channel.remoteWindowLeft
        self.conn.ssh_CHANNEL_WINDOW_ADJUST(b'\x00\x00\x00\x00\x00\x00\x00'
                                            b'\x01')
        self.assertEqual(channel.remoteWindowLeft, oldWindowSize + 1)

    def test_CHANNEL_DATA(self):
        """
        Test that channel data messages are passed up to the channel, or
        cause the channel to be closed if the data is too large.
        """
        channel = TestChannel(localWindow=6, localMaxPacket=5)
        self._openChannel(channel)
        self.conn.ssh_CHANNEL_DATA(b'\x00\x00\x00\x00' + common.NS(b'data'))
        self.assertEqual(channel.inBuffer, [b'data'])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_WINDOW_ADJUST, b'\x00\x00\x00\xff'
              b'\x00\x00\x00\x04')])
        self.transport.packets = []
        # More data than the remaining local window allows.
        longData = b'a' * (channel.localWindowLeft + 1)
        self.conn.ssh_CHANNEL_DATA(b'\x00\x00\x00\x00' + common.NS(longData))
        self.assertEqual(channel.inBuffer, [b'data'])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
        channel = TestChannel()
        self._openChannel(channel)
        # More data than the local maximum packet size allows.
        bigData = b'a' * (channel.localMaxPacket + 1)
        self.transport.packets = []
        self.conn.ssh_CHANNEL_DATA(b'\x00\x00\x00\x01' + common.NS(bigData))
        self.assertEqual(channel.inBuffer, [])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])

    def test_CHANNEL_EXTENDED_DATA(self):
        """
        Test that channel extended data messages are passed up to the channel,
        or cause the channel to be closed if they're too big.
        """
        channel = TestChannel(localWindow=6, localMaxPacket=5)
        self._openChannel(channel)
        self.conn.ssh_CHANNEL_EXTENDED_DATA(b'\x00\x00\x00\x00\x00\x00\x00'
                                            b'\x00' + common.NS(b'data'))
        self.assertEqual(channel.extBuffer, [(0, b'data')])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_WINDOW_ADJUST, b'\x00\x00\x00\xff'
              b'\x00\x00\x00\x04')])
        self.transport.packets = []
        # More data than the remaining local window allows.
        longData = b'a' * (channel.localWindowLeft + 1)
        self.conn.ssh_CHANNEL_EXTENDED_DATA(b'\x00\x00\x00\x00\x00\x00\x00'
                                            b'\x00' + common.NS(longData))
        self.assertEqual(channel.extBuffer, [(0, b'data')])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
        channel = TestChannel()
        self._openChannel(channel)
        # More data than the local maximum packet size allows.
        bigData = b'a' * (channel.localMaxPacket + 1)
        self.transport.packets = []
        self.conn.ssh_CHANNEL_EXTENDED_DATA(b'\x00\x00\x00\x01\x00\x00\x00'
                                            b'\x00' + common.NS(bigData))
        self.assertEqual(channel.extBuffer, [])
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])

    def test_CHANNEL_EOF(self):
        """
        Test that channel eof messages are passed up to the channel.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.ssh_CHANNEL_EOF(b'\x00\x00\x00\x00')
        self.assertTrue(channel.gotEOF)

    def test_CHANNEL_CLOSE(self):
        """
        Test that channel close messages are passed up to the channel.  Also,
        test that channel.close() is called if both sides are closed when this
        message is received.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendClose(channel)
        self.conn.ssh_CHANNEL_CLOSE(b'\x00\x00\x00\x00')
        self.assertTrue(channel.gotOneClose)
        self.assertTrue(channel.gotClosed)

    def test_CHANNEL_REQUEST_success(self):
        """
        Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.ssh_CHANNEL_REQUEST(b'\x00\x00\x00\x00' + common.NS(b'test')
                                      + b'\x00')
        self.assertEqual(channel.numberRequests, 1)
        d = self.conn.ssh_CHANNEL_REQUEST(b'\x00\x00\x00\x00' + common.NS(
            b'test') + b'\xff' + b'data')

        def check(result):
            self.assertEqual(
                self.transport.packets,
                [(connection.MSG_CHANNEL_SUCCESS, b'\x00\x00\x00\xff')])

        d.addCallback(check)
        return d

    def test_CHANNEL_REQUEST_failure(self):
        """
        Test that channel requests that fail send MSG_CHANNEL_FAILURE.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.ssh_CHANNEL_REQUEST(b'\x00\x00\x00\x00' + common.NS(
            b'test') + b'\xff')

        def check(result):
            self.assertEqual(
                self.transport.packets,
                [(connection.MSG_CHANNEL_FAILURE, b'\x00\x00\x00\xff')])

        d.addCallback(self.fail)
        d.addErrback(check)
        return d

    def test_CHANNEL_REQUEST_SUCCESS(self):
        """
        Test that channel request success messages cause the Deferred to be
        called back.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.sendRequest(channel, b'test', b'data', True)
        self.conn.ssh_CHANNEL_SUCCESS(b'\x00\x00\x00\x00')
        # Returning the Deferred is the assertion: the test only passes if
        # it fires with a success result.  The value it fires with is not
        # inspected here.
        return d

    def test_CHANNEL_REQUEST_FAILURE(self):
        """
        Test that channel request failure messages cause the Deferred to be
        erred back.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.sendRequest(channel, b'test', b'', True)
        self.conn.ssh_CHANNEL_FAILURE(b'\x00\x00\x00\x00')

        def check(result):
            self.assertEqual(result.value.value, 'channel request failed')

        d.addCallback(self.fail)
        d.addErrback(check)
        return d

    def test_sendGlobalRequest(self):
        """
        Test that global request messages are sent in the right format.
        """
        d = self.conn.sendGlobalRequest(b'wantReply', b'data', True)
        # must be added to prevent errbacking during teardown
        d.addErrback(lambda failure: None)
        self.conn.sendGlobalRequest(b'noReply', b'', False)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_GLOBAL_REQUEST, common.NS(b'wantReply') +
              b'\xffdata'),
             (connection.MSG_GLOBAL_REQUEST, common.NS(b'noReply') +
              b'\x00')])
        self.assertEqual(self.conn.deferreds, {'global': [d]})

    def test_openChannel(self):
        """
        Test that open channel messages are sent in the right format.
        """
        channel = TestChannel()
        self.conn.openChannel(channel, b'aaaa')
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_OPEN, common.NS(b'TestChannel') +
              b'\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa')])
        self.assertEqual(channel.id, 0)
        self.assertEqual(self.conn.localChannelID, 1)

    def test_sendRequest(self):
        """
        Test that channel request messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.sendRequest(channel, b'test', b'test', True)
        # needed to prevent errbacks during teardown.
        d.addErrback(lambda failure: None)
        self.conn.sendRequest(channel, b'test2', b'', False)
        channel.localClosed = True  # emulate sending a close message
        # Nothing is expected on the wire for a locally closed channel.
        self.conn.sendRequest(channel, b'test3', b'', True)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_REQUEST, b'\x00\x00\x00\xff' +
              common.NS(b'test') + b'\x01test'),
             (connection.MSG_CHANNEL_REQUEST, b'\x00\x00\x00\xff' +
              common.NS(b'test2') + b'\x00')])
        self.assertEqual(self.conn.deferreds[0], [d])

    def test_adjustWindow(self):
        """
        Test that channel window adjust messages cause bytes to be added
        to the window.
        """
        channel = TestChannel(localWindow=5)
        self._openChannel(channel)
        channel.localWindowLeft = 0
        self.conn.adjustWindow(channel, 1)
        self.assertEqual(channel.localWindowLeft, 1)
        channel.localClosed = True
        # Adjusting a locally closed channel changes nothing.
        self.conn.adjustWindow(channel, 2)
        self.assertEqual(channel.localWindowLeft, 1)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_WINDOW_ADJUST, b'\x00\x00\x00\xff'
              b'\x00\x00\x00\x01')])

    def test_sendData(self):
        """
        Test that channel data messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendData(channel, b'a')
        channel.localClosed = True
        self.conn.sendData(channel, b'b')
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_DATA, b'\x00\x00\x00\xff' +
              common.NS(b'a'))])

    def test_sendExtendedData(self):
        """
        Test that channel extended data messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendExtendedData(channel, 1, b'test')
        channel.localClosed = True
        self.conn.sendExtendedData(channel, 2, b'test2')
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_EXTENDED_DATA, b'\x00\x00\x00\xff' +
              b'\x00\x00\x00\x01' + common.NS(b'test'))])

    def test_sendEOF(self):
        """
        Test that channel EOF messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendEOF(channel)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_EOF, b'\x00\x00\x00\xff')])
        channel.localClosed = True
        # A second EOF on a locally closed channel sends nothing more.
        self.conn.sendEOF(channel)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_EOF, b'\x00\x00\x00\xff')])

    def test_sendClose(self):
        """
        Test that channel close messages are sent in the right format.
        """
        channel = TestChannel()
        self._openChannel(channel)
        self.conn.sendClose(channel)
        self.assertTrue(channel.localClosed)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
        # Closing again sends nothing more.
        self.conn.sendClose(channel)
        self.assertEqual(
            self.transport.packets,
            [(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
        channel2 = TestChannel()
        self._openChannel(channel2)
        channel2.remoteClosed = True
        self.conn.sendClose(channel2)
        self.assertTrue(channel2.gotClosed)

    def test_getChannelWithAvatar(self):
        """
        Test that getChannel dispatches to the avatar when an avatar is
        present.  Correct functioning without the avatar is verified in
        test_CHANNEL_OPEN.
        """
        channel = self.conn.getChannel(b'TestChannel', 50, 30, b'data')
        self.assertEqual(channel.data, b'data')
        self.assertEqual(channel.remoteWindowLeft, 50)
        self.assertEqual(channel.remoteMaxPacket, 30)
        self.assertRaises(error.ConchError, self.conn.getChannel,
                          b'BadChannel', 50, 30, b'data')

    def test_gotGlobalRequestWithoutAvatar(self):
        """
        Test that gotGlobalRequests dispatches to global_* without an avatar.
        """
        del self.transport.avatar
        self.assertTrue(self.conn.gotGlobalRequest(b'TestGlobal', b'data'))
        self.assertEqual(self.conn.gotGlobalRequest(b'Test-Data', b'data'),
                         (True, b'data'))
        self.assertFalse(self.conn.gotGlobalRequest(b'BadGlobal', b'data'))

    def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
        """
        Whenever an SSH channel gets closed any Deferred that was returned by a
        sendRequest() on its parent connection must be errbacked.
        """
        channel = TestChannel()
        self._openChannel(channel)
        d = self.conn.sendRequest(
            channel, b"dummyrequest", b"dummydata", wantReply=1)
        d = self.assertFailure(d, error.ConchError)
        self.conn.channelClosed(channel)
        return d
class CleanConnectionShutdownTests(unittest.TestCase):
    """
    Verify that shutting the connection down cleans up after itself.
    """

    if not cryptography:
        skip = "Cannot run without cryptography"
    if test_userauth.transport is None:
        skip = "Cannot run without both cryptography and pyasn1"

    def setUp(self):
        """
        Build a L{TestConnection} attached to a fake transport.  The service
        is not started here.
        """
        transport = test_userauth.FakeTransport(None)
        transport.avatar = TestAvatar()
        self.transport = transport

        conn = TestConnection()
        conn.transport = transport
        self.conn = conn

    def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
        """
        A Deferred returned by sendGlobalRequest() which is still outstanding
        when the service is stopped must be errbacked.
        """
        self.conn.serviceStarted()
        pending = self.conn.sendGlobalRequest(
            b"dummyrequest", b"dummydata", wantReply=1)
        failed = self.assertFailure(pending, error.ConchError)
        self.conn.serviceStopped()
        return failed