LEFT | RIGHT |
1 // Copyright 2011 The Go Authors. All rights reserved. | 1 // Copyright 2011 The Go Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style | 2 // Use of this source code is governed by a BSD-style |
3 // license that can be found in the LICENSE file. | 3 // license that can be found in the LICENSE file. |
4 | 4 |
5 package ssh | 5 package ssh |
6 | 6 |
7 // Session tests. | 7 // Session tests. |
8 | 8 |
9 import ( | 9 import ( |
10 "bytes" | 10 "bytes" |
11 crypto_rand "crypto/rand" | 11 crypto_rand "crypto/rand" |
12 "io" | 12 "io" |
13 "io/ioutil" | 13 "io/ioutil" |
14 "log" | |
15 "math/rand" | 14 "math/rand" |
16 "net" | 15 "net" |
17 "testing" | 16 "testing" |
18 | 17 |
19 "code.google.com/p/go.crypto/ssh/terminal" | 18 "code.google.com/p/go.crypto/ssh/terminal" |
20 ) | 19 ) |
21 | 20 |
22 var _ = log.Println | 21 type serverType func(*channel, *testing.T) |
23 | |
24 type serverType func(Channel, *testing.T) | |
25 | 22 |
26 // dial constructs a new test server and returns a *ClientConn. | 23 // dial constructs a new test server and returns a *ClientConn. |
27 func dial(handler serverType, t *testing.T) *ClientConn { | 24 func dial(handler serverType, t *testing.T) *ClientConn { |
28 l, err := Listen("tcp", "127.0.0.1:0", serverConfig) | 25 l, err := Listen("tcp", "127.0.0.1:0", serverConfig) |
29 if err != nil { | 26 if err != nil { |
30 t.Fatalf("unable to listen: %v", err) | 27 t.Fatalf("unable to listen: %v", err) |
31 } | 28 } |
32 go func() { | 29 go func() { |
33 defer l.Close() | 30 defer l.Close() |
34 conn, err := l.Accept() | 31 conn, err := l.Accept() |
(...skipping 18 matching lines...) Expand all Loading... |
53 } | 50 } |
54 if err != nil { | 51 if err != nil { |
55 t.Errorf("Unable to accept incoming channel requ
est: %v", err) | 52 t.Errorf("Unable to accept incoming channel requ
est: %v", err) |
56 return | 53 return |
57 } | 54 } |
58 if ch.ChannelType() != "session" { | 55 if ch.ChannelType() != "session" { |
59 ch.Reject(UnknownChannelType, "unknown channel t
ype") | 56 ch.Reject(UnknownChannelType, "unknown channel t
ype") |
60 continue | 57 continue |
61 } | 58 } |
62 | 59 |
63 » » » ch.Accept() | 60 » » » if err = ch.Accept(); err != nil { |
| 61 » » » » t.Errorf("Accept: %v", err) |
| 62 » » » } |
64 go func() { | 63 go func() { |
65 defer close(done) | 64 defer close(done) |
66 » » » » handler(ch, t) | 65 » » » » handler(ch.(*compatChannel).channel, t) |
67 }() | 66 }() |
68 } | 67 } |
69 <-done | 68 <-done |
70 }() | 69 }() |
71 | 70 |
72 config := &ClientConfig{ | 71 config := &ClientConfig{ |
73 User: "testuser", | 72 User: "testuser", |
74 Auth: []ClientAuth{ | 73 Auth: []ClientAuth{ |
75 ClientAuthPassword(clientPassword), | 74 ClientAuthPassword(clientPassword), |
76 }, | 75 }, |
77 } | 76 } |
78 | 77 |
79 c, err := Dial("tcp", l.Addr().String(), config) | 78 c, err := Dial("tcp", l.Addr().String(), config) |
80 if err != nil { | 79 if err != nil { |
81 t.Fatalf("unable to dial remote side: %v", err) | 80 t.Fatalf("unable to dial remote side: %v", err) |
82 } | 81 } |
83 return c | 82 return c |
84 } | 83 } |
85 | 84 |
86 // TEST a simple string is returned to session.Stdout. | 85 // Test a simple string is returned to session.Stdout. |
87 func TestSessionShell(t *testing.T) { | 86 func TestSessionShell(t *testing.T) { |
88 conn := dial(shellHandler, t) | 87 conn := dial(shellHandler, t) |
89 defer conn.Close() | 88 defer conn.Close() |
90 session, err := conn.NewSession() | 89 session, err := conn.NewSession() |
91 if err != nil { | 90 if err != nil { |
92 t.Fatalf("Unable to request new session: %v", err) | 91 t.Fatalf("Unable to request new session: %v", err) |
93 } | 92 } |
94 defer session.Close() | 93 defer session.Close() |
95 stdout := new(bytes.Buffer) | 94 stdout := new(bytes.Buffer) |
96 session.Stdout = stdout | 95 session.Stdout = stdout |
(...skipping 230 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
327 t.Fatalf("expected command to fail but it didn't") | 326 t.Fatalf("expected command to fail but it didn't") |
328 } | 327 } |
329 _, ok := err.(*ExitError) | 328 _, ok := err.(*ExitError) |
330 if ok { | 329 if ok { |
331 // you can't actually test for errors.errorString | 330 // you can't actually test for errors.errorString |
332 // because it's not exported. | 331 // because it's not exported. |
333 t.Fatalf("expected *errorString but got %T", err) | 332 t.Fatalf("expected *errorString but got %T", err) |
334 } | 333 } |
335 } | 334 } |
336 | 335 |
337 // Verify that the client never sends a packet larger than maxpacket. | |
338 // | |
339 // Broken: writer sends more than the window, which is a protocol error. | |
340 func brokenTestClientStdinRespectsMaxPacketSize(t *testing.T) { | |
341 conn := dial(discardHandler, t) | |
342 defer conn.Close() | |
343 session, err := conn.NewSession() | |
344 if err != nil { | |
345 t.Fatalf("failed to request new session: %v", err) | |
346 } | |
347 defer session.Close() | |
348 stdin, err := session.StdinPipe() | |
349 if err != nil { | |
350 t.Fatalf("failed to obtain stdinpipe: %v", err) | |
351 } | |
352 const size = 100 * 1000 | |
353 for i := 0; i < 10; i++ { | |
354 n, err := stdin.Write(make([]byte, size)) | |
355 if n != size || err != nil { | |
356 t.Fatalf("failed to write: %d, %v", n, err) | |
357 } | |
358 } | |
359 } | |
360 | |
361 // Verify that the client never accepts a packet larger than maxpacket. | |
362 func TestServerStdoutRespectsMaxPacketSize(t *testing.T) { | |
363 conn := dial(largeSendHandler, t) | |
364 defer conn.Close() | |
365 session, err := conn.NewSession() | |
366 if err != nil { | |
367 t.Fatalf("Unable to request new session: %v", err) | |
368 } | |
369 defer session.Close() | |
370 out, err := session.StdoutPipe() | |
371 if err != nil { | |
372 t.Fatalf("Unable to connect to Stdout: %v", err) | |
373 } | |
374 if err := session.Shell(); err != nil { | |
375 t.Fatalf("Unable to execute command: %v", err) | |
376 } | |
377 if _, err := ioutil.ReadAll(out); err != nil { | |
378 t.Fatalf("failed to read: %v", err) | |
379 } | |
380 } | |
381 | |
382 // TODO(hanwen): this test should be at the transport level. | 336 // TODO(hanwen): this test should be at the transport level. |
383 func TestClientCannotSendHugePacket(t *testing.T) { | 337 func TestClientCannotSendHugePacket(t *testing.T) { |
384 // client and server use the same transport write code so this | 338 // client and server use the same transport write code so this |
385 // test suffices for both. | 339 // test suffices for both. |
386 conn := dial(shellHandler, t) | 340 conn := dial(shellHandler, t) |
387 defer conn.Close() | 341 defer conn.Close() |
388 if err := conn.transport.writePacket(make([]byte, maxPacket*2)); err ==
nil { | 342 if err := conn.transport.writePacket(make([]byte, maxPacket*2)); err ==
nil { |
389 t.Fatalf("huge packet write should fail") | 343 t.Fatalf("huge packet write should fail") |
390 } | 344 } |
391 } | 345 } |
(...skipping 31 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
423 } | 377 } |
424 result <- echoedBuf.Bytes() | 378 result <- echoedBuf.Bytes() |
425 }() | 379 }() |
426 | 380 |
427 serverStdin, err := session.StdinPipe() | 381 serverStdin, err := session.StdinPipe() |
428 if err != nil { | 382 if err != nil { |
429 t.Fatalf("StdinPipe failed: %v", err) | 383 t.Fatalf("StdinPipe failed: %v", err) |
430 } | 384 } |
431 written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestB
ytes) | 385 written, err := copyNRandomly("stdin", serverStdin, origBuf, windowTestB
ytes) |
432 if err != nil { | 386 if err != nil { |
433 » » t.Fatalf("falied to copy origBuf to serverStdin: %v", err) | 387 » » t.Fatalf("failed to copy origBuf to serverStdin: %v", err) |
434 } | 388 } |
435 if written != windowTestBytes { | 389 if written != windowTestBytes { |
436 t.Fatalf("Wrote only %d of %d bytes to server", written, windowT
estBytes) | 390 t.Fatalf("Wrote only %d of %d bytes to server", written, windowT
estBytes) |
437 } | 391 } |
438 | 392 |
439 echoedBytes := <-result | 393 echoedBytes := <-result |
440 | 394 |
441 if !bytes.Equal(origBytes, echoedBytes) { | 395 if !bytes.Equal(origBytes, echoedBytes) { |
442 t.Fatalf("Echoed buffer differed from original, orig %d, echoed
%d", len(origBytes), len(echoedBytes)) | 396 t.Fatalf("Echoed buffer differed from original, orig %d, echoed
%d", len(origBytes), len(echoedBytes)) |
443 } | 397 } |
(...skipping 21 matching lines...) Expand all Loading... |
465 Status uint32 | 419 Status uint32 |
466 } | 420 } |
467 | 421 |
// exitSignalMsg is the payload that sendSignal marshals into an "exit-signal"
// channel request.
// NOTE(review): marshal presumably encodes fields in declaration order, so the
// order below is assumed to be wire order — do not reorder without checking.
type exitSignalMsg struct {
	Signal     string // signal name without the "SIG" prefix, e.g. "TERM"
	CoreDumped bool
	Errmsg     string
	Lang       string // language tag accompanying Errmsg
}
474 | 428 |
475 func newServerShell(ch Channel, prompt string) *ServerTerminal { | 429 func newServerShell(ch *channel, prompt string) *ServerTerminal { |
476 term := terminal.NewTerminal(ch, prompt) | 430 term := terminal.NewTerminal(ch, prompt) |
477 s := &ServerTerminal{ | 431 s := &ServerTerminal{ |
478 Term: term, | 432 Term: term, |
479 » » Channel: ch, | 433 » » Channel: newCompatChannel(ch), |
480 » } | 434 » } |
481 » go s.HandleRequests() | |
482 return s | 435 return s |
483 } | 436 } |
484 | 437 |
485 func exitStatusZeroHandler(ch Channel, t *testing.T) { | 438 func exitStatusZeroHandler(ch *channel, t *testing.T) { |
486 defer ch.Close() | 439 defer ch.Close() |
487 // this string is returned to stdout | 440 // this string is returned to stdout |
488 shell := newServerShell(ch, "> ") | 441 shell := newServerShell(ch, "> ") |
489 readLine(shell, t) | 442 readLine(shell, t) |
490 sendStatus(0, ch, t) | 443 sendStatus(0, ch, t) |
491 } | 444 } |
492 | 445 |
493 func exitStatusNonZeroHandler(ch Channel, t *testing.T) { | 446 func exitStatusNonZeroHandler(ch *channel, t *testing.T) { |
494 defer ch.Close() | 447 defer ch.Close() |
495 shell := newServerShell(ch, "> ") | 448 shell := newServerShell(ch, "> ") |
496 readLine(shell, t) | 449 readLine(shell, t) |
497 sendStatus(15, ch, t) | 450 sendStatus(15, ch, t) |
498 } | 451 } |
499 | 452 |
500 func exitSignalAndStatusHandler(ch Channel, t *testing.T) { | 453 func exitSignalAndStatusHandler(ch *channel, t *testing.T) { |
501 defer ch.Close() | 454 defer ch.Close() |
502 shell := newServerShell(ch, "> ") | 455 shell := newServerShell(ch, "> ") |
503 readLine(shell, t) | 456 readLine(shell, t) |
504 sendStatus(15, ch, t) | 457 sendStatus(15, ch, t) |
505 sendSignal("TERM", ch, t) | 458 sendSignal("TERM", ch, t) |
506 } | 459 } |
507 | 460 |
508 func exitSignalHandler(ch Channel, t *testing.T) { | 461 func exitSignalHandler(ch *channel, t *testing.T) { |
509 defer ch.Close() | 462 defer ch.Close() |
510 shell := newServerShell(ch, "> ") | 463 shell := newServerShell(ch, "> ") |
511 readLine(shell, t) | 464 readLine(shell, t) |
512 sendSignal("TERM", ch, t) | 465 sendSignal("TERM", ch, t) |
513 } | 466 } |
514 | 467 |
515 func exitSignalUnknownHandler(ch Channel, t *testing.T) { | 468 func exitSignalUnknownHandler(ch *channel, t *testing.T) { |
516 defer ch.Close() | 469 defer ch.Close() |
517 shell := newServerShell(ch, "> ") | 470 shell := newServerShell(ch, "> ") |
518 readLine(shell, t) | 471 readLine(shell, t) |
519 sendSignal("SYS", ch, t) | 472 sendSignal("SYS", ch, t) |
520 } | 473 } |
521 | 474 |
522 func exitWithoutSignalOrStatus(ch Channel, t *testing.T) { | 475 func exitWithoutSignalOrStatus(ch *channel, t *testing.T) { |
523 defer ch.Close() | 476 defer ch.Close() |
524 shell := newServerShell(ch, "> ") | 477 shell := newServerShell(ch, "> ") |
525 readLine(shell, t) | 478 readLine(shell, t) |
526 } | 479 } |
527 | 480 |
528 func shellHandler(ch Channel, t *testing.T) { | 481 func shellHandler(ch *channel, t *testing.T) { |
529 defer ch.Close() | 482 defer ch.Close() |
530 // this string is returned to stdout | 483 // this string is returned to stdout |
531 shell := newServerShell(ch, "golang") | 484 shell := newServerShell(ch, "golang") |
532 readLine(shell, t) | 485 readLine(shell, t) |
533 sendStatus(0, ch, t) | 486 sendStatus(0, ch, t) |
534 } | 487 } |
535 | 488 |
536 // Ignores the command, writes fixed strings to stderr and stdout. | 489 // Ignores the command, writes fixed strings to stderr and stdout. |
537 // Strings are "this-is-stdout." and "this-is-stderr.". | 490 // Strings are "this-is-stdout." and "this-is-stderr.". |
538 func fixedOutputHandler(ch Channel, t *testing.T) { | 491 func fixedOutputHandler(ch *channel, t *testing.T) { |
539 » defer ch.Close() | 492 » defer ch.Close() |
540 | 493 » _, err := ch.Read(make([]byte, 0)) |
541 » _ = <-ch.ReceivedRequests() | 494 » if _, ok := err.(ChannelRequest); !ok { |
| 495 » » t.Fatalf("error: expected channel request, got: %#v", err) |
| 496 » » return |
| 497 » } |
| 498 |
542 // ignore request, always send some text | 499 // ignore request, always send some text |
543 ch.AckRequest(true) | 500 ch.AckRequest(true) |
544 | 501 |
545 » _, err := io.WriteString(ch, "this-is-stdout.") | 502 » _, err = io.WriteString(ch, "this-is-stdout.") |
546 if err != nil { | 503 if err != nil { |
547 t.Fatalf("error writing on server: %v", err) | 504 t.Fatalf("error writing on server: %v", err) |
548 } | 505 } |
549 » _, err = io.WriteString(ch.Stderr(), "this-is-stderr.") | 506 » _, err = io.WriteString(ch.Extended(1), "this-is-stderr.") |
550 if err != nil { | 507 if err != nil { |
551 t.Fatalf("error writing on server: %v", err) | 508 t.Fatalf("error writing on server: %v", err) |
552 } | 509 } |
553 sendStatus(0, ch, t) | 510 sendStatus(0, ch, t) |
554 } | 511 } |
555 | 512 |
556 func readLine(shell *ServerTerminal, t *testing.T) { | 513 func readLine(shell *ServerTerminal, t *testing.T) { |
557 if _, err := shell.ReadLine(); err != nil && err != io.EOF { | 514 if _, err := shell.ReadLine(); err != nil && err != io.EOF { |
558 t.Errorf("unable to read line: %v", err) | 515 t.Errorf("unable to read line: %v", err) |
559 } | 516 } |
560 } | 517 } |
561 | 518 |
562 func sendStatus(status uint32, ch Channel, t *testing.T) { | 519 func sendStatus(status uint32, ch *channel, t *testing.T) { |
563 msg := exitStatusMsg{ | 520 msg := exitStatusMsg{ |
564 Status: status, | 521 Status: status, |
565 } | 522 } |
566 » if _, err := ch.SendRequest("exit-status", false, marshal(0, msg)[1:]);
err != nil { | 523 » if _, err := ch.SendRequest("exit-status", false, marshal(0, msg)); err
!= nil { |
567 » » panic("x") | |
568 t.Errorf("unable to send status: %v", err) | 524 t.Errorf("unable to send status: %v", err) |
569 } | 525 } |
570 } | 526 } |
571 | 527 |
572 func sendSignal(signal string, ch Channel, t *testing.T) { | 528 func sendSignal(signal string, ch *channel, t *testing.T) { |
573 sig := exitSignalMsg{ | 529 sig := exitSignalMsg{ |
574 Signal: signal, | 530 Signal: signal, |
575 CoreDumped: false, | 531 CoreDumped: false, |
576 Errmsg: "Process terminated", | 532 Errmsg: "Process terminated", |
577 Lang: "en-GB-oed", | 533 Lang: "en-GB-oed", |
578 } | 534 } |
579 » if _, err := ch.SendRequest("exit-signal", false, marshal(0, sig)[1:]);
err != nil { | 535 » if _, err := ch.SendRequest("exit-signal", false, marshal(0, sig)); err
!= nil { |
580 t.Errorf("unable to send signal: %v", err) | 536 t.Errorf("unable to send signal: %v", err) |
581 } | 537 } |
582 } | 538 } |
583 | 539 |
584 func discardHandler(ch Channel, t *testing.T) { | 540 func discardHandler(ch *channel, t *testing.T) { |
585 defer ch.Close() | 541 defer ch.Close() |
586 // grow the window to avoid being fooled by | 542 // grow the window to avoid being fooled by |
587 // the initial 1 << 14 window. | 543 // the initial 1 << 14 window. |
588 » ch.(*channel).adjustWindow(1024 * 1024) | 544 » ch.adjustWindow(1024 * 1024) |
589 | 545 |
590 io.Copy(ioutil.Discard, ch) | 546 io.Copy(ioutil.Discard, ch) |
591 } | 547 } |
592 | 548 |
593 func largeSendHandler(ch Channel, t *testing.T) { | 549 func echoHandler(ch *channel, t *testing.T) { |
594 » defer ch.Close() | |
595 » // grow the window to avoid being fooled by | |
596 » // the initial 1 << 14 window. | |
597 » pCh := ch.(*channel) | |
598 » pCh.adjustWindow(1024 * 1024) | |
599 » shell := newServerShell(ch, "> ") | |
600 » readLine(shell, t) | |
601 » // try to send more than the 32k window | |
602 » // will allow | |
603 » if err := pCh.writePacket(make([]byte, 128*1024)); err == nil { | |
604 » » t.Errorf("wrote packet larger than 32k") | |
605 » } | |
606 } | |
607 | |
608 func echoHandler(ch Channel, t *testing.T) { | |
609 defer ch.Close() | 550 defer ch.Close() |
610 if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err
!= nil { | 551 if n, err := copyNRandomly("echohandler", ch, ch, windowTestBytes); err
!= nil { |
611 t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTes
tBytes, err) | 552 t.Errorf("short write, wrote %d, expected %d: %v ", n, windowTes
tBytes, err) |
612 } | 553 } |
613 } | 554 } |
614 | 555 |
615 // copyNRandomly copies n bytes from src to dst. It uses a variable, and random, | 556 // copyNRandomly copies n bytes from src to dst. It uses a variable, and random, |
616 // buffer size to exercise more code paths. | 557 // buffer size to exercise more code paths. |
617 func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, erro
r) { | 558 func copyNRandomly(title string, dst io.Writer, src io.Reader, n int) (int, erro
r) { |
618 var ( | 559 var ( |
(...skipping 16 matching lines...) Expand all Loading... |
635 if nr != nw { | 576 if nr != nw { |
636 return written, io.ErrShortWrite | 577 return written, io.ErrShortWrite |
637 } | 578 } |
638 if er != nil && er != io.EOF { | 579 if er != nil && er != io.EOF { |
639 return written, er | 580 return written, er |
640 } | 581 } |
641 } | 582 } |
642 return written, nil | 583 return written, nil |
643 } | 584 } |
644 | 585 |
645 func channelKeepaliveSender(ch Channel, t *testing.T) { | 586 func channelKeepaliveSender(ch *channel, t *testing.T) { |
646 defer ch.Close() | 587 defer ch.Close() |
647 shell := newServerShell(ch, "> ") | 588 shell := newServerShell(ch, "> ") |
648 readLine(shell, t) | 589 readLine(shell, t) |
649 if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err !=
nil { | 590 if _, err := ch.SendRequest("keepalive@openssh.com", true, nil); err !=
nil { |
650 t.Errorf("unable to send channel keepalive request: %v", err) | 591 t.Errorf("unable to send channel keepalive request: %v", err) |
651 } | 592 } |
652 sendStatus(0, ch, t) | 593 sendStatus(0, ch, t) |
653 } | 594 } |
| 595 |
| 596 func TestClientWriteEOF(t *testing.T) { |
| 597 conn := dial(simpleEchoHandler, t) |
| 598 defer conn.Close() |
| 599 |
| 600 session, err := conn.NewSession() |
| 601 if err != nil { |
| 602 t.Fatal(err) |
| 603 } |
| 604 defer session.Close() |
| 605 stdin, err := session.StdinPipe() |
| 606 if err != nil { |
| 607 t.Fatalf("StdinPipe failed: %v", err) |
| 608 } |
| 609 stdout, err := session.StdoutPipe() |
| 610 if err != nil { |
| 611 t.Fatalf("StdoutPipe failed: %v", err) |
| 612 } |
| 613 |
| 614 data := []byte(`0000`) |
| 615 _, err = stdin.Write(data) |
| 616 if err != nil { |
| 617 t.Fatalf("Write failed: %v", err) |
| 618 } |
| 619 stdin.Close() |
| 620 |
| 621 res, err := ioutil.ReadAll(stdout) |
| 622 if err != nil { |
| 623 t.Fatalf("Read failed: %v", err) |
| 624 } |
| 625 |
| 626 if !bytes.Equal(data, res) { |
| 627 t.Fatalf("Read differed from write, wrote: %v, read: %v", data,
res) |
| 628 } |
| 629 } |
| 630 |
| 631 func simpleEchoHandler(ch *channel, t *testing.T) { |
| 632 defer ch.Close() |
| 633 data, err := ioutil.ReadAll(ch) |
| 634 if err != nil { |
| 635 t.Errorf("handler read error: %v", err) |
| 636 } |
| 637 _, err = ch.Write(data) |
| 638 if err != nil { |
| 639 t.Errorf("handler write error: %v", err) |
| 640 } |
| 641 } |
LEFT | RIGHT |