/usr/share/pyshared/twisted/conch/test/test_channel.py is in python-twisted-conch 1:11.1.0-1.
This file is owned by root:root, with mode 0o644.
The actual contents of the file can be viewed below.
# Copyright (C) 2007-2008 Twisted Matrix Laboratories
# See LICENSE for details
"""
Test ssh/channel.py.
"""
from twisted.conch.ssh import channel
from twisted.trial import unittest
class MockTransport(object):
    """
    Minimal fake transport that only answers getPeer() and getHost().

    SSH channels implement ITransport and build their own getPeer()/getHost()
    results as ('SSH', <value from the connection's transport>), so the
    connection's transport must provide both methods for the channel to
    delegate to.
    """

    def getHost(self):
        """
        Return a fixed, easily recognizable host address.
        """
        host = ('MockHost',)
        return host

    def getPeer(self):
        """
        Return a fixed, easily recognizable peer address.
        """
        peer = ('MockPeer',)
        return peer
class MockConnection(object):
    """
    Recording stand-in for twisted.conch.ssh.connection.SSHConnection.

    Nothing is transmitted; every data/extended-data/close call a channel
    makes is stored so tests can inspect it afterwards.

    @ivar data: a C{dict} mapping each channel that sent data to the list of
        data strings it sent, in order.
    @ivar extData: a C{dict} mapping each channel that sent extended data to
        the list of C{(extended data type, data)} 2-tuples it sent, in order.
    @ivar closes: a C{dict} mapping each channel that sent a close message
        to C{True}.
    """

    # A single fake transport shared at class level; channels only ask it
    # for getPeer()/getHost().
    transport = MockTransport()

    def __init__(self):
        self.closes = {}
        self.extData = {}
        self.data = {}

    def logPrefix(self):
        """
        Return the fixed string identifying this connection in log prefixes.
        """
        return "MockConnection"

    def sendData(self, channel, data):
        """
        Remember C{data} as having been sent by C{channel}.
        """
        # Append in place so a list fetched earlier by a test keeps seeing
        # later writes.
        if channel not in self.data:
            self.data[channel] = []
        self.data[channel].append(data)

    def sendExtendedData(self, channel, type, data):
        """
        Remember a C{(type, data)} pair as extended data sent by C{channel}.
        """
        record = (type, data)
        if channel not in self.extData:
            self.extData[channel] = []
        self.extData[channel].append(record)

    def sendClose(self, channel):
        """
        Flag C{channel} as having sent a close message.
        """
        self.closes[channel] = True
