Left: | ||
Right: |
LEFT | RIGHT |
---|---|
1 // Copyright 2013 The Go Authors. All rights reserved. | 1 // Copyright 2013 The Go Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style | 2 // Use of this source code is governed by a BSD-style |
3 // license that can be found in the LICENSE file. | 3 // license that can be found in the LICENSE file. |
4 | 4 |
5 package ssh | 5 package ssh |
6 | 6 |
7 import ( | 7 import ( |
8 "encoding/binary" | 8 "encoding/binary" |
9 "errors" | 9 "errors" |
10 "fmt" | 10 "fmt" |
(...skipping 29 matching lines...) Expand all Loading... | |
40 c.chans[i] = ch | 40 c.chans[i] = ch |
41 return uint32(i) + c.offset | 41 return uint32(i) + c.offset |
42 } | 42 } |
43 } | 43 } |
44 c.chans = append(c.chans, ch) | 44 c.chans = append(c.chans, ch) |
45 return uint32(len(c.chans)-1) + c.offset | 45 return uint32(len(c.chans)-1) + c.offset |
46 } | 46 } |
47 | 47 |
48 // getChan returns the channel for the given ID. | 48 // getChan returns the channel for the given ID. |
49 func (c *chanList) getChan(id uint32) *channel { | 49 func (c *chanList) getChan(id uint32) *channel { |
50 id -= c.offset | 50 id -= c.offset |
dfc
2013/10/14 00:59:11
the race detector won't like this
hanwen-google
2013/10/14 06:30:26
id is thread-local, while offset is constant. I do
dfc
2013/10/14 06:36:40
You are correct, I was mistaken.
| |
51 | 51 |
52 c.Lock() | 52 c.Lock() |
53 defer c.Unlock() | 53 defer c.Unlock() |
54 if id < uint32(len(c.chans)) { | 54 if id < uint32(len(c.chans)) { |
55 return c.chans[id] | 55 return c.chans[id] |
56 } | 56 } |
57 return nil | 57 return nil |
58 } | 58 } |
59 | 59 |
60 func (c *chanList) remove(id uint32) { | 60 func (c *chanList) remove(id uint32) { |
61 id -= c.offset | 61 id -= c.offset |
dfc
2013/10/14 00:59:11
same
hanwen-google
2013/10/14 06:30:26
same.
| |
62 c.Lock() | 62 c.Lock() |
63 if id < uint32(len(c.chans)) { | 63 if id < uint32(len(c.chans)) { |
64 c.chans[id] = nil | 64 c.chans[id] = nil |
65 } | 65 } |
66 c.Unlock() | 66 c.Unlock() |
67 } | 67 } |
68 | 68 |
69 // dropAll drops all remaining channels | 69 // dropAll drops all remaining channels |
70 func (c *chanList) dropAll() []*channel { | 70 func (c *chanList) dropAll() []*channel { |
71 c.Lock() | 71 c.Lock() |
72 defer c.Unlock() | 72 defer c.Unlock() |
73 var r []*channel | 73 var r []*channel |
74 | 74 |
75 for _, ch := range c.chans { | 75 for _, ch := range c.chans { |
76 if ch == nil { | 76 if ch == nil { |
77 continue | 77 continue |
78 } | 78 } |
79 r = append(r, ch) | 79 r = append(r, ch) |
80 } | 80 } |
81 c.chans = nil | 81 c.chans = nil |
82 return r | 82 return r |
83 } | 83 } |
84 | 84 |
85 // mux contains the state for the SSH connection protocol, which | 85 // mux represents the state for the SSH connection protocol, which |
86 // multiplexes many channels onto a single packet transport. | 86 // multiplexes many channels onto a single packet transport. |
dfc
2013/10/14 00:59:11
// mux represents the state of an SSH connection.
hanwen-google
2013/10/14 06:30:26
On 2013/10/14 00:59:11, dfc wrote:
> // mux repres
| |
87 type mux struct { | 87 type mux struct { |
88 conn packetConn | 88 conn packetConn |
89 chanList chanList | 89 chanList chanList |
dfc
2013/10/14 00:59:11
s/chanList chanList/chanList/
hanwen-google
2013/10/14 06:30:26
one of the problems I had with the old code was ac
| |
90 | 90 |
91 » openedChans chan *channel | 91 » incomingChannels chan *channel |
92 | 92 |
93 » globalSentMu sync.Mutex | 93 » globalSentMu sync.Mutex |
94 » globalResponses chan interface{} | 94 » globalResponses chan interface{} |
95 » globalReceived chan *ChannelRequest | 95 » incomingRequests chan *ChannelRequest |
96 } | |
97 | |
98 func (m *mux) writePacket(p []byte) error { | |
dfc
2013/10/14 00:59:11
if you embed packetConn into mux, then you won't n
hanwen-google
2013/10/14 06:30:26
Done.
| |
99 » return m.conn.writePacket(p) | |
100 } | 96 } |
101 | 97 |
102 // Each new chanList instantiation has a different offset. | 98 // Each new chanList instantiation has a different offset. |
103 var globalOff uint32 | 99 var globalOff uint32 |
104 | 100 |
105 // newMux returns a mux that runs over the given connection. The caller | 101 // newMux returns a mux that runs over the given connection. The caller |
106 // should run Loop on the returned mux. | 102 // should run Loop on the returned mux. |
107 func newMux(p packetConn) *mux { | 103 func newMux(p packetConn) *mux { |
108 m := &mux{ | 104 m := &mux{ |
109 » » conn: p, | 105 » » conn: p, |
110 » » openedChans: make(chan *channel, 16), | 106 » » incomingChannels: make(chan *channel, 16), |
111 » » globalResponses: make(chan interface{}, 1), | 107 » » globalResponses: make(chan interface{}, 1), |
112 » » globalReceived: make(chan *ChannelRequest, 16), | 108 » » incomingRequests: make(chan *ChannelRequest, 16), |
113 } | 109 } |
114 m.chanList.offset = atomic.AddUint32(&globalOff, 1) | 110 m.chanList.offset = atomic.AddUint32(&globalOff, 1) |
115 return m | 111 return m |
116 } | 112 } |
117 | 113 |
118 func (m *mux) sendMessage(code byte, msg interface{}) error { | 114 func (m *mux) sendMessage(code byte, msg interface{}) error { |
119 p := marshal(code, msg) | 115 p := marshal(code, msg) |
120 return m.conn.writePacket(p) | 116 return m.conn.writePacket(p) |
121 } | 117 } |
122 | 118 |
(...skipping 17 matching lines...) Expand all Loading... | |
140 if wantReply { | 136 if wantReply { |
141 msg, ok := <-m.globalResponses | 137 msg, ok := <-m.globalResponses |
142 if !ok { | 138 if !ok { |
143 return false, nil, io.EOF | 139 return false, nil, io.EOF |
144 } | 140 } |
145 switch msg := msg.(type) { | 141 switch msg := msg.(type) { |
146 case *globalRequestFailureMsg: | 142 case *globalRequestFailureMsg: |
147 return false, msg.Data, nil | 143 return false, msg.Data, nil |
148 case *globalRequestSuccessMsg: | 144 case *globalRequestSuccessMsg: |
149 return true, msg.Data, nil | 145 return true, msg.Data, nil |
146 default: | |
147 return false, nil, fmt.Errorf("ssh: unexpected response %#v", msg) | |
150 } | 148 } |
151 } | 149 } |
152 | 150 |
153 return false, nil, nil | 151 return false, nil, nil |
154 } | |
155 | |
156 // GlobalReceived returns the channel on which incoming global | |
157 // requests are handled. If this channel is not serviced, the entire | |
158 // mux may hang. | |
dfc
2013/10/14 00:59:11
s/may/will/
hanwen-google
2013/10/14 06:30:26
Done.
| |
159 func (m *mux) ReceivedRequests() <-chan *ChannelRequest { | |
160 return m.globalReceived | |
161 } | 152 } |
162 | 153 |
163 // AckRequest must be called after processing a global request that | 154 // AckRequest must be called after processing a global request that |
164 // has WantReply set. | 155 // has WantReply set. |
165 func (m *mux) AckRequest(ok bool, data []byte) error { | 156 func (m *mux) AckRequest(ok bool, data []byte) error { |
166 if ok { | 157 if ok { |
167 return m.sendMessage(msgRequestSuccess, | 158 return m.sendMessage(msgRequestSuccess, |
168 globalRequestSuccessMsg{Data: data}) | 159 globalRequestSuccessMsg{Data: data}) |
169 } | 160 } |
170 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: data}) | 161 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: data}) |
171 } | 162 } |
172 | 163 |
173 // TODO(hanwen): Disconnect is a transport layer message. We should | 164 // TODO(hanwen): Disconnect is a transport layer message. We should |
174 // probably send and receive Disconnect somewhere in the transport | 165 // probably send and receive Disconnect somewhere in the transport |
175 // code. | 166 // code. |
176 | 167 |
177 // Disconnect sends a disconnect message. | 168 // Disconnect sends a disconnect message. |
178 func (m *mux) Disconnect(reason uint32, message string) error { | 169 func (m *mux) Disconnect(reason uint32, message string) error { |
179 return m.sendMessage(msgDisconnect, disconnectMsg{ | 170 return m.sendMessage(msgDisconnect, disconnectMsg{ |
180 Reason: reason, | 171 Reason: reason, |
181 Message: message, | 172 Message: message, |
182 }) | 173 }) |
183 } | 174 } |
184 | 175 |
185 // Loop runs the connection machine. It will process packets until an | 176 // Loop runs the connection machine. It will process packets until an |
186 // error is encountered, returning that error. When the loop exits, | 177 // error is encountered, returning that error. When the loop exits, |
187 // the connection is closed. | 178 // the connection is closed. |
188 func (m *mux) Loop() error { | 179 func (m *mux) Loop() error { |
189 var err error | 180 var err error |
dfc
2013/10/14 00:59:11
for err != nil {
err = m.onePacket()
if de
hanwen-google
2013/10/14 06:30:26
Done.
| |
190 » for { | 181 » for err == nil { |
191 err = m.onePacket() | 182 err = m.onePacket() |
192 » » if err != nil { | 183 » } |
193 » » » if debug { | 184 » if debug && err != nil { |
194 » » » » log.Println("loop exit", err) | 185 » » log.Println("loop exit", err) |
195 » » » } | |
196 » » » break | |
197 » » } | |
198 } | 186 } |
199 | 187 |
200 for _, ch := range m.chanList.dropAll() { | 188 for _, ch := range m.chanList.dropAll() { |
201 ch.mu.Lock() | 189 ch.mu.Lock() |
202 ch.sentClose = true | 190 ch.sentClose = true |
203 ch.mu.Unlock() | 191 ch.mu.Unlock() |
204 ch.pending.eof() | 192 ch.pending.eof() |
205 ch.extPending.eof() | 193 ch.extPending.eof() |
194 close(ch.incomingRequests) | |
206 // ch.msg is otherwise only called from onePacket, so | 195 // ch.msg is otherwise only called from onePacket, so |
207 // this is safe. | 196 // this is safe. |
208 close(ch.pendingRequests) | |
209 close(ch.msg) | 197 close(ch.msg) |
210 } | 198 } |
211 | 199 |
212 » close(m.openedChans) | 200 » close(m.incomingChannels) |
213 » close(m.globalReceived) | 201 » close(m.incomingRequests) |
214 close(m.globalResponses) | 202 close(m.globalResponses) |
215 | 203 |
216 m.conn.Close() | 204 m.conn.Close() |
217 return err | 205 return err |
218 } | 206 } |
219 | 207 |
220 // onePacket reads and processes one packet. | 208 // onePacket reads and processes one packet. |
221 func (m *mux) onePacket() error { | 209 func (m *mux) onePacket() error { |
222 packet, err := m.conn.readPacket() | 210 packet, err := m.conn.readPacket() |
223 if err != nil { | 211 if err != nil { |
(...skipping 42 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
266 } | 254 } |
267 | 255 |
268 func (m *mux) handleGlobalPacket(packet []byte) error { | 256 func (m *mux) handleGlobalPacket(packet []byte) error { |
269 msg, err := decode(packet) | 257 msg, err := decode(packet) |
270 if err != nil { | 258 if err != nil { |
271 return err | 259 return err |
272 } | 260 } |
273 | 261 |
274 switch msg := msg.(type) { | 262 switch msg := msg.(type) { |
275 case *globalRequestMsg: | 263 case *globalRequestMsg: |
276 » » m.globalReceived <- &ChannelRequest{ | 264 » » m.incomingRequests <- &ChannelRequest{ |
277 msg.Type, | 265 msg.Type, |
278 msg.WantReply, | 266 msg.WantReply, |
279 msg.Data, | 267 msg.Data, |
280 } | 268 } |
281 case *globalRequestSuccessMsg, *globalRequestFailureMsg: | 269 case *globalRequestSuccessMsg, *globalRequestFailureMsg: |
282 m.globalResponses <- msg | 270 m.globalResponses <- msg |
283 default: | 271 default: |
284 panic(fmt.Sprintf("not a global message %#v", msg)) | 272 panic(fmt.Sprintf("not a global message %#v", msg)) |
285 } | 273 } |
286 | 274 |
(...skipping 19 matching lines...) Expand all Loading... | |
306 return m.sendMessage(msgChannelOpenFailure, failMsg) | 294 return m.sendMessage(msgChannelOpenFailure, failMsg) |
307 } | 295 } |
308 | 296 |
309 c := newChannel(msg.ChanType, msg.TypeSpecificData) | 297 c := newChannel(msg.ChanType, msg.TypeSpecificData) |
310 c.mux = m | 298 c.mux = m |
311 c.remoteId = msg.PeersId | 299 c.remoteId = msg.PeersId |
312 c.maxPacket = msg.MaxPacketSize | 300 c.maxPacket = msg.MaxPacketSize |
313 c.remoteWin.add(msg.PeersWindow) | 301 c.remoteWin.add(msg.PeersWindow) |
314 c.myWindow = defaultWindowSize | 302 c.myWindow = defaultWindowSize |
315 c.localId = m.chanList.add(c) | 303 c.localId = m.chanList.add(c) |
316 » m.openedChans <- c | 304 » m.incomingChannels <- c |
317 return nil | 305 return nil |
318 } | 306 } |
319 | 307 |
320 // OpenChannelError is returned when the other side rejects our OpenChannel | 308 // OpenChannelError is returned when the other side rejects our OpenChannel |
321 // request. | 309 // request. |
322 type OpenChannelError struct { | 310 type OpenChannelError struct { |
323 Reason RejectionReason | 311 Reason RejectionReason |
324 Message string | 312 Message string |
325 } | 313 } |
326 | 314 |
327 func (e *OpenChannelError) Error() string { | 315 func (e *OpenChannelError) Error() string { |
328 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) | 316 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) |
329 } | 317 } |
330 | 318 |
331 // OpenChannel asks for a new channel. If the other side rejects, it | 319 // OpenChannel asks for a new channel. If the other side rejects, it |
332 // returns a *OpenChannelError. | 320 // returns a *OpenChannelError. |
333 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, error) { | 321 func (m *mux) OpenChannel(chanType string, extra []byte) (*channel, error) { |
334 ch := newChannel(chanType, extra) | 322 ch := newChannel(chanType, extra) |
335 ch.mux = m | 323 ch.mux = m |
336 | 324 |
337 // As per RFC 4253 6.1, 32k is also the minimum. | 325 // As per RFC 4253 6.1, 32k is also the minimum. |
338 ch.maxPacket = 1 << 15 | 326 ch.maxPacket = 1 << 15 |
339 ch.myWindow = defaultWindowSize | 327 ch.myWindow = defaultWindowSize |
340 ch.localId = m.chanList.add(ch) | 328 ch.localId = m.chanList.add(ch) |
341 | 329 |
342 open := channelOpenMsg{ | 330 open := channelOpenMsg{ |
343 ChanType: chanType, | 331 ChanType: chanType, |
(...skipping 13 matching lines...) Expand all Loading... | |
357 } | 345 } |
358 // fixup remoteId field | 346 // fixup remoteId field |
359 ch.remoteId = msg.MyId | 347 ch.remoteId = msg.MyId |
360 ch.maxPacket = msg.MaxPacketSize | 348 ch.maxPacket = msg.MaxPacketSize |
361 ch.remoteWin.add(msg.MyWindow) | 349 ch.remoteWin.add(msg.MyWindow) |
362 ch.decided = true | 350 ch.decided = true |
363 return ch, nil | 351 return ch, nil |
364 case *channelOpenFailureMsg: | 352 case *channelOpenFailureMsg: |
365 m.chanList.remove(open.PeersId) | 353 m.chanList.remove(open.PeersId) |
366 return nil, &OpenChannelError{msg.Reason, msg.Message} | 354 return nil, &OpenChannelError{msg.Reason, msg.Message} |
367 » } | 355 » default: |
368 » return nil, errors.New("ssh: unexpected packet") | 356 » » return nil, fmt.Errorf("ssh: unexpected packet %T", msg) |
369 } | 357 » } |
370 | 358 } |
371 // Accept returns the next channel that the remote side opened. | |
372 func (m *mux) Accept() (Channel, error) { | |
373 » c, ok := <-m.openedChans | |
374 » if !ok { | |
375 » » return nil, io.EOF | |
376 » } | |
377 » return c, nil | |
378 } | |
LEFT | RIGHT |