Index: ssh/channel.go |
=================================================================== |
new file mode 100644 |
--- /dev/null |
+++ b/ssh/channel.go |
@@ -0,0 +1,318 @@ |
+// Copyright 2011 The Go Authors. All rights reserved. |
+// Use of this source code is governed by a BSD-style |
+// license that can be found in the LICENSE file. |
+ |
+package ssh |
+ |
+import ( |
+ "errors" |
+ "io" |
+ "sync" |
+) |
+ |
+// A Channel is an ordered, reliable, duplex stream that is multiplexed over an |
+// SSH connection. |
+type Channel interface { |
+ // Accept accepts the channel creation request. |
+ Accept() error |
+ // Reject rejects the channel creation request. After calling this, no |
+ // other methods on the Channel may be called. If they are then the |
+ // peer is likely to signal a protocol error and drop the connection. |
+ Reject(reason RejectionReason, message string) error |
+ |
+ // Read may return a ChannelRequest as an error. |
+ Read(data []byte) (int, error) |
+ Write(data []byte) (int, error) |
+ Close() error |
+ |
+ // AckRequest either sends an ack or nack to the channel request. |
+ AckRequest(ok bool) error |
+ |
+ // ChannelType returns the type of the channel, as supplied by the |
+ // client. |
+ ChannelType() string |
+ // ExtraData returns the arbitary payload for this channel, as supplied |
+ // by the client. This data is specific to the channel type. |
+ ExtraData() []byte |
+} |
+ |
+// ChannelRequest represents a request sent on a channel, outside of the normal |
+// stream of bytes. It may result from calling Read on a Channel. |
+type ChannelRequest struct { |
+ Request string |
+ WantReply bool |
+ Payload []byte |
+} |
+ |
+func (c ChannelRequest) Error() string { |
+ return "channel request received" |
+} |
+ |
+// RejectionReason is an enumeration used when rejecting channel creation |
+// requests. See RFC 4254, section 5.1. |
+type RejectionReason int |
+ |
+const ( |
+ Prohibited RejectionReason = iota + 1 |
+ ConnectionFailed |
+ UnknownChannelType |
+ ResourceShortage |
+) |
+ |
+type channel struct { |
+ // immutable once created |
+ chanType string |
+ extraData []byte |
+ |
+ theyClosed bool |
+ theySentEOF bool |
+ weClosed bool |
+ dead bool |
+ |
+ serverConn *ServerConn |
+ myId, theirId uint32 |
+ myWindow, theirWindow uint32 |
+ maxPacketSize uint32 |
+ err error |
+ |
+ pendingRequests []ChannelRequest |
+ pendingData []byte |
+ head, length int |
+ |
+ // This lock is inferior to serverConn.lock |
+ lock sync.Mutex |
+ cond *sync.Cond |
+} |
+ |
+func (c *channel) Accept() error { |
+ c.serverConn.lock.Lock() |
+ defer c.serverConn.lock.Unlock() |
+ |
+ if c.serverConn.err != nil { |
+ return c.serverConn.err |
+ } |
+ |
+ confirm := channelOpenConfirmMsg{ |
+ PeersId: c.theirId, |
+ MyId: c.myId, |
+ MyWindow: c.myWindow, |
+ MaxPacketSize: c.maxPacketSize, |
+ } |
+ return c.serverConn.writePacket(marshal(msgChannelOpenConfirm, confirm)) |
+} |
+ |
+func (c *channel) Reject(reason RejectionReason, message string) error { |
+ c.serverConn.lock.Lock() |
+ defer c.serverConn.lock.Unlock() |
+ |
+ if c.serverConn.err != nil { |
+ return c.serverConn.err |
+ } |
+ |
+ reject := channelOpenFailureMsg{ |
+ PeersId: c.theirId, |
+ Reason: uint32(reason), |
+ Message: message, |
+ Language: "en", |
+ } |
+ return c.serverConn.writePacket(marshal(msgChannelOpenFailure, reject)) |
+} |
+ |
+func (c *channel) handlePacket(packet interface{}) { |
+ c.lock.Lock() |
+ defer c.lock.Unlock() |
+ |
+ switch packet := packet.(type) { |
+ case *channelRequestMsg: |
+ req := ChannelRequest{ |
+ Request: packet.Request, |
+ WantReply: packet.WantReply, |
+ Payload: packet.RequestSpecificData, |
+ } |
+ |
+ c.pendingRequests = append(c.pendingRequests, req) |
+ c.cond.Signal() |
+ case *channelCloseMsg: |
+ c.theyClosed = true |
+ c.cond.Signal() |
+ case *channelEOFMsg: |
+ c.theySentEOF = true |
+ c.cond.Signal() |
+ default: |
+ panic("unknown packet type") |
+ } |
+} |
+ |
+func (c *channel) handleData(data []byte) { |
+ c.lock.Lock() |
+ defer c.lock.Unlock() |
+ |
+ // The other side should never send us more than our window. |
+ if len(data)+c.length > len(c.pendingData) { |
+ // TODO(agl): we should tear down the channel with a protocol |
+ // error. |
+ return |
+ } |
+ |
+ c.myWindow -= uint32(len(data)) |
+ for i := 0; i < 2; i++ { |
+ tail := c.head + c.length |
+ if tail > len(c.pendingData) { |
+ tail -= len(c.pendingData) |
+ } |
+ n := copy(c.pendingData[tail:], data) |
+ data = data[n:] |
+ c.length += n |
+ } |
+ |
+ c.cond.Signal() |
+} |
+ |
+func (c *channel) Read(data []byte) (n int, err error) { |
+ c.lock.Lock() |
+ defer c.lock.Unlock() |
+ |
+ if c.err != nil { |
+ return 0, c.err |
+ } |
+ |
+ if c.myWindow <= uint32(len(c.pendingData))/2 { |
+ packet := marshal(msgChannelWindowAdjust, windowAdjustMsg{ |
+ PeersId: c.theirId, |
+ AdditionalBytes: uint32(len(c.pendingData)) - c.myWindow, |
+ }) |
+ if err := c.serverConn.writePacket(packet); err != nil { |
+ return 0, err |
+ } |
+ } |
+ |
+ for { |
+ if c.theySentEOF || c.theyClosed || c.dead { |
+ return 0, io.EOF |
+ } |
+ |
+ if len(c.pendingRequests) > 0 { |
+ req := c.pendingRequests[0] |
+ if len(c.pendingRequests) == 1 { |
+ c.pendingRequests = nil |
+ } else { |
+ oldPendingRequests := c.pendingRequests |
+ c.pendingRequests = make([]ChannelRequest, len(oldPendingRequests)-1) |
+ copy(c.pendingRequests, oldPendingRequests[1:]) |
+ } |
+ |
+ return 0, req |
+ } |
+ |
+ if c.length > 0 { |
+ tail := c.head + c.length |
+ if tail > len(c.pendingData) { |
+ tail -= len(c.pendingData) |
+ } |
+ n = copy(data, c.pendingData[c.head:tail]) |
+ c.head += n |
+ c.length -= n |
+ if c.head == len(c.pendingData) { |
+ c.head = 0 |
+ } |
+ return |
+ } |
+ |
+ c.cond.Wait() |
+ } |
+ |
+ panic("unreachable") |
+} |
+ |
+func (c *channel) Write(data []byte) (n int, err error) { |
+ for len(data) > 0 { |
+ c.lock.Lock() |
+ if c.dead || c.weClosed { |
+ return 0, io.EOF |
+ } |
+ |
+ if c.theirWindow == 0 { |
+ c.cond.Wait() |
+ continue |
+ } |
+ c.lock.Unlock() |
+ |
+ todo := data |
+ if uint32(len(todo)) > c.theirWindow { |
+ todo = todo[:c.theirWindow] |
+ } |
+ |
+ packet := make([]byte, 1+4+4+len(todo)) |
+ packet[0] = msgChannelData |
+ packet[1] = byte(c.theirId >> 24) |
+ packet[2] = byte(c.theirId >> 16) |
+ packet[3] = byte(c.theirId >> 8) |
+ packet[4] = byte(c.theirId) |
+ packet[5] = byte(len(todo) >> 24) |
+ packet[6] = byte(len(todo) >> 16) |
+ packet[7] = byte(len(todo) >> 8) |
+ packet[8] = byte(len(todo)) |
+ copy(packet[9:], todo) |
+ |
+ c.serverConn.lock.Lock() |
+ if err = c.serverConn.writePacket(packet); err != nil { |
+ c.serverConn.lock.Unlock() |
+ return |
+ } |
+ c.serverConn.lock.Unlock() |
+ |
+ n += len(todo) |
+ data = data[len(todo):] |
+ } |
+ |
+ return |
+} |
+ |
+func (c *channel) Close() error { |
+ c.serverConn.lock.Lock() |
+ defer c.serverConn.lock.Unlock() |
+ |
+ if c.serverConn.err != nil { |
+ return c.serverConn.err |
+ } |
+ |
+ if c.weClosed { |
+ return errors.New("ssh: channel already closed") |
+ } |
+ c.weClosed = true |
+ |
+ closeMsg := channelCloseMsg{ |
+ PeersId: c.theirId, |
+ } |
+ return c.serverConn.writePacket(marshal(msgChannelClose, closeMsg)) |
+} |
+ |
+func (c *channel) AckRequest(ok bool) error { |
+ c.serverConn.lock.Lock() |
+ defer c.serverConn.lock.Unlock() |
+ |
+ if c.serverConn.err != nil { |
+ return c.serverConn.err |
+ } |
+ |
+ if ok { |
+ ack := channelRequestSuccessMsg{ |
+ PeersId: c.theirId, |
+ } |
+ return c.serverConn.writePacket(marshal(msgChannelSuccess, ack)) |
+ } else { |
+ ack := channelRequestFailureMsg{ |
+ PeersId: c.theirId, |
+ } |
+ return c.serverConn.writePacket(marshal(msgChannelFailure, ack)) |
+ } |
+ panic("unreachable") |
+} |
+ |
+func (c *channel) ChannelType() string { |
+ return c.chanType |
+} |
+ |
+func (c *channel) ExtraData() []byte { |
+ return c.extraData |
+} |