Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(1420)

Delta Between Two Patch Sets: ssh/mux_test.go

Issue 14225043: code review 14225043: go.crypto/ssh: reimplement SSH connection protocol modu... (Closed)
Left Patch Set: diff -r 2cd6b3b93cdb https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Right Patch Set: diff -r cd1eea1eb828 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « ssh/mux.go ('k') | ssh/server.go » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
1 package ssh 5 package ssh
2 6
3 import ( 7 import (
4 "io" 8 "io"
5 "io/ioutil" 9 "io/ioutil"
6 "log"
7 "sync" 10 "sync"
8 "testing" 11 "testing"
9 "time" 12 "time"
10 ) 13 )
11 14
12 var _ = log.Println
13
14 func muxPair() (*mux, *mux) { 15 func muxPair() (*mux, *mux) {
15 a, b := memPipe() 16 a, b := memPipe()
16 17
17 s := newMux(a) 18 s := newMux(a)
18 c := newMux(b) 19 c := newMux(b)
19 20
20 go s.Loop() 21 go s.Loop()
21 go c.Loop() 22 go c.Loop()
22 23
23 c.chanList.offset = 'c'
24 s.chanList.offset = 's'
25 return s, c 24 return s, c
26 } 25 }
27 26
28 func channelPair(t *testing.T) (Channel, Channel) { 27 // Returns both ends of a channel, and the mux for the the 2nd
28 // channel.
29 func channelPair(t *testing.T) (*channel, *channel, *mux) {
29 c, s := muxPair() 30 c, s := muxPair()
30 31
31 » res := make(chan Channel, 1) 32 » res := make(chan *channel, 1)
32 » go func() { 33 » go func() {
33 » » ch, err := s.Accept() 34 » » ch, ok := <-s.incomingChannels
35 » » if !ok {
36 » » » t.Fatalf("No incoming channel")
37 » » }
38 » » if ch.ChannelType() != "chan" {
39 » » » t.Fatalf("got type %q want chan", ch.ChannelType())
40 » » }
41 » » err := ch.Accept()
34 if err != nil { 42 if err != nil {
35 t.Fatalf("Accept %v", err) 43 t.Fatalf("Accept %v", err)
36 } 44 }
37 if ch.ChannelType() != "chan" {
38 t.Fatalf("got type %q want chan", ch.ChannelType())
39 }
40 ch.Accept()
41 res <- ch 45 res <- ch
42 }() 46 }()
43 47
44 ch, err := c.OpenChannel("chan", nil) 48 ch, err := c.OpenChannel("chan", nil)
45 if err != nil { 49 if err != nil {
46 t.Fatalf("OpenChannel: %v", err) 50 t.Fatalf("OpenChannel: %v", err)
47 } 51 }
48 52
49 » return <-res, ch 53 » return <-res, ch, c
50 } 54 }
51 55
52 func TestMuxReadWrite(t *testing.T) { 56 func TestMuxReadWrite(t *testing.T) {
53 » s, c := channelPair(t) 57 » s, c, _ := channelPair(t)
54 58
55 magic := "hello world" 59 magic := "hello world"
56 magicExt := "hello stderr" 60 magicExt := "hello stderr"
57 » var wg sync.WaitGroup 61 » go func() {
58 » wg.Add(1)
59 » go func() {
60 » » defer wg.Done()
61
62 _, err := s.Write([]byte(magic)) 62 _, err := s.Write([]byte(magic))
63 if err != nil { 63 if err != nil {
64 t.Fatalf("Write: %v", err) 64 t.Fatalf("Write: %v", err)
65 } 65 }
66 » » _, err = s.Stderr().Write([]byte(magicExt)) 66 » » _, err = s.Extended(1).Write([]byte(magicExt))
67 if err != nil { 67 if err != nil {
68 t.Fatalf("Write: %v", err) 68 t.Fatalf("Write: %v", err)
69 } 69 }
70 err = s.Close() 70 err = s.Close()
71 if err != nil { 71 if err != nil {
72 t.Fatalf("Close: %v", err) 72 t.Fatalf("Close: %v", err)
73 } 73 }
74 }() 74 }()
75 75
76 var buf [1024]byte 76 var buf [1024]byte
77 n, err := c.Read(buf[:]) 77 n, err := c.Read(buf[:])
78 if err != nil { 78 if err != nil {
79 t.Fatalf("server Read: %v", err) 79 t.Fatalf("server Read: %v", err)
80 } 80 }
81 got := string(buf[:n]) 81 got := string(buf[:n])
82 if got != magic { 82 if got != magic {
83 t.Fatalf("server: got %q want %q", got, magic) 83 t.Fatalf("server: got %q want %q", got, magic)
84 } 84 }
85 85
86 » n, err = c.Stderr().Read(buf[:]) 86 » n, err = c.Extended(1).Read(buf[:])
87 if err != nil { 87 if err != nil {
88 t.Fatalf("server Read: %v", err) 88 t.Fatalf("server Read: %v", err)
89 } 89 }
90 90
91 got = string(buf[:n]) 91 got = string(buf[:n])
92 if got != magicExt { 92 if got != magicExt {
93 t.Fatalf("server: got %q want %q", got, magic) 93 t.Fatalf("server: got %q want %q", got, magic)
94 } 94 }
95 } 95 }
96 96
97 func TestMuxFlowControl(t *testing.T) { 97 func TestMuxFlowControl(t *testing.T) {
98 writerMux, readerMux := muxPair() 98 writerMux, readerMux := muxPair()
99 99
100 var wg sync.WaitGroup
101 wg.Add(2)
102
103 // this goroutine reads just a bit. 100 // this goroutine reads just a bit.
104 go func() { 101 go func() {
105 » » reader, err := readerMux.Accept() 102 » » reader, ok := <-readerMux.incomingChannels
103 » » if !ok {
104 » » » t.Fatalf("no incoming channel")
105 » » }
106 » » err := reader.Accept()
106 if err != nil { 107 if err != nil {
107 t.Fatalf("Accept: %v", err)
108 }
109 if err = reader.Accept(); err != nil {
110 t.Fatalf("Accept: %v", err) 108 t.Fatalf("Accept: %v", err)
111 } 109 }
112 110
113 b := make([]byte, 1024) 111 b := make([]byte, 1024)
114 n, err := reader.Read(b) 112 n, err := reader.Read(b)
115 if err != nil || n != len(b) { 113 if err != nil || n != len(b) {
116 t.Errorf("Read: %v, %d bytes", err, n) 114 t.Errorf("Read: %v, %d bytes", err, n)
117 } 115 }
118 wg.Done()
119 }() 116 }()
120 117
121 writer, err := writerMux.OpenChannel("pipe", nil) 118 writer, err := writerMux.OpenChannel("pipe", nil)
122 if err != nil { 119 if err != nil {
123 t.Fatalf("OpenChannel: %v", err) 120 t.Fatalf("OpenChannel: %v", err)
124 } 121 }
125 122
126 // This goroutine writes is blocked from writing by the slow 123 // This goroutine writes is blocked from writing by the slow
127 // reader 124 // reader
128 go func() { 125 go func() {
129 largeData := make([]byte, 3*(1<<15)) 126 largeData := make([]byte, 3*(1<<15))
130 n, err := writer.Write(largeData) 127 n, err := writer.Write(largeData)
131 if err != io.EOF { 128 if err != io.EOF {
132 t.Errorf("want EOF, got %v", err) 129 t.Errorf("want EOF, got %v", err)
133 } 130 }
134 want := 1024 + (1 << 15) 131 want := 1024 + (1 << 15)
135 if n != want { 132 if n != want {
136 t.Errorf("wrote %d, want %d", n, want) 133 t.Errorf("wrote %d, want %d", n, want)
137 } 134 }
138 wg.Done()
139 }() 135 }()
140 136
141 // Wait for a bit for things to subside. The write should be 137 // Wait for a bit for things to subside. The write should be
142 // blocked. 138 // blocked.
143 time.Sleep(1 * time.Millisecond) 139 time.Sleep(1 * time.Millisecond)
144 140
145 readerMux.Disconnect(0, "") 141 readerMux.Disconnect(0, "")
146 writerMux.Disconnect(0, "") 142 writerMux.Disconnect(0, "")
147
148 wg.Done()
149 } 143 }
150 144
151 func TestMuxReject(t *testing.T) { 145 func TestMuxReject(t *testing.T) {
152 client, server := muxPair() 146 client, server := muxPair()
153 147
154 go func() { 148 go func() {
155 » » ch, err := server.Accept() 149 » » ch, ok := <-server.incomingChannels
156 » » if err != nil { 150 » » if !ok {
157 » » » t.Fatalf("Accept: %v", err) 151 » » » t.Fatalf("Accept")
158 } 152 }
159 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { 153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
160 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) 154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
161 } 155 }
162 ch.Reject(RejectionReason(42), "message") 156 ch.Reject(RejectionReason(42), "message")
163 }() 157 }()
164 158
165 ch, err := client.OpenChannel("ch", []byte("extra")) 159 ch, err := client.OpenChannel("ch", []byte("extra"))
166 if ch != nil { 160 if ch != nil {
167 t.Fatal("openChannel not rejected") 161 t.Fatal("openChannel not rejected")
168 } 162 }
169 163
170 » ocf, ok := err.(*OpenChannelFailed) 164 » ocf, ok := err.(*OpenChannelError)
171 if !ok { 165 if !ok {
172 » » t.Errorf("got %#v want *OpenChannelFailed", err) 166 » » t.Errorf("got %#v want *OpenChannelError", err)
173 » } 167 » } else if ocf.Reason != 42 || ocf.Message != "message" {
174 168 » » t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
175 » if ocf.Reason != 42 || ocf.Message != "message" {
176 » » t.Errorf("got %#v, want {Reason: 42, Mepassage: %q}", ocf, "message")
177 } 169 }
178 170
179 want := "ssh: rejected: unknown reason 42 (message)" 171 want := "ssh: rejected: unknown reason 42 (message)"
180 if err.Error() != want { 172 if err.Error() != want {
181 t.Errorf("got %q, want %q", err.Error(), want) 173 t.Errorf("got %q, want %q", err.Error(), want)
182 } 174 }
183 } 175 }
184 176
185 func TestMuxChannelRequest(t *testing.T) { 177 func TestMuxChannelRequest(t *testing.T) {
186 » client, server := channelPair(t) 178 » client, server, _ := channelPair(t)
187 var received int 179 var received int
188 var wg sync.WaitGroup 180 var wg sync.WaitGroup
189 wg.Add(1) 181 wg.Add(1)
190 go func() { 182 go func() {
191 » » for r := range server.ReceivedRequests() { 183 » » for r := range server.incomingRequests {
192 received++ 184 received++
193 if r.WantReply { 185 if r.WantReply {
194 server.AckRequest(r.Request == "yes") 186 server.AckRequest(r.Request == "yes")
195 } 187 }
196 } 188 }
197 wg.Done() 189 wg.Done()
198 }() 190 }()
199 _, err := client.SendRequest("yes", false, nil) 191 _, err := client.SendRequest("yes", false, nil)
200 if err != nil { 192 if err != nil {
201 t.Fatalf("SendRequest: %v", err) 193 t.Fatalf("SendRequest: %v", err)
(...skipping 22 matching lines...) Expand all
224 216
225 if received != 3 { 217 if received != 3 {
226 t.Errorf("got %d requests, want %d", received) 218 t.Errorf("got %d requests, want %d", received)
227 } 219 }
228 } 220 }
229 221
230 func TestMuxGlobalRequest(t *testing.T) { 222 func TestMuxGlobalRequest(t *testing.T) {
231 clientMux, serverMux := muxPair() 223 clientMux, serverMux := muxPair()
232 224
233 var seen bool 225 var seen bool
234 » var wg sync.WaitGroup 226 » go func() {
235 » wg.Add(1) 227 » » for r := range serverMux.incomingRequests {
236 » go func() {
237 » » for r := range serverMux.ReceivedRequests() {
238 seen = seen || r.Request == "peek" 228 seen = seen || r.Request == "peek"
239 if r.WantReply { 229 if r.WantReply {
240 » » » » err := serverMux.AckGlobalRequest(r.Request == "yes", 230 » » » » err := serverMux.AckRequest(r.Request == "yes",
241 append([]byte(r.Request), r.Payload...)) 231 append([]byte(r.Request), r.Payload...))
242 if err != nil { 232 if err != nil {
243 t.Errorf("AckRequest: %v", err) 233 t.Errorf("AckRequest: %v", err)
244 } 234 }
245 } 235 }
246 } 236 }
247 wg.Done()
248 }() 237 }()
249 238
250 _, _, err := clientMux.SendRequest("peek", false, nil) 239 _, _, err := clientMux.SendRequest("peek", false, nil)
251 if err != nil { 240 if err != nil {
252 t.Errorf("SendRequest: %v", err) 241 t.Errorf("SendRequest: %v", err)
253 } 242 }
254 243
255 ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) 244 ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
256 if !ok || string(data) != "yesa" || err != nil { 245 if !ok || string(data) != "yesa" || err != nil {
257 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", 246 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
258 ok, data, err) 247 ok, data, err)
259 } 248 }
260 if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil { 249 if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
261 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", 250 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
262 ok, data, err) 251 ok, data, err)
263 } 252 }
264 253
265 if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil { 254 if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
266 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", 255 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
267 ok, data, err) 256 ok, data, err)
268 } 257 }
269 » // not really related to global reqs, but try disconnect too. 258
270 » clientMux.Disconnect(42, "whatever") 259 » clientMux.Disconnect(0, "")
271 260 » if !seen {
272 » wg.Wait() 261 » » t.Errorf("never saw 'peek' request")
262 » }
263 }
264
265 func TestMuxGlobalRequestUnblock(t *testing.T) {
266 » clientMux, serverMux := muxPair()
267
268 » result := make(chan error, 1)
269 » go func() {
270 » » _, _, err := clientMux.SendRequest("hello", true, nil)
271 » » result <- err
272 » }()
273
274 » <-serverMux.incomingRequests
275 » serverMux.conn.Close()
276 » err := <-result
277
278 » if err != io.EOF {
279 » » t.Errorf("want EOF, got %v", io.EOF)
280 » }
281 }
282
283 func TestMuxChannelRequestUnblock(t *testing.T) {
284 » a, b, connB := channelPair(t)
285
286 » result := make(chan error, 1)
287 » go func() {
288 » » _, err := a.SendRequest("hello", true, nil)
289 » » result <- err
290 » }()
291
292 » <-b.incomingRequests
293 » connB.conn.Close()
294 » err := <-result
295
296 » if err != io.EOF {
297 » » t.Errorf("want EOF, got %v", err)
298 » }
299 }
300
301 func TestMuxDisconnect(t *testing.T) {
302 » a, b := muxPair()
303 » go func() {
304 » » for r := range b.incomingRequests {
305 » » » if r.WantReply {
306 » » » » b.AckRequest(true, nil)
307 » » » }
308 » » }
309 » }()
310
311 » a.Disconnect(42, "whatever")
312 » ok, _, err := a.SendRequest("hello", true, nil)
313 » if ok || err == nil {
314 » » t.Errorf("got reply after disconnecting")
315 » }
273 } 316 }
274 317
275 func TestMuxCloseChannel(t *testing.T) { 318 func TestMuxCloseChannel(t *testing.T) {
276 » r, w := channelPair(t) 319 » r, w, _ := channelPair(t)
277 320
278 timeout := time.After(10 * time.Millisecond) 321 timeout := time.After(10 * time.Millisecond)
279 result := make(chan error, 1) 322 result := make(chan error, 1)
280 go func() { 323 go func() {
281 var b [1024]byte 324 var b [1024]byte
282 _, err := r.Read(b[:]) 325 _, err := r.Read(b[:])
283 result <- err 326 result <- err
284 }() 327 }()
285 if err := w.Close(); err != nil { 328 if err := w.Close(); err != nil {
286 t.Errorf("w.Close: %v", err) 329 t.Errorf("w.Close: %v", err)
287 } 330 }
288 331
289 if _, err := w.Write([]byte("hello")); err != io.EOF { 332 if _, err := w.Write([]byte("hello")); err != io.EOF {
290 t.Errorf("got err %v, want io.EOF after Close", err) 333 t.Errorf("got err %v, want io.EOF after Close", err)
291 } 334 }
292 335
293 select { 336 select {
294 case e := <-result: 337 case e := <-result:
295 if e != io.EOF { 338 if e != io.EOF {
296 t.Errorf("got %v (%T), want io.EOF", e, e) 339 t.Errorf("got %v (%T), want io.EOF", e, e)
297 } 340 }
298 case <-timeout: 341 case <-timeout:
299 t.Errorf("timed out waiting for read to exit") 342 t.Errorf("timed out waiting for read to exit")
300 } 343 }
301 } 344 }
302 345
303 func TestMuxCloseWriteChannel(t *testing.T) { 346 func TestMuxCloseWriteChannel(t *testing.T) {
304 » r, w := channelPair(t) 347 » r, w, _ := channelPair(t)
305 348
306 timeout := time.After(10 * time.Millisecond) 349 timeout := time.After(10 * time.Millisecond)
307 result := make(chan error, 1) 350 result := make(chan error, 1)
308 go func() { 351 go func() {
309 var b [1024]byte 352 var b [1024]byte
310 _, err := r.Read(b[:]) 353 _, err := r.Read(b[:])
311 result <- err 354 result <- err
312 }() 355 }()
313 if err := w.CloseWrite(); err != nil { 356 if err := w.CloseWrite(); err != nil {
314 t.Errorf("w.CloseWrite: %v", err) 357 t.Errorf("w.CloseWrite: %v", err)
(...skipping 18 matching lines...) Expand all
333 376
334 packet := make([]byte, 1+4+4+1) 377 packet := make([]byte, 1+4+4+1)
335 packet[0] = msgChannelData 378 packet[0] = msgChannelData
336 marshalUint32(packet[1:], 29348723 /* invalid channel id */) 379 marshalUint32(packet[1:], 29348723 /* invalid channel id */)
337 marshalUint32(packet[5:], 1) 380 marshalUint32(packet[5:], 1)
338 packet[9] = 42 381 packet[9] = 42
339 382
340 a.conn.writePacket(packet) 383 a.conn.writePacket(packet)
341 go a.SendRequest("hello", false, nil) 384 go a.SendRequest("hello", false, nil)
342 // 'a' wrote an invalid packet, so 'b' has exited. 385 // 'a' wrote an invalid packet, so 'b' has exited.
343 » req, ok := <-b.ReceivedRequests() 386 » req, ok := <-b.incomingRequests
344 if ok { 387 if ok {
345 t.Errorf("got request %#v after receiving invalid packet", req) 388 t.Errorf("got request %#v after receiving invalid packet", req)
346 } 389 }
347 } 390 }
348 391
349 func TestZeroWindowAdjust(t *testing.T) { 392 func TestZeroWindowAdjust(t *testing.T) {
350 » a, b := channelPair(t) 393 » a, b, _ := channelPair(t)
351 394
352 go func() { 395 go func() {
353 io.WriteString(a, "hello") 396 io.WriteString(a, "hello")
354 // bogus adjust. 397 // bogus adjust.
355 » » a.(*channel).sendMessage( 398 » » a.sendMessage(
356 msgChannelWindowAdjust, windowAdjustMsg{}) 399 msgChannelWindowAdjust, windowAdjustMsg{})
357 io.WriteString(a, "world") 400 io.WriteString(a, "world")
358 a.Close() 401 a.Close()
359 }() 402 }()
360 403
361 want := "helloworld" 404 want := "helloworld"
362 c, _ := ioutil.ReadAll(b) 405 c, _ := ioutil.ReadAll(b)
363 if string(c) != want { 406 if string(c) != want {
364 t.Errorf("got %q want %q", c, want) 407 t.Errorf("got %q want %q", c, want)
365 } 408 }
366 } 409 }
410
411 func TestMuxMaxPacketSize(t *testing.T) {
412 a, b, _ := channelPair(t)
413
414 large := make([]byte, a.maxPacket+1)
415 if err := a.writePacket(large); err == nil {
416 t.Errorf("channel sent out packet larger than maxPacket")
417 }
418
419 packet := make([]byte, 1+4+4+1+len(large))
420 packet[0] = msgChannelData
421 marshalUint32(packet[1:], a.remoteId)
422 marshalUint32(packet[5:], uint32(len(large)))
423 packet[9] = 42
424
425 if err := a.mux.conn.writePacket(packet); err != nil {
426 t.Errorf("could not send packet")
427 }
428
429 go a.SendRequest("hello", false, nil)
430
431 _, ok := <-b.incomingRequests
432 if ok {
433 t.Errorf("connection still alive after receiving large packet.")
434 }
435 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b