Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(126)

Delta Between Two Patch Sets: ssh/mux.go

Issue 14225043: code review 14225043: go.crypto/ssh: reimplement SSH connection protocol modu... (Closed)
Left Patch Set: diff -r bb19605bfacc https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Right Patch Set: diff -r cd1eea1eb828 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « ssh/messages.go ('k') | ssh/mux_test.go » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2013 The Go Authors. All rights reserved. 1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style 2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file. 3 // license that can be found in the LICENSE file.
4 4
5 package ssh 5 package ssh
6 6
7 import ( 7 import (
8 "encoding/binary" 8 "encoding/binary"
9 "errors" 9 "errors"
10 "fmt" 10 "fmt"
(...skipping 29 matching lines...) Expand all
40 c.chans[i] = ch 40 c.chans[i] = ch
41 return uint32(i) + c.offset 41 return uint32(i) + c.offset
42 } 42 }
43 } 43 }
44 c.chans = append(c.chans, ch) 44 c.chans = append(c.chans, ch)
45 return uint32(len(c.chans)-1) + c.offset 45 return uint32(len(c.chans)-1) + c.offset
46 } 46 }
47 47
48 // getChan returns the channel for the given ID. 48 // getChan returns the channel for the given ID.
49 func (c *chanList) getChan(id uint32) *channel { 49 func (c *chanList) getChan(id uint32) *channel {
50 id -= c.offset 50 id -= c.offset
dfc 2013/10/14 00:59:11 the race detector won't like this
hanwen-google 2013/10/14 06:30:26 id is thread-local, while offset is constant. I do
dfc 2013/10/14 06:36:40 You are correct, I was mistaken.
51 51
52 c.Lock() 52 c.Lock()
53 defer c.Unlock() 53 defer c.Unlock()
54 if id < uint32(len(c.chans)) { 54 if id < uint32(len(c.chans)) {
55 return c.chans[id] 55 return c.chans[id]
56 } 56 }
57 return nil 57 return nil
58 } 58 }
59 59
60 func (c *chanList) remove(id uint32) { 60 func (c *chanList) remove(id uint32) {
61 id -= c.offset 61 id -= c.offset
dfc 2013/10/14 00:59:11 same
hanwen-google 2013/10/14 06:30:26 same.
62 c.Lock() 62 c.Lock()
63 if id < uint32(len(c.chans)) { 63 if id < uint32(len(c.chans)) {
64 c.chans[id] = nil 64 c.chans[id] = nil
65 } 65 }
66 c.Unlock() 66 c.Unlock()
67 } 67 }
68 68
69 // dropAll drops all remaining channels 69 // dropAll drops all remaining channels
70 func (c *chanList) dropAll() []*channel { 70 func (c *chanList) dropAll() []*channel {
71 c.Lock() 71 c.Lock()
72 defer c.Unlock() 72 defer c.Unlock()
73 var r []*channel 73 var r []*channel
74 74
75 for _, ch := range c.chans { 75 for _, ch := range c.chans {
76 if ch == nil { 76 if ch == nil {
77 continue 77 continue
78 } 78 }
79 r = append(r, ch) 79 r = append(r, ch)
80 } 80 }
81 c.chans = nil 81 c.chans = nil
82 return r 82 return r
83 } 83 }
84 84
85 // mux contains the state for the SSH connection protocol, which 85 // mux represents the state for the SSH connection protocol, which
86 // multiplexes many channels onto a single packet transport. 86 // multiplexes many channels onto a single packet transport.
dfc 2013/10/14 00:59:11 // mux represents the state of an SSH connection.
hanwen-google 2013/10/14 06:30:26 On 2013/10/14 00:59:11, dfc wrote: > // mux repres
87 type mux struct { 87 type mux struct {
88 conn packetConn 88 conn packetConn
89 chanList chanList 89 chanList chanList
dfc 2013/10/14 00:59:11 s/chanList chanList/chanList/
hanwen-google 2013/10/14 06:30:26 one of the problems I had with the old code was ac
90 90
91 » openedChans chan *channel 91 » incomingChannels chan *channel
92 92
93 » globalSentMu sync.Mutex 93 » globalSentMu sync.Mutex
94 » globalResponses chan interface{} 94 » globalResponses chan interface{}
95 » globalReceived chan *ChannelRequest 95 » incomingRequests chan *ChannelRequest
96 }
97
98 func (m *mux) writePacket(p []byte) error {
dfc 2013/10/14 00:59:11 if you embed packetConn into mux, then you won't n
hanwen-google 2013/10/14 06:30:26 Done.
99 » return m.conn.writePacket(p)
100 } 96 }
101 97
102 // Each new chanList instantiation has a different offset. 98 // Each new chanList instantiation has a different offset.
103 var globalOff uint32 99 var globalOff uint32
104 100
105 // newMux returns a mux that runs over the given connection. Caller 101 // newMux returns a mux that runs over the given connection. Caller
106 // should run Loop for returned mux. 102 // should run Loop for returned mux.
107 func newMux(p packetConn) *mux { 103 func newMux(p packetConn) *mux {
108 m := &mux{ 104 m := &mux{
109 » » conn: p, 105 » » conn: p,
110 » » openedChans: make(chan *channel, 16), 106 » » incomingChannels: make(chan *channel, 16),
111 » » globalResponses: make(chan interface{}, 1), 107 » » globalResponses: make(chan interface{}, 1),
112 » » globalReceived: make(chan *ChannelRequest, 16), 108 » » incomingRequests: make(chan *ChannelRequest, 16),
113 } 109 }
114 m.chanList.offset = atomic.AddUint32(&globalOff, 1) 110 m.chanList.offset = atomic.AddUint32(&globalOff, 1)
115 return m 111 return m
116 } 112 }
117 113
118 func (m *mux) sendMessage(code byte, msg interface{}) error { 114 func (m *mux) sendMessage(code byte, msg interface{}) error {
119 p := marshal(code, msg) 115 p := marshal(code, msg)
120 return m.conn.writePacket(p) 116 return m.conn.writePacket(p)
121 } 117 }
122 118
(...skipping 17 matching lines...) Expand all
140 if wantReply { 136 if wantReply {
141 msg, ok := <-m.globalResponses 137 msg, ok := <-m.globalResponses
142 if !ok { 138 if !ok {
143 return false, nil, io.EOF 139 return false, nil, io.EOF
144 } 140 }
145 switch msg := msg.(type) { 141 switch msg := msg.(type) {
146 case *globalRequestFailureMsg: 142 case *globalRequestFailureMsg:
147 return false, msg.Data, nil 143 return false, msg.Data, nil
148 case *globalRequestSuccessMsg: 144 case *globalRequestSuccessMsg:
149 return true, msg.Data, nil 145 return true, msg.Data, nil
146 default:
147 return false, nil, fmt.Errorf("ssh: unexpected response %#v", msg)
150 } 148 }
151 } 149 }
152 150
153 return false, nil, nil 151 return false, nil, nil
154 }
155
156 // GlobalReceived returns the channel on which incoming global
157 // requests are handled. If this channel is not serviced, the entire
158 // mux may hang.
dfc 2013/10/14 00:59:11 s/may/will/
hanwen-google 2013/10/14 06:30:26 Done.
159 func (m *mux) ReceivedRequests() <-chan *ChannelRequest {
160 return m.globalReceived
161 } 152 }
162 153
163 // AckRequest must be called after processing a global request that 154 // AckRequest must be called after processing a global request that
164 // has WantReply set. 155 // has WantReply set.
165 func (m *mux) AckRequest(ok bool, data []byte) error { 156 func (m *mux) AckRequest(ok bool, data []byte) error {
166 if ok { 157 if ok {
167 return m.sendMessage(msgRequestSuccess, 158 return m.sendMessage(msgRequestSuccess,
168 globalRequestSuccessMsg{Data: data}) 159 globalRequestSuccessMsg{Data: data})
169 } 160 }
170 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: data}) 161 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: data})
171 } 162 }
172 163
173 // TODO(hanwen): Disconnect is a transport layer message. We should 164 // TODO(hanwen): Disconnect is a transport layer message. We should
174 // probably send and receive Disconnect somewhere in the transport 165 // probably send and receive Disconnect somewhere in the transport
175 // code. 166 // code.
176 167
177 // Disconnect sends a disconnect message. 168 // Disconnect sends a disconnect message.
178 func (m *mux) Disconnect(reason uint32, message string) error { 169 func (m *mux) Disconnect(reason uint32, message string) error {
179 return m.sendMessage(msgDisconnect, disconnectMsg{ 170 return m.sendMessage(msgDisconnect, disconnectMsg{
180 Reason: reason, 171 Reason: reason,
181 Message: message, 172 Message: message,
182 }) 173 })
183 } 174 }
184 175
185 // Loop runs the connection machine. It will process packets until an 176 // Loop runs the connection machine. It will process packets until an
186 // error is encountered, returning that error. When the loop exits, 177 // error is encountered, returning that error. When the loop exits,
187 // the connection is closed. 178 // the connection is closed.
188 func (m *mux) Loop() error { 179 func (m *mux) Loop() error {
189 var err error 180 var err error
dfc 2013/10/14 00:59:11 for err != nil { err = m.onePacket() if de
hanwen-google 2013/10/14 06:30:26 Done.
190 » for { 181 » for err == nil {
191 err = m.onePacket() 182 err = m.onePacket()
192 » » if err != nil { 183 » }
193 » » » if debug { 184 » if debug && err != nil {
194 » » » » log.Println("loop exit", err) 185 » » log.Println("loop exit", err)
195 » » » }
196 » » » break
197 » » }
198 } 186 }
199 187
200 for _, ch := range m.chanList.dropAll() { 188 for _, ch := range m.chanList.dropAll() {
201 ch.mu.Lock() 189 ch.mu.Lock()
202 ch.sentClose = true 190 ch.sentClose = true
203 ch.mu.Unlock() 191 ch.mu.Unlock()
204 ch.pending.eof() 192 ch.pending.eof()
205 ch.extPending.eof() 193 ch.extPending.eof()
194 close(ch.incomingRequests)
206 // ch.msg is otherwise only called from onePacket, so 195 // ch.msg is otherwise only called from onePacket, so
207 // this is safe. 196 // this is safe.
208 close(ch.pendingRequests)
209 close(ch.msg) 197 close(ch.msg)
210 } 198 }
211 199
212 » close(m.openedChans) 200 » close(m.incomingChannels)
213 » close(m.globalReceived) 201 » close(m.incomingRequests)
214 close(m.globalResponses) 202 close(m.globalResponses)
215 203
216 m.conn.Close() 204 m.conn.Close()
217 return err 205 return err
218 } 206 }
219 207
220 // onePacket reads and processes one packet. 208 // onePacket reads and processes one packet.
221 func (m *mux) onePacket() error { 209 func (m *mux) onePacket() error {
222 packet, err := m.conn.readPacket() 210 packet, err := m.conn.readPacket()
223 if err != nil { 211 if err != nil {
(...skipping 42 matching lines...) Expand 10 before | Expand all | Expand 10 after
266 } 254 }
267 255
268 func (m *mux) handleGlobalPacket(packet []byte) error { 256 func (m *mux) handleGlobalPacket(packet []byte) error {
269 msg, err := decode(packet) 257 msg, err := decode(packet)
270 if err != nil { 258 if err != nil {
271 return err 259 return err
272 } 260 }
273 261
274 switch msg := msg.(type) { 262 switch msg := msg.(type) {
275 case *globalRequestMsg: 263 case *globalRequestMsg:
276 » » m.globalReceived <- &ChannelRequest{ 264 » » m.incomingRequests <- &ChannelRequest{
277 msg.Type, 265 msg.Type,
278 msg.WantReply, 266 msg.WantReply,
279 msg.Data, 267 msg.Data,
280 } 268 }
281 case *globalRequestSuccessMsg, *globalRequestFailureMsg: 269 case *globalRequestSuccessMsg, *globalRequestFailureMsg:
282 m.globalResponses <- msg 270 m.globalResponses <- msg
283 default: 271 default:
284 panic(fmt.Sprintf("not a global message %#v", msg)) 272 panic(fmt.Sprintf("not a global message %#v", msg))
285 } 273 }
286 274
(...skipping 19 matching lines...) Expand all
306 return m.sendMessage(msgChannelOpenFailure, failMsg) 294 return m.sendMessage(msgChannelOpenFailure, failMsg)
307 } 295 }
308 296
309 c := newChannel(msg.ChanType, msg.TypeSpecificData) 297 c := newChannel(msg.ChanType, msg.TypeSpecificData)
310 c.mux = m 298 c.mux = m
311 c.remoteId = msg.PeersId 299 c.remoteId = msg.PeersId
312 c.maxPacket = msg.MaxPacketSize 300 c.maxPacket = msg.MaxPacketSize
313 c.remoteWin.add(msg.PeersWindow) 301 c.remoteWin.add(msg.PeersWindow)
314 c.myWindow = defaultWindowSize 302 c.myWindow = defaultWindowSize
315 c.localId = m.chanList.add(c) 303 c.localId = m.chanList.add(c)
316 » m.openedChans <- c 304 » m.incomingChannels <- c
317 return nil 305 return nil
318 } 306 }
319 307
320 // OpenChannelError is returned the other side rejects our OpenChannel 308 // OpenChannelError is returned the other side rejects our OpenChannel
321 // request. 309 // request.
322 type OpenChannelError struct { 310 type OpenChannelError struct {
323 Reason RejectionReason 311 Reason RejectionReason
324 Message string 312 Message string
325 } 313 }
326 314
327 func (e *OpenChannelError) Error() string { 315 func (e *OpenChannelError) Error() string {
328 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) 316 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message)
329 } 317 }
330 318
331 // OpenChannel asks for a new channel. If the other side rejects, it 319 // OpenChannel asks for a new channel. If the other side rejects, it
332 // returns a *OpenChannelError. 320 // returns a *OpenChannelError.
333 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, error) { 321 func (m *mux) OpenChannel(chanType string, extra []byte) (*channel, error) {
334 ch := newChannel(chanType, extra) 322 ch := newChannel(chanType, extra)
335 ch.mux = m 323 ch.mux = m
336 324
337 // As per RFC 4253 6.1, 32k is also the minimum. 325 // As per RFC 4253 6.1, 32k is also the minimum.
338 ch.maxPacket = 1 << 15 326 ch.maxPacket = 1 << 15
339 ch.myWindow = defaultWindowSize 327 ch.myWindow = defaultWindowSize
340 ch.localId = m.chanList.add(ch) 328 ch.localId = m.chanList.add(ch)
341 329
342 open := channelOpenMsg{ 330 open := channelOpenMsg{
343 ChanType: chanType, 331 ChanType: chanType,
(...skipping 13 matching lines...) Expand all
357 } 345 }
358 // fixup remoteId field 346 // fixup remoteId field
359 ch.remoteId = msg.MyId 347 ch.remoteId = msg.MyId
360 ch.maxPacket = msg.MaxPacketSize 348 ch.maxPacket = msg.MaxPacketSize
361 ch.remoteWin.add(msg.MyWindow) 349 ch.remoteWin.add(msg.MyWindow)
362 ch.decided = true 350 ch.decided = true
363 return ch, nil 351 return ch, nil
364 case *channelOpenFailureMsg: 352 case *channelOpenFailureMsg:
365 m.chanList.remove(open.PeersId) 353 m.chanList.remove(open.PeersId)
366 return nil, &OpenChannelError{msg.Reason, msg.Message} 354 return nil, &OpenChannelError{msg.Reason, msg.Message}
367 » } 355 » default:
368 » return nil, errors.New("ssh: unexpected packet") 356 » » return nil, fmt.Errorf("ssh: unexpected packet %T", msg)
369 } 357 » }
370 358 }
371 // Accept returns the next channel that the remote side opened.
372 func (m *mux) Accept() (Channel, error) {
373 » c, ok := <-m.openedChans
374 » if !ok {
375 » » return nil, io.EOF
376 » }
377 » return c, nil
378 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b