Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(1105)

Delta Between Two Patch Sets: ssh/server.go

Issue 14494058: code review 14494058: go.crypto/ssh: support rekeying in both directions. (Closed)
Left Patch Set: diff -r 5ff5636e18c9 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Right Patch Set: diff -r cd1eea1eb828 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
LEFTRIGHT
1 // Copyright 2011 The Go Authors. All rights reserved. 1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style 2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file. 3 // license that can be found in the LICENSE file.
4 4
5 package ssh 5 package ssh
6 6
7 import ( 7 import (
8 "bytes" 8 "bytes"
9 "crypto/rand" 9 "crypto/rand"
10 "encoding/binary" 10 "encoding/binary"
(...skipping 78 matching lines...) Expand 10 before | Expand all | Expand 10 after
89 type cachedPubKey struct { 89 type cachedPubKey struct {
90 user, algo string 90 user, algo string
91 pubKey []byte 91 pubKey []byte
92 result bool 92 result bool
93 } 93 }
94 94
95 const maxCachedPubKeys = 16 95 const maxCachedPubKeys = 16
96 96
97 // A ServerConn represents an incoming connection. 97 // A ServerConn represents an incoming connection.
98 type ServerConn struct { 98 type ServerConn struct {
99 » *handshakeTransport 99 » transport *handshakeTransport
100 » config *ServerConfig 100 » config *ServerConfig
101 sshConn 101 sshConn
102 102
103 channels map[uint32]*serverChan 103 channels map[uint32]*serverChan
104 nextChanId uint32 104 nextChanId uint32
105 105
106 // lock protects err and channels. 106 // lock protects err and channels.
107 lock sync.Mutex 107 lock sync.Mutex
108 err error 108 err error
109 109
110 // cachedPubKeys contains the cache results of tests for public keys. 110 // cachedPubKeys contains the cache results of tests for public keys.
(...skipping 12 matching lines...) Expand all
123 ClientVersion []byte 123 ClientVersion []byte
124 124
125 // Our version. 125 // Our version.
126 serverVersion []byte 126 serverVersion []byte
127 } 127 }
128 128
129 // Server returns a new SSH server connection 129 // Server returns a new SSH server connection
130 // using c as the underlying transport. 130 // using c as the underlying transport.
131 func Server(c net.Conn, config *ServerConfig) *ServerConn { 131 func Server(c net.Conn, config *ServerConfig) *ServerConn {
132 return &ServerConn{ 132 return &ServerConn{
133 » » sshConn: sshConn{c}, 133 » » sshConn: sshConn{c, c},
134 channels: make(map[uint32]*serverChan), 134 channels: make(map[uint32]*serverChan),
135 config: config, 135 config: config,
136 } 136 }
137 } 137 }
138 138
139 // signAndMarshal signs the data with the appropriate algorithm, 139 // signAndMarshal signs the data with the appropriate algorithm,
140 // and serializes the result in SSH wire format. 140 // and serializes the result in SSH wire format.
141 func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { 141 func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
142 sig, err := k.Sign(rand, data) 142 sig, err := k.Sign(rand, data)
143 if err != nil { 143 if err != nil {
144 return nil, err 144 return nil, err
145 } 145 }
146 146
147 return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil 147 return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
148 } 148 }
149 149
150 // Handshake performs an SSH transport and client authentication on the given Se rverConn. 150 // Handshake performs an SSH transport and client authentication on the given Se rverConn.
151 func (s *ServerConn) Handshake() error { 151 func (s *ServerConn) Handshake() error {
152 var err error 152 var err error
153 s.serverVersion = []byte(packageVersion) 153 s.serverVersion = []byte(packageVersion)
154 » s.ClientVersion, err = exchangeVersions(s.sshConn.Conn, s.serverVersion) 154 » s.ClientVersion, err = exchangeVersions(s.sshConn.conn, s.serverVersion)
155 if err != nil { 155 if err != nil {
156 return err 156 return err
157 } 157 }
158 158
159 » tr := newTransport(s.sshConn.Conn, s.config.rand(), false /* not client */) 159 » tr := newTransport(s.sshConn.conn, s.config.rand(), false /* not client */)
160 » hs := newServerTransport(tr, s.ClientVersion, s.serverVersion, s.config) 160 » s.transport = newServerTransport(tr, s.ClientVersion, s.serverVersion, s .config)
161 » s.handshakeTransport = hs 161
162 162 » if err := s.transport.requestKeyChange(); err != nil {
163 » if _, _, err := s.sendKexInit(); err != nil { 163 » » return err
164 » » return err 164 » }
165 » } 165
166 166 » if packet, err := s.transport.readPacket(); err != nil {
167 » // ignore newkeys message. 167 » » return err
168 » if _, err := s.readPacket(); err != nil { 168 » } else if packet[0] != msgNewKeys {
169 » » return err 169 » » return UnexpectedMessageError{msgNewKeys, packet[0]}
170 } 170 }
171 171
172 var packet []byte 172 var packet []byte
173 » if packet, err = s.readPacket(); err != nil { 173 » if packet, err = s.transport.readPacket(); err != nil {
174 » » return err 174 » » return err
175 » } 175 » }
176
176 var serviceRequest serviceRequestMsg 177 var serviceRequest serviceRequestMsg
177 if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != n il { 178 if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != n il {
178 return err 179 return err
179 } 180 }
180 if serviceRequest.Service != serviceUserAuth { 181 if serviceRequest.Service != serviceUserAuth {
181 return errors.New("ssh: requested service '" + serviceRequest.Se rvice + "' before authenticating") 182 return errors.New("ssh: requested service '" + serviceRequest.Se rvice + "' before authenticating")
182 } 183 }
183 serviceAccept := serviceAcceptMsg{ 184 serviceAccept := serviceAcceptMsg{
184 Service: serviceUserAuth, 185 Service: serviceUserAuth,
185 } 186 }
186 » if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil { 187 » if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccep t)); err != nil {
187 » » return err 188 » » return err
188 » } 189 » }
189 190
190 » if err = s.authenticate(s.handshakeTransport.SessionID()); err != nil { 191 » if err := s.authenticate(); err != nil {
191 return err 192 return err
192 } 193 }
193 return err 194 return err
194 } 195 }
195 196
196 func isAcceptableAlgo(algo string) bool { 197 func isAcceptableAlgo(algo string) bool {
197 switch algo { 198 switch algo {
198 case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoEC DSA521, 199 case KeyAlgoRSA, KeyAlgoDSA, KeyAlgoECDSA256, KeyAlgoECDSA384, KeyAlgoEC DSA521,
199 CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECD SA384v01, CertAlgoECDSA521v01: 200 CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, CertAlgoECD SA384v01, CertAlgoECDSA521v01:
200 return true 201 return true
(...skipping 21 matching lines...) Expand all
222 pubKey: make([]byte, len(pubKey)), 223 pubKey: make([]byte, len(pubKey)),
223 result: result, 224 result: result,
224 } 225 }
225 copy(c.pubKey, pubKey) 226 copy(c.pubKey, pubKey)
226 s.cachedPubKeys = append(s.cachedPubKeys, c) 227 s.cachedPubKeys = append(s.cachedPubKeys, c)
227 } 228 }
228 229
229 return result 230 return result
230 } 231 }
231 232
232 func (s *ServerConn) authenticate(H []byte) error { 233 func (s *ServerConn) authenticate() error {
233 var userAuthReq userAuthRequestMsg 234 var userAuthReq userAuthRequestMsg
234 var err error 235 var err error
235 var packet []byte 236 var packet []byte
236 237
237 userAuthLoop: 238 userAuthLoop:
238 for { 239 for {
239 » » if packet, err = s.readPacket(); err != nil { 240 » » if packet, err = s.transport.readPacket(); err != nil {
240 return err 241 return err
241 } 242 }
242 if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); er r != nil { 243 if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); er r != nil {
243 return err 244 return err
244 } 245 }
245 246
246 if userAuthReq.Service != serviceSSH { 247 if userAuthReq.Service != serviceSSH {
247 return errors.New("ssh: client attempted to negotiate fo r unknown service: " + userAuthReq.Service) 248 return errors.New("ssh: client attempted to negotiate fo r unknown service: " + userAuthReq.Service)
248 } 249 }
249 250
(...skipping 53 matching lines...) Expand 10 before | Expand all | Expand 10 after
303 // The client can query if the given public key 304 // The client can query if the given public key
304 // would be okay. 305 // would be okay.
305 if len(payload) > 0 { 306 if len(payload) > 0 {
306 return ParseError{msgUserAuthRequest} 307 return ParseError{msgUserAuthRequest}
307 } 308 }
308 if s.testPubKey(userAuthReq.User, algo, pubKey) { 309 if s.testPubKey(userAuthReq.User, algo, pubKey) {
309 okMsg := userAuthPubKeyOkMsg{ 310 okMsg := userAuthPubKeyOkMsg{
310 Algo: algo, 311 Algo: algo,
311 PubKey: string(pubKey), 312 PubKey: string(pubKey),
312 } 313 }
313 » » » » » if err = s.writePacket(marshal(msgUserAu thPubKeyOk, okMsg)); err != nil { 314 » » » » » if err = s.transport.writePacket(marshal (msgUserAuthPubKeyOk, okMsg)); err != nil {
314 return err 315 return err
315 } 316 }
316 continue userAuthLoop 317 continue userAuthLoop
317 } 318 }
318 } else { 319 } else {
319 sig, payload, ok := parseSignature(payload) 320 sig, payload, ok := parseSignature(payload)
320 if !ok || len(payload) > 0 { 321 if !ok || len(payload) > 0 {
321 return ParseError{msgUserAuthRequest} 322 return ParseError{msgUserAuthRequest}
322 } 323 }
323 // Ensure the public key algo and signature algo 324 // Ensure the public key algo and signature algo
324 // are supported. Compare the private key 325 // are supported. Compare the private key
325 // algorithm name that corresponds to algo with 326 // algorithm name that corresponds to algo with
326 // sig.Format. This is usually the same, but 327 // sig.Format. This is usually the same, but
327 // for certs, the names differ. 328 // for certs, the names differ.
328 if !isAcceptableAlgo(algo) || !isAcceptableAlgo( sig.Format) || pubAlgoToPrivAlgo(algo) != sig.Format { 329 if !isAcceptableAlgo(algo) || !isAcceptableAlgo( sig.Format) || pubAlgoToPrivAlgo(algo) != sig.Format {
329 break 330 break
330 } 331 }
331 » » » » signedData := buildDataSignedForAuth(H, userAuth Req, algoBytes, pubKey) 332 » » » » signedData := buildDataSignedForAuth(s.transport .getSessionID(), userAuthReq, algoBytes, pubKey)
332 key, _, ok := ParsePublicKey(pubKey) 333 key, _, ok := ParsePublicKey(pubKey)
333 if !ok { 334 if !ok {
334 return ParseError{msgUserAuthRequest} 335 return ParseError{msgUserAuthRequest}
335 } 336 }
336 337
337 if !key.Verify(signedData, sig.Blob) { 338 if !key.Verify(signedData, sig.Blob) {
339 // TODO(hanwen): fix this
340 // message. It's not a parse
341 // error
338 return ParseError{msgUserAuthRequest} 342 return ParseError{msgUserAuthRequest}
339 } 343 }
340 // TODO(jmpittman): Implement full validation fo r certificates. 344 // TODO(jmpittman): Implement full validation fo r certificates.
341 s.User = userAuthReq.User 345 s.User = userAuthReq.User
342 if s.testPubKey(userAuthReq.User, algo, pubKey) { 346 if s.testPubKey(userAuthReq.User, algo, pubKey) {
343 break userAuthLoop 347 break userAuthLoop
344 } 348 }
345 } 349 }
346 } 350 }
347 351
348 var failureMsg userAuthFailureMsg 352 var failureMsg userAuthFailureMsg
349 if s.config.PasswordCallback != nil { 353 if s.config.PasswordCallback != nil {
350 failureMsg.Methods = append(failureMsg.Methods, "passwor d") 354 failureMsg.Methods = append(failureMsg.Methods, "passwor d")
351 } 355 }
352 if s.config.PublicKeyCallback != nil { 356 if s.config.PublicKeyCallback != nil {
353 failureMsg.Methods = append(failureMsg.Methods, "publick ey") 357 failureMsg.Methods = append(failureMsg.Methods, "publick ey")
354 } 358 }
355 if s.config.KeyboardInteractiveCallback != nil { 359 if s.config.KeyboardInteractiveCallback != nil {
356 failureMsg.Methods = append(failureMsg.Methods, "keyboar d-interactive") 360 failureMsg.Methods = append(failureMsg.Methods, "keyboar d-interactive")
357 } 361 }
358 362
359 if len(failureMsg.Methods) == 0 { 363 if len(failureMsg.Methods) == 0 {
360 return errors.New("ssh: no authentication methods config ured but NoClientAuth is also false") 364 return errors.New("ssh: no authentication methods config ured but NoClientAuth is also false")
361 } 365 }
362 366
363 » » if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil { 367 » » if err = s.transport.writePacket(marshal(msgUserAuthFailure, fai lureMsg)); err != nil {
364 return err 368 return err
365 } 369 }
366 } 370 }
367 371
368 packet = []byte{msgUserAuthSuccess} 372 packet = []byte{msgUserAuthSuccess}
369 » if err = s.writePacket(packet); err != nil { 373 » if err = s.transport.writePacket(packet); err != nil {
370 return err 374 return err
371 } 375 }
372 376
373 return nil 377 return nil
374 } 378 }
375 379
376 // sshClientKeyboardInteractive implements a ClientKeyboardInteractive by 380 // sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
377 // asking the client on the other side of a ServerConn. 381 // asking the client on the other side of a ServerConn.
378 type sshClientKeyboardInteractive struct { 382 type sshClientKeyboardInteractive struct {
379 *ServerConn 383 *ServerConn
380 } 384 }
381 385
382 func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest ions []string, echos []bool) (answers []string, err error) { 386 func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest ions []string, echos []bool) (answers []string, err error) {
383 if len(questions) != len(echos) { 387 if len(questions) != len(echos) {
384 return nil, errors.New("ssh: echos and questions must have equal length") 388 return nil, errors.New("ssh: echos and questions must have equal length")
385 } 389 }
386 390
387 var prompts []byte 391 var prompts []byte
388 for i := range questions { 392 for i := range questions {
389 prompts = appendString(prompts, questions[i]) 393 prompts = appendString(prompts, questions[i])
390 prompts = appendBool(prompts, echos[i]) 394 prompts = appendBool(prompts, echos[i])
391 } 395 }
392 396
393 » if err := c.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequ estMsg{ 397 » if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAu thInfoRequestMsg{
394 Instruction: instruction, 398 Instruction: instruction,
395 NumPrompts: uint32(len(questions)), 399 NumPrompts: uint32(len(questions)),
396 Prompts: prompts, 400 Prompts: prompts,
397 })); err != nil { 401 })); err != nil {
398 return nil, err 402 return nil, err
399 } 403 }
400 404
401 » packet, err := c.readPacket() 405 » packet, err := c.transport.readPacket()
402 if err != nil { 406 if err != nil {
403 return nil, err 407 return nil, err
404 } 408 }
405 if packet[0] != msgUserAuthInfoResponse { 409 if packet[0] != msgUserAuthInfoResponse {
406 return nil, UnexpectedMessageError{msgUserAuthInfoResponse, pack et[0]} 410 return nil, UnexpectedMessageError{msgUserAuthInfoResponse, pack et[0]}
407 } 411 }
408 packet = packet[1:] 412 packet = packet[1:]
409 413
410 n, packet, ok := parseUint32(packet) 414 n, packet, ok := parseUint32(packet)
411 if !ok || int(n) != len(questions) { 415 if !ok || int(n) != len(questions) {
412 return nil, &ParseError{msgUserAuthInfoResponse} 416 return nil, &ParseError{msgUserAuthInfoResponse}
413 } 417 }
414 418
415 for i := uint32(0); i < n; i++ { 419 for i := uint32(0); i < n; i++ {
416 ans, rest, ok := parseString(packet) 420 ans, rest, ok := parseString(packet)
417 if !ok { 421 if !ok {
418 return nil, &ParseError{msgUserAuthInfoResponse} 422 return nil, &ParseError{msgUserAuthInfoResponse}
419 } 423 }
420 424
421 answers = append(answers, string(ans)) 425 answers = append(answers, string(ans))
422 packet = rest 426 packet = rest
423 } 427 }
424 if len(packet) != 0 { 428 if len(packet) != 0 {
425 return nil, errors.New("ssh: junk at end of message") 429 return nil, errors.New("ssh: junk at end of message")
426 } 430 }
427 431
428 return answers, nil 432 return answers, nil
429 }
430
431 // Close shuts down the underlying network connection.
432 func (s *ServerConn) Close() error {
433 return s.sshConn.Conn.Close()
434 } 433 }
435 434
436 const defaultWindowSize = 32768 435 const defaultWindowSize = 32768
437 436
438 // Accept reads and processes messages on a ServerConn. It must be called 437 // Accept reads and processes messages on a ServerConn. It must be called
439 // in order to demultiplex messages to any resulting Channels. 438 // in order to demultiplex messages to any resulting Channels.
440 func (s *ServerConn) Accept() (Channel, error) { 439 func (s *ServerConn) Accept() (Channel, error) {
441 // TODO(dfc) s.lock is not held here so visibility of s.err is not guara nteed. 440 // TODO(dfc) s.lock is not held here so visibility of s.err is not guara nteed.
442 if s.err != nil { 441 if s.err != nil {
443 return nil, s.err 442 return nil, s.err
444 } 443 }
445 444
446 for { 445 for {
447 » » packet, err := s.readPacket() 446 » » packet, err := s.transport.readPacket()
448 if err != nil { 447 if err != nil {
449
450 s.lock.Lock() 448 s.lock.Lock()
451 s.err = err 449 s.err = err
452 s.lock.Unlock() 450 s.lock.Unlock()
453 451
454 // TODO(dfc) s.lock protects s.channels but isn't being held here. 452 // TODO(dfc) s.lock protects s.channels but isn't being held here.
455 for _, c := range s.channels { 453 for _, c := range s.channels {
456 c.setDead() 454 c.setDead()
457 c.handleData(nil) 455 c.handleData(nil)
458 } 456 }
459 457
(...skipping 26 matching lines...) Expand all
486 if err != nil { 484 if err != nil {
487 return nil, err 485 return nil, err
488 } 486 }
489 switch msg := decoded.(type) { 487 switch msg := decoded.(type) {
490 case *channelOpenMsg: 488 case *channelOpenMsg:
491 if msg.MaxPacketSize < minPacketLength || msg.Ma xPacketSize > 1<<31 { 489 if msg.MaxPacketSize < minPacketLength || msg.Ma xPacketSize > 1<<31 {
492 return nil, errors.New("ssh: invalid Max PacketSize from peer") 490 return nil, errors.New("ssh: invalid Max PacketSize from peer")
493 } 491 }
494 c := &serverChan{ 492 c := &serverChan{
495 channel: channel{ 493 channel: channel{
496 » » » » » » packetConn: s, 494 » » » » » » packetConn: s.transport,
497 remoteId: msg.PeersId, 495 remoteId: msg.PeersId,
498 remoteWin: window{Cond: newCond ()}, 496 remoteWin: window{Cond: newCond ()},
499 maxPacket: msg.MaxPacketSize, 497 maxPacket: msg.MaxPacketSize,
500 }, 498 },
501 chanType: msg.ChanType, 499 chanType: msg.ChanType,
502 extraData: msg.TypeSpecificData, 500 extraData: msg.TypeSpecificData,
503 myWindow: defaultWindowSize, 501 myWindow: defaultWindowSize,
504 serverConn: s, 502 serverConn: s,
505 cond: newCond(), 503 cond: newCond(),
506 pendingData: make([]byte, defaultWindowS ize), 504 pendingData: make([]byte, defaultWindowS ize),
(...skipping 41 matching lines...) Expand 10 before | Expand all | Expand 10 after
548 c, ok := s.channels[msg.PeersId] 546 c, ok := s.channels[msg.PeersId]
549 if !ok { 547 if !ok {
550 s.lock.Unlock() 548 s.lock.Unlock()
551 continue 549 continue
552 } 550 }
553 c.handlePacket(msg) 551 c.handlePacket(msg)
554 s.lock.Unlock() 552 s.lock.Unlock()
555 553
556 case *globalRequestMsg: 554 case *globalRequestMsg:
557 if msg.WantReply { 555 if msg.WantReply {
558 » » » » » if err := s.writePacket([]byte{msgReques tFailure}); err != nil { 556 » » » » » if err := s.transport.writePacket([]byte {msgRequestFailure}); err != nil {
559 return nil, err 557 return nil, err
560 } 558 }
561 } 559 }
562 560
563 case *disconnectMsg: 561 case *disconnectMsg:
564 return nil, io.EOF 562 return nil, io.EOF
565 default: 563 default:
566 // Unknown message. Ignore. 564 // Unknown message. Ignore.
567 } 565 }
568 } 566 }
(...skipping 34 matching lines...) Expand 10 before | Expand all | Expand 10 after
603 func Listen(network, addr string, config *ServerConfig) (*Listener, error) { 601 func Listen(network, addr string, config *ServerConfig) (*Listener, error) {
604 l, err := net.Listen(network, addr) 602 l, err := net.Listen(network, addr)
605 if err != nil { 603 if err != nil {
606 return nil, err 604 return nil, err
607 } 605 }
608 return &Listener{ 606 return &Listener{
609 l, 607 l,
610 config, 608 config,
611 }, nil 609 }, nil
612 } 610 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b