Index: ssh/session.go |
=================================================================== |
--- a/ssh/session.go |
+++ b/ssh/session.go |
@@ -129,128 +129,126 @@ |
Stdout io.Writer |
Stderr io.Writer |
- *clientChan // the channel backing this session |
- |
- started bool // true once Start, Run or Shell is invoked. |
+ ch Channel // the channel backing this session |
+ started bool // true once Start, Run or Shell is invoked. |
copyFuncs []func() error |
errors chan error // one send per copyFunc |
// true if pipe method is active |
stdinpipe, stdoutpipe, stderrpipe bool |
+ |
+ // stdinPipeWriter is non-nil if StdinPipe has not been called |
+ // and Stdin was specified by the user; it is the write end of |
+ // a pipe connecting Session.Stdin to the stdin channel. |
+ stdinPipeWriter io.WriteCloser |
+ |
+ exitStatus chan error |
+} |
+ |
+// SendRequest sends an out-of-band channel request on the SSH channel |
+// underlying the session. |
+func (s *Session) SendRequest(name string, wantReply bool, payload []byte) (bool, error) { |
+ return s.ch.SendRequest(name, wantReply, payload) |
+} |
+ |
+func (s *Session) Close() error { |
+ return s.ch.Close() |
} |
// RFC 4254 Section 6.4. |
type setenvRequest struct { |
- PeersId uint32 |
- Request string |
- WantReply bool |
- Name string |
- Value string |
-} |
- |
-// RFC 4254 Section 6.5. |
-type subsystemRequestMsg struct { |
- PeersId uint32 |
- Request string |
- WantReply bool |
- Subsystem string |
+ Name string |
+ Value string |
} |
// Setenv sets an environment variable that will be applied to any |
// command executed by Shell or Run. |
func (s *Session) Setenv(name, value string) error { |
- req := setenvRequest{ |
- PeersId: s.remoteId, |
- Request: "env", |
- WantReply: true, |
- Name: name, |
- Value: value, |
+ msg := setenvRequest{ |
+ Name: name, |
+ Value: value, |
} |
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { |
- return err |
+ ok, err := s.ch.SendRequest("env", true, Marshal(&msg)) |
+ if err == nil && !ok { |
+ err = errors.New("ssh: setenv failed") |
} |
- return s.waitForResponse() |
+ return err |
} |
// RFC 4254 Section 6.2. |
type ptyRequestMsg struct { |
- PeersId uint32 |
- Request string |
- WantReply bool |
- Term string |
- Columns uint32 |
- Rows uint32 |
- Width uint32 |
- Height uint32 |
- Modelist string |
+ Term string |
+ Columns uint32 |
+ Rows uint32 |
+ Width uint32 |
+ Height uint32 |
+ Modelist string |
} |
// RequestPty requests the association of a pty with the session on the remote host. |
func (s *Session) RequestPty(term string, h, w int, termmodes TerminalModes) error { |
var tm []byte |
for k, v := range termmodes { |
- tm = append(tm, k) |
- tm = appendU32(tm, v) |
+ kv := struct { |
+ Key byte |
+ Val uint32 |
+ }{k, v} |
+ |
+ tm = append(tm, Marshal(&kv)...) |
} |
tm = append(tm, tty_OP_END) |
req := ptyRequestMsg{ |
- PeersId: s.remoteId, |
- Request: "pty-req", |
- WantReply: true, |
- Term: term, |
- Columns: uint32(w), |
- Rows: uint32(h), |
- Width: uint32(w * 8), |
- Height: uint32(h * 8), |
- Modelist: string(tm), |
+ Term: term, |
+ Columns: uint32(w), |
+ Rows: uint32(h), |
+ Width: uint32(w * 8), |
+ Height: uint32(h * 8), |
+ Modelist: string(tm), |
} |
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { |
- return err |
+ ok, err := s.ch.SendRequest("pty-req", true, Marshal(&req)) |
+ if err == nil && !ok { |
+ err = errors.New("ssh: pty-req failed") |
} |
- return s.waitForResponse() |
+ return err |
+} |
+ |
+// RFC 4254 Section 6.5. |
+type subsystemRequestMsg struct { |
+ Subsystem string |
} |
// RequestSubsystem requests the association of a subsystem with the session on the remote host. |
// A subsystem is a predefined command that runs in the background when the ssh session is initiated |
func (s *Session) RequestSubsystem(subsystem string) error { |
- req := subsystemRequestMsg{ |
- PeersId: s.remoteId, |
- Request: "subsystem", |
- WantReply: true, |
+ msg := subsystemRequestMsg{ |
Subsystem: subsystem, |
} |
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { |
- return err |
+ ok, err := s.ch.SendRequest("subsystem", true, Marshal(&msg)) |
+ if err == nil && !ok { |
+ err = errors.New("ssh: subsystem request failed") |
} |
- return s.waitForResponse() |
+ return err |
} |
// RFC 4254 Section 6.9. |
type signalMsg struct { |
- PeersId uint32 |
- Request string |
- WantReply bool |
- Signal string |
+ Signal string |
} |
// Signal sends the given signal to the remote process. |
// sig is one of the SIG* constants. |
func (s *Session) Signal(sig Signal) error { |
- req := signalMsg{ |
- PeersId: s.remoteId, |
- Request: "signal", |
- WantReply: false, |
- Signal: string(sig), |
+ msg := signalMsg{ |
+ Signal: string(sig), |
} |
- return s.writePacket(marshal(msgChannelRequest, req)) |
+ |
+ _, err := s.ch.SendRequest("signal", false, Marshal(&msg)) |
+ return err |
} |
// RFC 4254 Section 6.5. |
type execMsg struct { |
- PeersId uint32 |
- Request string |
- WantReply bool |
- Command string |
+ Command string |
} |
// Start runs cmd on the remote host. Typically, the remote |
@@ -261,17 +259,16 @@ |
return errors.New("ssh: session already started") |
} |
req := execMsg{ |
- PeersId: s.remoteId, |
- Request: "exec", |
- WantReply: true, |
- Command: cmd, |
+ Command: cmd, |
} |
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { |
+ |
+ ok, err := s.ch.SendRequest("exec", true, Marshal(&req)) |
+ if err == nil && !ok { |
+ err = fmt.Errorf("ssh: command %v failed", cmd) |
+ } |
+ if err != nil { |
return err |
} |
- if err := s.waitForResponse(); err != nil { |
- return fmt.Errorf("ssh: could not execute command %s: %v", cmd, err) |
- } |
return s.start() |
} |
@@ -339,31 +336,17 @@ |
if s.started { |
return errors.New("ssh: session already started") |
} |
- req := channelRequestMsg{ |
- PeersId: s.remoteId, |
- Request: "shell", |
- WantReply: true, |
+ |
+ ok, err := s.ch.SendRequest("shell", true, nil) |
+ if err == nil && !ok { |
+ return fmt.Errorf("ssh: cound not start shell") |
} |
- if err := s.writePacket(marshal(msgChannelRequest, req)); err != nil { |
+ if err != nil { |
return err |
} |
- if err := s.waitForResponse(); err != nil { |
- return fmt.Errorf("ssh: could not execute shell: %v", err) |
- } |
return s.start() |
} |
-func (s *Session) waitForResponse() error { |
- msg := <-s.msg |
- switch msg.(type) { |
- case *channelRequestSuccessMsg: |
- return nil |
- case *channelRequestFailureMsg: |
- return errors.New("ssh: request failed") |
- } |
- return fmt.Errorf("ssh: unknown packet %T received: %v", msg, msg) |
-} |
- |
func (s *Session) start() error { |
s.started = true |
@@ -394,8 +377,11 @@ |
if !s.started { |
return errors.New("ssh: session not started") |
} |
- waitErr := s.wait() |
+ waitErr := <-s.exitStatus |
+ if s.stdinPipeWriter != nil { |
+ s.stdinPipeWriter.Close() |
+ } |
var copyError error |
for _ = range s.copyFuncs { |
if err := <-s.errors; err != nil && copyError == nil { |
@@ -408,52 +394,35 @@ |
return copyError |
} |
-func (s *Session) wait() error { |
+func (s *Session) wait(reqs <-chan *Request) error { |
wm := Waitmsg{status: -1} |
+ // Wait for msg channel to be closed before returning. |
+ for msg := range reqs { |
+ switch msg.Type { |
+ case "exit-status": |
+ d := msg.Payload |
+ wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) |
+ case "exit-signal": |
+ var sigval struct { |
+ Signal string |
+ CoreDumped bool |
+ Error string |
+ Lang string |
+ } |
+ if err := Unmarshal(msg.Payload, &sigval); err != nil { |
+ return err |
+ } |
- // Wait for msg channel to be closed before returning. |
- for msg := range s.msg { |
- switch msg := msg.(type) { |
- case *channelRequestMsg: |
- switch msg.Request { |
- case "exit-status": |
- d := msg.RequestSpecificData |
- wm.status = int(d[0])<<24 | int(d[1])<<16 | int(d[2])<<8 | int(d[3]) |
- case "exit-signal": |
- signal, rest, ok := parseString(msg.RequestSpecificData) |
- if !ok { |
- return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) |
- } |
- wm.signal = safeString(string(signal)) |
- |
- // skip coreDumped bool |
- if len(rest) == 0 { |
- return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) |
- } |
- rest = rest[1:] |
- |
- errmsg, rest, ok := parseString(rest) |
- if !ok { |
- return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) |
- } |
- wm.msg = safeString(string(errmsg)) |
- |
- lang, _, ok := parseString(rest) |
- if !ok { |
- return fmt.Errorf("wait: could not parse request data: %v", msg.RequestSpecificData) |
- } |
- wm.lang = safeString(string(lang)) |
- default: |
- // This handles keepalives and matches |
- // OpenSSH's behaviour. |
- if msg.WantReply { |
- s.writePacket(marshal(msgChannelFailure, channelRequestFailureMsg{ |
- PeersId: s.remoteId, |
- })) |
- } |
+ // Must sanitize strings? |
+ wm.signal = sigval.Signal |
+ wm.msg = sigval.Error |
+ wm.lang = sigval.Lang |
+ default: |
+ // This handles keepalives and matches |
+ // OpenSSH's behaviour. |
+ if msg.WantReply { |
+ msg.Reply(false, nil) |
} |
- default: |
- return fmt.Errorf("wait: unexpected packet %T received: %v", msg, msg) |
} |
} |
if wm.status == 0 { |
@@ -476,12 +445,20 @@ |
if s.stdinpipe { |
return |
} |
+ var stdin io.Reader |
if s.Stdin == nil { |
- s.Stdin = new(bytes.Buffer) |
+ stdin = new(bytes.Buffer) |
+ } else { |
+ r, w := io.Pipe() |
+ go func() { |
+ _, err := io.Copy(w, s.Stdin) |
+ w.CloseWithError(err) |
+ }() |
+ stdin, s.stdinPipeWriter = r, w |
} |
s.copyFuncs = append(s.copyFuncs, func() error { |
- _, err := io.Copy(s.clientChan.stdin, s.Stdin) |
- if err1 := s.clientChan.stdin.Close(); err == nil && err1 != io.EOF { |
+ _, err := io.Copy(s.ch, stdin) |
+ if err1 := s.ch.CloseWrite(); err == nil && err1 != io.EOF { |
err = err1 |
} |
return err |
@@ -496,7 +473,7 @@ |
s.Stdout = ioutil.Discard |
} |
s.copyFuncs = append(s.copyFuncs, func() error { |
- _, err := io.Copy(s.Stdout, s.clientChan.stdout) |
+ _, err := io.Copy(s.Stdout, s.ch) |
return err |
}) |
} |
@@ -509,11 +486,21 @@ |
s.Stderr = ioutil.Discard |
} |
s.copyFuncs = append(s.copyFuncs, func() error { |
- _, err := io.Copy(s.Stderr, s.clientChan.stderr) |
+ _, err := io.Copy(s.Stderr, s.ch.Stderr()) |
return err |
}) |
} |
+// sessionStdin reroutes Close to CloseWrite. |
+type sessionStdin struct { |
+ io.Writer |
+ ch Channel |
+} |
+ |
+func (s *sessionStdin) Close() error { |
+ return s.ch.CloseWrite() |
+} |
+ |
// StdinPipe returns a pipe that will be connected to the |
// remote command's standard input when the command starts. |
func (s *Session) StdinPipe() (io.WriteCloser, error) { |
@@ -524,7 +511,7 @@ |
return nil, errors.New("ssh: StdinPipe after process started") |
} |
s.stdinpipe = true |
- return s.clientChan.stdin, nil |
+ return &sessionStdin{s.ch, s.ch}, nil |
} |
// StdoutPipe returns a pipe that will be connected to the |
@@ -541,7 +528,7 @@ |
return nil, errors.New("ssh: StdoutPipe after process started") |
} |
s.stdoutpipe = true |
- return s.clientChan.stdout, nil |
+ return s.ch, nil |
} |
// StderrPipe returns a pipe that will be connected to the |
@@ -558,28 +545,20 @@ |
return nil, errors.New("ssh: StderrPipe after process started") |
} |
s.stderrpipe = true |
- return s.clientChan.stderr, nil |
+ return s.ch.Stderr(), nil |
} |
-// NewSession returns a new interactive session on the remote host. |
-func (c *ClientConn) NewSession() (*Session, error) { |
- ch := c.newChan(c.transport) |
- if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{ |
- ChanType: "session", |
- PeersId: ch.localId, |
- PeersWindow: channelWindowSize, |
- MaxPacketSize: channelMaxPacketSize, |
- })); err != nil { |
- c.chanList.remove(ch.localId) |
- return nil, err |
+// newSession returns a new interactive session on the remote host. |
+func newSession(ch Channel, reqs <-chan *Request) (*Session, error) { |
+ s := &Session{ |
+ ch: ch, |
} |
- if err := ch.waitForChannelOpenResponse(); err != nil { |
- c.chanList.remove(ch.localId) |
- return nil, fmt.Errorf("ssh: unable to open session: %v", err) |
- } |
- return &Session{ |
- clientChan: ch, |
- }, nil |
+ s.exitStatus = make(chan error, 1) |
+ go func() { |
+ s.exitStatus <- s.wait(reqs) |
+ }() |
+ |
+ return s, nil |
} |
// An ExitError reports unsuccessful completion of a remote command. |