Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(708)

Delta Between Two Patch Sets: ssh/mux.go

Issue 14225043: code review 14225043: go.crypto/ssh: reimplement SSH connection protocol modu... (Closed)
Left Patch Set: diff -r 2cd6b3b93cdb https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Right Patch Set: diff -r cd1eea1eb828 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « ssh/messages.go ('k') | ssh/mux_test.go » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
1 package ssh 5 package ssh
2 6
3 import ( 7 import (
4 "encoding/binary" 8 "encoding/binary"
5 "errors" 9 "errors"
6 "fmt" 10 "fmt"
7 "io" 11 "io"
8 "log" 12 "log"
9 "sync" 13 "sync"
14 "sync/atomic"
10 ) 15 )
16
17 const debug = false
11 18
12 // Thread safe channel list. 19 // Thread safe channel list.
13 type chanList struct { 20 type chanList struct {
14 // protects concurrent access to chans 21 // protects concurrent access to chans
15 sync.Mutex 22 sync.Mutex
16 » // chans are indexed by the local id of the channel, nChannel.localId. 23
17 » // The PeersId value of messages received by ClientConn.mainLoop is 24 » // chans are indexed by the local id of the channel, which the
18 » // used to locate the right local nChannel in this slice. 25 » // other side should send in the PeersId field.
19 chans []*channel 26 chans []*channel
20 27
21 // This is a debugging aid: it offsets all IDs by this 28 // This is a debugging aid: it offsets all IDs by this
22 // amount. This helps distinguish otherwise identical 29 // amount. This helps distinguish otherwise identical
23 // server/client muxes 30 // server/client muxes
24 offset uint32 31 offset uint32
25 } 32 }
26 33
27 // Assigns a channel ID to the given channel. 34 // Assigns a channel ID to the given channel.
28 func (c *chanList) add(ch *channel) uint32 { 35 func (c *chanList) add(ch *channel) uint32 {
(...skipping 39 matching lines...) Expand 10 before | Expand all | Expand 10 after
68 for _, ch := range c.chans { 75 for _, ch := range c.chans {
69 if ch == nil { 76 if ch == nil {
70 continue 77 continue
71 } 78 }
72 r = append(r, ch) 79 r = append(r, ch)
73 } 80 }
74 c.chans = nil 81 c.chans = nil
75 return r 82 return r
76 } 83 }
77 84
78 // mux contains the state for the SSH connection protocol, which 85 // mux represents the state for the SSH connection protocol, which
79 // multiplexes many channels onto a single packet transport. 86 // multiplexes many channels onto a single packet transport.
80 type mux struct { 87 type mux struct {
81 conn packetConn 88 conn packetConn
82 chanList chanList 89 chanList chanList
83 90
84 » openedChans chan *channel 91 » incomingChannels chan *channel
85 92
86 » globalSentMu sync.Mutex 93 » globalSentMu sync.Mutex
87 » globalResponses chan interface{} 94 » globalResponses chan interface{}
88 » globalReceived chan *ChannelRequest 95 » incomingRequests chan *ChannelRequest
89 } 96 }
90 97
98 // Each new chanList instantiation has a different offset.
91 var globalOff uint32 99 var globalOff uint32
92 100
93 // newMux returns a mux that runs over the given connection. 101 // newMux returns a mux that runs over the given connection. Caller
102 // should run Loop for returned mux.
94 func newMux(p packetConn) *mux { 103 func newMux(p packetConn) *mux {
95 m := &mux{ 104 m := &mux{
96 » » conn: p, 105 » » conn: p,
97 » » openedChans: make(chan *channel, 16), 106 » » incomingChannels: make(chan *channel, 16),
98 » » globalResponses: make(chan interface{}, 1), 107 » » globalResponses: make(chan interface{}, 1),
99 » » globalReceived: make(chan *ChannelRequest, 16), 108 » » incomingRequests: make(chan *ChannelRequest, 16),
100 » } 109 » }
101 » m.chanList.offset = globalOff 110 » m.chanList.offset = atomic.AddUint32(&globalOff, 1)
102 » globalOff++
103
104 return m 111 return m
105 } 112 }
106 113
107 func (m *mux) sendMessage(code byte, msg interface{}) error { 114 func (m *mux) sendMessage(code byte, msg interface{}) error {
108 p := marshal(code, msg) 115 p := marshal(code, msg)
109 return m.conn.writePacket(p) 116 return m.conn.writePacket(p)
110 } 117 }
111 118
112 // SendRequest sends a global request. If wantReply is set, the 119 // SendRequest sends a global request. If wantReply is set, the
113 // return includes success status and extra data. See also RFC4254 section 4 120 // return includes success status and extra data. See also RFC4254 section 4
(...skipping 15 matching lines...) Expand all
129 if wantReply { 136 if wantReply {
130 msg, ok := <-m.globalResponses 137 msg, ok := <-m.globalResponses
131 if !ok { 138 if !ok {
132 return false, nil, io.EOF 139 return false, nil, io.EOF
133 } 140 }
134 switch msg := msg.(type) { 141 switch msg := msg.(type) {
135 case *globalRequestFailureMsg: 142 case *globalRequestFailureMsg:
136 return false, msg.Data, nil 143 return false, msg.Data, nil
137 case *globalRequestSuccessMsg: 144 case *globalRequestSuccessMsg:
138 return true, msg.Data, nil 145 return true, msg.Data, nil
146 default:
147 return false, nil, fmt.Errorf("ssh: unexpected response %#v", msg)
139 } 148 }
140 } 149 }
141 150
142 return false, nil, nil 151 return false, nil, nil
143 } 152 }
144 153
145 // GlobalReceived returns the channel on which incoming global 154 // AckRequest must be called after processing a global request that
146 // requests are handled. 155 // has WantReply set.
147 func (m *mux) ReceivedRequests() <-chan *ChannelRequest { 156 func (m *mux) AckRequest(ok bool, data []byte) error {
148 » return m.globalReceived
149 }
150
151 // AckGlobalRequest must be called for a global request with WantReply
152 // set.
153 func (m *mux) AckGlobalRequest(ok bool, data []byte) error {
154 if ok { 157 if ok {
155 return m.sendMessage(msgRequestSuccess, 158 return m.sendMessage(msgRequestSuccess,
156 globalRequestSuccessMsg{Data: data}) 159 globalRequestSuccessMsg{Data: data})
157 } 160 }
158 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: da ta}) 161 return m.sendMessage(msgRequestFailure, globalRequestFailureMsg{Data: da ta})
159 } 162 }
160 163
164 // TODO(hanwen): Disconnect is a transport layer message. We should
165 // probably send and receive Disconnect somewhere in the transport
166 // code.
167
168 // Disconnect sends a disconnect message.
161 func (m *mux) Disconnect(reason uint32, message string) error { 169 func (m *mux) Disconnect(reason uint32, message string) error {
162 return m.sendMessage(msgDisconnect, disconnectMsg{ 170 return m.sendMessage(msgDisconnect, disconnectMsg{
163 Reason: reason, 171 Reason: reason,
164 Message: message, 172 Message: message,
165 }) 173 })
166 } 174 }
167 175
168 // Loop runs the connection machine. It will process packets until an 176 // Loop runs the connection machine. It will process packets until an
169 // error is encountered, returning that error. 177 // error is encountered, returning that error. When the loop exits,
178 // the connection is closed.
170 func (m *mux) Loop() error { 179 func (m *mux) Loop() error {
171 var err error 180 var err error
172 » for { 181 » for err == nil {
173 err = m.onePacket() 182 err = m.onePacket()
174 » » if err != nil { 183 » }
175 » » » if debug { 184 » if debug && err != nil {
176 » » » » log.Println("loop exit", err) 185 » » log.Println("loop exit", err)
177 » » » }
178 » » » break
179 » » }
180 } 186 }
181 187
182 for _, ch := range m.chanList.dropAll() { 188 for _, ch := range m.chanList.dropAll() {
183 ch.mu.Lock() 189 ch.mu.Lock()
184 ch.sentClose = true 190 ch.sentClose = true
185 ch.mu.Unlock() 191 ch.mu.Unlock()
186 ch.pending.eof() 192 ch.pending.eof()
187 ch.extPending.eof() 193 ch.extPending.eof()
194 close(ch.incomingRequests)
188 // ch.msg is otherwise only called from onePacket, so 195 // ch.msg is otherwise only called from onePacket, so
189 // this is safe. 196 // this is safe.
190 close(ch.pendingRequests)
191 close(ch.msg) 197 close(ch.msg)
192 } 198 }
193 199
194 » close(m.openedChans) 200 » close(m.incomingChannels)
195 » close(m.globalReceived) 201 » close(m.incomingRequests)
196 close(m.globalResponses) 202 close(m.globalResponses)
197 » // TODO(hanwen): should we close packetConn? 203
198 » // m.conn.Close() 204 » m.conn.Close()
199 return err 205 return err
200 } 206 }
201 207
202 var debug = false 208 // onePacket reads and processes one packet.
203
204 func (m *mux) onePacket() error { 209 func (m *mux) onePacket() error {
205 packet, err := m.conn.readPacket() 210 packet, err := m.conn.readPacket()
206 if err != nil { 211 if err != nil {
207 return err 212 return err
208 } 213 }
209 214
210 if debug { 215 if debug {
211 p, _ := decode(packet) 216 p, _ := decode(packet)
212 » » log.Printf("decoding(%c): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet)) 217 » » log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
213 } 218 }
214 219
215 switch packet[0] { 220 switch packet[0] {
216 case msgDisconnect: 221 case msgDisconnect:
217 return m.handleDisconnect(packet) 222 return m.handleDisconnect(packet)
218 case msgChannelOpen: 223 case msgChannelOpen:
219 return m.handleChannelOpen(packet) 224 return m.handleChannelOpen(packet)
220 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure: 225 case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
221 return m.handleGlobalPacket(packet) 226 return m.handleGlobalPacket(packet)
222 } 227 }
(...skipping 26 matching lines...) Expand all
249 } 254 }
250 255
251 func (m *mux) handleGlobalPacket(packet []byte) error { 256 func (m *mux) handleGlobalPacket(packet []byte) error {
252 msg, err := decode(packet) 257 msg, err := decode(packet)
253 if err != nil { 258 if err != nil {
254 return err 259 return err
255 } 260 }
256 261
257 switch msg := msg.(type) { 262 switch msg := msg.(type) {
258 case *globalRequestMsg: 263 case *globalRequestMsg:
259 » » m.globalReceived <- &ChannelRequest{ 264 » » m.incomingRequests <- &ChannelRequest{
260 msg.Type, 265 msg.Type,
261 msg.WantReply, 266 msg.WantReply,
262 msg.Data, 267 msg.Data,
263 } 268 }
264 case *globalRequestSuccessMsg, *globalRequestFailureMsg: 269 case *globalRequestSuccessMsg, *globalRequestFailureMsg:
265 m.globalResponses <- msg 270 m.globalResponses <- msg
266 default: 271 default:
267 panic(fmt.Sprintf("not a global message %#v", msg)) 272 panic(fmt.Sprintf("not a global message %#v", msg))
268 } 273 }
269 274
270 return nil 275 return nil
271 } 276 }
272 277
273 const minPacketLength = 0 278 const minPacketLength = 0
274 279
280 // handleChannelOpen schedules a channel to be Accept()ed.
275 func (m *mux) handleChannelOpen(packet []byte) error { 281 func (m *mux) handleChannelOpen(packet []byte) error {
276 var msg channelOpenMsg 282 var msg channelOpenMsg
277 if err := unmarshal(&msg, packet, msgChannelOpen); err != nil { 283 if err := unmarshal(&msg, packet, msgChannelOpen); err != nil {
278 return err 284 return err
279 } 285 }
280 286
281 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 { 287 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
282 failMsg := channelOpenFailureMsg{ 288 failMsg := channelOpenFailureMsg{
283 PeersId: msg.PeersId, 289 PeersId: msg.PeersId,
284 Reason: ConnectionFailed, 290 Reason: ConnectionFailed,
285 Message: "invalid request", 291 Message: "invalid request",
286 Language: "en_US.UTF-8", 292 Language: "en_US.UTF-8",
287 } 293 }
288 return m.sendMessage(msgChannelOpenFailure, failMsg) 294 return m.sendMessage(msgChannelOpenFailure, failMsg)
289 } 295 }
290 296
291 » c := newChannel(m.conn, msg.ChanType, msg.TypeSpecificData) 297 » c := newChannel(msg.ChanType, msg.TypeSpecificData)
292 c.mux = m 298 c.mux = m
293 c.remoteId = msg.PeersId 299 c.remoteId = msg.PeersId
294 c.maxPacket = msg.MaxPacketSize 300 c.maxPacket = msg.MaxPacketSize
295 c.remoteWin.add(msg.PeersWindow) 301 c.remoteWin.add(msg.PeersWindow)
296 c.myWindow = defaultWindowSize 302 c.myWindow = defaultWindowSize
297 c.localId = m.chanList.add(c) 303 c.localId = m.chanList.add(c)
298 » m.openedChans <- c 304 » m.incomingChannels <- c
299 return nil 305 return nil
300 } 306 }
301 307
302 type OpenChannelFailed struct { 308 // OpenChannelError is returned the other side rejects our OpenChannel
309 // request.
310 type OpenChannelError struct {
303 Reason RejectionReason 311 Reason RejectionReason
304 Message string 312 Message string
305 } 313 }
306 314
307 func (e *OpenChannelFailed) Error() string { 315 func (e *OpenChannelError) Error() string {
308 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message) 316 return fmt.Sprintf("ssh: rejected: %s (%s)", e.Reason, e.Message)
309 } 317 }
310 318
311 // Opens an outgoing channel to the other side. 319 // OpenChannel asks for a new channel. If the other side rejects, it
312 func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, error) { 320 // returns a *OpenChannelError.
313 » ch := newChannel(m.conn, chanType, extra) 321 func (m *mux) OpenChannel(chanType string, extra []byte) (*channel, error) {
322 » ch := newChannel(chanType, extra)
314 ch.mux = m 323 ch.mux = m
324
325 // As per RFC 4253 6.1, 32k is also the minimum.
315 ch.maxPacket = 1 << 15 326 ch.maxPacket = 1 << 15
316 ch.myWindow = defaultWindowSize 327 ch.myWindow = defaultWindowSize
317 ch.localId = m.chanList.add(ch) 328 ch.localId = m.chanList.add(ch)
318 329
319 open := channelOpenMsg{ 330 open := channelOpenMsg{
320 ChanType: chanType, 331 ChanType: chanType,
321 PeersWindow: ch.myWindow, 332 PeersWindow: ch.myWindow,
322 MaxPacketSize: ch.maxPacket, 333 MaxPacketSize: ch.maxPacket,
323 TypeSpecificData: extra, 334 TypeSpecificData: extra,
324 PeersId: ch.localId, 335 PeersId: ch.localId,
325 } 336 }
326 if err := m.sendMessage(msgChannelOpen, open); err != nil { 337 if err := m.sendMessage(msgChannelOpen, open); err != nil {
327 return nil, err 338 return nil, err
328 } 339 }
329 340
330 switch msg := (<-ch.msg).(type) { 341 switch msg := (<-ch.msg).(type) {
331 case *channelOpenConfirmMsg: 342 case *channelOpenConfirmMsg:
332 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1< <31 { 343 if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1< <31 {
333 return nil, errors.New("ssh: invalid MaxPacketSize from peer") 344 return nil, errors.New("ssh: invalid MaxPacketSize from peer")
334 } 345 }
335 // fixup remoteId field 346 // fixup remoteId field
336 ch.remoteId = msg.MyId 347 ch.remoteId = msg.MyId
337 ch.maxPacket = msg.MaxPacketSize 348 ch.maxPacket = msg.MaxPacketSize
338 ch.remoteWin.add(msg.MyWindow) 349 ch.remoteWin.add(msg.MyWindow)
339 ch.decided = true 350 ch.decided = true
340 return ch, nil 351 return ch, nil
341 case *channelOpenFailureMsg: 352 case *channelOpenFailureMsg:
342 m.chanList.remove(open.PeersId) 353 m.chanList.remove(open.PeersId)
343 » » // What type is appropriate? OpenChannelFailed, 354 » » return nil, &OpenChannelError{msg.Reason, msg.Message}
344 » » // *OpenChannelFailed? 355 » default:
345 » » return nil, &OpenChannelFailed{msg.Reason, msg.Message} 356 » » return nil, fmt.Errorf("ssh: unexpected packet %T", msg)
346 » } 357 » }
347 » return nil, errors.New("ssh: unexpected packet") 358 }
348 }
349
350 func (m *mux) Accept() (Channel, error) {
351 » c, ok := <-m.openedChans
352 » if !ok {
353 » » return nil, io.EOF
354 » }
355 » return c, nil
356 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b