LEFT | RIGHT |
1 // Copyright 2013 The Go Authors. All rights reserved. | 1 // Copyright 2013 The Go Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style | 2 // Use of this source code is governed by a BSD-style |
3 // license that can be found in the LICENSE file. | 3 // license that can be found in the LICENSE file. |
4 | 4 |
5 package ssh | 5 package ssh |
6 | 6 |
7 import ( | 7 import ( |
8 "io" | 8 "io" |
9 "io/ioutil" | 9 "io/ioutil" |
10 "sync" | 10 "sync" |
11 "testing" | 11 "testing" |
12 "time" | 12 "time" |
13 ) | 13 ) |
14 | 14 |
15 func muxPair() (*mux, *mux) { | 15 func muxPair() (*mux, *mux) { |
16 a, b := memPipe() | 16 a, b := memPipe() |
17 | 17 |
18 s := newMux(a) | 18 s := newMux(a) |
19 c := newMux(b) | 19 c := newMux(b) |
20 | 20 |
21 go s.Loop() | 21 go s.Loop() |
22 go c.Loop() | 22 go c.Loop() |
23 | 23 |
24 return s, c | 24 return s, c |
25 } | 25 } |
26 | 26 |
27 // Returns both ends of a channel, and the mux for the the 2nd | 27 // Returns both ends of a channel, and the mux for the the 2nd |
28 // channel. | 28 // channel. |
29 func channelPair(t *testing.T) (Channel, Channel, *mux) { | 29 func channelPair(t *testing.T) (*channel, *channel, *mux) { |
30 c, s := muxPair() | 30 c, s := muxPair() |
31 | 31 |
32 » res := make(chan Channel, 1) | 32 » res := make(chan *channel, 1) |
33 » go func() { | 33 » go func() { |
34 » » ch, ok := <-s.IncomingChannels() | 34 » » ch, ok := <-s.incomingChannels |
35 if !ok { | 35 if !ok { |
36 t.Fatalf("No incoming channel") | 36 t.Fatalf("No incoming channel") |
37 } | 37 } |
38 if ch.ChannelType() != "chan" { | 38 if ch.ChannelType() != "chan" { |
39 t.Fatalf("got type %q want chan", ch.ChannelType()) | 39 t.Fatalf("got type %q want chan", ch.ChannelType()) |
40 } | 40 } |
41 » » channel, err := ch.Accept() | 41 » » err := ch.Accept() |
42 if err != nil { | 42 if err != nil { |
43 t.Fatalf("Accept %v", err) | 43 t.Fatalf("Accept %v", err) |
44 } | 44 } |
45 » » res <- channel | 45 » » res <- ch |
46 }() | 46 }() |
47 | 47 |
48 ch, err := c.OpenChannel("chan", nil) | 48 ch, err := c.OpenChannel("chan", nil) |
49 if err != nil { | 49 if err != nil { |
50 t.Fatalf("OpenChannel: %v", err) | 50 t.Fatalf("OpenChannel: %v", err) |
51 } | 51 } |
52 | 52 |
53 return <-res, ch, c | 53 return <-res, ch, c |
54 } | 54 } |
55 | 55 |
56 func TestMuxReadWrite(t *testing.T) { | 56 func TestMuxReadWrite(t *testing.T) { |
57 s, c, _ := channelPair(t) | 57 s, c, _ := channelPair(t) |
58 | 58 |
59 magic := "hello world" | 59 magic := "hello world" |
60 magicExt := "hello stderr" | 60 magicExt := "hello stderr" |
61 go func() { | 61 go func() { |
62 _, err := s.Write([]byte(magic)) | 62 _, err := s.Write([]byte(magic)) |
63 if err != nil { | 63 if err != nil { |
64 t.Fatalf("Write: %v", err) | 64 t.Fatalf("Write: %v", err) |
65 } | 65 } |
66 » » _, err = s.Stderr().Write([]byte(magicExt)) | 66 » » _, err = s.Extended(1).Write([]byte(magicExt)) |
67 if err != nil { | 67 if err != nil { |
68 t.Fatalf("Write: %v", err) | 68 t.Fatalf("Write: %v", err) |
69 } | 69 } |
70 err = s.Close() | 70 err = s.Close() |
71 if err != nil { | 71 if err != nil { |
72 t.Fatalf("Close: %v", err) | 72 t.Fatalf("Close: %v", err) |
73 } | 73 } |
74 }() | 74 }() |
75 | 75 |
76 var buf [1024]byte | 76 var buf [1024]byte |
77 n, err := c.Read(buf[:]) | 77 n, err := c.Read(buf[:]) |
78 if err != nil { | 78 if err != nil { |
79 t.Fatalf("server Read: %v", err) | 79 t.Fatalf("server Read: %v", err) |
80 } | 80 } |
81 got := string(buf[:n]) | 81 got := string(buf[:n]) |
82 if got != magic { | 82 if got != magic { |
83 t.Fatalf("server: got %q want %q", got, magic) | 83 t.Fatalf("server: got %q want %q", got, magic) |
84 } | 84 } |
85 | 85 |
86 » n, err = c.Stderr().Read(buf[:]) | 86 » n, err = c.Extended(1).Read(buf[:]) |
87 if err != nil { | 87 if err != nil { |
88 t.Fatalf("server Read: %v", err) | 88 t.Fatalf("server Read: %v", err) |
89 } | 89 } |
90 | 90 |
91 got = string(buf[:n]) | 91 got = string(buf[:n]) |
92 if got != magicExt { | 92 if got != magicExt { |
93 t.Fatalf("server: got %q want %q", got, magic) | 93 t.Fatalf("server: got %q want %q", got, magic) |
94 } | 94 } |
95 } | 95 } |
96 | 96 |
97 func TestMuxFlowControl(t *testing.T) { | 97 func TestMuxFlowControl(t *testing.T) { |
98 writerMux, readerMux := muxPair() | 98 writerMux, readerMux := muxPair() |
99 | 99 |
100 // this goroutine reads just a bit. | 100 // this goroutine reads just a bit. |
101 go func() { | 101 go func() { |
102 » » readerCreate, ok := <-readerMux.IncomingChannels() | 102 » » reader, ok := <-readerMux.incomingChannels |
103 if !ok { | 103 if !ok { |
104 t.Fatalf("no incoming channel") | 104 t.Fatalf("no incoming channel") |
105 } | 105 } |
106 » » reader, err := readerCreate.Accept() | 106 » » err := reader.Accept() |
107 if err != nil { | 107 if err != nil { |
108 t.Fatalf("Accept: %v", err) | 108 t.Fatalf("Accept: %v", err) |
109 } | 109 } |
110 | 110 |
111 b := make([]byte, 1024) | 111 b := make([]byte, 1024) |
112 n, err := reader.Read(b) | 112 n, err := reader.Read(b) |
113 if err != nil || n != len(b) { | 113 if err != nil || n != len(b) { |
114 t.Errorf("Read: %v, %d bytes", err, n) | 114 t.Errorf("Read: %v, %d bytes", err, n) |
115 } | 115 } |
116 }() | 116 }() |
(...skipping 22 matching lines...) Expand all Loading... |
139 time.Sleep(1 * time.Millisecond) | 139 time.Sleep(1 * time.Millisecond) |
140 | 140 |
141 readerMux.Disconnect(0, "") | 141 readerMux.Disconnect(0, "") |
142 writerMux.Disconnect(0, "") | 142 writerMux.Disconnect(0, "") |
143 } | 143 } |
144 | 144 |
145 func TestMuxReject(t *testing.T) { | 145 func TestMuxReject(t *testing.T) { |
146 client, server := muxPair() | 146 client, server := muxPair() |
147 | 147 |
148 go func() { | 148 go func() { |
149 » » ch, ok := <-server.IncomingChannels() | 149 » » ch, ok := <-server.incomingChannels |
150 if !ok { | 150 if !ok { |
151 t.Fatalf("Accept") | 151 t.Fatalf("Accept") |
152 } | 152 } |
153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra"
{ | 153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra"
{ |
154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(),
ch.ExtraData()) | 154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(),
ch.ExtraData()) |
155 } | 155 } |
156 ch.Reject(RejectionReason(42), "message") | 156 ch.Reject(RejectionReason(42), "message") |
157 }() | 157 }() |
158 | 158 |
159 ch, err := client.OpenChannel("ch", []byte("extra")) | 159 ch, err := client.OpenChannel("ch", []byte("extra")) |
(...skipping 13 matching lines...) Expand all Loading... |
173 t.Errorf("got %q, want %q", err.Error(), want) | 173 t.Errorf("got %q, want %q", err.Error(), want) |
174 } | 174 } |
175 } | 175 } |
176 | 176 |
177 func TestMuxChannelRequest(t *testing.T) { | 177 func TestMuxChannelRequest(t *testing.T) { |
178 client, server, _ := channelPair(t) | 178 client, server, _ := channelPair(t) |
179 var received int | 179 var received int |
180 var wg sync.WaitGroup | 180 var wg sync.WaitGroup |
181 wg.Add(1) | 181 wg.Add(1) |
182 go func() { | 182 go func() { |
183 » » for r := range server.IncomingRequests() { | 183 » » for r := range server.incomingRequests { |
184 received++ | 184 received++ |
185 if r.WantReply { | 185 if r.WantReply { |
186 server.AckRequest(r.Request == "yes") | 186 server.AckRequest(r.Request == "yes") |
187 } | 187 } |
188 } | 188 } |
189 wg.Done() | 189 wg.Done() |
190 }() | 190 }() |
191 _, err := client.SendRequest("yes", false, nil) | 191 _, err := client.SendRequest("yes", false, nil) |
192 if err != nil { | 192 if err != nil { |
193 t.Fatalf("SendRequest: %v", err) | 193 t.Fatalf("SendRequest: %v", err) |
(...skipping 23 matching lines...) Expand all Loading... |
217 if received != 3 { | 217 if received != 3 { |
218 t.Errorf("got %d requests, want %d", received) | 218 t.Errorf("got %d requests, want %d", received) |
219 } | 219 } |
220 } | 220 } |
221 | 221 |
222 func TestMuxGlobalRequest(t *testing.T) { | 222 func TestMuxGlobalRequest(t *testing.T) { |
223 clientMux, serverMux := muxPair() | 223 clientMux, serverMux := muxPair() |
224 | 224 |
225 var seen bool | 225 var seen bool |
226 go func() { | 226 go func() { |
227 » » for r := range serverMux.IncomingRequests() { | 227 » » for r := range serverMux.incomingRequests { |
228 seen = seen || r.Request == "peek" | 228 seen = seen || r.Request == "peek" |
229 if r.WantReply { | 229 if r.WantReply { |
230 err := serverMux.AckRequest(r.Request == "yes", | 230 err := serverMux.AckRequest(r.Request == "yes", |
231 append([]byte(r.Request), r.Payload...)) | 231 append([]byte(r.Request), r.Payload...)) |
232 if err != nil { | 232 if err != nil { |
233 t.Errorf("AckRequest: %v", err) | 233 t.Errorf("AckRequest: %v", err) |
234 } | 234 } |
235 } | 235 } |
236 } | 236 } |
237 }() | 237 }() |
(...skipping 26 matching lines...) Expand all Loading... |
264 | 264 |
265 func TestMuxGlobalRequestUnblock(t *testing.T) { | 265 func TestMuxGlobalRequestUnblock(t *testing.T) { |
266 clientMux, serverMux := muxPair() | 266 clientMux, serverMux := muxPair() |
267 | 267 |
268 result := make(chan error, 1) | 268 result := make(chan error, 1) |
269 go func() { | 269 go func() { |
270 _, _, err := clientMux.SendRequest("hello", true, nil) | 270 _, _, err := clientMux.SendRequest("hello", true, nil) |
271 result <- err | 271 result <- err |
272 }() | 272 }() |
273 | 273 |
274 » <-serverMux.IncomingRequests() | 274 » <-serverMux.incomingRequests |
275 serverMux.conn.Close() | 275 serverMux.conn.Close() |
276 err := <-result | 276 err := <-result |
277 | 277 |
278 if err != io.EOF { | 278 if err != io.EOF { |
279 t.Errorf("want EOF, got %v", io.EOF) | 279 t.Errorf("want EOF, got %v", io.EOF) |
280 } | 280 } |
281 } | 281 } |
282 | 282 |
283 func TestMuxChannelRequestUnblock(t *testing.T) { | 283 func TestMuxChannelRequestUnblock(t *testing.T) { |
284 a, b, connB := channelPair(t) | 284 a, b, connB := channelPair(t) |
285 | 285 |
286 result := make(chan error, 1) | 286 result := make(chan error, 1) |
287 go func() { | 287 go func() { |
288 _, err := a.SendRequest("hello", true, nil) | 288 _, err := a.SendRequest("hello", true, nil) |
289 result <- err | 289 result <- err |
290 }() | 290 }() |
291 | 291 |
292 » <-b.IncomingRequests() | 292 » <-b.incomingRequests |
293 connB.conn.Close() | 293 connB.conn.Close() |
294 err := <-result | 294 err := <-result |
295 | 295 |
296 if err != io.EOF { | 296 if err != io.EOF { |
297 t.Errorf("want EOF, got %v", err) | 297 t.Errorf("want EOF, got %v", err) |
298 } | 298 } |
299 } | 299 } |
300 | 300 |
301 func TestMuxDisconnect(t *testing.T) { | 301 func TestMuxDisconnect(t *testing.T) { |
302 a, b := muxPair() | 302 a, b := muxPair() |
303 go func() { | 303 go func() { |
304 » » for r := range b.IncomingRequests() { | 304 » » for r := range b.incomingRequests { |
305 if r.WantReply { | 305 if r.WantReply { |
306 b.AckRequest(true, nil) | 306 b.AckRequest(true, nil) |
307 } | 307 } |
308 } | 308 } |
309 }() | 309 }() |
310 | 310 |
311 a.Disconnect(42, "whatever") | 311 a.Disconnect(42, "whatever") |
312 ok, _, err := a.SendRequest("hello", true, nil) | 312 ok, _, err := a.SendRequest("hello", true, nil) |
313 if ok || err == nil { | 313 if ok || err == nil { |
314 t.Errorf("got reply after disconnecting") | 314 t.Errorf("got reply after disconnecting") |
(...skipping 61 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
376 | 376 |
377 packet := make([]byte, 1+4+4+1) | 377 packet := make([]byte, 1+4+4+1) |
378 packet[0] = msgChannelData | 378 packet[0] = msgChannelData |
379 marshalUint32(packet[1:], 29348723 /* invalid channel id */) | 379 marshalUint32(packet[1:], 29348723 /* invalid channel id */) |
380 marshalUint32(packet[5:], 1) | 380 marshalUint32(packet[5:], 1) |
381 packet[9] = 42 | 381 packet[9] = 42 |
382 | 382 |
383 a.conn.writePacket(packet) | 383 a.conn.writePacket(packet) |
384 go a.SendRequest("hello", false, nil) | 384 go a.SendRequest("hello", false, nil) |
385 // 'a' wrote an invalid packet, so 'b' has exited. | 385 // 'a' wrote an invalid packet, so 'b' has exited. |
386 » req, ok := <-b.IncomingRequests() | 386 » req, ok := <-b.incomingRequests |
387 if ok { | 387 if ok { |
388 t.Errorf("got request %#v after receiving invalid packet", req) | 388 t.Errorf("got request %#v after receiving invalid packet", req) |
389 } | 389 } |
390 } | 390 } |
391 | 391 |
392 func TestZeroWindowAdjust(t *testing.T) { | 392 func TestZeroWindowAdjust(t *testing.T) { |
393 a, b, _ := channelPair(t) | 393 a, b, _ := channelPair(t) |
394 | 394 |
395 go func() { | 395 go func() { |
396 io.WriteString(a, "hello") | 396 io.WriteString(a, "hello") |
397 // bogus adjust. | 397 // bogus adjust. |
398 » » a.(*channel).sendMessage( | 398 » » a.sendMessage( |
399 msgChannelWindowAdjust, windowAdjustMsg{}) | 399 msgChannelWindowAdjust, windowAdjustMsg{}) |
400 io.WriteString(a, "world") | 400 io.WriteString(a, "world") |
401 a.Close() | 401 a.Close() |
402 }() | 402 }() |
403 | 403 |
404 want := "helloworld" | 404 want := "helloworld" |
405 c, _ := ioutil.ReadAll(b) | 405 c, _ := ioutil.ReadAll(b) |
406 if string(c) != want { | 406 if string(c) != want { |
407 t.Errorf("got %q want %q", c, want) | 407 t.Errorf("got %q want %q", c, want) |
408 } | 408 } |
409 } | 409 } |
410 | 410 |
411 func TestMuxMaxPacketSize(t *testing.T) { | 411 func TestMuxMaxPacketSize(t *testing.T) { |
412 a, b, _ := channelPair(t) | 412 a, b, _ := channelPair(t) |
413 | 413 |
414 » ch := a.(*channel) | 414 » large := make([]byte, a.maxPacket+1) |
415 » large := make([]byte, ch.maxPacket+1) | 415 » if err := a.writePacket(large); err == nil { |
416 » if err := ch.writePacket(large); err == nil { | |
417 t.Errorf("channel sent out packet larger than maxPacket") | 416 t.Errorf("channel sent out packet larger than maxPacket") |
418 } | 417 } |
419 | 418 |
420 packet := make([]byte, 1+4+4+1+len(large)) | 419 packet := make([]byte, 1+4+4+1+len(large)) |
421 packet[0] = msgChannelData | 420 packet[0] = msgChannelData |
422 » marshalUint32(packet[1:], ch.remoteId) | 421 » marshalUint32(packet[1:], a.remoteId) |
423 marshalUint32(packet[5:], uint32(len(large))) | 422 marshalUint32(packet[5:], uint32(len(large))) |
424 packet[9] = 42 | 423 packet[9] = 42 |
425 | 424 |
426 » if err := ch.mux.conn.writePacket(packet); err != nil { | 425 » if err := a.mux.conn.writePacket(packet); err != nil { |
427 t.Errorf("could not send packet") | 426 t.Errorf("could not send packet") |
428 } | 427 } |
429 | 428 |
430 go a.SendRequest("hello", false, nil) | 429 go a.SendRequest("hello", false, nil) |
431 | 430 |
432 » _, ok := <-b.IncomingRequests() | 431 » _, ok := <-b.incomingRequests |
433 if ok { | 432 if ok { |
434 t.Errorf("connection still alive after receiving large packet.") | 433 t.Errorf("connection still alive after receiving large packet.") |
435 } | 434 } |
436 } | 435 } |
LEFT | RIGHT |