OLD | NEW |
(Empty) | |
| 1 // Copyright 2013 The Go Authors. All rights reserved. |
| 2 // Use of this source code is governed by a BSD-style |
| 3 // license that can be found in the LICENSE file. |
| 4 |
| 5 package ssh |
| 6 |
| 7 import ( |
| 8 "errors" |
| 9 "fmt" |
| 10 "io" |
| 11 "log" |
| 12 "net" |
| 13 "sync" |
| 14 ) |
| 15 |
| 16 // If set, debug will log print messages sent and received. |
| 17 const debug = false |
| 18 |
| 19 // keyingTransport is a packet based transport that supports key |
| 20 // changes. It need not be thread-safe. It should pass through |
| 21 // msgNewKeys in both directions. |
| 22 type keyingTransport interface { |
| 23 packetConn |
| 24 |
| 25 // prepareKeyChange sets up a key change. The key change for a |
| 26 // direction will be effected if a msgNewKeys message is sent |
| 27 // or received. |
| 28 prepareKeyChange(*algorithms, *kexResult) error |
| 29 |
| 30 // getSessionID returns the session ID. prepareKeyChange must |
| 31 // have been called once. |
| 32 getSessionID() []byte |
| 33 } |
| 34 |
| 35 // rekeyingTransport is the interface of handshakeTransport that we |
| 36 // (internally) expose to ClientConn and ServerConn. |
| 37 type rekeyingTransport interface { |
| 38 packetConn |
| 39 |
| 40 // requestKeyChange asks the remote side to change keys. All |
| 41 // writes are blocked until the key change succeeds, which is |
| 42 // signaled by reading a msgNewKeys. |
| 43 requestKeyChange() error |
| 44 |
| 45 // getSessionID returns the session ID. This is only valid |
| 46 // after the first key change has completed. |
| 47 getSessionID() []byte |
| 48 } |
| 49 |
| 50 // handshakeTransport implements rekeying on top of a keyingTransport |
| 51 // and offers a thread-safe writePacket() interface. |
| 52 type handshakeTransport struct { |
| 53 conn keyingTransport |
| 54 config *CryptoConfig |
| 55 |
| 56 // TODO(hanwen): move Rand into CryptoConfig. |
| 57 rand func() io.Reader |
| 58 |
| 59 serverVersion []byte |
| 60 clientVersion []byte |
| 61 |
| 62 hostKeys []Signer // If hostKeys are given, we are the server. |
| 63 |
| 64 // On read error, incoming is closed, and readError is set. |
| 65 incoming chan []byte |
| 66 readError error |
| 67 |
| 68 // data for host key checking |
| 69 checker HostKeyChecker |
| 70 dialAddress string |
| 71 remoteAddr net.Addr |
| 72 |
| 73 rekeyThreshold uint64 // rekey after sending/receiving this much data. |
| 74 readSinceKex uint64 |
| 75 |
| 76 // Protects the writing side of the connection |
| 77 mu sync.Mutex |
| 78 cond *sync.Cond |
| 79 sentInitPacket []byte |
| 80 sentInitMsg *kexInitMsg |
| 81 writtenSinceKex uint64 |
| 82 writeError error |
| 83 } |
| 84 |
| 85 func newHandshakeTransport(conn keyingTransport, clientVersion, serverVersion []
byte) *handshakeTransport { |
| 86 t := &handshakeTransport{ |
| 87 conn: conn, |
| 88 serverVersion: serverVersion, |
| 89 clientVersion: clientVersion, |
| 90 incoming: make(chan []byte, 16), |
| 91 } |
| 92 t.cond = sync.NewCond(&t.mu) |
| 93 return t |
| 94 } |
| 95 |
| 96 func newClientTransport(conn keyingTransport, clientVersion, serverVersion []byt
e, config *ClientConfig, dialAddr string, addr net.Addr) *handshakeTransport { |
| 97 t := newHandshakeTransport(conn, clientVersion, serverVersion) |
| 98 t.setCryptoConfig(&config.Crypto) |
| 99 t.dialAddress = dialAddr |
| 100 t.rand = config.rand |
| 101 t.checker = config.HostKeyChecker |
| 102 go t.readLoop() |
| 103 return t |
| 104 } |
| 105 |
| 106 func newServerTransport(conn keyingTransport, clientVersion, serverVersion []byt
e, config *ServerConfig) *handshakeTransport { |
| 107 t := newHandshakeTransport(conn, clientVersion, serverVersion) |
| 108 t.setCryptoConfig(&config.Crypto) |
| 109 t.hostKeys = config.hostKeys |
| 110 t.rand = config.rand |
| 111 go t.readLoop() |
| 112 return t |
| 113 } |
| 114 |
| 115 func (t *handshakeTransport) getSessionID() []byte { |
| 116 return t.conn.getSessionID() |
| 117 } |
| 118 |
| 119 func (t *handshakeTransport) setCryptoConfig(c *CryptoConfig) { |
| 120 t.config = c |
| 121 t.rekeyThreshold = t.config.RekeyThreshold |
| 122 if t.rekeyThreshold == 0 { |
| 123 // RFC 4253, section 9 suggests rekeying after 1G. |
| 124 t.rekeyThreshold = 1 << 30 |
| 125 } |
| 126 } |
| 127 |
| 128 func (t *handshakeTransport) id() string { |
| 129 if len(t.hostKeys) > 0 { |
| 130 return "server" |
| 131 } |
| 132 return "client" |
| 133 } |
| 134 |
| 135 func (t *handshakeTransport) readPacket() ([]byte, error) { |
| 136 p, ok := <-t.incoming |
| 137 if !ok { |
| 138 return nil, t.readError |
| 139 } |
| 140 return p, nil |
| 141 } |
| 142 |
| 143 func (t *handshakeTransport) readLoop() { |
| 144 for { |
| 145 p, err := t.readOnePacket() |
| 146 if err != nil { |
| 147 t.readError = err |
| 148 close(t.incoming) |
| 149 break |
| 150 } |
| 151 if p[0] == msgIgnore || p[0] == msgDebug { |
| 152 continue |
| 153 } |
| 154 t.incoming <- p |
| 155 } |
| 156 } |
| 157 |
| 158 func (t *handshakeTransport) readOnePacket() ([]byte, error) { |
| 159 if t.readSinceKex > t.rekeyThreshold { |
| 160 if err := t.requestKeyChange(); err != nil { |
| 161 return nil, err |
| 162 } |
| 163 } |
| 164 |
| 165 p, err := t.conn.readPacket() |
| 166 if err != nil { |
| 167 return nil, err |
| 168 } |
| 169 |
| 170 t.readSinceKex += uint64(len(p)) |
| 171 if debug { |
| 172 msg, err := decode(p) |
| 173 log.Printf("%s got %T %v (%v)", t.id(), msg, msg, err) |
| 174 } |
| 175 if p[0] != msgKexInit { |
| 176 return p, nil |
| 177 } |
| 178 err = t.enterKeyExchange(p) |
| 179 |
| 180 t.mu.Lock() |
| 181 if err != nil { |
| 182 // drop connection |
| 183 t.conn.Close() |
| 184 t.writeError = err |
| 185 } |
| 186 |
| 187 if debug { |
| 188 log.Printf("%s exited key exchange, err %v", t.id(), err) |
| 189 } |
| 190 |
| 191 // Unblock writers. |
| 192 t.sentInitMsg = nil |
| 193 t.sentInitPacket = nil |
| 194 t.cond.Broadcast() |
| 195 t.writtenSinceKex = 0 |
| 196 t.mu.Unlock() |
| 197 |
| 198 if err != nil { |
| 199 return nil, err |
| 200 } |
| 201 |
| 202 t.readSinceKex = 0 |
| 203 return []byte{msgNewKeys}, nil |
| 204 } |
| 205 |
| 206 // sendKexInit sends a key change message, and returns the message |
| 207 // that was sent. After initiating the key change, all writes will be |
| 208 // blocked until the change is done, and a failed key change will |
| 209 // close the underlying transport. This function is safe for |
| 210 // concurrent use by multiple goroutines. |
| 211 func (t *handshakeTransport) sendKexInit() (*kexInitMsg, []byte, error) { |
| 212 t.mu.Lock() |
| 213 defer t.mu.Unlock() |
| 214 return t.sendKexInitLocked() |
| 215 } |
| 216 |
| 217 func (t *handshakeTransport) requestKeyChange() error { |
| 218 _, _, err := t.sendKexInit() |
| 219 return err |
| 220 } |
| 221 |
| 222 // sendKexInitLocked sends a key change message. t.mu must be locked |
| 223 // while this happens. |
| 224 func (t *handshakeTransport) sendKexInitLocked() (*kexInitMsg, []byte, error) { |
| 225 // kexInits may be sent either in response to the other side, |
| 226 // or because our side wants to initiate a key change, so we |
| 227 // may have already sent a kexInit. In that case, don't send a |
| 228 // second kexInit. |
| 229 if t.sentInitMsg != nil { |
| 230 return t.sentInitMsg, t.sentInitPacket, nil |
| 231 } |
| 232 msg := &kexInitMsg{ |
| 233 KexAlgos: t.config.kexes(), |
| 234 CiphersClientServer: t.config.ciphers(), |
| 235 CiphersServerClient: t.config.ciphers(), |
| 236 MACsClientServer: t.config.macs(), |
| 237 MACsServerClient: t.config.macs(), |
| 238 CompressionClientServer: supportedCompressions, |
| 239 CompressionServerClient: supportedCompressions, |
| 240 } |
| 241 |
| 242 // TODO(hanwen): add random bits to kexInit.Cookie. |
| 243 |
| 244 if len(t.hostKeys) > 0 { |
| 245 for _, k := range t.hostKeys { |
| 246 msg.ServerHostKeyAlgos = append( |
| 247 msg.ServerHostKeyAlgos, k.PublicKey().PublicKeyA
lgo()) |
| 248 } |
| 249 } else { |
| 250 msg.ServerHostKeyAlgos = supportedHostKeyAlgos |
| 251 } |
| 252 packet := marshal(msgKexInit, *msg) |
| 253 |
| 254 // writePacket destroys the contents, so save a copy. |
| 255 packetCopy := make([]byte, len(packet)) |
| 256 copy(packetCopy, packet) |
| 257 |
| 258 if err := t.conn.writePacket(packetCopy); err != nil { |
| 259 return nil, nil, err |
| 260 } |
| 261 |
| 262 t.sentInitMsg = msg |
| 263 t.sentInitPacket = packet |
| 264 return msg, packet, nil |
| 265 } |
| 266 |
| 267 func (t *handshakeTransport) writePacket(p []byte) error { |
| 268 t.mu.Lock() |
| 269 if t.writtenSinceKex > t.rekeyThreshold { |
| 270 t.sendKexInitLocked() |
| 271 } |
| 272 for t.sentInitMsg != nil { |
| 273 t.cond.Wait() |
| 274 } |
| 275 if t.writeError != nil { |
| 276 return t.writeError |
| 277 } |
| 278 t.writtenSinceKex += uint64(len(p)) |
| 279 |
| 280 var err error |
| 281 switch p[0] { |
| 282 case msgKexInit: |
| 283 err = errors.New("ssh: only handshakeTransport can send kexInit"
) |
| 284 case msgNewKeys: |
| 285 err = errors.New("ssh: only handshakeTransport can send newKeys"
) |
| 286 default: |
| 287 err = t.conn.writePacket(p) |
| 288 } |
| 289 t.mu.Unlock() |
| 290 return err |
| 291 } |
| 292 |
| 293 func (t *handshakeTransport) Close() error { |
| 294 return t.conn.Close() |
| 295 } |
| 296 |
| 297 // enterKeyExchange runs the key exchange. |
| 298 func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { |
| 299 if debug { |
| 300 log.Printf("%s entered key exchange", t.id()) |
| 301 } |
| 302 myInit, myInitPacket, err := t.sendKexInit() |
| 303 if err != nil { |
| 304 return err |
| 305 } |
| 306 |
| 307 otherInit := &kexInitMsg{} |
| 308 if err := unmarshal(otherInit, otherInitPacket, msgKexInit); err != nil
{ |
| 309 return err |
| 310 } |
| 311 |
| 312 magics := handshakeMagics{ |
| 313 clientVersion: t.clientVersion, |
| 314 serverVersion: t.serverVersion, |
| 315 clientKexInit: otherInitPacket, |
| 316 serverKexInit: myInitPacket, |
| 317 } |
| 318 |
| 319 clientInit := otherInit |
| 320 serverInit := myInit |
| 321 if len(t.hostKeys) == 0 { |
| 322 clientInit = myInit |
| 323 serverInit = otherInit |
| 324 |
| 325 magics.clientKexInit = myInitPacket |
| 326 magics.serverKexInit = otherInitPacket |
| 327 } |
| 328 |
| 329 algs := findAgreedAlgorithms(clientInit, serverInit) |
| 330 if algs == nil { |
| 331 return errors.New("ssh: no common algorithms") |
| 332 } |
| 333 |
| 334 // We don't send FirstKexFollows, but we handle receiving it. |
| 335 if otherInit.FirstKexFollows && algs.kex != otherInit.KexAlgos[0] { |
| 336 // other side sent a kex message for the wrong algorithm, |
| 337 // which we have to ignore. |
| 338 if _, err := t.conn.readPacket(); err != nil { |
| 339 return err |
| 340 } |
| 341 } |
| 342 |
| 343 kex, ok := kexAlgoMap[algs.kex] |
| 344 if !ok { |
| 345 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", a
lgs.kex) |
| 346 } |
| 347 |
| 348 var result *kexResult |
| 349 if len(t.hostKeys) > 0 { |
| 350 result, err = t.server(kex, algs, &magics) |
| 351 } else { |
| 352 result, err = t.client(kex, algs, &magics) |
| 353 } |
| 354 |
| 355 if err != nil { |
| 356 return err |
| 357 } |
| 358 |
| 359 t.conn.prepareKeyChange(algs, result) |
| 360 if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { |
| 361 return err |
| 362 } |
| 363 if packet, err := t.conn.readPacket(); err != nil { |
| 364 return err |
| 365 } else if packet[0] != msgNewKeys { |
| 366 return UnexpectedMessageError{msgNewKeys, packet[0]} |
| 367 } |
| 368 return nil |
| 369 } |
| 370 |
| 371 func (t *handshakeTransport) server(kex kexAlgorithm, algs *algorithms, magics *
handshakeMagics) (*kexResult, error) { |
| 372 var hostKey Signer |
| 373 for _, k := range t.hostKeys { |
| 374 if algs.hostKey == k.PublicKey().PublicKeyAlgo() { |
| 375 hostKey = k |
| 376 } |
| 377 } |
| 378 |
| 379 r, err := kex.Server(t.conn, t.rand(), magics, hostKey) |
| 380 return r, err |
| 381 } |
| 382 |
| 383 func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics *
handshakeMagics) (*kexResult, error) { |
| 384 result, err := kex.Client(t.conn, t.rand(), magics) |
| 385 if err != nil { |
| 386 return nil, err |
| 387 } |
| 388 |
| 389 if err := verifyHostKeySignature(algs.hostKey, result); err != nil { |
| 390 return nil, err |
| 391 } |
| 392 |
| 393 if t.checker != nil { |
| 394 err = t.checker.Check(t.dialAddress, t.remoteAddr, algs.hostKey,
result.HostKey) |
| 395 if err != nil { |
| 396 return nil, err |
| 397 } |
| 398 } |
| 399 |
| 400 return result, nil |
| 401 } |
OLD | NEW |