LEFT | RIGHT |
1 // Copyright 2013 The Go Authors. All rights reserved. | 1 // Copyright 2013 The Go Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style | 2 // Use of this source code is governed by a BSD-style |
3 // license that can be found in the LICENSE file. | 3 // license that can be found in the LICENSE file. |
4 | 4 |
5 package ssh | 5 package ssh |
6 | 6 |
7 import ( | 7 import ( |
8 "encoding/binary" | 8 "encoding/binary" |
9 "errors" | 9 "errors" |
10 "fmt" | 10 "fmt" |
(...skipping 70 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
81 c.chans = nil | 81 c.chans = nil |
82 return r | 82 return r |
83 } | 83 } |
84 | 84 |
85 // mux represents the state for the SSH connection protocol, which | 85 // mux represents the state for the SSH connection protocol, which |
86 // multiplexes many channels onto a single packet transport. | 86 // multiplexes many channels onto a single packet transport. |
87 type mux struct { | 87 type mux struct { |
88 conn packetConn | 88 conn packetConn |
89 chanList chanList | 89 chanList chanList |
90 | 90 |
91 » openedChans chan *channel | 91 » incomingChannels chan *channel |
92 | 92 |
93 » globalSentMu sync.Mutex | 93 » globalSentMu sync.Mutex |
94 » globalResponses chan interface{} | 94 » globalResponses chan interface{} |
95 » globalReceived chan *ChannelRequest | 95 » incomingRequests chan *ChannelRequest |
96 } | 96 } |
97 | 97 |
// globalOff is bumped atomically by newMux, so that every mux's chanList
// starts from a distinct id offset.
var globalOff uint32
100 | 100 |
101 // newMux returns a mux that runs over the given connection. Caller | 101 // newMux returns a mux that runs over the given connection. Caller |
102 // should run Loop for returned mux. | 102 // should run Loop for returned mux. |
103 func newMux(p packetConn) *mux { | 103 func newMux(p packetConn) *mux { |
104 m := &mux{ | 104 m := &mux{ |
105 » » conn: p, | 105 » » conn: p, |
106 » » openedChans: make(chan *channel, 16), | 106 » » incomingChannels: make(chan *channel, 16), |
107 » » globalResponses: make(chan interface{}, 1), | 107 » » globalResponses: make(chan interface{}, 1), |
108 » » globalReceived: make(chan *ChannelRequest, 16), | 108 » » incomingRequests: make(chan *ChannelRequest, 16), |
109 } | 109 } |
110 m.chanList.offset = atomic.AddUint32(&globalOff, 1) | 110 m.chanList.offset = atomic.AddUint32(&globalOff, 1) |
111 return m | 111 return m |
112 } | 112 } |
113 | 113 |
114 func (m *mux) sendMessage(code byte, msg interface{}) error { | 114 func (m *mux) sendMessage(code byte, msg interface{}) error { |
115 p := marshal(code, msg) | 115 p := marshal(code, msg) |
116 return m.conn.writePacket(p) | 116 return m.conn.writePacket(p) |
117 } | 117 } |
118 | 118 |
(...skipping 25 matching lines...) Expand all Loading... |
144 case *globalRequestSuccessMsg: | 144 case *globalRequestSuccessMsg: |
145 return true, msg.Data, nil | 145 return true, msg.Data, nil |
146 default: | 146 default: |
147 return false, nil, fmt.Errorf("ssh: unexpected response
%#v", msg) | 147 return false, nil, fmt.Errorf("ssh: unexpected response
%#v", msg) |
148 } | 148 } |
149 } | 149 } |
150 | 150 |
151 return false, nil, nil | 151 return false, nil, nil |
152 } | 152 } |
153 | 153 |
154 // GlobalReceived returns the channel on which incoming global | |
155 // requests are handled. If this channel is not serviced, the entire | |
156 // mux will hang. | |
157 func (m *mux) IncomingRequests() <-chan *ChannelRequest { | |
158 return m.globalReceived | |
159 } | |
160 | |
161 // AckRequest must be called after processing a global request that | 154 // AckRequest must be called after processing a global request that |
162 // has WantReply set. | 155 // has WantReply set. |
163 func (m *mux) AckRequest(ok bool, data []byte) error { | 156 func (m *mux) AckRequest(ok bool, data []byte) error { |
164 if ok { | 157 if ok { |
165 return m.sendMessage(msgRequestSuccess, | 158 return m.sendMessage(msgRequestSuccess, |
166 globalRequestSuccessMsg{Data: data}) | 159 globalRequestSuccessMsg{Data: data}) |
167 } | 160 } |
168 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: da
ta}) | 161 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: da
ta}) |
169 } | 162 } |
170 | 163 |
171 // TODO(hanwen): Disconnect is a transport layer message. We should | 164 // TODO(hanwen): Disconnect is a transport layer message. We should |
172 // probably send and receive Disconnect somewhere in the transport | 165 // probably send and receive Disconnect somewhere in the transport |
173 // code. | 166 // code. |
174 | 167 |
175 // Disconnect sends a disconnect message. | 168 // Disconnect sends a disconnect message. |
176 func (m *mux) Disconnect(reason uint32, message string) error { | 169 func (m *mux) Disconnect(reason uint32, message string) error { |
177 return m.sendMessage(msgDisconnect, disconnectMsg{ | 170 return m.sendMessage(msgDisconnect, disconnectMsg{ |
178 Reason: reason, | 171 Reason: reason, |
179 Message: message, | 172 Message: message, |
180 }) | 173 }) |
181 } | 174 } |
182 | 175 |
183 // Loop runs the connection machine. It will process packets until an | 176 // Loop runs the connection machine. It will process packets until an |
184 // error is encountered, returning that error. When the loop exits, | 177 // error is encountered, returning that error. When the loop exits, |
185 // the connection is closed. | 178 // the connection is closed. |
186 func (m *mux) Loop() error { | 179 func (m *mux) Loop() error { |
187 var err error | 180 var err error |
188 for err == nil { | 181 for err == nil { |
189 err = m.onePacket() | 182 err = m.onePacket() |
190 » » if debug { | 183 » } |
191 » » » log.Println("loop exit", err) | 184 » if debug && err != nil { |
192 » » } | 185 » » log.Println("loop exit", err) |
193 } | 186 } |
194 | 187 |
195 for _, ch := range m.chanList.dropAll() { | 188 for _, ch := range m.chanList.dropAll() { |
196 ch.mu.Lock() | 189 ch.mu.Lock() |
197 ch.sentClose = true | 190 ch.sentClose = true |
198 ch.mu.Unlock() | 191 ch.mu.Unlock() |
199 ch.pending.eof() | 192 ch.pending.eof() |
200 ch.extPending.eof() | 193 ch.extPending.eof() |
| 194 close(ch.incomingRequests) |
201 // ch.msg is otherwise only called from onePacket, so | 195 // ch.msg is otherwise only called from onePacket, so |
202 // this is safe. | 196 // this is safe. |
203 close(ch.pendingRequests) | |
204 close(ch.msg) | 197 close(ch.msg) |
205 } | 198 } |
206 | 199 |
207 » close(m.openedChans) | 200 » close(m.incomingChannels) |
208 » close(m.globalReceived) | 201 » close(m.incomingRequests) |
209 close(m.globalResponses) | 202 close(m.globalResponses) |
210 | 203 |
211 m.conn.Close() | 204 m.conn.Close() |
212 return err | 205 return err |
213 } | 206 } |
214 | 207 |
215 // onePacket reads and processes one packet. | 208 // onePacket reads and processes one packet. |
216 func (m *mux) onePacket() error { | 209 func (m *mux) onePacket() error { |
217 packet, err := m.conn.readPacket() | 210 packet, err := m.conn.readPacket() |
218 if err != nil { | 211 if err != nil { |
(...skipping 42 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
261 } | 254 } |
262 | 255 |
263 func (m *mux) handleGlobalPacket(packet []byte) error { | 256 func (m *mux) handleGlobalPacket(packet []byte) error { |
264 msg, err := decode(packet) | 257 msg, err := decode(packet) |
265 if err != nil { | 258 if err != nil { |
266 return err | 259 return err |
267 } | 260 } |
268 | 261 |
269 switch msg := msg.(type) { | 262 switch msg := msg.(type) { |
270 case *globalRequestMsg: | 263 case *globalRequestMsg: |
271 » » m.globalReceived <- &ChannelRequest{ | 264 » » m.incomingRequests <- &ChannelRequest{ |
272 msg.Type, | 265 msg.Type, |
273 msg.WantReply, | 266 msg.WantReply, |
274 msg.Data, | 267 msg.Data, |
275 } | 268 } |
276 case *globalRequestSuccessMsg, *globalRequestFailureMsg: | 269 case *globalRequestSuccessMsg, *globalRequestFailureMsg: |
277 m.globalResponses <- msg | 270 m.globalResponses <- msg |
278 default: | 271 default: |
279 panic(fmt.Sprintf("not a global message %#v", msg)) | 272 panic(fmt.Sprintf("not a global message %#v", msg)) |
280 } | 273 } |
281 | 274 |
(...skipping 19 matching lines...) Expand all Loading... |
301 return m.sendMessage(msgChannelOpenFailure, failMsg) | 294 return m.sendMessage(msgChannelOpenFailure, failMsg) |
302 } | 295 } |
303 | 296 |
304 c := newChannel(msg.ChanType, msg.TypeSpecificData) | 297 c := newChannel(msg.ChanType, msg.TypeSpecificData) |
305 c.mux = m | 298 c.mux = m |
306 c.remoteId = msg.PeersId | 299 c.remoteId = msg.PeersId |
307 c.maxPacket = msg.MaxPacketSize | 300 c.maxPacket = msg.MaxPacketSize |
308 c.remoteWin.add(msg.PeersWindow) | 301 c.remoteWin.add(msg.PeersWindow) |
309 c.myWindow = defaultWindowSize | 302 c.myWindow = defaultWindowSize |
310 c.localId = m.chanList.add(c) | 303 c.localId = m.chanList.add(c) |
311 » m.openedChans <- c | 304 » m.incomingChannels <- c |
312 return nil | 305 return nil |
313 } | 306 } |
314 | 307 |
315 // OpenChannelError is returned the other side rejects our OpenChannel | 308 // OpenChannelError is returned the other side rejects our OpenChannel |
316 // request. | 309 // request. |
317 type OpenChannelError struct { | 310 type OpenChannelError struct { |
318 Reason RejectionReason | 311 Reason RejectionReason |
319 Message string | 312 Message string |
320 } | 313 } |
321 | 314 |
322 func (e *OpenChannelError) Error() string { | 315 func (e *OpenChannelError) Error() string { |
323 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) | 316 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) |
324 } | 317 } |
325 | 318 |
326 // OpenChannel asks for a new channel. If the other side rejects, it | 319 // OpenChannel asks for a new channel. If the other side rejects, it |
327 // returns a *OpenChannelError. | 320 // returns a *OpenChannelError. |
328 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, error) { | 321 func (m *mux) OpenChannel(chanType string, extra []byte) (*channel, error) { |
329 ch := newChannel(chanType, extra) | 322 ch := newChannel(chanType, extra) |
330 ch.mux = m | 323 ch.mux = m |
331 | 324 |
332 // As per RFC 4253 6.1, 32k is also the minimum. | 325 // As per RFC 4253 6.1, 32k is also the minimum. |
333 ch.maxPacket = 1 << 15 | 326 ch.maxPacket = 1 << 15 |
334 ch.myWindow = defaultWindowSize | 327 ch.myWindow = defaultWindowSize |
335 ch.localId = m.chanList.add(ch) | 328 ch.localId = m.chanList.add(ch) |
336 | 329 |
337 open := channelOpenMsg{ | 330 open := channelOpenMsg{ |
338 ChanType: chanType, | 331 ChanType: chanType, |
(...skipping 17 matching lines...) Expand all Loading... |
356 ch.remoteWin.add(msg.MyWindow) | 349 ch.remoteWin.add(msg.MyWindow) |
357 ch.decided = true | 350 ch.decided = true |
358 return ch, nil | 351 return ch, nil |
359 case *channelOpenFailureMsg: | 352 case *channelOpenFailureMsg: |
360 m.chanList.remove(open.PeersId) | 353 m.chanList.remove(open.PeersId) |
361 return nil, &OpenChannelError{msg.Reason, msg.Message} | 354 return nil, &OpenChannelError{msg.Reason, msg.Message} |
362 default: | 355 default: |
363 return nil, fmt.Errorf("ssh: unexpected packet %T", msg) | 356 return nil, fmt.Errorf("ssh: unexpected packet %T", msg) |
364 } | 357 } |
365 } | 358 } |
366 | |
367 // Accept returns the next channel that the remote side opened. | |
368 func (m *mux) Accept() (Channel, error) { | |
369 c, ok := <-m.openedChans | |
370 if !ok { | |
371 return nil, io.EOF | |
372 } | |
373 return c, nil | |
374 } | |
LEFT | RIGHT |