LEFT | RIGHT |
| 1 // Copyright 2013 The Go Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style |
| 3 // license that can be found in the LICENSE file. |
| 4 |
1 package ssh | 5 package ssh |
2 | 6 |
3 import ( | 7 import ( |
4 "io" | 8 "io" |
5 » "log" | 9 » "io/ioutil" |
6 "sync" | 10 "sync" |
7 "testing" | 11 "testing" |
8 "time" | 12 "time" |
9 ) | 13 ) |
10 | 14 |
11 var _ = log.Println | |
12 | |
13 func muxPair() (*mux, *mux) { | 15 func muxPair() (*mux, *mux) { |
14 a, b := memPipe() | 16 a, b := memPipe() |
15 | 17 |
16 s := newMux(a) | 18 s := newMux(a) |
17 c := newMux(b) | 19 c := newMux(b) |
18 » c.chanList.offset = 'c' | 20 |
19 » s.chanList.offset = 's' | 21 » go s.Loop() |
| 22 » go c.Loop() |
| 23 |
20 return s, c | 24 return s, c |
21 } | 25 } |
22 | 26 |
23 func channelPair(t *testing.T) (Channel, Channel) { | 27 // Returns both ends of a channel, and the mux for the the 2nd |
| 28 // channel. |
| 29 func channelPair(t *testing.T) (*channel, *channel, *mux) { |
24 c, s := muxPair() | 30 c, s := muxPair() |
25 | 31 |
26 » res := make(chan Channel, 1) | 32 » res := make(chan *channel, 1) |
27 » go func() { | 33 » go func() { |
28 » » ch, err := s.Accept() | 34 » » ch, ok := <-s.incomingChannels |
| 35 » » if !ok { |
| 36 » » » t.Fatalf("No incoming channel") |
| 37 » » } |
| 38 » » if ch.ChannelType() != "chan" { |
| 39 » » » t.Fatalf("got type %q want chan", ch.ChannelType()) |
| 40 » » } |
| 41 » » err := ch.Accept() |
29 if err != nil { | 42 if err != nil { |
30 t.Fatalf("Accept %v", err) | 43 t.Fatalf("Accept %v", err) |
31 } | 44 } |
32 if ch.ChannelType() != "chan" { | |
33 t.Fatalf("got type %q want chan", ch.ChannelType()) | |
34 } | |
35 ch.Accept() | |
36 res <- ch | 45 res <- ch |
37 }() | 46 }() |
38 | 47 |
39 ch, err := c.OpenChannel("chan", nil) | 48 ch, err := c.OpenChannel("chan", nil) |
40 if err != nil { | 49 if err != nil { |
41 t.Fatalf("OpenChannel: %v", err) | 50 t.Fatalf("OpenChannel: %v", err) |
42 } | 51 } |
43 | 52 |
44 » return <-res, ch | 53 » return <-res, ch, c |
45 } | 54 } |
46 | 55 |
47 func TestMuxReadWrite(t *testing.T) { | 56 func TestMuxReadWrite(t *testing.T) { |
48 » s, c := channelPair(t) | 57 » s, c, _ := channelPair(t) |
49 | 58 |
50 magic := "hello world" | 59 magic := "hello world" |
51 magicExt := "hello stderr" | 60 magicExt := "hello stderr" |
52 » var wg sync.WaitGroup | 61 » go func() { |
53 » wg.Add(1) | |
54 » go func() { | |
55 » » defer wg.Done() | |
56 | |
57 _, err := s.Write([]byte(magic)) | 62 _, err := s.Write([]byte(magic)) |
58 if err != nil { | 63 if err != nil { |
59 t.Fatalf("Write: %v", err) | 64 t.Fatalf("Write: %v", err) |
60 } | 65 } |
61 » » _, err = s.Stderr().Write([]byte(magicExt)) | 66 » » _, err = s.Extended(1).Write([]byte(magicExt)) |
62 if err != nil { | 67 if err != nil { |
63 t.Fatalf("Write: %v", err) | 68 t.Fatalf("Write: %v", err) |
64 } | 69 } |
65 err = s.Close() | 70 err = s.Close() |
66 if err != nil { | 71 if err != nil { |
67 t.Fatalf("Close: %v", err) | 72 t.Fatalf("Close: %v", err) |
68 } | 73 } |
69 }() | 74 }() |
70 | 75 |
71 var buf [1024]byte | 76 var buf [1024]byte |
72 n, err := c.Read(buf[:]) | 77 n, err := c.Read(buf[:]) |
73 if err != nil { | 78 if err != nil { |
74 t.Fatalf("server Read: %v", err) | 79 t.Fatalf("server Read: %v", err) |
75 } | 80 } |
76 got := string(buf[:n]) | 81 got := string(buf[:n]) |
77 if got != magic { | 82 if got != magic { |
78 t.Fatalf("server: got %q want %q", got, magic) | 83 t.Fatalf("server: got %q want %q", got, magic) |
79 } | 84 } |
80 | 85 |
81 » n, err = c.Stderr().Read(buf[:]) | 86 » n, err = c.Extended(1).Read(buf[:]) |
82 if err != nil { | 87 if err != nil { |
83 t.Fatalf("server Read: %v", err) | 88 t.Fatalf("server Read: %v", err) |
84 } | 89 } |
85 | 90 |
86 got = string(buf[:n]) | 91 got = string(buf[:n]) |
87 if got != magicExt { | 92 if got != magicExt { |
88 t.Fatalf("server: got %q want %q", got, magic) | 93 t.Fatalf("server: got %q want %q", got, magic) |
89 } | 94 } |
90 } | 95 } |
91 | 96 |
92 func TestMuxFlowControl(t *testing.T) { | 97 func TestMuxFlowControl(t *testing.T) { |
93 writerMux, readerMux := muxPair() | 98 writerMux, readerMux := muxPair() |
94 | 99 |
95 » var wg sync.WaitGroup | 100 » // this goroutine reads just a bit. |
96 » wg.Add(2) | 101 » go func() { |
97 | 102 » » reader, ok := <-readerMux.incomingChannels |
98 » // More than window size | 103 » » if !ok { |
99 » go func() { | 104 » » » t.Fatalf("no incoming channel") |
100 » » reader, err := readerMux.Accept() | 105 » » } |
| 106 » » err := reader.Accept() |
101 if err != nil { | 107 if err != nil { |
102 t.Fatalf("Accept: %v", err) | |
103 } | |
104 if err = reader.Accept(); err != nil { | |
105 t.Fatalf("Accept: %v", err) | 108 t.Fatalf("Accept: %v", err) |
106 } | 109 } |
107 | 110 |
108 b := make([]byte, 1024) | 111 b := make([]byte, 1024) |
109 n, err := reader.Read(b) | 112 n, err := reader.Read(b) |
110 if err != nil || n != len(b) { | 113 if err != nil || n != len(b) { |
111 t.Errorf("Read: %v, %d bytes", err, n) | 114 t.Errorf("Read: %v, %d bytes", err, n) |
112 } | 115 } |
113 wg.Done() | |
114 }() | 116 }() |
115 | 117 |
116 writer, err := writerMux.OpenChannel("pipe", nil) | 118 writer, err := writerMux.OpenChannel("pipe", nil) |
117 if err != nil { | 119 if err != nil { |
118 t.Fatalf("OpenChannel: %v", err) | 120 t.Fatalf("OpenChannel: %v", err) |
119 } | 121 } |
120 | 122 |
| 123 // This goroutine writes is blocked from writing by the slow |
| 124 // reader |
121 go func() { | 125 go func() { |
122 largeData := make([]byte, 3*(1<<15)) | 126 largeData := make([]byte, 3*(1<<15)) |
123 n, err := writer.Write(largeData) | 127 n, err := writer.Write(largeData) |
124 if err != io.EOF { | 128 if err != io.EOF { |
125 t.Errorf("want EOF, got %v", err) | 129 t.Errorf("want EOF, got %v", err) |
126 } | 130 } |
127 want := 1024 + (1 << 15) | 131 want := 1024 + (1 << 15) |
128 if n != want { | 132 if n != want { |
129 t.Errorf("wrote %d, want %d", n, want) | 133 t.Errorf("wrote %d, want %d", n, want) |
130 } | 134 } |
131 wg.Done() | |
132 }() | 135 }() |
133 | 136 |
134 // Wait for a bit for things to subside. The write should be | 137 // Wait for a bit for things to subside. The write should be |
135 // blocked. | 138 // blocked. |
136 time.Sleep(1 * time.Millisecond) | 139 time.Sleep(1 * time.Millisecond) |
137 | 140 |
138 » readerMux.conn.Close() | 141 » readerMux.Disconnect(0, "") |
139 » writerMux.conn.Close() | 142 » writerMux.Disconnect(0, "") |
140 | |
141 » wg.Done() | |
142 } | 143 } |
143 | 144 |
144 func TestMuxReject(t *testing.T) { | 145 func TestMuxReject(t *testing.T) { |
145 client, server := muxPair() | 146 client, server := muxPair() |
146 | 147 |
147 go func() { | 148 go func() { |
148 » » ch, err := server.Accept() | 149 » » ch, ok := <-server.incomingChannels |
149 » » if err != nil { | 150 » » if !ok { |
150 » » » t.Fatalf("Accept: %v", err) | 151 » » » t.Fatalf("Accept") |
151 } | 152 } |
152 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra"
{ | 153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra"
{ |
153 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(),
ch.ExtraData()) | 154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(),
ch.ExtraData()) |
154 } | 155 } |
155 ch.Reject(RejectionReason(42), "message") | 156 ch.Reject(RejectionReason(42), "message") |
156 }() | 157 }() |
157 | 158 |
158 ch, err := client.OpenChannel("ch", []byte("extra")) | 159 ch, err := client.OpenChannel("ch", []byte("extra")) |
159 if ch != nil { | 160 if ch != nil { |
160 t.Fatal("openChannel not rejected") | 161 t.Fatal("openChannel not rejected") |
161 } | 162 } |
162 | 163 |
163 » ocf, ok := err.(*OpenChannelFailed) | 164 » ocf, ok := err.(*OpenChannelError) |
164 if !ok { | 165 if !ok { |
165 » » t.Errorf("got %#v want *OpenChannelFailed", err) | 166 » » t.Errorf("got %#v want *OpenChannelError", err) |
166 » } | 167 » } else if ocf.Reason != 42 || ocf.Message != "message" { |
167 | 168 » » t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "messag
e") |
168 » if ocf.Reason != 42 || ocf.Message != "message" { | |
169 » » t.Errorf("got %#v, want {Reason: 42, Mepassage: %q}", ocf, "mess
age") | |
170 } | 169 } |
171 | 170 |
172 want := "ssh: rejected: unknown reason 42 (message)" | 171 want := "ssh: rejected: unknown reason 42 (message)" |
173 if err.Error() != want { | 172 if err.Error() != want { |
174 t.Errorf("got %q, want %q", err.Error(), want) | 173 t.Errorf("got %q, want %q", err.Error(), want) |
175 } | 174 } |
176 } | 175 } |
177 | 176 |
178 func TestMuxChannelRequest(t *testing.T) { | 177 func TestMuxChannelRequest(t *testing.T) { |
179 » client, server := channelPair(t) | 178 » client, server, _ := channelPair(t) |
180 var received int | 179 var received int |
181 var wg sync.WaitGroup | 180 var wg sync.WaitGroup |
182 wg.Add(1) | 181 wg.Add(1) |
183 go func() { | 182 go func() { |
184 » » for r := range server.ReceivedRequests() { | 183 » » for r := range server.incomingRequests { |
185 received++ | 184 received++ |
186 if r.WantReply { | 185 if r.WantReply { |
187 server.AckRequest(r.Request == "yes") | 186 server.AckRequest(r.Request == "yes") |
188 } | 187 } |
189 } | 188 } |
190 wg.Done() | 189 wg.Done() |
191 }() | 190 }() |
192 _, err := client.SendRequest("yes", false, nil) | 191 _, err := client.SendRequest("yes", false, nil) |
193 if err != nil { | 192 if err != nil { |
194 t.Fatalf("SendRequest: %v", err) | 193 t.Fatalf("SendRequest: %v", err) |
195 } | 194 } |
196 ok, err := client.SendRequest("yes", true, nil) | 195 ok, err := client.SendRequest("yes", true, nil) |
197 if err != nil { | 196 if err != nil { |
198 t.Fatalf("SendRequest: %v", err) | 197 t.Fatalf("SendRequest: %v", err) |
199 } | 198 } |
200 log.Println("ok", ok) | |
201 | 199 |
202 if !ok { | 200 if !ok { |
203 t.Errorf("SendRequest(yes): %v", ok) | 201 t.Errorf("SendRequest(yes): %v", ok) |
204 | 202 |
205 } | 203 } |
206 | 204 |
207 ok, err = client.SendRequest("no", true, nil) | 205 ok, err = client.SendRequest("no", true, nil) |
208 if err != nil { | 206 if err != nil { |
209 t.Fatalf("SendRequest: %v", err) | 207 t.Fatalf("SendRequest: %v", err) |
210 } | 208 } |
211 if ok { | 209 if ok { |
212 t.Errorf("SendRequest(no): %v", ok) | 210 t.Errorf("SendRequest(no): %v", ok) |
213 | 211 |
214 } | 212 } |
| 213 |
215 client.Close() | 214 client.Close() |
216 wg.Wait() | 215 wg.Wait() |
217 | 216 |
218 if received != 3 { | 217 if received != 3 { |
219 t.Errorf("got %d requests, want %d", received) | 218 t.Errorf("got %d requests, want %d", received) |
220 } | 219 } |
221 } | 220 } |
222 | 221 |
223 func TestMuxGlobalRequest(t *testing.T) { | 222 func TestMuxGlobalRequest(t *testing.T) { |
224 clientMux, serverMux := muxPair() | 223 clientMux, serverMux := muxPair() |
225 | 224 |
226 var seen bool | 225 var seen bool |
227 » var wg sync.WaitGroup | 226 » go func() { |
228 » wg.Add(1) | 227 » » for r := range serverMux.incomingRequests { |
229 » go func() { | 228 » » » seen = seen || r.Request == "peek" |
230 » » for r := range serverMux.GlobalReceived() { | |
231 » » » seen = seen || r.Type == "peek" | |
232 if r.WantReply { | 229 if r.WantReply { |
233 » » » » err := serverMux.AckGlobalRequest(r.Type == "yes
", | 230 » » » » err := serverMux.AckRequest(r.Request == "yes", |
234 » » » » » append([]byte(r.Type), r.Data...)) | 231 » » » » » append([]byte(r.Request), r.Payload...)) |
235 if err != nil { | 232 if err != nil { |
236 t.Errorf("AckRequest: %v", err) | 233 t.Errorf("AckRequest: %v", err) |
237 } | 234 } |
238 } | 235 } |
239 } | 236 } |
240 » » wg.Done() | 237 » }() |
241 » }() | 238 |
242 | 239 » _, _, err := clientMux.SendRequest("peek", false, nil) |
243 » _, _, err := clientMux.SendGlobalRequest("peek", false, nil) | 240 » if err != nil { |
244 » if err != nil { | 241 » » t.Errorf("SendRequest: %v", err) |
245 » » t.Errorf("SendGlobalRequest: %v", err) | 242 » } |
246 » } | 243 |
247 | 244 » ok, data, err := clientMux.SendRequest("yes", true, []byte("a")) |
248 » ok, data, err := clientMux.SendGlobalRequest("yes", true, []byte("a")) | |
249 if !ok || string(data) != "yesa" || err != nil { | 245 if !ok || string(data) != "yesa" || err != nil { |
250 » » t.Errorf("SendGlobalRequest(\"yes\", true, \"a\"): %v %v %v", | 246 » » t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", |
251 ok, data, err) | 247 ok, data, err) |
252 } | 248 } |
253 » if ok, data, err := clientMux.SendGlobalRequest("yes", true, []byte("a")
); !ok || string(data) != "yesa" || err != nil { | 249 » if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok
|| string(data) != "yesa" || err != nil { |
254 » » t.Errorf("SendGlobalRequest(\"yes\", true, \"a\"): %v %v %v", | 250 » » t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v", |
255 ok, data, err) | 251 ok, data, err) |
256 } | 252 } |
257 | 253 |
258 » if ok, data, err := clientMux.SendGlobalRequest("no", true, []byte("a"))
; ok || string(data) != "noa" || err != nil { | 254 » if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok |
| string(data) != "noa" || err != nil { |
259 » » t.Errorf("SendGlobalRequest(\"no\", true, \"a\"): %v %v %v", | 255 » » t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v", |
260 ok, data, err) | 256 ok, data, err) |
261 } | 257 } |
262 » // not really related to global reqs, but try disconnect too. | 258 |
263 » clientMux.Disconnect(42, "whatever") | 259 » clientMux.Disconnect(0, "") |
264 | 260 » if !seen { |
265 » wg.Wait() | 261 » » t.Errorf("never saw 'peek' request") |
266 } | 262 » } |
| 263 } |
| 264 |
| 265 func TestMuxGlobalRequestUnblock(t *testing.T) { |
| 266 » clientMux, serverMux := muxPair() |
| 267 |
| 268 » result := make(chan error, 1) |
| 269 » go func() { |
| 270 » » _, _, err := clientMux.SendRequest("hello", true, nil) |
| 271 » » result <- err |
| 272 » }() |
| 273 |
| 274 » <-serverMux.incomingRequests |
| 275 » serverMux.conn.Close() |
| 276 » err := <-result |
| 277 |
| 278 » if err != io.EOF { |
| 279 » » t.Errorf("want EOF, got %v", io.EOF) |
| 280 » } |
| 281 } |
| 282 |
| 283 func TestMuxChannelRequestUnblock(t *testing.T) { |
| 284 » a, b, connB := channelPair(t) |
| 285 |
| 286 » result := make(chan error, 1) |
| 287 » go func() { |
| 288 » » _, err := a.SendRequest("hello", true, nil) |
| 289 » » result <- err |
| 290 » }() |
| 291 |
| 292 » <-b.incomingRequests |
| 293 » connB.conn.Close() |
| 294 » err := <-result |
| 295 |
| 296 » if err != io.EOF { |
| 297 » » t.Errorf("want EOF, got %v", err) |
| 298 » } |
| 299 } |
| 300 |
| 301 func TestMuxDisconnect(t *testing.T) { |
| 302 » a, b := muxPair() |
| 303 » go func() { |
| 304 » » for r := range b.incomingRequests { |
| 305 » » » if r.WantReply { |
| 306 » » » » b.AckRequest(true, nil) |
| 307 » » » } |
| 308 » » } |
| 309 » }() |
| 310 |
| 311 » a.Disconnect(42, "whatever") |
| 312 » ok, _, err := a.SendRequest("hello", true, nil) |
| 313 » if ok || err == nil { |
| 314 » » t.Errorf("got reply after disconnecting") |
| 315 » } |
| 316 } |
| 317 |
| 318 func TestMuxCloseChannel(t *testing.T) { |
| 319 » r, w, _ := channelPair(t) |
| 320 |
| 321 » timeout := time.After(10 * time.Millisecond) |
| 322 » result := make(chan error, 1) |
| 323 » go func() { |
| 324 » » var b [1024]byte |
| 325 » » _, err := r.Read(b[:]) |
| 326 » » result <- err |
| 327 » }() |
| 328 » if err := w.Close(); err != nil { |
| 329 » » t.Errorf("w.Close: %v", err) |
| 330 » } |
| 331 |
| 332 » if _, err := w.Write([]byte("hello")); err != io.EOF { |
| 333 » » t.Errorf("got err %v, want io.EOF after Close", err) |
| 334 » } |
| 335 |
| 336 » select { |
| 337 » case e := <-result: |
| 338 » » if e != io.EOF { |
| 339 » » » t.Errorf("got %v (%T), want io.EOF", e, e) |
| 340 » » } |
| 341 » case <-timeout: |
| 342 » » t.Errorf("timed out waiting for read to exit") |
| 343 » } |
| 344 } |
| 345 |
| 346 func TestMuxCloseWriteChannel(t *testing.T) { |
| 347 » r, w, _ := channelPair(t) |
| 348 |
| 349 » timeout := time.After(10 * time.Millisecond) |
| 350 » result := make(chan error, 1) |
| 351 » go func() { |
| 352 » » var b [1024]byte |
| 353 » » _, err := r.Read(b[:]) |
| 354 » » result <- err |
| 355 » }() |
| 356 » if err := w.CloseWrite(); err != nil { |
| 357 » » t.Errorf("w.CloseWrite: %v", err) |
| 358 » } |
| 359 |
| 360 » if _, err := w.Write([]byte("hello")); err != io.EOF { |
| 361 » » t.Errorf("got err %v, want io.EOF after CloseWrite", err) |
| 362 » } |
| 363 |
| 364 » select { |
| 365 » case e := <-result: |
| 366 » » if e != io.EOF { |
| 367 » » » t.Errorf("got %v (%T), want io.EOF", e, e) |
| 368 » » } |
| 369 » case <-timeout: |
| 370 » » t.Errorf("timed out waiting for read to exit") |
| 371 » } |
| 372 } |
| 373 |
| 374 func TestMuxInvalidRecord(t *testing.T) { |
| 375 » a, b := muxPair() |
| 376 |
| 377 » packet := make([]byte, 1+4+4+1) |
| 378 » packet[0] = msgChannelData |
| 379 » marshalUint32(packet[1:], 29348723 /* invalid channel id */) |
| 380 » marshalUint32(packet[5:], 1) |
| 381 » packet[9] = 42 |
| 382 |
| 383 » a.conn.writePacket(packet) |
| 384 » go a.SendRequest("hello", false, nil) |
| 385 » // 'a' wrote an invalid packet, so 'b' has exited. |
| 386 » req, ok := <-b.incomingRequests |
| 387 » if ok { |
| 388 » » t.Errorf("got request %#v after receiving invalid packet", req) |
| 389 » } |
| 390 } |
| 391 |
| 392 func TestZeroWindowAdjust(t *testing.T) { |
| 393 » a, b, _ := channelPair(t) |
| 394 |
| 395 » go func() { |
| 396 » » io.WriteString(a, "hello") |
| 397 » » // bogus adjust. |
| 398 » » a.sendMessage( |
| 399 » » » msgChannelWindowAdjust, windowAdjustMsg{}) |
| 400 » » io.WriteString(a, "world") |
| 401 » » a.Close() |
| 402 » }() |
| 403 |
| 404 » want := "helloworld" |
| 405 » c, _ := ioutil.ReadAll(b) |
| 406 » if string(c) != want { |
| 407 » » t.Errorf("got %q want %q", c, want) |
| 408 » } |
| 409 } |
| 410 |
| 411 func TestMuxMaxPacketSize(t *testing.T) { |
| 412 » a, b, _ := channelPair(t) |
| 413 |
| 414 » large := make([]byte, a.maxPacket+1) |
| 415 » if err := a.writePacket(large); err == nil { |
| 416 » » t.Errorf("channel sent out packet larger than maxPacket") |
| 417 » } |
| 418 |
| 419 » packet := make([]byte, 1+4+4+1+len(large)) |
| 420 » packet[0] = msgChannelData |
| 421 » marshalUint32(packet[1:], a.remoteId) |
| 422 » marshalUint32(packet[5:], uint32(len(large))) |
| 423 » packet[9] = 42 |
| 424 |
| 425 » if err := a.mux.conn.writePacket(packet); err != nil { |
| 426 » » t.Errorf("could not send packet") |
| 427 » } |
| 428 |
| 429 » go a.SendRequest("hello", false, nil) |
| 430 |
| 431 » _, ok := <-b.incomingRequests |
| 432 » if ok { |
| 433 » » t.Errorf("connection still alive after receiving large packet.") |
| 434 » } |
| 435 } |
LEFT | RIGHT |