LEFT | RIGHT |
| 1 // Copyright 2013 The Go Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style |
| 3 // license that can be found in the LICENSE file. |
| 4 |
1 package ssh | 5 package ssh |
2 | 6 |
3 import ( | 7 import ( |
4 "io" | 8 "io" |
5 "io/ioutil" | 9 "io/ioutil" |
6 "log" | |
7 "sync" | 10 "sync" |
8 "testing" | 11 "testing" |
9 "time" | 12 "time" |
10 ) | 13 ) |
11 | 14 |
12 var _ = log.Println | |
13 | |
14 func muxPair() (*mux, *mux) { | 15 func muxPair() (*mux, *mux) { |
15 a, b := memPipe() | 16 a, b := memPipe() |
16 | 17 |
17 s := newMux(a) | 18 s := newMux(a) |
18 c := newMux(b) | 19 c := newMux(b) |
19 | 20 |
20 go s.Loop() | 21 go s.Loop() |
21 go c.Loop() | 22 go c.Loop() |
22 | 23 |
23 c.chanList.offset = 'c' | |
24 s.chanList.offset = 's' | |
25 return s, c | 24 return s, c |
26 } | 25 } |
27 | 26 |
28 func channelPair(t *testing.T) (Channel, Channel) { | 27 // Returns both ends of a channel, and the mux for the the 2nd |
| 28 // channel. |
| 29 func channelPair(t *testing.T) (*channel, *channel, *mux) { |
29 c, s := muxPair() | 30 c, s := muxPair() |
30 | 31 |
31 » res := make(chan Channel, 1) | 32 » res := make(chan *channel, 1) |
32 » go func() { | 33 » go func() { |
33 » » ch, err := s.Accept() | 34 » » ch, ok := <-s.incomingChannels |
| 35 » » if !ok { |
| 36 » » » t.Fatalf("No incoming channel") |
| 37 » » } |
| 38 » » if ch.ChannelType() != "chan" { |
| 39 » » » t.Fatalf("got type %q want chan", ch.ChannelType()) |
| 40 » » } |
| 41 » » err := ch.Accept() |
34 if err != nil { | 42 if err != nil { |
35 t.Fatalf("Accept %v", err) | 43 t.Fatalf("Accept %v", err) |
36 } | 44 } |
37 if ch.ChannelType() != "chan" { | |
38 t.Fatalf("got type %q want chan", ch.ChannelType()) | |
39 } | |
40 ch.Accept() | |
41 res <- ch | 45 res <- ch |
42 }() | 46 }() |
43 | 47 |
44 ch, err := c.OpenChannel("chan", nil) | 48 ch, err := c.OpenChannel("chan", nil) |
45 if err != nil { | 49 if err != nil { |
46 t.Fatalf("OpenChannel: %v", err) | 50 t.Fatalf("OpenChannel: %v", err) |
47 } | 51 } |
48 | 52 |
49 » return <-res, ch | 53 » return <-res, ch, c |
50 } | 54 } |
51 | 55 |
52 func TestMuxReadWrite(t *testing.T) { | 56 func TestMuxReadWrite(t *testing.T) { |
53 » s, c := channelPair(t) | 57 » s, c, _ := channelPair(t) |
54 | 58 |
55 magic := "hello world" | 59 magic := "hello world" |
56 magicExt := "hello stderr" | 60 magicExt := "hello stderr" |
57 » var wg sync.WaitGroup | 61 » go func() { |
58 » wg.Add(1) | |
59 » go func() { | |
60 » » defer wg.Done() | |
61 | |
62 _, err := s.Write([]byte(magic)) | 62 _, err := s.Write([]byte(magic)) |
63 if err != nil { | 63 if err != nil { |
64 t.Fatalf("Write: %v", err) | 64 t.Fatalf("Write: %v", err) |
65 } | 65 } |
66 » » _, err = s.Stderr().Write([]byte(magicExt)) | 66 » » _, err = s.Extended(1).Write([]byte(magicExt)) |
67 if err != nil { | 67 if err != nil { |
68 t.Fatalf("Write: %v", err) | 68 t.Fatalf("Write: %v", err) |
69 } | 69 } |
70 err = s.Close() | 70 err = s.Close() |
71 if err != nil { | 71 if err != nil { |
72 t.Fatalf("Close: %v", err) | 72 t.Fatalf("Close: %v", err) |
73 } | 73 } |
74 }() | 74 }() |
75 | 75 |
76 var buf [1024]byte | 76 var buf [1024]byte |
77 n, err := c.Read(buf[:]) | 77 n, err := c.Read(buf[:]) |
78 if err != nil { | 78 if err != nil { |
79 t.Fatalf("server Read: %v", err) | 79 t.Fatalf("server Read: %v", err) |
80 } | 80 } |
81 got := string(buf[:n]) | 81 got := string(buf[:n]) |
82 if got != magic { | 82 if got != magic { |
83 t.Fatalf("server: got %q want %q", got, magic) | 83 t.Fatalf("server: got %q want %q", got, magic) |
84 } | 84 } |
85 | 85 |
86 » n, err = c.Stderr().Read(buf[:]) | 86 » n, err = c.Extended(1).Read(buf[:]) |
87 if err != nil { | 87 if err != nil { |
88 t.Fatalf("server Read: %v", err) | 88 t.Fatalf("server Read: %v", err) |
89 } | 89 } |
90 | 90 |
91 got = string(buf[:n]) | 91 got = string(buf[:n]) |
92 if got != magicExt { | 92 if got != magicExt { |
93 t.Fatalf("server: got %q want %q", got, magic) | 93 t.Fatalf("server: got %q want %q", got, magic) |
94 } | 94 } |
95 } | 95 } |
96 | 96 |
97 func TestMuxFlowControl(t *testing.T) { | 97 func TestMuxFlowControl(t *testing.T) { |
98 writerMux, readerMux := muxPair() | 98 writerMux, readerMux := muxPair() |
99 | 99 |
100 var wg sync.WaitGroup | |
101 wg.Add(2) | |
102 | |
103 // this goroutine reads just a bit. | 100 // this goroutine reads just a bit. |
104 go func() { | 101 go func() { |
105 » » reader, err := readerMux.Accept() | 102 » » reader, ok := <-readerMux.incomingChannels |
| 103 » » if !ok { |
| 104 » » » t.Fatalf("no incoming channel") |
| 105 » » } |
| 106 » » err := reader.Accept() |
106 if err != nil { | 107 if err != nil { |
107 t.Fatalf("Accept: %v", err) | |
108 } | |
109 if err = reader.Accept(); err != nil { | |
110 t.Fatalf("Accept: %v", err) | 108 t.Fatalf("Accept: %v", err) |
111 } | 109 } |
112 | 110 |
113 b := make([]byte, 1024) | 111 b := make([]byte, 1024) |
114 n, err := reader.Read(b) | 112 n, err := reader.Read(b) |
115 if err != nil || n != len(b) { | 113 if err != nil || n != len(b) { |
116 t.Errorf("Read: %v, %d bytes", err, n) | 114 t.Errorf("Read: %v, %d bytes", err, n) |
117 } | 115 } |
118 wg.Done() | |
119 }() | 116 }() |
120 | 117 |
121 writer, err := writerMux.OpenChannel("pipe", nil) | 118 writer, err := writerMux.OpenChannel("pipe", nil) |
122 if err != nil { | 119 if err != nil { |
123 t.Fatalf("OpenChannel: %v", err) | 120 t.Fatalf("OpenChannel: %v", err) |
124 } | 121 } |
125 | 122 |
126 // This goroutine writes is blocked from writing by the slow | 123 // This goroutine writes is blocked from writing by the slow |
127 // reader | 124 // reader |
128 go func() { | 125 go func() { |
129 largeData := make([]byte, 3*(1<<15)) | 126 largeData := make([]byte, 3*(1<<15)) |
130 n, err := writer.Write(largeData) | 127 n, err := writer.Write(largeData) |
131 if err != io.EOF { | 128 if err != io.EOF { |
132 t.Errorf("want EOF, got %v", err) | 129 t.Errorf("want EOF, got %v", err) |
133 } | 130 } |
134 want := 1024 + (1 << 15) | 131 want := 1024 + (1 << 15) |
135 if n != want { | 132 if n != want { |
136 t.Errorf("wrote %d, want %d", n, want) | 133 t.Errorf("wrote %d, want %d", n, want) |
137 } | 134 } |
138 wg.Done() | |
139 }() | 135 }() |
140 | 136 |
141 // Wait for a bit for things to subside. The write should be | 137 // Wait for a bit for things to subside. The write should be |
142 // blocked. | 138 // blocked. |
143 time.Sleep(1 * time.Millisecond) | 139 time.Sleep(1 * time.Millisecond) |
144 | 140 |
145 readerMux.Disconnect(0, "") | 141 readerMux.Disconnect(0, "") |
146 writerMux.Disconnect(0, "") | 142 writerMux.Disconnect(0, "") |
147 | |
148 wg.Done() | |
149 } | 143 } |
150 | 144 |
151 func TestMuxReject(t *testing.T) { | 145 func TestMuxReject(t *testing.T) { |
152 client, server := muxPair() | 146 client, server := muxPair() |
153 | 147 |
154 go func() { | 148 go func() { |
155 » » ch, err := server.Accept() | 149 » » ch, ok := <-server.incomingChannels |
156 » » if err != nil { | 150 » » if !ok { |
157 » » » t.Fatalf("Accept: %v", err) | 151 » » » t.Fatalf("Accept") |
158 } | 152 } |
159 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra"
{ | 153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra"
{ |
160 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(),
ch.ExtraData()) | 154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(),
ch.ExtraData()) |
161 } | 155 } |
162 ch.Reject(RejectionReason(42), "message") | 156 ch.Reject(RejectionReason(42), "message") |
163 }() | 157 }() |
164 | 158 |
165 ch, err := client.OpenChannel("ch", []byte("extra")) | 159 ch, err := client.OpenChannel("ch", []byte("extra")) |
166 if ch != nil { | 160 if ch != nil { |
167 t.Fatal("openChannel not rejected") | 161 t.Fatal("openChannel not rejected") |
168 } | 162 } |
169 | 163 |
170 » ocf, ok := err.(*OpenChannelFailed) | 164 » ocf, ok := err.(*OpenChannelError) |
171 if !ok { | 165 if !ok { |
172 » » t.Errorf("got %#v want *OpenChannelFailed", err) | 166 » » t.Errorf("got %#v want *OpenChannelError", err) |
173 » } | 167 » } else if ocf.Reason != 42 || ocf.Message != "message" { |
174 | 168 » » t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "messag
e") |
175 » if ocf.Reason != 42 || ocf.Message != "message" { | |
176 » » t.Errorf("got %#v, want {Reason: 42, Mepassage: %q}", ocf, "mess
age") | |
177 } | 169 } |
178 | 170 |
179 want := "ssh: rejected: unknown reason 42 (message)" | 171 want := "ssh: rejected: unknown reason 42 (message)" |
180 if err.Error() != want { | 172 if err.Error() != want { |
181 t.Errorf("got %q, want %q", err.Error(), want) | 173 t.Errorf("got %q, want %q", err.Error(), want) |
182 } | 174 } |
183 } | 175 } |
184 | 176 |
185 func TestMuxChannelRequest(t *testing.T) { | 177 func TestMuxChannelRequest(t *testing.T) { |
186 » client, server := channelPair(t) | 178 » client, server, _ := channelPair(t) |
187 var received int | 179 var received int |
188 var wg sync.WaitGroup | 180 var wg sync.WaitGroup |
189 wg.Add(1) | 181 wg.Add(1) |
190 go func() { | 182 go func() { |
191 » » for r := range server.ReceivedRequests() { | 183 » » for r := range server.incomingRequests { |
192 received++ | 184 received++ |
193 if r.WantReply { | 185 if r.WantReply { |
194 server.AckRequest(r.Request == "yes") | 186 server.AckRequest(r.Request == "yes") |
195 } | 187 } |
196 } | 188 } |
197 wg.Done() | 189 wg.Done() |
198 }() | 190 }() |
199 _, err := client.SendRequest("yes", false, nil) | 191 _, err := client.SendRequest("yes", false, nil) |
200 if err != nil { | 192 if err != nil { |
201 t.Fatalf("SendRequest: %v", err) | 193 t.Fatalf("SendRequest: %v", err) |
(...skipping 22 matching lines...) Expand all Loading... |
224 | 216 |
225 if received != 3 { | 217 if received != 3 { |
226 t.Errorf("got %d requests, want %d", received) | 218 t.Errorf("got %d requests, want %d", received) |
227 } | 219 } |
228 } | 220 } |
229 | 221 |
230 func TestMuxGlobalRequest(t *testing.T) { | 222 func TestMuxGlobalRequest(t *testing.T) { |
231 clientMux, serverMux := muxPair() | 223 clientMux, serverMux := muxPair() |
232 | 224 |
233 var seen bool | 225 var seen bool |
234 » var wg sync.WaitGroup | 226 » go func() { |
235 » wg.Add(1) | 227 » » for r := range serverMux.incomingRequests { |
236 » go func() { | |
237 » » for r := range serverMux.ReceivedRequests() { | |
238 seen = seen || r.Request == "peek" | 228 seen = seen || r.Request == "peek" |
239 if r.WantReply { | 229 if r.WantReply { |
240 » » » » err := serverMux.AckGlobalRequest(r.Request == "
yes", | 230 » » » » err := serverMux.AckRequest(r.Request == "yes", |
241 append([]byte(r.Request), r.Payload...)) | 231 append([]byte(r.Request), r.Payload...)) |
242 if err != nil { | 232 if err != nil { |
243 t.Errorf("AckRequest: %v", err) | 233 t.Errorf("AckRequest: %v", err) |
244 } | 234 } |
245 } | 235 } |
246 } | 236 } |
247 wg.Done() | |
248 }() | 237 }() |
249 | 238 |
250 _, _, err := clientMux.SendRequest("peek", false, nil) | 239 _, _, err := clientMux.SendRequest("peek", false, nil) |
251 if err != nil { | 240 if err != nil { |
252 t.Errorf("SendRequest: %v", err) | 241 t.Errorf("SendRequest: %v", err) |
253 } | 242 } |
254 | 243 |
255 ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) | 244 ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) |
256 if !ok || string(data) != "yesa" || err != nil { | 245 if !ok || string(data) != "yesa" || err != nil { |
257 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", | 246 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", |
258 ok, data, err) | 247 ok, data, err) |
259 } | 248 } |
260 if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok
|| string(data) != "yesa" || err != nil { | 249 if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok
|| string(data) != "yesa" || err != nil { |
261 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", | 250 t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", |
262 ok, data, err) | 251 ok, data, err) |
263 } | 252 } |
264 | 253 |
265 if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok |
| string(data) != "noa" || err != nil { | 254 if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok |
| string(data) != "noa" || err != nil { |
266 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", | 255 t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", |
267 ok, data, err) | 256 ok, data, err) |
268 } | 257 } |
269 » // not really related to global reqs, but try disconnect too. | 258 |
270 » clientMux.Disconnect(42, "whatever") | 259 » clientMux.Disconnect(0, "") |
271 | 260 » if !seen { |
272 » wg.Wait() | 261 » » t.Errorf("never saw 'peek' request") |
| 262 » } |
| 263 } |
| 264 |
| 265 func TestMuxGlobalRequestUnblock(t *testing.T) { |
| 266 » clientMux, serverMux := muxPair() |
| 267 |
| 268 » result := make(chan error, 1) |
| 269 » go func() { |
| 270 » » _, _, err := clientMux.SendRequest("hello", true, nil) |
| 271 » » result <- err |
| 272 » }() |
| 273 |
| 274 » <-serverMux.incomingRequests |
| 275 » serverMux.conn.Close() |
| 276 » err := <-result |
| 277 |
| 278 » if err != io.EOF { |
| 279 » » t.Errorf("want EOF, got %v", io.EOF) |
| 280 » } |
| 281 } |
| 282 |
| 283 func TestMuxChannelRequestUnblock(t *testing.T) { |
| 284 » a, b, connB := channelPair(t) |
| 285 |
| 286 » result := make(chan error, 1) |
| 287 » go func() { |
| 288 » » _, err := a.SendRequest("hello", true, nil) |
| 289 » » result <- err |
| 290 » }() |
| 291 |
| 292 » <-b.incomingRequests |
| 293 » connB.conn.Close() |
| 294 » err := <-result |
| 295 |
| 296 » if err != io.EOF { |
| 297 » » t.Errorf("want EOF, got %v", err) |
| 298 » } |
| 299 } |
| 300 |
| 301 func TestMuxDisconnect(t *testing.T) { |
| 302 » a, b := muxPair() |
| 303 » go func() { |
| 304 » » for r := range b.incomingRequests { |
| 305 » » » if r.WantReply { |
| 306 » » » » b.AckRequest(true, nil) |
| 307 » » » } |
| 308 » » } |
| 309 » }() |
| 310 |
| 311 » a.Disconnect(42, "whatever") |
| 312 » ok, _, err := a.SendRequest("hello", true, nil) |
| 313 » if ok || err == nil { |
| 314 » » t.Errorf("got reply after disconnecting") |
| 315 » } |
273 } | 316 } |
274 | 317 |
275 func TestMuxCloseChannel(t *testing.T) { | 318 func TestMuxCloseChannel(t *testing.T) { |
276 » r, w := channelPair(t) | 319 » r, w, _ := channelPair(t) |
277 | 320 |
278 timeout := time.After(10 * time.Millisecond) | 321 timeout := time.After(10 * time.Millisecond) |
279 result := make(chan error, 1) | 322 result := make(chan error, 1) |
280 go func() { | 323 go func() { |
281 var b [1024]byte | 324 var b [1024]byte |
282 _, err := r.Read(b[:]) | 325 _, err := r.Read(b[:]) |
283 result <- err | 326 result <- err |
284 }() | 327 }() |
285 if err := w.Close(); err != nil { | 328 if err := w.Close(); err != nil { |
286 t.Errorf("w.Close: %v", err) | 329 t.Errorf("w.Close: %v", err) |
287 } | 330 } |
288 | 331 |
289 if _, err := w.Write([]byte("hello")); err != io.EOF { | 332 if _, err := w.Write([]byte("hello")); err != io.EOF { |
290 t.Errorf("got err %v, want io.EOF after Close", err) | 333 t.Errorf("got err %v, want io.EOF after Close", err) |
291 } | 334 } |
292 | 335 |
293 select { | 336 select { |
294 case e := <-result: | 337 case e := <-result: |
295 if e != io.EOF { | 338 if e != io.EOF { |
296 t.Errorf("got %v (%T), want io.EOF", e, e) | 339 t.Errorf("got %v (%T), want io.EOF", e, e) |
297 } | 340 } |
298 case <-timeout: | 341 case <-timeout: |
299 t.Errorf("timed out waiting for read to exit") | 342 t.Errorf("timed out waiting for read to exit") |
300 } | 343 } |
301 } | 344 } |
302 | 345 |
303 func TestMuxCloseWriteChannel(t *testing.T) { | 346 func TestMuxCloseWriteChannel(t *testing.T) { |
304 » r, w := channelPair(t) | 347 » r, w, _ := channelPair(t) |
305 | 348 |
306 timeout := time.After(10 * time.Millisecond) | 349 timeout := time.After(10 * time.Millisecond) |
307 result := make(chan error, 1) | 350 result := make(chan error, 1) |
308 go func() { | 351 go func() { |
309 var b [1024]byte | 352 var b [1024]byte |
310 _, err := r.Read(b[:]) | 353 _, err := r.Read(b[:]) |
311 result <- err | 354 result <- err |
312 }() | 355 }() |
313 if err := w.CloseWrite(); err != nil { | 356 if err := w.CloseWrite(); err != nil { |
314 t.Errorf("w.CloseWrite: %v", err) | 357 t.Errorf("w.CloseWrite: %v", err) |
(...skipping 18 matching lines...) Expand all Loading... |
333 | 376 |
334 packet := make([]byte, 1+4+4+1) | 377 packet := make([]byte, 1+4+4+1) |
335 packet[0] = msgChannelData | 378 packet[0] = msgChannelData |
336 marshalUint32(packet[1:], 29348723 /* invalid channel id */) | 379 marshalUint32(packet[1:], 29348723 /* invalid channel id */) |
337 marshalUint32(packet[5:], 1) | 380 marshalUint32(packet[5:], 1) |
338 packet[9] = 42 | 381 packet[9] = 42 |
339 | 382 |
340 a.conn.writePacket(packet) | 383 a.conn.writePacket(packet) |
341 go a.SendRequest("hello", false, nil) | 384 go a.SendRequest("hello", false, nil) |
342 // 'a' wrote an invalid packet, so 'b' has exited. | 385 // 'a' wrote an invalid packet, so 'b' has exited. |
343 » req, ok := <-b.ReceivedRequests() | 386 » req, ok := <-b.incomingRequests |
344 if ok { | 387 if ok { |
345 t.Errorf("got request %#v after receiving invalid packet", req) | 388 t.Errorf("got request %#v after receiving invalid packet", req) |
346 } | 389 } |
347 } | 390 } |
348 | 391 |
349 func TestZeroWindowAdjust(t *testing.T) { | 392 func TestZeroWindowAdjust(t *testing.T) { |
350 » a, b := channelPair(t) | 393 » a, b, _ := channelPair(t) |
351 | 394 |
352 go func() { | 395 go func() { |
353 io.WriteString(a, "hello") | 396 io.WriteString(a, "hello") |
354 // bogus adjust. | 397 // bogus adjust. |
355 » » a.(*channel).sendMessage( | 398 » » a.sendMessage( |
356 msgChannelWindowAdjust, windowAdjustMsg{}) | 399 msgChannelWindowAdjust, windowAdjustMsg{}) |
357 io.WriteString(a, "world") | 400 io.WriteString(a, "world") |
358 a.Close() | 401 a.Close() |
359 }() | 402 }() |
360 | 403 |
361 want := "helloworld" | 404 want := "helloworld" |
362 c, _ := ioutil.ReadAll(b) | 405 c, _ := ioutil.ReadAll(b) |
363 if string(c) != want { | 406 if string(c) != want { |
364 t.Errorf("got %q want %q", c, want) | 407 t.Errorf("got %q want %q", c, want) |
365 } | 408 } |
366 } | 409 } |
| 410 |
| 411 func TestMuxMaxPacketSize(t *testing.T) { |
| 412 a, b, _ := channelPair(t) |
| 413 |
| 414 large := make([]byte, a.maxPacket+1) |
| 415 if err := a.writePacket(large); err == nil { |
| 416 t.Errorf("channel sent out packet larger than maxPacket") |
| 417 } |
| 418 |
| 419 packet := make([]byte, 1+4+4+1+len(large)) |
| 420 packet[0] = msgChannelData |
| 421 marshalUint32(packet[1:], a.remoteId) |
| 422 marshalUint32(packet[5:], uint32(len(large))) |
| 423 packet[9] = 42 |
| 424 |
| 425 if err := a.mux.conn.writePacket(packet); err != nil { |
| 426 t.Errorf("could not send packet") |
| 427 } |
| 428 |
| 429 go a.SendRequest("hello", false, nil) |
| 430 |
| 431 _, ok := <-b.incomingRequests |
| 432 if ok { |
| 433 t.Errorf("connection still alive after receiving large packet.") |
| 434 } |
| 435 } |
LEFT | RIGHT |