Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(1801)

Delta Between Two Patch Sets: ssh/mux_test.go

Issue 14225043: code review 14225043: go.crypto/ssh: reimplement SSH connection protocol modu... (Closed)
Left Patch Set: diff -r ef3d7138b0b5 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Right Patch Set: diff -r cd1eea1eb828 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « ssh/mux.go ('k') | ssh/server.go » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2013 The Go Authors. All rights reserved. 1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style 2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file. 3 // license that can be found in the LICENSE file.
4 4
5 package ssh 5 package ssh
6 6
7 import ( 7 import (
8 "io" 8 "io"
9 "io/ioutil" 9 "io/ioutil"
10 "sync" 10 "sync"
11 "testing" 11 "testing"
12 "time" 12 "time"
13 ) 13 )
14 14
15 func muxPair() (*mux, *mux) { 15 func muxPair() (*mux, *mux) {
16 a, b := memPipe() 16 a, b := memPipe()
17 17
18 s := newMux(a) 18 s := newMux(a)
19 c := newMux(b) 19 c := newMux(b)
20 20
21 go s.Loop() 21 go s.Loop()
22 go c.Loop() 22 go c.Loop()
23 23
24 return s, c 24 return s, c
25 } 25 }
26 26
27 // Returns both ends of a channel, and the mux for the the 2nd 27 // Returns both ends of a channel, and the mux for the the 2nd
28 // channel. 28 // channel.
29 func channelPair(t *testing.T) (Channel, Channel, *mux) { 29 func channelPair(t *testing.T) (*channel, *channel, *mux) {
30 c, s := muxPair() 30 c, s := muxPair()
31 31
32 » res := make(chan Channel, 1) 32 » res := make(chan *channel, 1)
33 » go func() { 33 » go func() {
34 » » ch, err := s.Accept() 34 » » ch, ok := <-s.incomingChannels
35 » » if !ok {
36 » » » t.Fatalf("No incoming channel")
37 » » }
38 » » if ch.ChannelType() != "chan" {
39 » » » t.Fatalf("got type %q want chan", ch.ChannelType())
40 » » }
41 » » err := ch.Accept()
35 if err != nil { 42 if err != nil {
36 t.Fatalf("Accept %v", err) 43 t.Fatalf("Accept %v", err)
37 } 44 }
38 if ch.ChannelType() != "chan" {
39 t.Fatalf("got type %q want chan", ch.ChannelType())
40 }
41 ch.Accept()
42 res <- ch 45 res <- ch
43 }() 46 }()
44 47
45 ch, err := c.OpenChannel("chan", nil) 48 ch, err := c.OpenChannel("chan", nil)
46 if err != nil { 49 if err != nil {
47 t.Fatalf("OpenChannel: %v", err) 50 t.Fatalf("OpenChannel: %v", err)
48 } 51 }
49 52
50 return <-res, ch, c 53 return <-res, ch, c
51 } 54 }
52 55
53 func TestMuxReadWrite(t *testing.T) { 56 func TestMuxReadWrite(t *testing.T) {
54 s, c, _ := channelPair(t) 57 s, c, _ := channelPair(t)
55 58
56 magic := "hello world" 59 magic := "hello world"
57 magicExt := "hello stderr" 60 magicExt := "hello stderr"
58 go func() { 61 go func() {
59 _, err := s.Write([]byte(magic)) 62 _, err := s.Write([]byte(magic))
60 if err != nil { 63 if err != nil {
61 t.Fatalf("Write: %v", err) 64 t.Fatalf("Write: %v", err)
62 } 65 }
63 » » _, err = s.Stderr().Write([]byte(magicExt)) 66 » » _, err = s.Extended(1).Write([]byte(magicExt))
64 if err != nil { 67 if err != nil {
65 t.Fatalf("Write: %v", err) 68 t.Fatalf("Write: %v", err)
66 } 69 }
67 err = s.Close() 70 err = s.Close()
68 if err != nil { 71 if err != nil {
69 t.Fatalf("Close: %v", err) 72 t.Fatalf("Close: %v", err)
70 } 73 }
71 }() 74 }()
72 75
73 var buf [1024]byte 76 var buf [1024]byte
74 n, err := c.Read(buf[:]) 77 n, err := c.Read(buf[:])
75 if err != nil { 78 if err != nil {
76 t.Fatalf("server Read: %v", err) 79 t.Fatalf("server Read: %v", err)
77 } 80 }
78 got := string(buf[:n]) 81 got := string(buf[:n])
79 if got != magic { 82 if got != magic {
80 t.Fatalf("server: got %q want %q", got, magic) 83 t.Fatalf("server: got %q want %q", got, magic)
81 } 84 }
82 85
83 » n, err = c.Stderr().Read(buf[:]) 86 » n, err = c.Extended(1).Read(buf[:])
84 if err != nil { 87 if err != nil {
85 t.Fatalf("server Read: %v", err) 88 t.Fatalf("server Read: %v", err)
86 } 89 }
87 90
88 got = string(buf[:n]) 91 got = string(buf[:n])
89 if got != magicExt { 92 if got != magicExt {
90 t.Fatalf("server: got %q want %q", got, magic) 93 t.Fatalf("server: got %q want %q", got, magic)
91 } 94 }
92 } 95 }
93 96
94 func TestMuxFlowControl(t *testing.T) { 97 func TestMuxFlowControl(t *testing.T) {
95 writerMux, readerMux := muxPair() 98 writerMux, readerMux := muxPair()
96 99
97 // this goroutine reads just a bit. 100 // this goroutine reads just a bit.
98 go func() { 101 go func() {
99 » » reader, err := readerMux.Accept() 102 » » reader, ok := <-readerMux.incomingChannels
103 » » if !ok {
104 » » » t.Fatalf("no incoming channel")
105 » » }
106 » » err := reader.Accept()
100 if err != nil { 107 if err != nil {
101 t.Fatalf("Accept: %v", err)
102 }
103 if err = reader.Accept(); err != nil {
104 t.Fatalf("Accept: %v", err) 108 t.Fatalf("Accept: %v", err)
105 } 109 }
106 110
107 b := make([]byte, 1024) 111 b := make([]byte, 1024)
108 n, err := reader.Read(b) 112 n, err := reader.Read(b)
109 if err != nil || n != len(b) { 113 if err != nil || n != len(b) {
110 t.Errorf("Read: %v, %d bytes", err, n) 114 t.Errorf("Read: %v, %d bytes", err, n)
111 } 115 }
112 }() 116 }()
113 117
(...skipping 21 matching lines...) Expand all
135 time.Sleep(1 * time.Millisecond) 139 time.Sleep(1 * time.Millisecond)
136 140
137 readerMux.Disconnect(0, "") 141 readerMux.Disconnect(0, "")
138 writerMux.Disconnect(0, "") 142 writerMux.Disconnect(0, "")
139 } 143 }
140 144
141 func TestMuxReject(t *testing.T) { 145 func TestMuxReject(t *testing.T) {
142 client, server := muxPair() 146 client, server := muxPair()
143 147
144 go func() { 148 go func() {
145 » » ch, err := server.Accept() 149 » » ch, ok := <-server.incomingChannels
146 » » if err != nil { 150 » » if !ok {
147 » » » t.Fatalf("Accept: %v", err) 151 » » » t.Fatalf("Accept")
148 } 152 }
149 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { 153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
150 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) 154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
151 } 155 }
152 ch.Reject(RejectionReason(42), "message") 156 ch.Reject(RejectionReason(42), "message")
153 }() 157 }()
154 158
155 ch, err := client.OpenChannel("ch", []byte("extra")) 159 ch, err := client.OpenChannel("ch", []byte("extra"))
156 if ch != nil { 160 if ch != nil {
157 t.Fatal("openChannel not rejected") 161 t.Fatal("openChannel not rejected")
(...skipping 11 matching lines...) Expand all
169 t.Errorf("got %q, want %q", err.Error(), want) 173 t.Errorf("got %q, want %q", err.Error(), want)
170 } 174 }
171 } 175 }
172 176
173 func TestMuxChannelRequest(t *testing.T) { 177 func TestMuxChannelRequest(t *testing.T) {
174 client, server, _ := channelPair(t) 178 client, server, _ := channelPair(t)
175 var received int 179 var received int
176 var wg sync.WaitGroup 180 var wg sync.WaitGroup
177 wg.Add(1) 181 wg.Add(1)
178 go func() { 182 go func() {
179 » » for r := range server.IncomingRequests() { 183 » » for r := range server.incomingRequests {
180 received++ 184 received++
181 if r.WantReply { 185 if r.WantReply {
182 server.AckRequest(r.Request == "yes") 186 server.AckRequest(r.Request == "yes")
183 } 187 }
184 } 188 }
185 wg.Done() 189 wg.Done()
186 }() 190 }()
187 _, err := client.SendRequest("yes", false, nil) 191 _, err := client.SendRequest("yes", false, nil)
188 if err != nil { 192 if err != nil {
189 t.Fatalf("SendRequest: %v", err) 193 t.Fatalf("SendRequest: %v", err)
(...skipping 23 matching lines...) Expand all
213 if received != 3 { 217 if received != 3 {
214 t.Errorf("got %d requests, want %d", received) 218 t.Errorf("got %d requests, want %d", received)
215 } 219 }
216 } 220 }
217 221
218 func TestMuxGlobalRequest(t *testing.T) { 222 func TestMuxGlobalRequest(t *testing.T) {
219 clientMux, serverMux := muxPair() 223 clientMux, serverMux := muxPair()
220 224
221 var seen bool 225 var seen bool
222 go func() { 226 go func() {
223 » » for r := range serverMux.IncomingRequests() { 227 » » for r := range serverMux.incomingRequests {
224 seen = seen || r.Request == "peek" 228 seen = seen || r.Request == "peek"
225 if r.WantReply { 229 if r.WantReply {
226 err := serverMux.AckRequest(r.Request == "yes", 230 err := serverMux.AckRequest(r.Request == "yes",
227 append([]byte(r.Request), r.Payload...)) 231 append([]byte(r.Request), r.Payload...))
228 if err != nil { 232 if err != nil {
229 t.Errorf("AckRequest: %v", err) 233 t.Errorf("AckRequest: %v", err)
230 } 234 }
231 } 235 }
232 } 236 }
233 }() 237 }()
(...skipping 26 matching lines...) Expand all
260 264
261 func TestMuxGlobalRequestUnblock(t *testing.T) { 265 func TestMuxGlobalRequestUnblock(t *testing.T) {
262 clientMux, serverMux := muxPair() 266 clientMux, serverMux := muxPair()
263 267
264 result := make(chan error, 1) 268 result := make(chan error, 1)
265 go func() { 269 go func() {
266 _, _, err := clientMux.SendRequest("hello", true, nil) 270 _, _, err := clientMux.SendRequest("hello", true, nil)
267 result <- err 271 result <- err
268 }() 272 }()
269 273
270 » <-serverMux.IncomingRequests() 274 » <-serverMux.incomingRequests
271 serverMux.conn.Close() 275 serverMux.conn.Close()
272 err := <-result 276 err := <-result
273 277
274 if err != io.EOF { 278 if err != io.EOF {
275 t.Errorf("want EOF, got %v", io.EOF) 279 t.Errorf("want EOF, got %v", io.EOF)
276 } 280 }
277 } 281 }
278 282
279 func TestMuxChannelRequestUnblock(t *testing.T) { 283 func TestMuxChannelRequestUnblock(t *testing.T) {
280 a, b, connB := channelPair(t) 284 a, b, connB := channelPair(t)
281 285
282 result := make(chan error, 1) 286 result := make(chan error, 1)
283 go func() { 287 go func() {
284 _, err := a.SendRequest("hello", true, nil) 288 _, err := a.SendRequest("hello", true, nil)
285 result <- err 289 result <- err
286 }() 290 }()
287 291
288 » <-b.IncomingRequests() 292 » <-b.incomingRequests
289 connB.conn.Close() 293 connB.conn.Close()
290 err := <-result 294 err := <-result
291 295
292 if err != io.EOF { 296 if err != io.EOF {
293 t.Errorf("want EOF, got %v", err) 297 t.Errorf("want EOF, got %v", err)
294 } 298 }
295 } 299 }
296 300
297 func TestMuxDisconnect(t *testing.T) { 301 func TestMuxDisconnect(t *testing.T) {
298 a, b := muxPair() 302 a, b := muxPair()
299 go func() { 303 go func() {
300 » » for r := range b.IncomingRequests() { 304 » » for r := range b.incomingRequests {
301 if r.WantReply { 305 if r.WantReply {
302 b.AckRequest(true, nil) 306 b.AckRequest(true, nil)
303 } 307 }
304 } 308 }
305 }() 309 }()
306 310
307 a.Disconnect(42, "whatever") 311 a.Disconnect(42, "whatever")
308 ok, _, err := a.SendRequest("hello", true, nil) 312 ok, _, err := a.SendRequest("hello", true, nil)
309 if ok || err == nil { 313 if ok || err == nil {
310 t.Errorf("got reply after disconnecting") 314 t.Errorf("got reply after disconnecting")
(...skipping 61 matching lines...) Expand 10 before | Expand all | Expand 10 after
372 376
373 packet := make([]byte, 1+4+4+1) 377 packet := make([]byte, 1+4+4+1)
374 packet[0] = msgChannelData 378 packet[0] = msgChannelData
375 marshalUint32(packet[1:], 29348723 /* invalid channel id */) 379 marshalUint32(packet[1:], 29348723 /* invalid channel id */)
376 marshalUint32(packet[5:], 1) 380 marshalUint32(packet[5:], 1)
377 packet[9] = 42 381 packet[9] = 42
378 382
379 a.conn.writePacket(packet) 383 a.conn.writePacket(packet)
380 go a.SendRequest("hello", false, nil) 384 go a.SendRequest("hello", false, nil)
381 // 'a' wrote an invalid packet, so 'b' has exited. 385 // 'a' wrote an invalid packet, so 'b' has exited.
382 » req, ok := <-b.IncomingRequests() 386 » req, ok := <-b.incomingRequests
383 if ok { 387 if ok {
384 t.Errorf("got request %#v after receiving invalid packet", req) 388 t.Errorf("got request %#v after receiving invalid packet", req)
385 } 389 }
386 } 390 }
387 391
388 func TestZeroWindowAdjust(t *testing.T) { 392 func TestZeroWindowAdjust(t *testing.T) {
389 a, b, _ := channelPair(t) 393 a, b, _ := channelPair(t)
390 394
391 go func() { 395 go func() {
392 io.WriteString(a, "hello") 396 io.WriteString(a, "hello")
393 // bogus adjust. 397 // bogus adjust.
394 » » a.(*channel).sendMessage( 398 » » a.sendMessage(
395 msgChannelWindowAdjust, windowAdjustMsg{}) 399 msgChannelWindowAdjust, windowAdjustMsg{})
396 io.WriteString(a, "world") 400 io.WriteString(a, "world")
397 a.Close() 401 a.Close()
398 }() 402 }()
399 403
400 want := "helloworld" 404 want := "helloworld"
401 c, _ := ioutil.ReadAll(b) 405 c, _ := ioutil.ReadAll(b)
402 if string(c) != want { 406 if string(c) != want {
403 t.Errorf("got %q want %q", c, want) 407 t.Errorf("got %q want %q", c, want)
404 } 408 }
405 } 409 }
406 410
407 func TestMuxMaxPacketSize(t *testing.T) { 411 func TestMuxMaxPacketSize(t *testing.T) {
408 a, b, _ := channelPair(t) 412 a, b, _ := channelPair(t)
409 413
410 » ch := a.(*channel) 414 » large := make([]byte, a.maxPacket+1)
411 » large := make([]byte, ch.maxPacket+1) 415 » if err := a.writePacket(large); err == nil {
412 » if err := ch.writePacket(large); err == nil {
413 t.Errorf("channel sent out packet larger than maxPacket") 416 t.Errorf("channel sent out packet larger than maxPacket")
414 } 417 }
415 418
416 packet := make([]byte, 1+4+4+1+len(large)) 419 packet := make([]byte, 1+4+4+1+len(large))
417 packet[0] = msgChannelData 420 packet[0] = msgChannelData
418 » marshalUint32(packet[1:], ch.remoteId) 421 » marshalUint32(packet[1:], a.remoteId)
419 marshalUint32(packet[5:], uint32(len(large))) 422 marshalUint32(packet[5:], uint32(len(large)))
420 packet[9] = 42 423 packet[9] = 42
421 424
422 » if err := ch.mux.conn.writePacket(packet); err != nil { 425 » if err := a.mux.conn.writePacket(packet); err != nil {
423 t.Errorf("could not send packet") 426 t.Errorf("could not send packet")
424 } 427 }
425 428
426 go a.SendRequest("hello", false, nil) 429 go a.SendRequest("hello", false, nil)
427 430
428 » _, ok := <-b.IncomingRequests() 431 » _, ok := <-b.incomingRequests
429 if ok { 432 if ok {
430 t.Errorf("connection still alive after receiving large packet.") 433 t.Errorf("connection still alive after receiving large packet.")
431 } 434 }
432 } 435 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b