Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(914)

Delta Between Two Patch Sets: ssh/mux_test.go

Issue 14225043: code review 14225043: go.crypto/ssh: reimplement SSH connection protocol modu... (Closed)
Left Patch Set: diff -r 2cd6b3b93cdb https://code.google.com/p/go.crypto Created 10 years, 6 months ago
Right Patch Set: diff -r cd1eea1eb828 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « ssh/mux.go ('k') | ssh/server.go » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
1 package ssh 5 package ssh
2 6
3 import ( 7 import (
4 "io" 8 "io"
5 » "log" 9 » "io/ioutil"
6 "sync" 10 "sync"
7 "testing" 11 "testing"
8 "time" 12 "time"
9 ) 13 )
10 14
11 var _ = log.Println
12
13 func muxPair() (*mux, *mux) { 15 func muxPair() (*mux, *mux) {
14 a, b := memPipe() 16 a, b := memPipe()
15 17
16 s := newMux(a) 18 s := newMux(a)
17 c := newMux(b) 19 c := newMux(b)
18 » c.chanList.offset = 'c' 20
19 » s.chanList.offset = 's' 21 » go s.Loop()
22 » go c.Loop()
23
20 return s, c 24 return s, c
21 } 25 }
22 26
23 func channelPair(t *testing.T) (Channel, Channel) { 27 // Returns both ends of a channel, and the mux for the the 2nd
28 // channel.
29 func channelPair(t *testing.T) (*channel, *channel, *mux) {
24 c, s := muxPair() 30 c, s := muxPair()
25 31
26 » res := make(chan Channel, 1) 32 » res := make(chan *channel, 1)
27 » go func() { 33 » go func() {
28 » » ch, err := s.Accept() 34 » » ch, ok := <-s.incomingChannels
35 » » if !ok {
36 » » » t.Fatalf("No incoming channel")
37 » » }
38 » » if ch.ChannelType() != "chan" {
39 » » » t.Fatalf("got type %q want chan", ch.ChannelType())
40 » » }
41 » » err := ch.Accept()
29 if err != nil { 42 if err != nil {
30 t.Fatalf("Accept %v", err) 43 t.Fatalf("Accept %v", err)
31 } 44 }
32 if ch.ChannelType() != "chan" {
33 t.Fatalf("got type %q want chan", ch.ChannelType())
34 }
35 ch.Accept()
36 res <- ch 45 res <- ch
37 }() 46 }()
38 47
39 ch, err := c.OpenChannel("chan", nil) 48 ch, err := c.OpenChannel("chan", nil)
40 if err != nil { 49 if err != nil {
41 t.Fatalf("OpenChannel: %v", err) 50 t.Fatalf("OpenChannel: %v", err)
42 } 51 }
43 52
44 » return <-res, ch 53 » return <-res, ch, c
45 } 54 }
46 55
47 func TestMuxReadWrite(t *testing.T) { 56 func TestMuxReadWrite(t *testing.T) {
48 » s, c := channelPair(t) 57 » s, c, _ := channelPair(t)
49 58
50 magic := "hello world" 59 magic := "hello world"
51 magicExt := "hello stderr" 60 magicExt := "hello stderr"
52 » var wg sync.WaitGroup 61 » go func() {
53 » wg.Add(1)
54 » go func() {
55 » » defer wg.Done()
56
57 _, err := s.Write([]byte(magic)) 62 _, err := s.Write([]byte(magic))
58 if err != nil { 63 if err != nil {
59 t.Fatalf("Write: %v", err) 64 t.Fatalf("Write: %v", err)
60 } 65 }
61 » » _, err = s.Stderr().Write([]byte(magicExt)) 66 » » _, err = s.Extended(1).Write([]byte(magicExt))
62 if err != nil { 67 if err != nil {
63 t.Fatalf("Write: %v", err) 68 t.Fatalf("Write: %v", err)
64 } 69 }
65 err = s.Close() 70 err = s.Close()
66 if err != nil { 71 if err != nil {
67 t.Fatalf("Close: %v", err) 72 t.Fatalf("Close: %v", err)
68 } 73 }
69 }() 74 }()
70 75
71 var buf [1024]byte 76 var buf [1024]byte
72 n, err := c.Read(buf[:]) 77 n, err := c.Read(buf[:])
73 if err != nil { 78 if err != nil {
74 t.Fatalf("server Read: %v", err) 79 t.Fatalf("server Read: %v", err)
75 } 80 }
76 got := string(buf[:n]) 81 got := string(buf[:n])
77 if got != magic { 82 if got != magic {
78 t.Fatalf("server: got %q want %q", got, magic) 83 t.Fatalf("server: got %q want %q", got, magic)
79 } 84 }
80 85
81 » n, err = c.Stderr().Read(buf[:]) 86 » n, err = c.Extended(1).Read(buf[:])
82 if err != nil { 87 if err != nil {
83 t.Fatalf("server Read: %v", err) 88 t.Fatalf("server Read: %v", err)
84 } 89 }
85 90
86 got = string(buf[:n]) 91 got = string(buf[:n])
87 if got != magicExt { 92 if got != magicExt {
88 t.Fatalf("server: got %q want %q", got, magic) 93 t.Fatalf("server: got %q want %q", got, magic)
89 } 94 }
90 } 95 }
91 96
92 func TestMuxFlowControl(t *testing.T) { 97 func TestMuxFlowControl(t *testing.T) {
93 writerMux, readerMux := muxPair() 98 writerMux, readerMux := muxPair()
94 99
95 » var wg sync.WaitGroup 100 » // this goroutine reads just a bit.
96 » wg.Add(2) 101 » go func() {
97 102 » » reader, ok := <-readerMux.incomingChannels
98 » // More than window size 103 » » if !ok {
99 » go func() { 104 » » » t.Fatalf("no incoming channel")
100 » » reader, err := readerMux.Accept() 105 » » }
106 » » err := reader.Accept()
101 if err != nil { 107 if err != nil {
102 t.Fatalf("Accept: %v", err)
103 }
104 if err = reader.Accept(); err != nil {
105 t.Fatalf("Accept: %v", err) 108 t.Fatalf("Accept: %v", err)
106 } 109 }
107 110
108 b := make([]byte, 1024) 111 b := make([]byte, 1024)
109 n, err := reader.Read(b) 112 n, err := reader.Read(b)
110 if err != nil || n != len(b) { 113 if err != nil || n != len(b) {
111 t.Errorf("Read: %v, %d bytes", err, n) 114 t.Errorf("Read: %v, %d bytes", err, n)
112 } 115 }
113 wg.Done()
114 }() 116 }()
115 117
116 writer, err := writerMux.OpenChannel("pipe", nil) 118 writer, err := writerMux.OpenChannel("pipe", nil)
117 if err != nil { 119 if err != nil {
118 t.Fatalf("OpenChannel: %v", err) 120 t.Fatalf("OpenChannel: %v", err)
119 } 121 }
120 122
123 // This goroutine writes is blocked from writing by the slow
124 // reader
121 go func() { 125 go func() {
122 largeData := make([]byte, 3*(1<<15)) 126 largeData := make([]byte, 3*(1<<15))
123 n, err := writer.Write(largeData) 127 n, err := writer.Write(largeData)
124 if err != io.EOF { 128 if err != io.EOF {
125 t.Errorf("want EOF, got %v", err) 129 t.Errorf("want EOF, got %v", err)
126 } 130 }
127 want := 1024 + (1 << 15) 131 want := 1024 + (1 << 15)
128 if n != want { 132 if n != want {
129 t.Errorf("wrote %d, want %d", n, want) 133 t.Errorf("wrote %d, want %d", n, want)
130 } 134 }
131 wg.Done()
132 }() 135 }()
133 136
134 // Wait for a bit for things to subside. The write should be 137 // Wait for a bit for things to subside. The write should be
135 // blocked. 138 // blocked.
136 time.Sleep(1 * time.Millisecond) 139 time.Sleep(1 * time.Millisecond)
137 140
138 » readerMux.conn.Close() 141 » readerMux.Disconnect(0, "")
139 » writerMux.conn.Close() 142 » writerMux.Disconnect(0, "")
140
141 » wg.Done()
142 } 143 }
143 144
144 func TestMuxReject(t *testing.T) { 145 func TestMuxReject(t *testing.T) {
145 client, server := muxPair() 146 client, server := muxPair()
146 147
147 go func() { 148 go func() {
148 » » ch, err := server.Accept() 149 » » ch, ok := <-server.incomingChannels
149 » » if err != nil { 150 » » if !ok {
150 » » » t.Fatalf("Accept: %v", err) 151 » » » t.Fatalf("Accept")
151 } 152 }
152 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" { 153 if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
153 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData()) 154 t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
154 } 155 }
155 ch.Reject(RejectionReason(42), "message") 156 ch.Reject(RejectionReason(42), "message")
156 }() 157 }()
157 158
158 ch, err := client.OpenChannel("ch", []byte("extra")) 159 ch, err := client.OpenChannel("ch", []byte("extra"))
159 if ch != nil { 160 if ch != nil {
160 t.Fatal("openChannel not rejected") 161 t.Fatal("openChannel not rejected")
161 } 162 }
162 163
163 » ocf, ok := err.(*OpenChannelFailed) 164 » ocf, ok := err.(*OpenChannelError)
164 if !ok { 165 if !ok {
165 » » t.Errorf("got %#v want *OpenChannelFailed", err) 166 » » t.Errorf("got %#v want *OpenChannelError", err)
166 » } 167 » } else if ocf.Reason != 42 || ocf.Message != "message" {
167 168 » » t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "messag e")
168 » if ocf.Reason != 42 || ocf.Message != "message" {
169 » » t.Errorf("got %#v, want {Reason: 42, Mepassage: %q}", ocf, "mess age")
170 } 169 }
171 170
172 want := "ssh: rejected: unknown reason 42 (message)" 171 want := "ssh: rejected: unknown reason 42 (message)"
173 if err.Error() != want { 172 if err.Error() != want {
174 t.Errorf("got %q, want %q", err.Error(), want) 173 t.Errorf("got %q, want %q", err.Error(), want)
175 } 174 }
176 } 175 }
177 176
178 func TestMuxChannelRequest(t *testing.T) { 177 func TestMuxChannelRequest(t *testing.T) {
179 » client, server := channelPair(t) 178 » client, server, _ := channelPair(t)
180 var received int 179 var received int
181 var wg sync.WaitGroup 180 var wg sync.WaitGroup
182 wg.Add(1) 181 wg.Add(1)
183 go func() { 182 go func() {
184 » » for r := range server.ReceivedRequests() { 183 » » for r := range server.incomingRequests {
185 received++ 184 received++
186 if r.WantReply { 185 if r.WantReply {
187 server.AckRequest(r.Request == "yes") 186 server.AckRequest(r.Request == "yes")
188 } 187 }
189 } 188 }
190 wg.Done() 189 wg.Done()
191 }() 190 }()
192 _, err := client.SendRequest("yes", false, nil) 191 _, err := client.SendRequest("yes", false, nil)
193 if err != nil { 192 if err != nil {
194 t.Fatalf("SendRequest: %v", err) 193 t.Fatalf("SendRequest: %v", err)
195 } 194 }
196 ok, err := client.SendRequest("yes", true, nil) 195 ok, err := client.SendRequest("yes", true, nil)
197 if err != nil { 196 if err != nil {
198 t.Fatalf("SendRequest: %v", err) 197 t.Fatalf("SendRequest: %v", err)
199 } 198 }
200 log.Println("ok", ok)
201 199
202 if !ok { 200 if !ok {
203 t.Errorf("SendRequest(yes): %v", ok) 201 t.Errorf("SendRequest(yes): %v", ok)
204 202
205 } 203 }
206 204
207 ok, err = client.SendRequest("no", true, nil) 205 ok, err = client.SendRequest("no", true, nil)
208 if err != nil { 206 if err != nil {
209 t.Fatalf("SendRequest: %v", err) 207 t.Fatalf("SendRequest: %v", err)
210 } 208 }
211 if ok { 209 if ok {
212 t.Errorf("SendRequest(no): %v", ok) 210 t.Errorf("SendRequest(no): %v", ok)
213 211
214 } 212 }
213
215 client.Close() 214 client.Close()
216 wg.Wait() 215 wg.Wait()
217 216
218 if received != 3 { 217 if received != 3 {
219 t.Errorf("got %d requests, want %d", received) 218 t.Errorf("got %d requests, want %d", received)
220 } 219 }
221 } 220 }
222 221
223 func TestMuxGlobalRequest(t *testing.T) { 222 func TestMuxGlobalRequest(t *testing.T) {
224 clientMux, serverMux := muxPair() 223 clientMux, serverMux := muxPair()
225 224
226 var seen bool 225 var seen bool
227 » var wg sync.WaitGroup 226 » go func() {
228 » wg.Add(1) 227 » » for r := range serverMux.incomingRequests {
229 » go func() { 228 » » » seen = seen || r.Request == "peek"
230 » » for r := range serverMux.GlobalReceived() {
231 » » » seen = seen || r.Type == "peek"
232 if r.WantReply { 229 if r.WantReply {
233 » » » » err := serverMux.AckGlobalRequest(r.Type == "yes ", 230 » » » » err := serverMux.AckRequest(r.Request == "yes",
234 » » » » » append([]byte(r.Type), r.Data...)) 231 » » » » » append([]byte(r.Request), r.Payload...))
235 if err != nil { 232 if err != nil {
236 t.Errorf("AckRequest: %v", err) 233 t.Errorf("AckRequest: %v", err)
237 } 234 }
238 } 235 }
239 } 236 }
240 » » wg.Done() 237 » }()
241 » }() 238
242 239 » _, _, err := clientMux.SendRequest("peek", false, nil)
243 » _, _, err := clientMux.SendGlobalRequest("peek", false, nil) 240 » if err != nil {
244 » if err != nil { 241 » » t.Errorf("SendRequest: %v", err)
245 » » t.Errorf("SendGlobalRequest: %v", err) 242 » }
246 » } 243
247 244 » ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
248 » ok, data, err := clientMux.SendGlobalRequest("yes", true, []byte("a"))
249 if !ok || string(data) != "yesa" || err != nil { 245 if !ok || string(data) != "yesa" || err != nil {
250 » » t.Errorf("SendGlobalRequest(\"yes\", true, \"a\"): %v %v %v", 246 » » t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
251 ok, data, err) 247 ok, data, err)
252 } 248 }
253 » if ok, data, err := clientMux.SendGlobalRequest("yes", true, []byte("a") ); !ok || string(data) != "yesa" || err != nil { 249 » if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
254 » » t.Errorf("SendGlobalRequest(\"yes\", true, \"a\"): %v %v %v", 250 » » t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
255 ok, data, err) 251 ok, data, err)
256 } 252 }
257 253
258 » if ok, data, err := clientMux.SendGlobalRequest("no", true, []byte("a")) ; ok || string(data) != "noa" || err != nil { 254 » if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok | | string(data) != "noa" || err != nil {
259 » » t.Errorf("SendGlobalRequest(\"no\", true, \"a\"): %v %v %v", 255 » » t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
260 ok, data, err) 256 ok, data, err)
261 } 257 }
262 » // not really related to global reqs, but try disconnect too. 258
263 » clientMux.Disconnect(42, "whatever") 259 » clientMux.Disconnect(0, "")
264 260 » if !seen {
265 » wg.Wait() 261 » » t.Errorf("never saw 'peek' request")
266 } 262 » }
263 }
264
265 func TestMuxGlobalRequestUnblock(t *testing.T) {
266 » clientMux, serverMux := muxPair()
267
268 » result := make(chan error, 1)
269 » go func() {
270 » » _, _, err := clientMux.SendRequest("hello", true, nil)
271 » » result <- err
272 » }()
273
274 » <-serverMux.incomingRequests
275 » serverMux.conn.Close()
276 » err := <-result
277
278 » if err != io.EOF {
279 » » t.Errorf("want EOF, got %v", io.EOF)
280 » }
281 }
282
283 func TestMuxChannelRequestUnblock(t *testing.T) {
284 » a, b, connB := channelPair(t)
285
286 » result := make(chan error, 1)
287 » go func() {
288 » » _, err := a.SendRequest("hello", true, nil)
289 » » result <- err
290 » }()
291
292 » <-b.incomingRequests
293 » connB.conn.Close()
294 » err := <-result
295
296 » if err != io.EOF {
297 » » t.Errorf("want EOF, got %v", err)
298 » }
299 }
300
301 func TestMuxDisconnect(t *testing.T) {
302 » a, b := muxPair()
303 » go func() {
304 » » for r := range b.incomingRequests {
305 » » » if r.WantReply {
306 » » » » b.AckRequest(true, nil)
307 » » » }
308 » » }
309 » }()
310
311 » a.Disconnect(42, "whatever")
312 » ok, _, err := a.SendRequest("hello", true, nil)
313 » if ok || err == nil {
314 » » t.Errorf("got reply after disconnecting")
315 » }
316 }
317
318 func TestMuxCloseChannel(t *testing.T) {
319 » r, w, _ := channelPair(t)
320
321 » timeout := time.After(10 * time.Millisecond)
322 » result := make(chan error, 1)
323 » go func() {
324 » » var b [1024]byte
325 » » _, err := r.Read(b[:])
326 » » result <- err
327 » }()
328 » if err := w.Close(); err != nil {
329 » » t.Errorf("w.Close: %v", err)
330 » }
331
332 » if _, err := w.Write([]byte("hello")); err != io.EOF {
333 » » t.Errorf("got err %v, want io.EOF after Close", err)
334 » }
335
336 » select {
337 » case e := <-result:
338 » » if e != io.EOF {
339 » » » t.Errorf("got %v (%T), want io.EOF", e, e)
340 » » }
341 » case <-timeout:
342 » » t.Errorf("timed out waiting for read to exit")
343 » }
344 }
345
346 func TestMuxCloseWriteChannel(t *testing.T) {
347 » r, w, _ := channelPair(t)
348
349 » timeout := time.After(10 * time.Millisecond)
350 » result := make(chan error, 1)
351 » go func() {
352 » » var b [1024]byte
353 » » _, err := r.Read(b[:])
354 » » result <- err
355 » }()
356 » if err := w.CloseWrite(); err != nil {
357 » » t.Errorf("w.CloseWrite: %v", err)
358 » }
359
360 » if _, err := w.Write([]byte("hello")); err != io.EOF {
361 » » t.Errorf("got err %v, want io.EOF after CloseWrite", err)
362 » }
363
364 » select {
365 » case e := <-result:
366 » » if e != io.EOF {
367 » » » t.Errorf("got %v (%T), want io.EOF", e, e)
368 » » }
369 » case <-timeout:
370 » » t.Errorf("timed out waiting for read to exit")
371 » }
372 }
373
374 func TestMuxInvalidRecord(t *testing.T) {
375 » a, b := muxPair()
376
377 » packet := make([]byte, 1+4+4+1)
378 » packet[0] = msgChannelData
379 » marshalUint32(packet[1:], 29348723 /* invalid channel id */)
380 » marshalUint32(packet[5:], 1)
381 » packet[9] = 42
382
383 » a.conn.writePacket(packet)
384 » go a.SendRequest("hello", false, nil)
385 » // 'a' wrote an invalid packet, so 'b' has exited.
386 » req, ok := <-b.incomingRequests
387 » if ok {
388 » » t.Errorf("got request %#v after receiving invalid packet", req)
389 » }
390 }
391
392 func TestZeroWindowAdjust(t *testing.T) {
393 » a, b, _ := channelPair(t)
394
395 » go func() {
396 » » io.WriteString(a, "hello")
397 » » // bogus adjust.
398 » » a.sendMessage(
399 » » » msgChannelWindowAdjust, windowAdjustMsg{})
400 » » io.WriteString(a, "world")
401 » » a.Close()
402 » }()
403
404 » want := "helloworld"
405 » c, _ := ioutil.ReadAll(b)
406 » if string(c) != want {
407 » » t.Errorf("got %q want %q", c, want)
408 » }
409 }
410
411 func TestMuxMaxPacketSize(t *testing.T) {
412 » a, b, _ := channelPair(t)
413
414 » large := make([]byte, a.maxPacket+1)
415 » if err := a.writePacket(large); err == nil {
416 » » t.Errorf("channel sent out packet larger than maxPacket")
417 » }
418
419 » packet := make([]byte, 1+4+4+1+len(large))
420 » packet[0] = msgChannelData
421 » marshalUint32(packet[1:], a.remoteId)
422 » marshalUint32(packet[5:], uint32(len(large)))
423 » packet[9] = 42
424
425 » if err := a.mux.conn.writePacket(packet); err != nil {
426 » » t.Errorf("could not send packet")
427 » }
428
429 » go a.SendRequest("hello", false, nil)
430
431 » _, ok := <-b.incomingRequests
432 » if ok {
433 » » t.Errorf("connection still alive after receiving large packet.")
434 » }
435 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b