LEFT | RIGHT |
(no file at all) | |
| 1 // Copyright 2013 The Go Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style |
| 3 // license that can be found in the LICENSE file. |
| 4 |
| 5 package ssh |
| 6 |
| 7 import ( |
| 8 "encoding/binary" |
| 9 "errors" |
| 10 "fmt" |
| 11 "io" |
| 12 "log" |
| 13 "sync" |
| 14 "sync/atomic" |
| 15 ) |
| 16 |
// debug enables diagnostic logging: onePacket logs every decoded packet,
// handleDisconnect logs the peer's disconnect message, and Loop logs the
// error that ended it.
const (
	debug = false
)
| 18 |
| 19 // Thread safe channel list. |
| 20 type chanList struct { |
| 21 // protects concurrent access to chans |
| 22 sync.Mutex |
| 23 |
| 24 // chans are indexed by the local id of the channel, which the |
| 25 // other side should send in the PeersId field. |
| 26 chans []*channel |
| 27 |
| 28 // This is a debugging aid: it offsets all IDs by this |
| 29 // amount. This helps distinguish otherwise identical |
| 30 // server/client muxes |
| 31 offset uint32 |
| 32 } |
| 33 |
| 34 // Assigns a channel ID to the given channel. |
| 35 func (c *chanList) add(ch *channel) uint32 { |
| 36 c.Lock() |
| 37 defer c.Unlock() |
| 38 for i := range c.chans { |
| 39 if c.chans[i] == nil { |
| 40 c.chans[i] = ch |
| 41 return uint32(i) + c.offset |
| 42 } |
| 43 } |
| 44 c.chans = append(c.chans, ch) |
| 45 return uint32(len(c.chans)-1) + c.offset |
| 46 } |
| 47 |
| 48 // getChan returns the channel for the given ID. |
| 49 func (c *chanList) getChan(id uint32) *channel { |
| 50 id -= c.offset |
| 51 |
| 52 c.Lock() |
| 53 defer c.Unlock() |
| 54 if id < uint32(len(c.chans)) { |
| 55 return c.chans[id] |
| 56 } |
| 57 return nil |
| 58 } |
| 59 |
| 60 func (c *chanList) remove(id uint32) { |
| 61 id -= c.offset |
| 62 c.Lock() |
| 63 if id < uint32(len(c.chans)) { |
| 64 c.chans[id] = nil |
| 65 } |
| 66 c.Unlock() |
| 67 } |
| 68 |
| 69 // dropAll drops all remaining channels |
| 70 func (c *chanList) dropAll() []*channel { |
| 71 c.Lock() |
| 72 defer c.Unlock() |
| 73 var r []*channel |
| 74 |
| 75 for _, ch := range c.chans { |
| 76 if ch == nil { |
| 77 continue |
| 78 } |
| 79 r = append(r, ch) |
| 80 } |
| 81 c.chans = nil |
| 82 return r |
| 83 } |
| 84 |
| 85 // mux represents the state for the SSH connection protocol, which |
| 86 // multiplexes many channels onto a single packet transport. |
| 87 type mux struct { |
| 88 conn packetConn |
| 89 chanList chanList |
| 90 |
| 91 incomingChannels chan *channel |
| 92 |
| 93 globalSentMu sync.Mutex |
| 94 globalResponses chan interface{} |
| 95 incomingRequests chan *ChannelRequest |
| 96 } |
| 97 |
// globalOff is a process-wide counter. newMux atomically increments it and
// uses the result as the new mux's chanList offset, so each mux hands out
// distinguishable channel IDs.
var (
	globalOff uint32
)
| 100 |
// newMux creates a mux that multiplexes channels over p. The caller must run
// Loop on the returned mux to start processing incoming packets.
func newMux(p packetConn) *mux {
	m := new(mux)
	m.conn = p
	m.incomingChannels = make(chan *channel, 16)
	m.globalResponses = make(chan interface{}, 1)
	m.incomingRequests = make(chan *ChannelRequest, 16)
	// Give every mux its own ID offset so channel IDs from different muxes
	// can be told apart when debugging.
	m.chanList.offset = atomic.AddUint32(&globalOff, 1)
	return m
}
| 113 |
| 114 func (m *mux) sendMessage(code byte, msg interface{}) error { |
| 115 p := marshal(code, msg) |
| 116 return m.conn.writePacket(p) |
| 117 } |
| 118 |
// SendRequest sends a global request named name carrying payload (RFC 4254
// section 4).
//
// When wantReply is false it returns (false, nil, err), where err is the
// send error, if any. When wantReply is true it waits for the peer's answer
// and returns whether the request succeeded together with the reply data. It
// returns io.EOF if the mux shuts down before a reply arrives.
func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
	req := globalRequestMsg{
		Type:      name,
		WantReply: wantReply,
		Data:      payload,
	}

	if !wantReply {
		return false, nil, m.sendMessage(msgGlobalRequest, req)
	}

	// Hold the lock across the send and the receive so that only one reply
	// is awaited at a time.
	m.globalSentMu.Lock()
	defer m.globalSentMu.Unlock()

	if err := m.sendMessage(msgGlobalRequest, req); err != nil {
		return false, nil, err
	}

	reply, ok := <-m.globalResponses
	if !ok {
		return false, nil, io.EOF
	}
	switch r := reply.(type) {
	case *globalRequestSuccessMsg:
		return true, r.Data, nil
	case *globalRequestFailureMsg:
		return false, r.Data, nil
	default:
		return false, nil, fmt.Errorf("ssh: unexpected response %#v", reply)
	}
}
| 153 |
| 154 // AckRequest must be called after processing a global request that |
| 155 // has WantReply set. |
| 156 func (m *mux) AckRequest(ok bool, data []byte) error { |
| 157 if ok { |
| 158 return m.sendMessage(msgRequestSuccess, |
| 159 globalRequestSuccessMsg{Data: data}) |
| 160 } |
| 161 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: da
ta}) |
| 162 } |
| 163 |
| 164 // TODO(hanwen): Disconnect is a transport layer message. We should |
| 165 // probably send and receive Disconnect somewhere in the transport |
| 166 // code. |
| 167 |
| 168 // Disconnect sends a disconnect message. |
| 169 func (m *mux) Disconnect(reason uint32, message string) error { |
| 170 return m.sendMessage(msgDisconnect, disconnectMsg{ |
| 171 Reason: reason, |
| 172 Message: message, |
| 173 }) |
| 174 } |
| 175 |
// Loop runs the connection machine. It processes incoming packets until one
// fails, then tears the mux down and returns that error.
//
// The teardown runs in this order:
//  1. Every registered channel is removed from the table, flagged sentClose,
//     signalled EOF on its pending and extPending buffers, and has its
//     incomingRequests and msg channels closed.
//  2. The mux's incomingChannels, incomingRequests and globalResponses
//     channels are closed.
//  3. The underlying connection is closed.
func (m *mux) Loop() error {
	var err error
	for {
		if err = m.onePacket(); err != nil {
			break
		}
	}
	if debug {
		log.Println("loop exit", err)
	}

	channels := m.chanList.dropAll()
	for _, ch := range channels {
		ch.mu.Lock()
		ch.sentClose = true
		ch.mu.Unlock()

		ch.pending.eof()
		ch.extPending.eof()
		close(ch.incomingRequests)
		// ch.msg is otherwise only used from onePacket, which has stopped
		// running, so closing it here is safe.
		close(ch.msg)
	}

	close(m.incomingChannels)
	close(m.incomingRequests)
	close(m.globalResponses)

	m.conn.Close()
	return err
}
| 207 |
// onePacket reads one packet from the connection and dispatches it.
// Disconnect, channel-open and global request/reply messages are handled by
// the mux itself. Anything else is treated as a channel message and routed
// to the channel whose local ID is stored big-endian in bytes 1-4. The
// returned error, if any, ends Loop.
func (m *mux) onePacket() error {
	packet, err := m.conn.readPacket()
	if err != nil {
		return err
	}
	// Every path below reads packet[0]; an empty packet would otherwise
	// panic with an index-out-of-range error.
	if len(packet) == 0 {
		return errors.New("ssh: zero length packet")
	}

	if debug {
		p, _ := decode(packet)
		log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
	}

	switch packet[0] {
	case msgDisconnect:
		return m.handleDisconnect(packet)
	case msgChannelOpen:
		return m.handleChannelOpen(packet)
	case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
		return m.handleGlobalPacket(packet)
	}

	// Assume a channel message: a type byte followed by the uint32 recipient
	// channel ID.
	if len(packet) < 5 {
		return ParseError{packet[0]}
	}
	id := binary.BigEndian.Uint32(packet[1:])
	ch := m.chanList.getChan(id)
	if ch == nil {
		return fmt.Errorf("ssh: invalid channel %d", id)
	}

	return ch.handlePacket(packet)
}
| 241 |
// handleDisconnect decodes a disconnect message from the peer and returns
// io.EOF, which makes Loop stop. A malformed message yields the unmarshal
// error instead.
func (m *mux) handleDisconnect(packet []byte) error {
	var msg disconnectMsg
	err := unmarshal(&msg, packet, msgDisconnect)
	if err != nil {
		return err
	}
	if debug {
		// TODO(hanwen): the disconnect message carries further diagnostics
		// that could be returned to the caller rather than only logged.
		log.Printf("caught disconnect: %v", msg)
	}
	return io.EOF
}
| 255 |
// handleGlobalPacket decodes a global request or a reply to one.
//
// Requests from the peer are queued on incomingRequests. Success and failure
// replies are passed to SendRequest via globalResponses. Both sends block
// while the target channel's buffer is full. onePacket routes only global
// message codes here, so any other decoded type indicates a programming
// error and panics.
func (m *mux) handleGlobalPacket(packet []byte) error {
	decoded, err := decode(packet)
	if err != nil {
		return err
	}

	switch msg := decoded.(type) {
	case *globalRequestMsg:
		req := &ChannelRequest{msg.Type, msg.WantReply, msg.Data}
		m.incomingRequests <- req
	case *globalRequestSuccessMsg, *globalRequestFailureMsg:
		m.globalResponses <- msg
	default:
		panic(fmt.Sprintf("not a global message %#v", msg))
	}
	return nil
}
| 277 |
// minPacketLength is the smallest MaxPacketSize accepted from a peer, whether
// it arrives in a channel-open request or in a channel-open confirmation.
// A channel data message spends 9 bytes on its header: a 1-byte message
// type, a uint32 recipient channel and a uint32 data length (RFC 4254
// section 5.2). A smaller maximum therefore cannot carry any payload. It
// must be non-zero, otherwise the "< minPacketLength" checks on uint32 sizes
// can never trigger.
const minPacketLength = 9
| 279 |
| 280 // handleChannelOpen schedules a channel to be Accept()ed. |
| 281 func (m *mux) handleChannelOpen(packet []byte) error { |
| 282 var msg channelOpenMsg |
| 283 if err := unmarshal(&msg, packet, msgChannelOpen); err != nil { |
| 284 return err |
| 285 } |
| 286 |
| 287 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
| 288 failMsg := channelOpenFailureMsg{ |
| 289 PeersId: msg.PeersId, |
| 290 Reason: ConnectionFailed, |
| 291 Message: "invalid request", |
| 292 Language: "en_US.UTF-8", |
| 293 } |
| 294 return m.sendMessage(msgChannelOpenFailure, failMsg) |
| 295 } |
| 296 |
| 297 c := newChannel(msg.ChanType, msg.TypeSpecificData) |
| 298 c.mux = m |
| 299 c.remoteId = msg.PeersId |
| 300 c.maxPacket = msg.MaxPacketSize |
| 301 c.remoteWin.add(msg.PeersWindow) |
| 302 c.myWindow = defaultWindowSize |
| 303 c.localId = m.chanList.add(c) |
| 304 m.incomingChannels <- c |
| 305 return nil |
| 306 } |
| 307 |
| 308 // OpenChannelError is returned the other side rejects our OpenChannel |
| 309 // request. |
| 310 type OpenChannelError struct { |
| 311 Reason RejectionReason |
| 312 Message string |
| 313 } |
| 314 |
| 315 func (e *OpenChannelError) Error() string { |
| 316 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) |
| 317 } |
| 318 |
// OpenChannel asks the peer to open a new channel of type chanType, with
// extra as the type-specific data, and waits for the answer.
//
// It returns the opened channel on confirmation. It returns an
// *OpenChannelError if the peer rejects the request, and io.EOF if the mux
// shuts down before an answer arrives. Any other outcome is reported as a
// plain error.
func (m *mux) OpenChannel(chanType string, extra []byte) (*channel, error) {
	ch := newChannel(chanType, extra)
	ch.mux = m

	// RFC 4253 section 6.1 requires every implementation to handle packets of
	// 32768 bytes, so this size is always acceptable to the peer.
	ch.maxPacket = 1 << 15
	ch.myWindow = defaultWindowSize
	ch.localId = m.chanList.add(ch)

	open := channelOpenMsg{
		ChanType:         chanType,
		PeersWindow:      ch.myWindow,
		MaxPacketSize:    ch.maxPacket,
		TypeSpecificData: extra,
		PeersId:          ch.localId,
	}
	if err := m.sendMessage(msgChannelOpen, open); err != nil {
		// The peer never learned this ID, so free the slot instead of leaving
		// an orphaned entry in the channel table.
		m.chanList.remove(ch.localId)
		return nil, err
	}

	reply, ok := <-ch.msg
	if !ok {
		// Loop closed ch.msg while tearing the mux down. Report EOF, as
		// SendRequest does, rather than "unexpected packet <nil>".
		return nil, io.EOF
	}

	switch msg := reply.(type) {
	case *channelOpenConfirmMsg:
		if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
			return nil, errors.New("ssh: invalid MaxPacketSize from peer")
		}
		// The peer's ID for this channel is only known now.
		ch.remoteId = msg.MyId
		ch.maxPacket = msg.MaxPacketSize
		ch.remoteWin.add(msg.MyWindow)
		ch.decided = true
		return ch, nil
	case *channelOpenFailureMsg:
		m.chanList.remove(open.PeersId)
		return nil, &OpenChannelError{msg.Reason, msg.Message}
	default:
		return nil, fmt.Errorf("ssh: unexpected packet %T", msg)
	}
}
LEFT | RIGHT |