Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(171)

Delta Between Two Patch Sets: ssh/server.go

Issue 15870044: code review 15870044: go.crypto/ssh: in {Server,Client}Conn, read session ID ... (Closed)
Left Patch Set: diff -r e3c12cab1b35 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Right Patch Set: diff -r 213a06a7ce81 https://code.google.com/p/go.crypto Created 10 years, 5 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « ssh/client_auth.go ('k') | no next file » | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2011 The Go Authors. All rights reserved. 1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style 2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file. 3 // license that can be found in the LICENSE file.
4 4
5 package ssh 5 package ssh
6 6
7 import ( 7 import (
8 "bytes" 8 "bytes"
9 "crypto/rand" 9 "crypto/rand"
10 "encoding/binary" 10 "encoding/binary"
(...skipping 79 matching lines...) Expand 10 before | Expand all | Expand 10 after
90 type cachedPubKey struct { 90 type cachedPubKey struct {
91 user, algo string 91 user, algo string
92 pubKey []byte 92 pubKey []byte
93 result bool 93 result bool
94 } 94 }
95 95
96 const maxCachedPubKeys = 16 96 const maxCachedPubKeys = 16
97 97
98 // A ServerConn represents an incoming connection. 98 // A ServerConn represents an incoming connection.
99 type ServerConn struct { 99 type ServerConn struct {
100 » *transport 100 » transport *transport
101 » config *ServerConfig 101 » config *ServerConfig
102 102
103 channels map[uint32]*serverChan 103 channels map[uint32]*serverChan
104 nextChanId uint32 104 nextChanId uint32
105 105
106 // lock protects err and channels. 106 // lock protects err and channels.
107 lock sync.Mutex 107 lock sync.Mutex
108 err error 108 err error
109 109
110 // cachedPubKeys contains the cache results of tests for public keys. 110 // cachedPubKeys contains the cache results of tests for public keys.
111 // Since SSH clients will query whether a public key is acceptable 111 // Since SSH clients will query whether a public key is acceptable
(...skipping 28 matching lines...) Expand all
140 // and serializes the result in SSH wire format. 140 // and serializes the result in SSH wire format.
141 func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) { 141 func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
142 sig, err := k.Sign(rand, data) 142 sig, err := k.Sign(rand, data)
143 if err != nil { 143 if err != nil {
144 return nil, err 144 return nil, err
145 } 145 }
146 146
147 return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil 147 return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
148 } 148 }
149 149
150 // Close closes the connection.
151 func (s *ServerConn) Close() error { return s.transport.Close() }
152
153 // LocalAddr returns the local network address.
154 func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
155
156 // RemoteAddr returns the remote network address.
157 func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
158
150 // Handshake performs an SSH transport and client authentication on the given Se rverConn. 159 // Handshake performs an SSH transport and client authentication on the given Se rverConn.
151 func (s *ServerConn) Handshake() error { 160 func (s *ServerConn) Handshake() error {
152 var err error 161 var err error
153 s.serverVersion = []byte(packageVersion) 162 s.serverVersion = []byte(packageVersion)
154 s.ClientVersion, err = exchangeVersions(s.transport.Conn, s.serverVersio n) 163 s.ClientVersion, err = exchangeVersions(s.transport.Conn, s.serverVersio n)
155 if err != nil { 164 if err != nil {
156 return err 165 return err
157 } 166 }
158 if err := s.clientInitHandshake(nil, nil); err != nil { 167 if err := s.clientInitHandshake(nil, nil); err != nil {
159 return err 168 return err
160 } 169 }
161 170
162 var packet []byte 171 var packet []byte
163 » if packet, err = s.readPacket(); err != nil { 172 » if packet, err = s.transport.readPacket(); err != nil {
164 return err 173 return err
165 } 174 }
166 var serviceRequest serviceRequestMsg 175 var serviceRequest serviceRequestMsg
167 if err := unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil { 176 if err := unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
168 return err 177 return err
169 } 178 }
170 if serviceRequest.Service != serviceUserAuth { 179 if serviceRequest.Service != serviceUserAuth {
171 return errors.New("ssh: requested service '" + serviceRequest.Se rvice + "' before authenticating") 180 return errors.New("ssh: requested service '" + serviceRequest.Se rvice + "' before authenticating")
172 } 181 }
173 serviceAccept := serviceAcceptMsg{ 182 serviceAccept := serviceAcceptMsg{
174 Service: serviceUserAuth, 183 Service: serviceUserAuth,
175 } 184 }
176 » if err := s.writePacket(marshal(msgServiceAccept, serviceAccept)); err ! = nil { 185 » if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccep t)); err != nil {
177 return err 186 return err
178 } 187 }
179 188
180 if err := s.authenticate(); err != nil { 189 if err := s.authenticate(); err != nil {
181 return err 190 return err
182 } 191 }
183 return err 192 return err
184 } 193 }
185 194
186 func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni tPacket []byte) (err error) { 195 func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni tPacket []byte) (err error) {
187 serverKexInit := kexInitMsg{ 196 serverKexInit := kexInitMsg{
188 KexAlgos: s.config.Crypto.kexes(), 197 KexAlgos: s.config.Crypto.kexes(),
189 CiphersClientServer: s.config.Crypto.ciphers(), 198 CiphersClientServer: s.config.Crypto.ciphers(),
190 CiphersServerClient: s.config.Crypto.ciphers(), 199 CiphersServerClient: s.config.Crypto.ciphers(),
191 MACsClientServer: s.config.Crypto.macs(), 200 MACsClientServer: s.config.Crypto.macs(),
192 MACsServerClient: s.config.Crypto.macs(), 201 MACsServerClient: s.config.Crypto.macs(),
193 CompressionClientServer: supportedCompressions, 202 CompressionClientServer: supportedCompressions,
194 CompressionServerClient: supportedCompressions, 203 CompressionServerClient: supportedCompressions,
195 } 204 }
196 for _, k := range s.config.hostKeys { 205 for _, k := range s.config.hostKeys {
197 serverKexInit.ServerHostKeyAlgos = append( 206 serverKexInit.ServerHostKeyAlgos = append(
198 serverKexInit.ServerHostKeyAlgos, k.PublicKey().PublicKe yAlgo()) 207 serverKexInit.ServerHostKeyAlgos, k.PublicKey().PublicKe yAlgo())
199 } 208 }
200 209
201 serverKexInitPacket := marshal(msgKexInit, serverKexInit) 210 serverKexInitPacket := marshal(msgKexInit, serverKexInit)
202 » if err = s.writePacket(serverKexInitPacket); err != nil { 211 » if err = s.transport.writePacket(serverKexInitPacket); err != nil {
203 return 212 return
204 } 213 }
205 214
206 if clientKexInitPacket == nil { 215 if clientKexInitPacket == nil {
207 clientKexInit = new(kexInitMsg) 216 clientKexInit = new(kexInitMsg)
208 » » if clientKexInitPacket, err = s.readPacket(); err != nil { 217 » » if clientKexInitPacket, err = s.transport.readPacket(); err != n il {
209 return 218 return
210 } 219 }
211 if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexIni t); err != nil { 220 if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexIni t); err != nil {
212 return 221 return
213 } 222 }
214 } 223 }
215 224
216 algs := findAgreedAlgorithms(clientKexInit, &serverKexInit) 225 algs := findAgreedAlgorithms(clientKexInit, &serverKexInit)
217 if algs == nil { 226 if algs == nil {
218 return errors.New("ssh: no common algorithms") 227 return errors.New("ssh: no common algorithms")
219 } 228 }
220 229
221 if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0 ] { 230 if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0 ] {
222 // The client sent a Kex message for the wrong algorithm, 231 // The client sent a Kex message for the wrong algorithm,
223 // which we have to ignore. 232 // which we have to ignore.
224 » » if _, err = s.readPacket(); err != nil { 233 » » if _, err = s.transport.readPacket(); err != nil {
225 return 234 return
226 } 235 }
227 } 236 }
228 237
229 var hostKey Signer 238 var hostKey Signer
230 for _, k := range s.config.hostKeys { 239 for _, k := range s.config.hostKeys {
231 if algs.hostKey == k.PublicKey().PublicKeyAlgo() { 240 if algs.hostKey == k.PublicKey().PublicKeyAlgo() {
232 hostKey = k 241 hostKey = k
233 } 242 }
234 } 243 }
235 244
236 kex, ok := kexAlgoMap[algs.kex] 245 kex, ok := kexAlgoMap[algs.kex]
237 if !ok { 246 if !ok {
238 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", a lgs.kex) 247 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", a lgs.kex)
239 } 248 }
240 249
241 magics := handshakeMagics{ 250 magics := handshakeMagics{
242 serverVersion: s.serverVersion, 251 serverVersion: s.serverVersion,
243 clientVersion: s.ClientVersion, 252 clientVersion: s.ClientVersion,
244 serverKexInit: marshal(msgKexInit, serverKexInit), 253 serverKexInit: marshal(msgKexInit, serverKexInit),
245 clientKexInit: clientKexInitPacket, 254 clientKexInit: clientKexInitPacket,
246 } 255 }
247 » result, err := kex.Server(s, s.config.rand(), &magics, hostKey) 256 » result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey )
248 if err != nil { 257 if err != nil {
249 return err 258 return err
250 } 259 }
251 260
252 if err = s.transport.prepareKeyChange(algs, result); err != nil { 261 if err = s.transport.prepareKeyChange(algs, result); err != nil {
253 return err 262 return err
254 } 263 }
255 264
256 » if err = s.writePacket([]byte{msgNewKeys}); err != nil { 265 » if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
257 return 266 return
258 } 267 }
259 » if packet, err := s.readPacket(); err != nil { 268 » if packet, err := s.transport.readPacket(); err != nil {
260 return err 269 return err
261 } else if packet[0] != msgNewKeys { 270 } else if packet[0] != msgNewKeys {
262 return UnexpectedMessageError{msgNewKeys, packet[0]} 271 return UnexpectedMessageError{msgNewKeys, packet[0]}
263 } 272 }
264 273
265 return 274 return
266 } 275 }
267 276
268 func isAcceptableAlgo(algo string) bool { 277 func isAcceptableAlgo(algo string) bool {
269 switch algo { 278 switch algo {
(...skipping 31 matching lines...) Expand 10 before | Expand all | Expand 10 after
301 return result 310 return result
302 } 311 }
303 312
304 func (s *ServerConn) authenticate() error { 313 func (s *ServerConn) authenticate() error {
305 var userAuthReq userAuthRequestMsg 314 var userAuthReq userAuthRequestMsg
306 var err error 315 var err error
307 var packet []byte 316 var packet []byte
308 317
309 userAuthLoop: 318 userAuthLoop:
310 for { 319 for {
311 » » if packet, err = s.readPacket(); err != nil { 320 » » if packet, err = s.transport.readPacket(); err != nil {
312 return err 321 return err
313 } 322 }
314 if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); er r != nil { 323 if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); er r != nil {
315 return err 324 return err
316 } 325 }
317 326
318 if userAuthReq.Service != serviceSSH { 327 if userAuthReq.Service != serviceSSH {
319 return errors.New("ssh: client attempted to negotiate fo r unknown service: " + userAuthReq.Service) 328 return errors.New("ssh: client attempted to negotiate fo r unknown service: " + userAuthReq.Service)
320 } 329 }
321 330
(...skipping 53 matching lines...) Expand 10 before | Expand all | Expand 10 after
375 // The client can query if the given public key 384 // The client can query if the given public key
376 // would be okay. 385 // would be okay.
377 if len(payload) > 0 { 386 if len(payload) > 0 {
378 return ParseError{msgUserAuthRequest} 387 return ParseError{msgUserAuthRequest}
379 } 388 }
380 if s.testPubKey(userAuthReq.User, algo, pubKey) { 389 if s.testPubKey(userAuthReq.User, algo, pubKey) {
381 okMsg := userAuthPubKeyOkMsg{ 390 okMsg := userAuthPubKeyOkMsg{
382 Algo: algo, 391 Algo: algo,
383 PubKey: string(pubKey), 392 PubKey: string(pubKey),
384 } 393 }
385 » » » » » if err = s.writePacket(marshal(msgUserAu thPubKeyOk, okMsg)); err != nil { 394 » » » » » if err = s.transport.writePacket(marshal (msgUserAuthPubKeyOk, okMsg)); err != nil {
386 return err 395 return err
387 } 396 }
388 continue userAuthLoop 397 continue userAuthLoop
389 } 398 }
390 } else { 399 } else {
391 sig, payload, ok := parseSignature(payload) 400 sig, payload, ok := parseSignature(payload)
392 if !ok || len(payload) > 0 { 401 if !ok || len(payload) > 0 {
393 return ParseError{msgUserAuthRequest} 402 return ParseError{msgUserAuthRequest}
394 } 403 }
395 // Ensure the public key algo and signature algo 404 // Ensure the public key algo and signature algo
(...skipping 29 matching lines...) Expand all
425 failureMsg.Methods = append(failureMsg.Methods, "publick ey") 434 failureMsg.Methods = append(failureMsg.Methods, "publick ey")
426 } 435 }
427 if s.config.KeyboardInteractiveCallback != nil { 436 if s.config.KeyboardInteractiveCallback != nil {
428 failureMsg.Methods = append(failureMsg.Methods, "keyboar d-interactive") 437 failureMsg.Methods = append(failureMsg.Methods, "keyboar d-interactive")
429 } 438 }
430 439
431 if len(failureMsg.Methods) == 0 { 440 if len(failureMsg.Methods) == 0 {
432 return errors.New("ssh: no authentication methods config ured but NoClientAuth is also false") 441 return errors.New("ssh: no authentication methods config ured but NoClientAuth is also false")
433 } 442 }
434 443
435 » » if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil { 444 » » if err = s.transport.writePacket(marshal(msgUserAuthFailure, fai lureMsg)); err != nil {
436 return err 445 return err
437 } 446 }
438 } 447 }
439 448
440 packet = []byte{msgUserAuthSuccess} 449 packet = []byte{msgUserAuthSuccess}
441 » if err = s.writePacket(packet); err != nil { 450 » if err = s.transport.writePacket(packet); err != nil {
442 return err 451 return err
443 } 452 }
444 453
445 return nil 454 return nil
446 } 455 }
447 456
448 // sshClientKeyboardInteractive implements a ClientKeyboardInteractive by 457 // sshClientKeyboardInteractive implements a ClientKeyboardInteractive by
449 // asking the client on the other side of a ServerConn. 458 // asking the client on the other side of a ServerConn.
450 type sshClientKeyboardInteractive struct { 459 type sshClientKeyboardInteractive struct {
451 *ServerConn 460 *ServerConn
452 } 461 }
453 462
454 func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest ions []string, echos []bool) (answers []string, err error) { 463 func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest ions []string, echos []bool) (answers []string, err error) {
455 if len(questions) != len(echos) { 464 if len(questions) != len(echos) {
456 return nil, errors.New("ssh: echos and questions must have equal length") 465 return nil, errors.New("ssh: echos and questions must have equal length")
457 } 466 }
458 467
459 var prompts []byte 468 var prompts []byte
460 for i := range questions { 469 for i := range questions {
461 prompts = appendString(prompts, questions[i]) 470 prompts = appendString(prompts, questions[i])
462 prompts = appendBool(prompts, echos[i]) 471 prompts = appendBool(prompts, echos[i])
463 } 472 }
464 473
465 » if err := c.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequ estMsg{ 474 » if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAu thInfoRequestMsg{
466 Instruction: instruction, 475 Instruction: instruction,
467 NumPrompts: uint32(len(questions)), 476 NumPrompts: uint32(len(questions)),
468 Prompts: prompts, 477 Prompts: prompts,
469 })); err != nil { 478 })); err != nil {
470 return nil, err 479 return nil, err
471 } 480 }
472 481
473 » packet, err := c.readPacket() 482 » packet, err := c.transport.readPacket()
474 if err != nil { 483 if err != nil {
475 return nil, err 484 return nil, err
476 } 485 }
477 if packet[0] != msgUserAuthInfoResponse { 486 if packet[0] != msgUserAuthInfoResponse {
478 return nil, UnexpectedMessageError{msgUserAuthInfoResponse, pack et[0]} 487 return nil, UnexpectedMessageError{msgUserAuthInfoResponse, pack et[0]}
479 } 488 }
480 packet = packet[1:] 489 packet = packet[1:]
481 490
482 n, packet, ok := parseUint32(packet) 491 n, packet, ok := parseUint32(packet)
483 if !ok || int(n) != len(questions) { 492 if !ok || int(n) != len(questions) {
(...skipping 20 matching lines...) Expand all
504 513
505 // Accept reads and processes messages on a ServerConn. It must be called 514 // Accept reads and processes messages on a ServerConn. It must be called
506 // in order to demultiplex messages to any resulting Channels. 515 // in order to demultiplex messages to any resulting Channels.
507 func (s *ServerConn) Accept() (Channel, error) { 516 func (s *ServerConn) Accept() (Channel, error) {
508 // TODO(dfc) s.lock is not held here so visibility of s.err is not guara nteed. 517 // TODO(dfc) s.lock is not held here so visibility of s.err is not guara nteed.
509 if s.err != nil { 518 if s.err != nil {
510 return nil, s.err 519 return nil, s.err
511 } 520 }
512 521
513 for { 522 for {
514 » » packet, err := s.readPacket() 523 » » packet, err := s.transport.readPacket()
515 if err != nil { 524 if err != nil {
516 525
517 s.lock.Lock() 526 s.lock.Lock()
518 s.err = err 527 s.err = err
519 s.lock.Unlock() 528 s.lock.Unlock()
520 529
521 // TODO(dfc) s.lock protects s.channels but isn't being held here. 530 // TODO(dfc) s.lock protects s.channels but isn't being held here.
522 for _, c := range s.channels { 531 for _, c := range s.channels {
523 c.setDead() 532 c.setDead()
524 c.handleData(nil) 533 c.handleData(nil)
(...skipping 25 matching lines...) Expand all
550 if err != nil { 559 if err != nil {
551 return nil, err 560 return nil, err
552 } 561 }
553 switch msg := decoded.(type) { 562 switch msg := decoded.(type) {
554 case *channelOpenMsg: 563 case *channelOpenMsg:
555 if msg.MaxPacketSize < minPacketLength || msg.Ma xPacketSize > 1<<31 { 564 if msg.MaxPacketSize < minPacketLength || msg.Ma xPacketSize > 1<<31 {
556 return nil, errors.New("ssh: invalid Max PacketSize from peer") 565 return nil, errors.New("ssh: invalid Max PacketSize from peer")
557 } 566 }
558 c := &serverChan{ 567 c := &serverChan{
559 channel: channel{ 568 channel: channel{
560 » » » » » » packetConn: s, 569 » » » » » » packetConn: s.transport,
561 remoteId: msg.PeersId, 570 remoteId: msg.PeersId,
562 remoteWin: window{Cond: newCond ()}, 571 remoteWin: window{Cond: newCond ()},
563 maxPacket: msg.MaxPacketSize, 572 maxPacket: msg.MaxPacketSize,
564 }, 573 },
565 chanType: msg.ChanType, 574 chanType: msg.ChanType,
566 extraData: msg.TypeSpecificData, 575 extraData: msg.TypeSpecificData,
567 myWindow: defaultWindowSize, 576 myWindow: defaultWindowSize,
568 serverConn: s, 577 serverConn: s,
569 cond: newCond(), 578 cond: newCond(),
570 pendingData: make([]byte, defaultWindowS ize), 579 pendingData: make([]byte, defaultWindowS ize),
(...skipping 41 matching lines...) Expand 10 before | Expand all | Expand 10 after
612 c, ok := s.channels[msg.PeersId] 621 c, ok := s.channels[msg.PeersId]
613 if !ok { 622 if !ok {
614 s.lock.Unlock() 623 s.lock.Unlock()
615 continue 624 continue
616 } 625 }
617 c.handlePacket(msg) 626 c.handlePacket(msg)
618 s.lock.Unlock() 627 s.lock.Unlock()
619 628
620 case *globalRequestMsg: 629 case *globalRequestMsg:
621 if msg.WantReply { 630 if msg.WantReply {
622 » » » » » if err := s.writePacket([]byte{msgReques tFailure}); err != nil { 631 » » » » » if err := s.transport.writePacket([]byte {msgRequestFailure}); err != nil {
623 return nil, err 632 return nil, err
624 } 633 }
625 } 634 }
626 635
627 case *kexInitMsg: 636 case *kexInitMsg:
628 s.lock.Lock() 637 s.lock.Lock()
629 if err := s.clientInitHandshake(msg, packet); er r != nil { 638 if err := s.clientInitHandshake(msg, packet); er r != nil {
630 s.lock.Unlock() 639 s.lock.Unlock()
631 return nil, err 640 return nil, err
632 } 641 }
(...skipping 41 matching lines...) Expand 10 before | Expand all | Expand 10 after
674 func Listen(network, addr string, config *ServerConfig) (*Listener, error) { 683 func Listen(network, addr string, config *ServerConfig) (*Listener, error) {
675 l, err := net.Listen(network, addr) 684 l, err := net.Listen(network, addr)
676 if err != nil { 685 if err != nil {
677 return nil, err 686 return nil, err
678 } 687 }
679 return &Listener{ 688 return &Listener{
680 l, 689 l,
681 config, 690 config,
682 }, nil 691 }, nil
683 } 692 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b