LEFT | RIGHT |
| 1 // Copyright 2013 The Go Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style |
| 3 // license that can be found in the LICENSE file. |
| 4 |
1 package ssh | 5 package ssh |
2 | 6 |
3 import ( | 7 import ( |
4 "encoding/binary" | 8 "encoding/binary" |
5 "errors" | 9 "errors" |
6 "fmt" | 10 "fmt" |
7 "io" | 11 "io" |
8 "log" | 12 "log" |
9 "sync" | 13 "sync" |
| 14 "sync/atomic" |
10 ) | 15 ) |
| 16 |
| 17 const debug = false |
11 | 18 |
12 // Thread safe channel list. | 19 // Thread safe channel list. |
13 type chanList struct { | 20 type chanList struct { |
14 // protects concurrent access to chans | 21 // protects concurrent access to chans |
15 sync.Mutex | 22 sync.Mutex |
16 » // chans are indexed by the local id of the channel, nChannel.localId. | 23 |
17 » // The PeersId value of messages received by ClientConn.mainLoop is | 24 » // chans are indexed by the local id of the channel, which the |
18 » // used to locate the right local nChannel in this slice. | 25 » // other side should send in the PeersId field. |
19 chans []*channel | 26 chans []*channel |
20 | 27 |
21 // This is a debugging aid: it offsets all IDs by this | 28 // This is a debugging aid: it offsets all IDs by this |
22 // amount. This helps distinguish otherwise identical | 29 // amount. This helps distinguish otherwise identical |
23 // server/client muxes | 30 // server/client muxes |
24 offset uint32 | 31 offset uint32 |
25 } | 32 } |
26 | 33 |
27 // Assigns a channel ID to the given channel. | 34 // Assigns a channel ID to the given channel. |
28 func (c *chanList) add(ch *channel) uint32 { | 35 func (c *chanList) add(ch *channel) uint32 { |
(...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
68 for _, ch := range c.chans { | 75 for _, ch := range c.chans { |
69 if ch == nil { | 76 if ch == nil { |
70 continue | 77 continue |
71 } | 78 } |
72 r = append(r, ch) | 79 r = append(r, ch) |
73 } | 80 } |
74 c.chans = nil | 81 c.chans = nil |
75 return r | 82 return r |
76 } | 83 } |
77 | 84 |
78 // mux contains the state for the SSH connection protocol, which | 85 // mux represents the state for the SSH connection protocol, which |
79 // multiplexes many channels onto a single packet transport. | 86 // multiplexes many channels onto a single packet transport. |
80 type mux struct { | 87 type mux struct { |
81 conn packetConn | 88 conn packetConn |
82 chanList chanList | 89 chanList chanList |
83 | 90 |
84 » openedChans chan *channel | 91 » incomingChannels chan *channel |
85 | 92 |
86 » globalSentMu sync.Mutex | 93 » globalSentMu sync.Mutex |
87 » globalResponses chan interface{} | 94 » globalResponses chan interface{} |
88 » globalReceived chan *ChannelRequest | 95 » incomingRequests chan *ChannelRequest |
89 } | 96 } |
90 | 97 |
| 98 // Each new chanList instantiation has a different offset. |
91 var globalOff uint32 | 99 var globalOff uint32 |
92 | 100 |
93 // newMux returns a mux that runs over the given connection. | 101 // newMux returns a mux that runs over the given connection. The caller |
| 102 // should run Loop on the returned mux. |
94 func newMux(p packetConn) *mux { | 103 func newMux(p packetConn) *mux { |
95 m := &mux{ | 104 m := &mux{ |
96 » » conn: p, | 105 » » conn: p, |
97 » » openedChans: make(chan *channel, 16), | 106 » » incomingChannels: make(chan *channel, 16), |
98 » » globalResponses: make(chan interface{}, 1), | 107 » » globalResponses: make(chan interface{}, 1), |
99 » » globalReceived: make(chan *ChannelRequest, 16), | 108 » » incomingRequests: make(chan *ChannelRequest, 16), |
100 » } | 109 » } |
101 » m.chanList.offset = globalOff | 110 » m.chanList.offset = atomic.AddUint32(&globalOff, 1) |
102 » globalOff++ | |
103 | |
104 return m | 111 return m |
105 } | 112 } |
106 | 113 |
107 func (m *mux) sendMessage(code byte, msg interface{}) error { | 114 func (m *mux) sendMessage(code byte, msg interface{}) error { |
108 p := marshal(code, msg) | 115 p := marshal(code, msg) |
109 return m.conn.writePacket(p) | 116 return m.conn.writePacket(p) |
110 } | 117 } |
111 | 118 |
112 // SendRequest sends a global request. If wantReply is set, the | 119 // SendRequest sends a global request. If wantReply is set, the |
113 // return includes success status and extra data. See also RFC4254 section 4 | 120 // return includes success status and extra data. See also RFC4254 section 4 |
(...skipping 15 matching lines...) Expand all Loading... |
129 if wantReply { | 136 if wantReply { |
130 msg, ok := <-m.globalResponses | 137 msg, ok := <-m.globalResponses |
131 if !ok { | 138 if !ok { |
132 return false, nil, io.EOF | 139 return false, nil, io.EOF |
133 } | 140 } |
134 switch msg := msg.(type) { | 141 switch msg := msg.(type) { |
135 case *globalRequestFailureMsg: | 142 case *globalRequestFailureMsg: |
136 return false, msg.Data, nil | 143 return false, msg.Data, nil |
137 case *globalRequestSuccessMsg: | 144 case *globalRequestSuccessMsg: |
138 return true, msg.Data, nil | 145 return true, msg.Data, nil |
| 146 default: |
| 147 return false, nil, fmt.Errorf("ssh: unexpected response
%#v", msg) |
139 } | 148 } |
140 } | 149 } |
141 | 150 |
142 return false, nil, nil | 151 return false, nil, nil |
143 } | 152 } |
144 | 153 |
145 // ReceivedRequests returns the channel on which incoming global | 154 // AckRequest must be called after processing a global request that |
146 // requests are handled. | 155 // has WantReply set. |
147 func (m *mux) ReceivedRequests() <-chan *ChannelRequest { | 156 func (m *mux) AckRequest(ok bool, data []byte) error { |
148 » return m.globalReceived | |
149 } | |
150 | |
151 // AckGlobalRequest must be called for a global request with WantReply | |
152 // set. | |
153 func (m *mux) AckGlobalRequest(ok bool, data []byte) error { | |
154 if ok { | 157 if ok { |
155 return m.sendMessage(msgRequestSuccess, | 158 return m.sendMessage(msgRequestSuccess, |
156 globalRequestSuccessMsg{Data: data}) | 159 globalRequestSuccessMsg{Data: data}) |
157 } | 160 } |
158 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: da
ta}) | 161 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: da
ta}) |
159 } | 162 } |
160 | 163 |
| 164 // TODO(hanwen): Disconnect is a transport layer message. We should |
| 165 // probably send and receive Disconnect somewhere in the transport |
| 166 // code. |
| 167 |
| 168 // Disconnect sends a disconnect message. |
161 func (m *mux) Disconnect(reason uint32, message string) error { | 169 func (m *mux) Disconnect(reason uint32, message string) error { |
162 return m.sendMessage(msgDisconnect, disconnectMsg{ | 170 return m.sendMessage(msgDisconnect, disconnectMsg{ |
163 Reason: reason, | 171 Reason: reason, |
164 Message: message, | 172 Message: message, |
165 }) | 173 }) |
166 } | 174 } |
167 | 175 |
168 // Loop runs the connection machine. It will process packets until an | 176 // Loop runs the connection machine. It will process packets until an |
169 // error is encountered, returning that error. | 177 // error is encountered, returning that error. When the loop exits, |
| 178 // the connection is closed. |
170 func (m *mux) Loop() error { | 179 func (m *mux) Loop() error { |
171 var err error | 180 var err error |
172 » for { | 181 » for err == nil { |
173 err = m.onePacket() | 182 err = m.onePacket() |
174 » » if err != nil { | 183 » } |
175 » » » if debug { | 184 » if debug && err != nil { |
176 » » » » log.Println("loop exit", err) | 185 » » log.Println("loop exit", err) |
177 » » » } | |
178 » » » break | |
179 » » } | |
180 } | 186 } |
181 | 187 |
182 for _, ch := range m.chanList.dropAll() { | 188 for _, ch := range m.chanList.dropAll() { |
183 ch.mu.Lock() | 189 ch.mu.Lock() |
184 ch.sentClose = true | 190 ch.sentClose = true |
185 ch.mu.Unlock() | 191 ch.mu.Unlock() |
186 ch.pending.eof() | 192 ch.pending.eof() |
187 ch.extPending.eof() | 193 ch.extPending.eof() |
| 194 close(ch.incomingRequests) |
188 // ch.msg is otherwise only used from onePacket, so | 195 // ch.msg is otherwise only used from onePacket, so |
189 // this is safe. | 196 // this is safe. |
190 close(ch.pendingRequests) | |
191 close(ch.msg) | 197 close(ch.msg) |
192 } | 198 } |
193 | 199 |
194 » close(m.openedChans) | 200 » close(m.incomingChannels) |
195 » close(m.globalReceived) | 201 » close(m.incomingRequests) |
196 close(m.globalResponses) | 202 close(m.globalResponses) |
197 » // TODO(hanwen): should we close packetConn? | 203 |
198 » // m.conn.Close() | 204 » m.conn.Close() |
199 return err | 205 return err |
200 } | 206 } |
201 | 207 |
202 var debug = false | 208 // onePacket reads and processes one packet. |
203 | |
204 func (m *mux) onePacket() error { | 209 func (m *mux) onePacket() error { |
205 packet, err := m.conn.readPacket() | 210 packet, err := m.conn.readPacket() |
206 if err != nil { | 211 if err != nil { |
207 return err | 212 return err |
208 } | 213 } |
209 | 214 |
210 if debug { | 215 if debug { |
211 p, _ := decode(packet) | 216 p, _ := decode(packet) |
212 » » log.Printf("decoding(%c): %d %#v - %d bytes", m.chanList.offset,
packet[0], p, len(packet)) | 217 » » log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset,
packet[0], p, len(packet)) |
213 } | 218 } |
214 | 219 |
215 switch packet[0] { | 220 switch packet[0] { |
216 case msgDisconnect: | 221 case msgDisconnect: |
217 return m.handleDisconnect(packet) | 222 return m.handleDisconnect(packet) |
218 case msgChannelOpen: | 223 case msgChannelOpen: |
219 return m.handleChannelOpen(packet) | 224 return m.handleChannelOpen(packet) |
220 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: | 225 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: |
221 return m.handleGlobalPacket(packet) | 226 return m.handleGlobalPacket(packet) |
222 } | 227 } |
(...skipping 26 matching lines...) Expand all Loading... |
249 } | 254 } |
250 | 255 |
251 func (m *mux) handleGlobalPacket(packet []byte) error { | 256 func (m *mux) handleGlobalPacket(packet []byte) error { |
252 msg, err := decode(packet) | 257 msg, err := decode(packet) |
253 if err != nil { | 258 if err != nil { |
254 return err | 259 return err |
255 } | 260 } |
256 | 261 |
257 switch msg := msg.(type) { | 262 switch msg := msg.(type) { |
258 case *globalRequestMsg: | 263 case *globalRequestMsg: |
259 » » m.globalReceived <- &ChannelRequest{ | 264 » » m.incomingRequests <- &ChannelRequest{ |
260 msg.Type, | 265 msg.Type, |
261 msg.WantReply, | 266 msg.WantReply, |
262 msg.Data, | 267 msg.Data, |
263 } | 268 } |
264 case *globalRequestSuccessMsg, *globalRequestFailureMsg: | 269 case *globalRequestSuccessMsg, *globalRequestFailureMsg: |
265 m.globalResponses <- msg | 270 m.globalResponses <- msg |
266 default: | 271 default: |
267 panic(fmt.Sprintf("not a global message %#v", msg)) | 272 panic(fmt.Sprintf("not a global message %#v", msg)) |
268 } | 273 } |
269 | 274 |
270 return nil | 275 return nil |
271 } | 276 } |
272 | 277 |
273 const minPacketLength = 0 | 278 const minPacketLength = 0 |
274 | 279 |
| 280 // handleChannelOpen schedules a channel to be Accept()ed. |
275 func (m *mux) handleChannelOpen(packet []byte) error { | 281 func (m *mux) handleChannelOpen(packet []byte) error { |
276 var msg channelOpenMsg | 282 var msg channelOpenMsg |
277 if err := unmarshal(&msg, packet, msgChannelOpen); err != nil { | 283 if err := unmarshal(&msg, packet, msgChannelOpen); err != nil { |
278 return err | 284 return err |
279 } | 285 } |
280 | 286 |
281 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { | 287 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { |
282 failMsg := channelOpenFailureMsg{ | 288 failMsg := channelOpenFailureMsg{ |
283 PeersId: msg.PeersId, | 289 PeersId: msg.PeersId, |
284 Reason: ConnectionFailed, | 290 Reason: ConnectionFailed, |
285 Message: "invalid request", | 291 Message: "invalid request", |
286 Language: "en_US.UTF-8", | 292 Language: "en_US.UTF-8", |
287 } | 293 } |
288 return m.sendMessage(msgChannelOpenFailure, failMsg) | 294 return m.sendMessage(msgChannelOpenFailure, failMsg) |
289 } | 295 } |
290 | 296 |
291 » c := newChannel(m.conn, msg.ChanType, msg.TypeSpecificData) | 297 » c := newChannel(msg.ChanType, msg.TypeSpecificData) |
292 c.mux = m | 298 c.mux = m |
293 c.remoteId = msg.PeersId | 299 c.remoteId = msg.PeersId |
294 c.maxPacket = msg.MaxPacketSize | 300 c.maxPacket = msg.MaxPacketSize |
295 c.remoteWin.add(msg.PeersWindow) | 301 c.remoteWin.add(msg.PeersWindow) |
296 c.myWindow = defaultWindowSize | 302 c.myWindow = defaultWindowSize |
297 c.localId = m.chanList.add(c) | 303 c.localId = m.chanList.add(c) |
298 » m.openedChans <- c | 304 » m.incomingChannels <- c |
299 return nil | 305 return nil |
300 } | 306 } |
301 | 307 |
302 type OpenChannelFailed struct { | 308 // OpenChannelError is returned when the other side rejects our OpenChannel |
| 309 // request. |
| 310 type OpenChannelError struct { |
303 Reason RejectionReason | 311 Reason RejectionReason |
304 Message string | 312 Message string |
305 } | 313 } |
306 | 314 |
307 func (e *OpenChannelFailed) Error() string { | 315 func (e *OpenChannelError) Error() string { |
308 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) | 316 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) |
309 } | 317 } |
310 | 318 |
311 // Opens an outgoing channel to the other side. | 319 // OpenChannel asks for a new channel. If the other side rejects, it |
312 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, error) { | 320 // returns a *OpenChannelError. |
313 » ch := newChannel(m.conn, chanType, extra) | 321 func (m *mux) OpenChannel(chanType string, extra []byte) (*channel, error) { |
| 322 » ch := newChannel(chanType, extra) |
314 ch.mux = m | 323 ch.mux = m |
| 324 |
| 325 // As per RFC 4253 6.1, 32k is also the minimum. |
315 ch.maxPacket = 1 << 15 | 326 ch.maxPacket = 1 << 15 |
316 ch.myWindow = defaultWindowSize | 327 ch.myWindow = defaultWindowSize |
317 ch.localId = m.chanList.add(ch) | 328 ch.localId = m.chanList.add(ch) |
318 | 329 |
319 open := channelOpenMsg{ | 330 open := channelOpenMsg{ |
320 ChanType: chanType, | 331 ChanType: chanType, |
321 PeersWindow: ch.myWindow, | 332 PeersWindow: ch.myWindow, |
322 MaxPacketSize: ch.maxPacket, | 333 MaxPacketSize: ch.maxPacket, |
323 TypeSpecificData: extra, | 334 TypeSpecificData: extra, |
324 PeersId: ch.localId, | 335 PeersId: ch.localId, |
325 } | 336 } |
326 if err := m.sendMessage(msgChannelOpen, open); err != nil { | 337 if err := m.sendMessage(msgChannelOpen, open); err != nil { |
327 return nil, err | 338 return nil, err |
328 } | 339 } |
329 | 340 |
330 switch msg := (<-ch.msg).(type) { | 341 switch msg := (<-ch.msg).(type) { |
331 case *channelOpenConfirmMsg: | 342 case *channelOpenConfirmMsg: |
332 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<
<31 { | 343 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<
<31 { |
333 return nil, errors.New("ssh: invalid MaxPacketSize from
peer") | 344 return nil, errors.New("ssh: invalid MaxPacketSize from
peer") |
334 } | 345 } |
335 // fixup remoteId field | 346 // fixup remoteId field |
336 ch.remoteId = msg.MyId | 347 ch.remoteId = msg.MyId |
337 ch.maxPacket = msg.MaxPacketSize | 348 ch.maxPacket = msg.MaxPacketSize |
338 ch.remoteWin.add(msg.MyWindow) | 349 ch.remoteWin.add(msg.MyWindow) |
339 ch.decided = true | 350 ch.decided = true |
340 return ch, nil | 351 return ch, nil |
341 case *channelOpenFailureMsg: | 352 case *channelOpenFailureMsg: |
342 m.chanList.remove(open.PeersId) | 353 m.chanList.remove(open.PeersId) |
343 » » // What type is appropriate? OpenChannelFailed, | 354 » » return nil, &OpenChannelError{msg.Reason, msg.Message} |
344 » » // *OpenChannelFailed? | 355 » default: |
345 » » return nil, &OpenChannelFailed{msg.Reason, msg.Message} | 356 » » return nil, fmt.Errorf("ssh: unexpected packet %T", msg) |
346 » } | 357 » } |
347 » return nil, errors.New("ssh: unexpected packet") | 358 } |
348 } | |
349 | |
350 func (m *mux) Accept() (Channel, error) { | |
351 » c, ok := <-m.openedChans | |
352 » if !ok { | |
353 » » return nil, io.EOF | |
354 » } | |
355 » return c, nil | |
356 } | |
LEFT | RIGHT |