LEFT | RIGHT |
1 // Copyright 2013 The Go Authors. All rights reserved. | 1 // Copyright 2013 The Go Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style | 2 // Use of this source code is governed by a BSD-style |
3 // license that can be found in the LICENSE file. | 3 // license that can be found in the LICENSE file. |
4 | 4 |
5 package ssh | 5 package ssh |
6 | 6 |
7 import ( | 7 import ( |
8 "io" | 8 "io" |
9 "io/ioutil" | 9 "io/ioutil" |
10 "sync" | 10 "sync" |
11 "testing" | 11 "testing" |
12 "time" | 12 "time" |
13 ) | 13 ) |
14 | 14 |
15 func muxPair() (*mux, *mux) { | 15 func muxPair() (*mux, *mux) { |
16 a, b := memPipe() | 16 a, b := memPipe() |
17 | 17 |
18 s := newMux(a) | 18 s := newMux(a) |
19 c := newMux(b) | 19 c := newMux(b) |
20 | 20 |
21 go s.Loop() | 21 go s.Loop() |
22 go c.Loop() | 22 go c.Loop() |
23 | 23 |
24 return s, c | 24 return s, c |
25 } | 25 } |
26 | 26 |
27 // Returns both ends of a channel, and the mux for the the 2nd | 27 // Returns both ends of a channel, and the mux for the the 2nd |
28 // channel. | 28 // channel. |
29 func channelPair(t *testing.T) (Channel, Channel, *mux) { | 29 func channelPair(t *testing.T) (*channel, *channel, *mux) { |
30 c, s := muxPair() | 30 c, s := muxPair() |
31 | 31 |
32 » res := make(chan Channel, 1) | 32 » res := make(chan *channel, 1) |
33 » go func() { | 33 » go func() { |
34 » » ch, err := s.Accept() | 34 » » ch, ok := <-s.incomingChannels |
| 35 » » if !ok { |
| 36 » » » t.Fatalf("No incoming channel") |
| 37 » » } |
| 38 » » if ch.ChannelType() != "chan" { |
| 39 » » » t.Fatalf("got type %q want chan", ch.ChannelType()) |
| 40 » » } |
| 41 » » err := ch.Accept() |
35 if err != nil { | 42 if err != nil { |
36 t.Fatalf("Accept %v", err) | 43 t.Fatalf("Accept %v", err) |
37 } | 44 } |
38 if ch.ChannelType() != "chan" { | |
39 t.Fatalf("got type %q want chan", ch.ChannelType()) | |
40 } | |
41 ch.Accept() | |
42 res <- ch | 45 res <- ch |
43 }() | 46 }() |
44 | 47 |
45 ch, err := c.OpenChannel("chan", nil) | 48 ch, err := c.OpenChannel("chan", nil) |
46 if err != nil { | 49 if err != nil { |
47 t.Fatalf("OpenChannel: %v", err) | 50 t.Fatalf("OpenChannel: %v", err) |
48 } | 51 } |
49 | 52 |
50 return <-res, ch, c | 53 return <-res, ch, c |
51 } | 54 } |
52 | 55 |
53 func TestMuxReadWrite(t *testing.T) { | 56 func TestMuxReadWrite(t *testing.T) { |
54 s, c, _ := channelPair(t) | 57 s, c, _ := channelPair(t) |
55 | 58 |
56 magic := "hello world" | 59 magic := "hello world" |
57 magicExt := "hello stderr" | 60 magicExt := "hello stderr" |
58 go func() { | 61 go func() { |
59 _, err := s.Write([]byte(magic)) | 62 _, err := s.Write([]byte(magic)) |
60 if err != nil { | 63 if err != nil { |
61 t.Fatalf("Write: %v", err) | 64 t.Fatalf("Write: %v", err) |
62 } | 65 } |
63 » » _, err = s.Stderr().Write([]byte(magicExt)) | 66 » » _, err = s.Extended(1).Write([]byte(magicExt)) |
64 if err != nil { | 67 if err != nil { |
65 t.Fatalf("Write: %v", err) | 68 t.Fatalf("Write: %v", err) |
66 } | 69 } |
67 err = s.Close() | 70 err = s.Close() |
68 if err != nil { | 71 if err != nil { |
69 t.Fatalf("Close: %v", err) | 72 t.Fatalf("Close: %v", err) |
70 } | 73 } |
71 }() | 74 }() |
72 | 75 |
73 var buf [1024]byte | 76 var buf [1024]byte |
74 n, err := c.Read(buf[:]) | 77 n, err := c.Read(buf[:]) |
75 if err != nil { | 78 if err != nil { |
76 t.Fatalf("server Read: %v", err) | 79 t.Fatalf("server Read: %v", err) |
77 } | 80 } |
78 got := string(buf[:n]) | 81 got := string(buf[:n]) |
79 if got != magic { | 82 if got != magic { |
80 t.Fatalf("server: got %q want %q", got, magic) | 83 t.Fatalf("server: got %q want %q", got, magic) |
81 } | 84 } |
82 | 85 |
83 » n, err = c.Stderr().Read(buf[:]) | 86 » n, err = c.Extended(1).Read(buf[:]) |
84 if err != nil { | 87 if err != nil { |
85 t.Fatalf("server Read: %v", err) | 88 t.Fatalf("server Read: %v", err) |
86 } | 89 } |
87 | 90 |
88 got = string(buf[:n]) | 91 got = string(buf[:n]) |
89 if got != magicExt { | 92 if got != magicExt { |
90 t.Fatalf("server: got %q want %q", got, magic) | 93 t.Fatalf("server: got %q want %q", got, magic) |
91 } | 94 } |
92 } | 95 } |
93 | 96 |
94 func TestMuxFlowControl(t *testing.T) { | 97 func TestMuxFlowControl(t *testing.T) { |
95 writerMux, readerMux := muxPair() | 98 writerMux, readerMux := muxPair() |
96 | 99 |
97 // this goroutine reads just a bit. | 100 // this goroutine reads just a bit. |
98 go func() { | 101 go func() { |
99 » » reader, err := readerMux.Accept() | 102 » » reader, ok := <-readerMux.incomingChannels |
| 103 » » if !ok { |
| 104 » » » t.Fatalf("no incoming channel") |
| 105 » » } |
| 106 » » err := reader.Accept() |
100 if err != nil { | 107 if err != nil { |
101 t.Fatalf("Accept: %v", err) | |
102 } | |
103 if err = reader.Accept(); err != nil { | |
104 t.Fatalf("Accept: %v", err) | 108 t.Fatalf("Accept: %v", err) |
105 } | 109 } |
106 | 110 |
107 b := make([]byte, 1024) | 111 b := make([]byte, 1024) |
108 n, err := reader.Read(b) | 112 n, err := reader.Read(b) |
109 if err != nil || n != len(b) { | 113 if err != nil || n != len(b) { |
110 t.Errorf("Read: %v, %d bytes", err, n) | 114 t.Errorf("Read: %v, %d bytes", err, n) |
111 } | 115 } |
112 }() | 116 }() |
113 | 117 |
(...skipping 21 matching lines...) Expand all Loading... |
135 time.Sleep(1 * time.Millisecond) | 139 time.Sleep(1 * time.Millisecond) |
136 | 140 |
137 readerMux.Disconnect(0, "") | 141 readerMux.Disconnect(0, "") |
138 writerMux.Disconnect(0, "") | 142 writerMux.Disconnect(0, "") |
139 } | 143 } |
140 | 144 |
141 func TestMuxReject(t *testing.T) { | 145 func TestMuxReject(t *testing.T) { |
142 client, server := muxPair() | 146 client, server := muxPair() |
143 | 147 |
144 go func() { | 148 go func() { |
145 » » ch, err := server.Accept() | 149 » » ch, ok := <-server.incomingChannels |
146 » » if err != nil { | 150 » » if !ok { |
147 » » » t.Fatalf("Accept: %v", err) | 151 » » » t.Fatalf("Accept") |
148 } | 152 } |
149 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra"
{ | 153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra"
{ |
150 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(),
ch.ExtraData()) | 154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(),
ch.ExtraData()) |
151 } | 155 } |
152 ch.Reject(RejectionReason(42), "message") | 156 ch.Reject(RejectionReason(42), "message") |
153 }() | 157 }() |
154 | 158 |
155 ch, err := client.OpenChannel("ch", []byte("extra")) | 159 ch, err := client.OpenChannel("ch", []byte("extra")) |
156 if ch != nil { | 160 if ch != nil { |
157 t.Fatal("openChannel not rejected") | 161 t.Fatal("openChannel not rejected") |
(...skipping 11 matching lines...) Expand all Loading... |
169 t.Errorf("got %q, want %q", err.Error(), want) | 173 t.Errorf("got %q, want %q", err.Error(), want) |
170 } | 174 } |
171 } | 175 } |
172 | 176 |
173 func TestMuxChannelRequest(t *testing.T) { | 177 func TestMuxChannelRequest(t *testing.T) { |
174 client, server, _ := channelPair(t) | 178 client, server, _ := channelPair(t) |
175 var received int | 179 var received int |
176 var wg sync.WaitGroup | 180 var wg sync.WaitGroup |
177 wg.Add(1) | 181 wg.Add(1) |
178 go func() { | 182 go func() { |
179 » » for r := range server.IncomingRequests() { | 183 » » for r := range server.incomingRequests { |
180 received++ | 184 received++ |
181 if r.WantReply { | 185 if r.WantReply { |
182 server.AckRequest(r.Request == "yes") | 186 server.AckRequest(r.Request == "yes") |
183 } | 187 } |
184 } | 188 } |
185 wg.Done() | 189 wg.Done() |
186 }() | 190 }() |
187 _, err := client.SendRequest("yes", false, nil) | 191 _, err := client.SendRequest("yes", false, nil) |
188 if err != nil { | 192 if err != nil { |
189 t.Fatalf("SendRequest: %v", err) | 193 t.Fatalf("SendRequest: %v", err) |
(...skipping 23 matching lines...) Expand all Loading... |
213 if received != 3 { | 217 if received != 3 { |
214 t.Errorf("got %d requests, want %d", received) | 218 t.Errorf("got %d requests, want %d", received) |
215 } | 219 } |
216 } | 220 } |
217 | 221 |
218 func TestMuxGlobalRequest(t *testing.T) { | 222 func TestMuxGlobalRequest(t *testing.T) { |
219 clientMux, serverMux := muxPair() | 223 clientMux, serverMux := muxPair() |
220 | 224 |
221 var seen bool | 225 var seen bool |
222 go func() { | 226 go func() { |
223 » » for r := range serverMux.IncomingRequests() { | 227 » » for r := range serverMux.incomingRequests { |
224 seen = seen || r.Request == "peek" | 228 seen = seen || r.Request == "peek" |
225 if r.WantReply { | 229 if r.WantReply { |
226 err := serverMux.AckRequest(r.Request == "yes", | 230 err := serverMux.AckRequest(r.Request == "yes", |
227 append([]byte(r.Request), r.Payload...)) | 231 append([]byte(r.Request), r.Payload...)) |
228 if err != nil { | 232 if err != nil { |
229 t.Errorf("AckRequest: %v", err) | 233 t.Errorf("AckRequest: %v", err) |
230 } | 234 } |
231 } | 235 } |
232 } | 236 } |
233 }() | 237 }() |
(...skipping 26 matching lines...) Expand all Loading... |
260 | 264 |
261 func TestMuxGlobalRequestUnblock(t *testing.T) { | 265 func TestMuxGlobalRequestUnblock(t *testing.T) { |
262 clientMux, serverMux := muxPair() | 266 clientMux, serverMux := muxPair() |
263 | 267 |
264 result := make(chan error, 1) | 268 result := make(chan error, 1) |
265 go func() { | 269 go func() { |
266 _, _, err := clientMux.SendRequest("hello", true, nil) | 270 _, _, err := clientMux.SendRequest("hello", true, nil) |
267 result <- err | 271 result <- err |
268 }() | 272 }() |
269 | 273 |
270 » <-serverMux.IncomingRequests() | 274 » <-serverMux.incomingRequests |
271 serverMux.conn.Close() | 275 serverMux.conn.Close() |
272 err := <-result | 276 err := <-result |
273 | 277 |
274 if err != io.EOF { | 278 if err != io.EOF { |
275 t.Errorf("want EOF, got %v", io.EOF) | 279 t.Errorf("want EOF, got %v", io.EOF) |
276 } | 280 } |
277 } | 281 } |
278 | 282 |
279 func TestMuxChannelRequestUnblock(t *testing.T) { | 283 func TestMuxChannelRequestUnblock(t *testing.T) { |
280 a, b, connB := channelPair(t) | 284 a, b, connB := channelPair(t) |
281 | 285 |
282 result := make(chan error, 1) | 286 result := make(chan error, 1) |
283 go func() { | 287 go func() { |
284 _, err := a.SendRequest("hello", true, nil) | 288 _, err := a.SendRequest("hello", true, nil) |
285 result <- err | 289 result <- err |
286 }() | 290 }() |
287 | 291 |
288 » <-b.IncomingRequests() | 292 » <-b.incomingRequests |
289 connB.conn.Close() | 293 connB.conn.Close() |
290 err := <-result | 294 err := <-result |
291 | 295 |
292 if err != io.EOF { | 296 if err != io.EOF { |
293 t.Errorf("want EOF, got %v", err) | 297 t.Errorf("want EOF, got %v", err) |
294 } | 298 } |
295 } | 299 } |
296 | 300 |
297 func TestMuxDisconnect(t *testing.T) { | 301 func TestMuxDisconnect(t *testing.T) { |
298 a, b := muxPair() | 302 a, b := muxPair() |
299 go func() { | 303 go func() { |
300 » » for r := range b.IncomingRequests() { | 304 » » for r := range b.incomingRequests { |
301 if r.WantReply { | 305 if r.WantReply { |
302 b.AckRequest(true, nil) | 306 b.AckRequest(true, nil) |
303 } | 307 } |
304 } | 308 } |
305 }() | 309 }() |
306 | 310 |
307 a.Disconnect(42, "whatever") | 311 a.Disconnect(42, "whatever") |
308 ok, _, err := a.SendRequest("hello", true, nil) | 312 ok, _, err := a.SendRequest("hello", true, nil) |
309 if ok || err == nil { | 313 if ok || err == nil { |
310 t.Errorf("got reply after disconnecting") | 314 t.Errorf("got reply after disconnecting") |
(...skipping 61 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
372 | 376 |
373 packet := make([]byte, 1+4+4+1) | 377 packet := make([]byte, 1+4+4+1) |
374 packet[0] = msgChannelData | 378 packet[0] = msgChannelData |
375 marshalUint32(packet[1:], 29348723 /* invalid channel id */) | 379 marshalUint32(packet[1:], 29348723 /* invalid channel id */) |
376 marshalUint32(packet[5:], 1) | 380 marshalUint32(packet[5:], 1) |
377 packet[9] = 42 | 381 packet[9] = 42 |
378 | 382 |
379 a.conn.writePacket(packet) | 383 a.conn.writePacket(packet) |
380 go a.SendRequest("hello", false, nil) | 384 go a.SendRequest("hello", false, nil) |
381 // 'a' wrote an invalid packet, so 'b' has exited. | 385 // 'a' wrote an invalid packet, so 'b' has exited. |
382 » req, ok := <-b.IncomingRequests() | 386 » req, ok := <-b.incomingRequests |
383 if ok { | 387 if ok { |
384 t.Errorf("got request %#v after receiving invalid packet", req) | 388 t.Errorf("got request %#v after receiving invalid packet", req) |
385 } | 389 } |
386 } | 390 } |
387 | 391 |
388 func TestZeroWindowAdjust(t *testing.T) { | 392 func TestZeroWindowAdjust(t *testing.T) { |
389 a, b, _ := channelPair(t) | 393 a, b, _ := channelPair(t) |
390 | 394 |
391 go func() { | 395 go func() { |
392 io.WriteString(a, "hello") | 396 io.WriteString(a, "hello") |
393 // bogus adjust. | 397 // bogus adjust. |
394 » » a.(*channel).sendMessage( | 398 » » a.sendMessage( |
395 msgChannelWindowAdjust, windowAdjustMsg{}) | 399 msgChannelWindowAdjust, windowAdjustMsg{}) |
396 io.WriteString(a, "world") | 400 io.WriteString(a, "world") |
397 a.Close() | 401 a.Close() |
398 }() | 402 }() |
399 | 403 |
400 want := "helloworld" | 404 want := "helloworld" |
401 c, _ := ioutil.ReadAll(b) | 405 c, _ := ioutil.ReadAll(b) |
402 if string(c) != want { | 406 if string(c) != want { |
403 t.Errorf("got %q want %q", c, want) | 407 t.Errorf("got %q want %q", c, want) |
404 } | 408 } |
405 } | 409 } |
406 | 410 |
407 func TestMuxMaxPacketSize(t *testing.T) { | 411 func TestMuxMaxPacketSize(t *testing.T) { |
408 a, b, _ := channelPair(t) | 412 a, b, _ := channelPair(t) |
409 | 413 |
410 » ch := a.(*channel) | 414 » large := make([]byte, a.maxPacket+1) |
411 » large := make([]byte, ch.maxPacket+1) | 415 » if err := a.writePacket(large); err == nil { |
412 » if err := ch.writePacket(large); err == nil { | |
413 t.Errorf("channel sent out packet larger than maxPacket") | 416 t.Errorf("channel sent out packet larger than maxPacket") |
414 } | 417 } |
415 | 418 |
416 packet := make([]byte, 1+4+4+1+len(large)) | 419 packet := make([]byte, 1+4+4+1+len(large)) |
417 packet[0] = msgChannelData | 420 packet[0] = msgChannelData |
418 » marshalUint32(packet[1:], ch.remoteId) | 421 » marshalUint32(packet[1:], a.remoteId) |
419 marshalUint32(packet[5:], uint32(len(large))) | 422 marshalUint32(packet[5:], uint32(len(large))) |
420 packet[9] = 42 | 423 packet[9] = 42 |
421 | 424 |
422 » if err := ch.mux.conn.writePacket(packet); err != nil { | 425 » if err := a.mux.conn.writePacket(packet); err != nil { |
423 t.Errorf("could not send packet") | 426 t.Errorf("could not send packet") |
424 } | 427 } |
425 | 428 |
426 go a.SendRequest("hello", false, nil) | 429 go a.SendRequest("hello", false, nil) |
427 | 430 |
428 » _, ok := <-b.IncomingRequests() | 431 » _, ok := <-b.incomingRequests |
429 if ok { | 432 if ok { |
430 t.Errorf("connection still alive after receiving large packet.") | 433 t.Errorf("connection still alive after receiving large packet.") |
431 } | 434 } |
432 } | 435 } |
LEFT | RIGHT |