Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(552)

Delta Between Two Patch Sets: ssh/mux_test.go

Issue 14225043: code review 14225043: go.crypto/ssh: reimplement SSH connection protocol modu... (Closed)
Left Patch Set: diff -r 5ff5636e18c9 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Right Patch Set: diff -r cd1eea1eb828 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « ssh/mux.go ('k') | ssh/server.go » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2013 The Go Authors. All rights reserved. 1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style 2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file. 3 // license that can be found in the LICENSE file.
4 4
5 package ssh 5 package ssh
6 6
7 import ( 7 import (
8 "io" 8 "io"
9 "io/ioutil" 9 "io/ioutil"
10 "sync" 10 "sync"
11 "testing" 11 "testing"
12 "time" 12 "time"
13 ) 13 )
14 14
15 func muxPair() (*mux, *mux) { 15 func muxPair() (*mux, *mux) {
16 a, b := memPipe() 16 a, b := memPipe()
17 17
18 s := newMux(a) 18 s := newMux(a)
19 c := newMux(b) 19 c := newMux(b)
20 20
21 go s.Loop() 21 go s.Loop()
22 go c.Loop() 22 go c.Loop()
23 23
24 return s, c 24 return s, c
25 } 25 }
26 26
27 // Returns both ends of a channel, and the mux for the the 2nd 27 // Returns both ends of a channel, and the mux for the the 2nd
28 // channel. 28 // channel.
29 func channelPair(t *testing.T) (Channel, Channel, *mux) { 29 func channelPair(t *testing.T) (*channel, *channel, *mux) {
30 c, s := muxPair() 30 c, s := muxPair()
31 31
32 » res := make(chan Channel, 1) 32 » res := make(chan *channel, 1)
33 » go func() { 33 » go func() {
34 » » ch, ok := <-s.IncomingChannels() 34 » » ch, ok := <-s.incomingChannels
35 if !ok { 35 if !ok {
36 t.Fatalf("No incoming channel") 36 t.Fatalf("No incoming channel")
37 } 37 }
38 if ch.ChannelType() != "chan" { 38 if ch.ChannelType() != "chan" {
39 t.Fatalf("got type %q want chan", ch.ChannelType()) 39 t.Fatalf("got type %q want chan", ch.ChannelType())
40 } 40 }
41 » » channel, err := ch.Accept() 41 » » err := ch.Accept()
42 if err != nil { 42 if err != nil {
43 t.Fatalf("Accept %v", err) 43 t.Fatalf("Accept %v", err)
44 } 44 }
45 » » res <- channel 45 » » res <- ch
46 }() 46 }()
47 47
48 ch, err := c.OpenChannel("chan", nil) 48 ch, err := c.OpenChannel("chan", nil)
49 if err != nil { 49 if err != nil {
50 t.Fatalf("OpenChannel: %v", err) 50 t.Fatalf("OpenChannel: %v", err)
51 } 51 }
52 52
53 return <-res, ch, c 53 return <-res, ch, c
54 } 54 }
55 55
56 func TestMuxReadWrite(t *testing.T) { 56 func TestMuxReadWrite(t *testing.T) {
57 s, c, _ := channelPair(t) 57 s, c, _ := channelPair(t)
58 58
59 magic := "hello world" 59 magic := "hello world"
60 magicExt := "hello stderr" 60 magicExt := "hello stderr"
61 go func() { 61 go func() {
62 _, err := s.Write([]byte(magic)) 62 _, err := s.Write([]byte(magic))
63 if err != nil { 63 if err != nil {
64 t.Fatalf("Write: %v", err) 64 t.Fatalf("Write: %v", err)
65 } 65 }
66 » » _, err = s.Stderr().Write([]byte(magicExt)) 66 » » _, err = s.Extended(1).Write([]byte(magicExt))
67 if err != nil { 67 if err != nil {
68 t.Fatalf("Write: %v", err) 68 t.Fatalf("Write: %v", err)
69 } 69 }
70 err = s.Close() 70 err = s.Close()
71 if err != nil { 71 if err != nil {
72 t.Fatalf("Close: %v", err) 72 t.Fatalf("Close: %v", err)
73 } 73 }
74 }() 74 }()
75 75
76 var buf [1024]byte 76 var buf [1024]byte
77 n, err := c.Read(buf[:]) 77 n, err := c.Read(buf[:])
78 if err != nil { 78 if err != nil {
79 t.Fatalf("server Read: %v", err) 79 t.Fatalf("server Read: %v", err)
80 } 80 }
81 got := string(buf[:n]) 81 got := string(buf[:n])
82 if got != magic { 82 if got != magic {
83 t.Fatalf("server: got %q want %q", got, magic) 83 t.Fatalf("server: got %q want %q", got, magic)
84 } 84 }
85 85
86 » n, err = c.Stderr().Read(buf[:]) 86 » n, err = c.Extended(1).Read(buf[:])
87 if err != nil { 87 if err != nil {
88 t.Fatalf("server Read: %v", err) 88 t.Fatalf("server Read: %v", err)
89 } 89 }
90 90
91 got = string(buf[:n]) 91 got = string(buf[:n])
92 if got != magicExt { 92 if got != magicExt {
93 t.Fatalf("server: got %q want %q", got, magic) 93 t.Fatalf("server: got %q want %q", got, magic)
94 } 94 }
95 } 95 }
96 96
97 func TestMuxFlowControl(t *testing.T) { 97 func TestMuxFlowControl(t *testing.T) {
98 writerMux, readerMux := muxPair() 98 writerMux, readerMux := muxPair()
99 99
100 // this goroutine reads just a bit. 100 // this goroutine reads just a bit.
101 go func() { 101 go func() {
102 » » readerCreate, ok := <-readerMux.IncomingChannels() 102 » » reader, ok := <-readerMux.incomingChannels
103 if !ok { 103 if !ok {
104 t.Fatalf("no incoming channel") 104 t.Fatalf("no incoming channel")
105 } 105 }
106 » » reader, err := readerCreate.Accept() 106 » » err := reader.Accept()
107 if err != nil { 107 if err != nil {
108 t.Fatalf("Accept: %v", err) 108 t.Fatalf("Accept: %v", err)
109 } 109 }
110 110
111 b := make([]byte, 1024) 111 b := make([]byte, 1024)
112 n, err := reader.Read(b) 112 n, err := reader.Read(b)
113 if err != nil || n != len(b) { 113 if err != nil || n != len(b) {
114 t.Errorf("Read: %v, %d bytes", err, n) 114 t.Errorf("Read: %v, %d bytes", err, n)
115 } 115 }
116 }() 116 }()
(...skipping 22 matching lines...) Expand all
139 time.Sleep(1 * time.Millisecond) 139 time.Sleep(1 * time.Millisecond)
140 140
141 readerMux.Disconnect(0, "") 141 readerMux.Disconnect(0, "")
142 writerMux.Disconnect(0, "") 142 writerMux.Disconnect(0, "")
143 } 143 }
144 144
145 func TestMuxReject(t *testing.T) { 145 func TestMuxReject(t *testing.T) {
146 client, server := muxPair() 146 client, server := muxPair()
147 147
148 go func() { 148 go func() {
149 » » ch, ok := <-server.IncomingChannels() 149 » » ch, ok := <-server.incomingChannels
150 if !ok { 150 if !ok {
151 t.Fatalf("Accept") 151 t.Fatalf("Accept")
152 } 152 }
153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { 153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) 154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
155 } 155 }
156 ch.Reject(RejectionReason(42), "message") 156 ch.Reject(RejectionReason(42), "message")
157 }() 157 }()
158 158
159 ch, err := client.OpenChannel("ch", []byte("extra")) 159 ch, err := client.OpenChannel("ch", []byte("extra"))
(...skipping 13 matching lines...) Expand all
173 t.Errorf("got %q, want %q", err.Error(), want) 173 t.Errorf("got %q, want %q", err.Error(), want)
174 } 174 }
175 } 175 }
176 176
177 func TestMuxChannelRequest(t *testing.T) { 177 func TestMuxChannelRequest(t *testing.T) {
178 client, server, _ := channelPair(t) 178 client, server, _ := channelPair(t)
179 var received int 179 var received int
180 var wg sync.WaitGroup 180 var wg sync.WaitGroup
181 wg.Add(1) 181 wg.Add(1)
182 go func() { 182 go func() {
183 » » for r := range server.IncomingRequests() { 183 » » for r := range server.incomingRequests {
184 received++ 184 received++
185 if r.WantReply { 185 if r.WantReply {
186 server.AckRequest(r.Request == "yes") 186 server.AckRequest(r.Request == "yes")
187 } 187 }
188 } 188 }
189 wg.Done() 189 wg.Done()
190 }() 190 }()
191 _, err := client.SendRequest("yes", false, nil) 191 _, err := client.SendRequest("yes", false, nil)
192 if err != nil { 192 if err != nil {
193 t.Fatalf("SendRequest: %v", err) 193 t.Fatalf("SendRequest: %v", err)
(...skipping 23 matching lines...) Expand all
217 if received != 3 { 217 if received != 3 {
218 t.Errorf("got %d requests, want %d", received) 218 t.Errorf("got %d requests, want %d", received)
219 } 219 }
220 } 220 }
221 221
222 func TestMuxGlobalRequest(t *testing.T) { 222 func TestMuxGlobalRequest(t *testing.T) {
223 clientMux, serverMux := muxPair() 223 clientMux, serverMux := muxPair()
224 224
225 var seen bool 225 var seen bool
226 go func() { 226 go func() {
227 » » for r := range serverMux.IncomingRequests() { 227 » » for r := range serverMux.incomingRequests {
228 seen = seen || r.Request == "peek" 228 seen = seen || r.Request == "peek"
229 if r.WantReply { 229 if r.WantReply {
230 err := serverMux.AckRequest(r.Request == "yes", 230 err := serverMux.AckRequest(r.Request == "yes",
231 append([]byte(r.Request), r.Payload...)) 231 append([]byte(r.Request), r.Payload...))
232 if err != nil { 232 if err != nil {
233 t.Errorf("AckRequest: %v", err) 233 t.Errorf("AckRequest: %v", err)
234 } 234 }
235 } 235 }
236 } 236 }
237 }() 237 }()
(...skipping 26 matching lines...) Expand all
264 264
265 func TestMuxGlobalRequestUnblock(t *testing.T) { 265 func TestMuxGlobalRequestUnblock(t *testing.T) {
266 clientMux, serverMux := muxPair() 266 clientMux, serverMux := muxPair()
267 267
268 result := make(chan error, 1) 268 result := make(chan error, 1)
269 go func() { 269 go func() {
270 _, _, err := clientMux.SendRequest("hello", true, nil) 270 _, _, err := clientMux.SendRequest("hello", true, nil)
271 result <- err 271 result <- err
272 }() 272 }()
273 273
274 » <-serverMux.IncomingRequests() 274 » <-serverMux.incomingRequests
275 serverMux.conn.Close() 275 serverMux.conn.Close()
276 err := <-result 276 err := <-result
277 277
278 if err != io.EOF { 278 if err != io.EOF {
279 t.Errorf("want EOF, got %v", io.EOF) 279 t.Errorf("want EOF, got %v", io.EOF)
280 } 280 }
281 } 281 }
282 282
283 func TestMuxChannelRequestUnblock(t *testing.T) { 283 func TestMuxChannelRequestUnblock(t *testing.T) {
284 a, b, connB := channelPair(t) 284 a, b, connB := channelPair(t)
285 285
286 result := make(chan error, 1) 286 result := make(chan error, 1)
287 go func() { 287 go func() {
288 _, err := a.SendRequest("hello", true, nil) 288 _, err := a.SendRequest("hello", true, nil)
289 result <- err 289 result <- err
290 }() 290 }()
291 291
292 » <-b.IncomingRequests() 292 » <-b.incomingRequests
293 connB.conn.Close() 293 connB.conn.Close()
294 err := <-result 294 err := <-result
295 295
296 if err != io.EOF { 296 if err != io.EOF {
297 t.Errorf("want EOF, got %v", err) 297 t.Errorf("want EOF, got %v", err)
298 } 298 }
299 } 299 }
300 300
301 func TestMuxDisconnect(t *testing.T) { 301 func TestMuxDisconnect(t *testing.T) {
302 a, b := muxPair() 302 a, b := muxPair()
303 go func() { 303 go func() {
304 » » for r := range b.IncomingRequests() { 304 » » for r := range b.incomingRequests {
305 if r.WantReply { 305 if r.WantReply {
306 b.AckRequest(true, nil) 306 b.AckRequest(true, nil)
307 } 307 }
308 } 308 }
309 }() 309 }()
310 310
311 a.Disconnect(42, "whatever") 311 a.Disconnect(42, "whatever")
312 ok, _, err := a.SendRequest("hello", true, nil) 312 ok, _, err := a.SendRequest("hello", true, nil)
313 if ok || err == nil { 313 if ok || err == nil {
314 t.Errorf("got reply after disconnecting") 314 t.Errorf("got reply after disconnecting")
(...skipping 61 matching lines...) Expand 10 before | Expand all | Expand 10 after
376 376
377 packet := make([]byte, 1+4+4+1) 377 packet := make([]byte, 1+4+4+1)
378 packet[0] = msgChannelData 378 packet[0] = msgChannelData
379 marshalUint32(packet[1:], 29348723 /* invalid channel id */) 379 marshalUint32(packet[1:], 29348723 /* invalid channel id */)
380 marshalUint32(packet[5:], 1) 380 marshalUint32(packet[5:], 1)
381 packet[9] = 42 381 packet[9] = 42
382 382
383 a.conn.writePacket(packet) 383 a.conn.writePacket(packet)
384 go a.SendRequest("hello", false, nil) 384 go a.SendRequest("hello", false, nil)
385 // 'a' wrote an invalid packet, so 'b' has exited. 385 // 'a' wrote an invalid packet, so 'b' has exited.
386 » req, ok := <-b.IncomingRequests() 386 » req, ok := <-b.incomingRequests
387 if ok { 387 if ok {
388 t.Errorf("got request %#v after receiving invalid packet", req) 388 t.Errorf("got request %#v after receiving invalid packet", req)
389 } 389 }
390 } 390 }
391 391
392 func TestZeroWindowAdjust(t *testing.T) { 392 func TestZeroWindowAdjust(t *testing.T) {
393 a, b, _ := channelPair(t) 393 a, b, _ := channelPair(t)
394 394
395 go func() { 395 go func() {
396 io.WriteString(a, "hello") 396 io.WriteString(a, "hello")
397 // bogus adjust. 397 // bogus adjust.
398 » » a.(*channel).sendMessage( 398 » » a.sendMessage(
399 msgChannelWindowAdjust, windowAdjustMsg{}) 399 msgChannelWindowAdjust, windowAdjustMsg{})
400 io.WriteString(a, "world") 400 io.WriteString(a, "world")
401 a.Close() 401 a.Close()
402 }() 402 }()
403 403
404 want := "helloworld" 404 want := "helloworld"
405 c, _ := ioutil.ReadAll(b) 405 c, _ := ioutil.ReadAll(b)
406 if string(c) != want { 406 if string(c) != want {
407 t.Errorf("got %q want %q", c, want) 407 t.Errorf("got %q want %q", c, want)
408 } 408 }
409 } 409 }
410 410
411 func TestMuxMaxPacketSize(t *testing.T) { 411 func TestMuxMaxPacketSize(t *testing.T) {
412 a, b, _ := channelPair(t) 412 a, b, _ := channelPair(t)
413 413
414 » ch := a.(*channel) 414 » large := make([]byte, a.maxPacket+1)
415 » large := make([]byte, ch.maxPacket+1) 415 » if err := a.writePacket(large); err == nil {
416 » if err := ch.writePacket(large); err == nil {
417 t.Errorf("channel sent out packet larger than maxPacket") 416 t.Errorf("channel sent out packet larger than maxPacket")
418 } 417 }
419 418
420 packet := make([]byte, 1+4+4+1+len(large)) 419 packet := make([]byte, 1+4+4+1+len(large))
421 packet[0] = msgChannelData 420 packet[0] = msgChannelData
422 » marshalUint32(packet[1:], ch.remoteId) 421 » marshalUint32(packet[1:], a.remoteId)
423 marshalUint32(packet[5:], uint32(len(large))) 422 marshalUint32(packet[5:], uint32(len(large)))
424 packet[9] = 42 423 packet[9] = 42
425 424
426 » if err := ch.mux.conn.writePacket(packet); err != nil { 425 » if err := a.mux.conn.writePacket(packet); err != nil {
427 t.Errorf("could not send packet") 426 t.Errorf("could not send packet")
428 } 427 }
429 428
430 go a.SendRequest("hello", false, nil) 429 go a.SendRequest("hello", false, nil)
431 430
432 » _, ok := <-b.IncomingRequests() 431 » _, ok := <-b.incomingRequests
433 if ok { 432 if ok {
434 t.Errorf("connection still alive after receiving large packet.") 433 t.Errorf("connection still alive after receiving large packet.")
435 } 434 }
436 } 435 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b