class ChannelTestCase(unittest.TestCase):
    """
    Tests for L{channel.SSHChannel}, driven against a L{MockConnection} that
    records everything the channel sends.
    """

    def setUp(self):
        """
        Initialize the channel. remoteMaxPacket is 10 so that data is able
        to be sent (the default of 0 means no data is sent because no packets
        are made).
        """
        self.conn = MockConnection()
        self.channel = channel.SSHChannel(conn=self.conn,
                                          remoteMaxPacket=10)
        self.channel.name = 'channel'

    def test_init(self):
        """
        Test that SSHChannel initializes correctly. localWindowSize defaults
        to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
        defaults (what OpenSSH uses for those variables).

        The values in the second set of assertions are meaningless; they serve
        only to verify that the instance variables are assigned in the correct
        order.
        """
        c = channel.SSHChannel(conn=self.conn)
        self.assertEqual(c.localWindowSize, 131072)
        self.assertEqual(c.localWindowLeft, 131072)
        self.assertEqual(c.localMaxPacket, 32768)
        self.assertEqual(c.remoteWindowLeft, 0)
        self.assertEqual(c.remoteMaxPacket, 0)
        self.assertEqual(c.conn, self.conn)
        self.assertEqual(c.data, None)
        self.assertEqual(c.avatar, None)

        # Positional arguments: each value is distinct so a swapped
        # assignment inside __init__ would be caught.
        c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
        self.assertEqual(c2.localWindowSize, 1)
        self.assertEqual(c2.localWindowLeft, 1)
        self.assertEqual(c2.localMaxPacket, 2)
        self.assertEqual(c2.remoteWindowLeft, 3)
        self.assertEqual(c2.remoteMaxPacket, 4)
        self.assertEqual(c2.conn, 5)
        self.assertEqual(c2.data, 6)
        self.assertEqual(c2.avatar, 7)

    def test_str(self):
        """
        Test that str(SSHChannel) gives the channel name and the local and
        remote windows at a glance.
        """
        self.assertEqual(str(self.channel), '<SSHChannel channel (lw 131072 '
                         'rw 0)>')

    def test_logPrefix(self):
        """
        Test that SSHChannel.logPrefix gives the name of the channel, the
        local channel ID and the underlying connection.
        """
        self.assertEqual(self.channel.logPrefix(), 'SSHChannel channel '
                         '(unknown) on MockConnection')

    def test_addWindowBytes(self):
        """
        Test that addWindowBytes adds bytes to the window and resumes writing
        if it was paused.
        """
        cb = [False]

        def stubStartWriting():
            cb[0] = True

        self.channel.startWriting = stubStartWriting

        # With no remote window yet, both writes are held back until
        # addWindowBytes opens the window.
        self.channel.write('test')
        self.channel.writeExtended(1, 'test')
        self.channel.addWindowBytes(50)
        self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
        self.assertTrue(self.channel.areWriting)
        self.assertTrue(cb[0])
        self.assertEqual(self.channel.buf, '')
        self.assertEqual(self.conn.data[self.channel], ['test'])
        self.assertEqual(self.channel.extBuf, [])
        self.assertEqual(self.conn.extData[self.channel], [(1, 'test')])

        # Writing was not paused, so growing the window must not call
        # startWriting again.
        cb[0] = False
        self.channel.addWindowBytes(20)
        self.assertFalse(cb[0])

        # Once loseConnection() has been called, opening the window must
        # not resume writing either.
        self.channel.write('a' * 80)
        self.channel.loseConnection()
        self.channel.addWindowBytes(20)
        self.assertFalse(cb[0])

    def test_requestReceived(self):
        """
        Test that requestReceived handles requests by dispatching them to
        request_* methods.
        """
        # 'test-method' is dispatched to request_test_method; its return
        # value is passed back to the caller.
        self.channel.request_test_method = lambda data: data == ''
        self.assertTrue(self.channel.requestReceived('test-method', ''))
        self.assertFalse(self.channel.requestReceived('test-method', 'a'))
        # No request_bad_method exists, so the request is refused.
        self.assertFalse(self.channel.requestReceived('bad-method', ''))

    def test_closeReceieved(self):
        """
        Test that the default closeReceived starts closing the channel
        (its closing flag becomes true).
        """
        self.assertFalse(self.channel.closing)
        self.channel.closeReceived()
        self.assertTrue(self.channel.closing)

    def test_write(self):
        """
        Test that write handles data correctly. Send data up to the size
        of the remote window, splitting the data into packets of length
        remoteMaxPacket.
        """
        cb = [False]

        def stubStopWriting():
            cb[0] = True

        # no window to start with
        self.channel.stopWriting = stubStopWriting
        self.channel.write('d')
        self.channel.write('a')
        self.assertFalse(self.channel.areWriting)
        self.assertTrue(cb[0])

        # regular write: the buffered 'da' is flushed first, then 'ta'
        self.channel.addWindowBytes(20)
        self.channel.write('ta')
        data = self.conn.data[self.channel]
        self.assertEqual(data, ['da', 'ta'])
        self.assertEqual(self.channel.remoteWindowLeft, 16)

        # larger than max packet (remoteMaxPacket is 10)
        self.channel.write('12345678901')
        self.assertEqual(data, ['da', 'ta', '1234567890', '1'])
        self.assertEqual(self.channel.remoteWindowLeft, 5)

        # running out of window: only 5 bytes fit, the rest stays buffered
        cb[0] = False
        self.channel.write('123456')
        self.assertFalse(self.channel.areWriting)
        self.assertTrue(cb[0])
        self.assertEqual(data, ['da', 'ta', '1234567890', '1', '12345'])
        self.assertEqual(self.channel.buf, '6')
        self.assertEqual(self.channel.remoteWindowLeft, 0)

    def test_writeExtended(self):
        """
        Test that writeExtended handles data correctly. Send extended data
        up to the size of the window, splitting the extended data into packets
        of length remoteMaxPacket.
        """
        cb = [False]

        def stubStopWriting():
            cb[0] = True

        # no window to start with
        self.channel.stopWriting = stubStopWriting
        self.channel.writeExtended(1, 'd')
        self.channel.writeExtended(1, 'a')
        self.channel.writeExtended(2, 't')
        self.assertFalse(self.channel.areWriting)
        self.assertTrue(cb[0])

        # regular write: the buffered same-type writes ('d', 'a') arrive as
        # one (1, 'da') packet
        self.channel.addWindowBytes(20)
        self.channel.writeExtended(2, 'a')
        data = self.conn.extData[self.channel]
        self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a')])
        self.assertEqual(self.channel.remoteWindowLeft, 16)

        # larger than max packet (remoteMaxPacket is 10)
        self.channel.writeExtended(3, '12345678901')
        self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
                                (3, '1234567890'), (3, '1')])
        self.assertEqual(self.channel.remoteWindowLeft, 5)

        # running out of window: only 5 bytes fit, the rest stays buffered
        cb[0] = False
        self.channel.writeExtended(4, '123456')
        self.assertFalse(self.channel.areWriting)
        self.assertTrue(cb[0])
        self.assertEqual(data, [(1, 'da'), (2, 't'), (2, 'a'),
                                (3, '1234567890'), (3, '1'), (4, '12345')])
        self.assertEqual(self.channel.extBuf, [[4, '6']])
        self.assertEqual(self.channel.remoteWindowLeft, 0)

    def test_writeSequence(self):
        """
        Test that writeSequence is equivalent to write(''.join(sequence)).
        """
        self.channel.addWindowBytes(20)
        # Pass a real list so a genuine sequence is supplied regardless of
        # whether map() is lazy on the running Python version.
        self.channel.writeSequence([str(i) for i in range(10)])
        self.assertEqual(self.conn.data[self.channel], ['0123456789'])

    def test_loseConnection(self):
        """
        Test that loseConnection() doesn't close the channel until all
        the data is sent.
        """
        self.channel.write('data')
        self.channel.writeExtended(1, 'datadata')
        self.channel.loseConnection()
        self.assertEqual(self.conn.closes.get(self.channel), None)
        self.channel.addWindowBytes(4)  # send regular data
        self.assertEqual(self.conn.closes.get(self.channel), None)
        self.channel.addWindowBytes(8)  # send extended data
        self.assertTrue(self.conn.closes.get(self.channel))

    def test_getPeer(self):
        """
        Test that getPeer() returns ('SSH', <connection transport peer>).
        """
        self.assertEqual(self.channel.getPeer(), ('SSH', 'MockPeer'))

    def test_getHost(self):
        """
        Test that getHost() returns ('SSH', <connection transport host>).
        """
        self.assertEqual(self.channel.getHost(), ('SSH', 'MockHost'))
|