LEFT | RIGHT |
1 // Copyright 2011 The Go Authors. All rights reserved. | 1 // Copyright 2011 The Go Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style | 2 // Use of this source code is governed by a BSD-style |
3 // license that can be found in the LICENSE file. | 3 // license that can be found in the LICENSE file. |
4 | 4 |
5 package ssh | 5 package ssh |
6 | 6 |
7 import ( | 7 import ( |
8 "crypto/rand" | 8 "crypto/rand" |
9 "errors" | 9 "errors" |
10 "fmt" | 10 "fmt" |
11 "io" | 11 "io" |
12 "net" | 12 "net" |
13 ) | 13 ) |
14 | 14 |
15 // ClientConn represents the client side of an SSH connection. | 15 // ClientConn represents the client side of an SSH connection. |
16 type ClientConn struct { | 16 type ClientConn struct { |
17 » *transport | 17 » transport *transport |
18 config *ClientConfig | 18 config *ClientConfig |
19 forwardList // forwarded tcpip connections from the remote side | 19 forwardList // forwarded tcpip connections from the remote side |
20 | 20 |
21 // Address as passed to the Dial function. | 21 // Address as passed to the Dial function. |
22 dialAddress string | 22 dialAddress string |
23 | 23 |
24 serverVersion string | 24 serverVersion string |
25 | 25 |
26 mux *mux | 26 mux *mux |
27 } | 27 } |
28 | 28 |
29 // Client returns a new SSH client connection using c as the underlying transpor
t. | 29 // Client returns a new SSH client connection using c as the underlying transpor
t. |
30 func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) { | 30 func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) { |
31 return clientWithAddress(c, "", config) | 31 return clientWithAddress(c, "", config) |
32 } | 32 } |
33 | 33 |
34 func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientCo
nn, error) { | 34 func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientCo
nn, error) { |
35 conn := &ClientConn{ | 35 conn := &ClientConn{ |
36 transport: newTransport(c, config.rand(), true /* is client */
), | 36 transport: newTransport(c, config.rand(), true /* is client */
), |
37 config: config, | 37 config: config, |
38 dialAddress: addr, | 38 dialAddress: addr, |
39 } | 39 } |
40 | 40 |
41 if err := conn.handshake(); err != nil { | 41 if err := conn.handshake(); err != nil { |
42 » » conn.Close() | 42 » » conn.transport.Close() |
43 return nil, fmt.Errorf("handshake failed: %v", err) | 43 return nil, fmt.Errorf("handshake failed: %v", err) |
44 } | 44 } |
45 | 45 |
46 » conn.mux = newMux(conn) | 46 » conn.mux = newMux(conn.transport) |
47 » go conn.loop() | 47 » go conn.handleGlobalRequests(conn.mux.incomingRequests) |
| 48 » go conn.handleChannelOpens(conn.mux.incomingChannels) |
| 49 » go func() { |
| 50 » » conn.mux.Loop() |
| 51 » » conn.forwardList.closeAll() |
| 52 » }() |
48 return conn, nil | 53 return conn, nil |
49 } | 54 } |
50 | 55 |
51 func (c *ClientConn) loop() { | 56 // Close closes the connection. |
52 » go c.handleGlobalRequests() | 57 func (c *ClientConn) Close() error { return c.transport.Close() } |
53 » go c.handleChannelOpens() | 58 |
54 » c.mux.Loop() | 59 // LocalAddr returns the local network address. |
55 » c.forwardList.closeAll() | 60 func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() } |
56 } | 61 |
| 62 // RemoteAddr returns the remote network address. |
| 63 func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() } |
57 | 64 |
58 // handshake performs the client side key exchange. See RFC 4253 Section 7. | 65 // handshake performs the client side key exchange. See RFC 4253 Section 7. |
59 func (c *ClientConn) handshake() error { | 66 func (c *ClientConn) handshake() error { |
60 clientVersion := []byte(packageVersion) | 67 clientVersion := []byte(packageVersion) |
61 if c.config.ClientVersion != "" { | 68 if c.config.ClientVersion != "" { |
62 clientVersion = []byte(c.config.ClientVersion) | 69 clientVersion = []byte(c.config.ClientVersion) |
63 } | 70 } |
64 | 71 |
65 serverVersion, err := exchangeVersions(c.transport.Conn, clientVersion) | 72 serverVersion, err := exchangeVersions(c.transport.Conn, clientVersion) |
66 if err != nil { | 73 if err != nil { |
67 return err | 74 return err |
68 } | 75 } |
69 c.serverVersion = string(serverVersion) | 76 c.serverVersion = string(serverVersion) |
70 clientKexInit := kexInitMsg{ | 77 clientKexInit := kexInitMsg{ |
71 KexAlgos: c.config.Crypto.kexes(), | 78 KexAlgos: c.config.Crypto.kexes(), |
72 ServerHostKeyAlgos: supportedHostKeyAlgos, | 79 ServerHostKeyAlgos: supportedHostKeyAlgos, |
73 CiphersClientServer: c.config.Crypto.ciphers(), | 80 CiphersClientServer: c.config.Crypto.ciphers(), |
74 CiphersServerClient: c.config.Crypto.ciphers(), | 81 CiphersServerClient: c.config.Crypto.ciphers(), |
75 MACsClientServer: c.config.Crypto.macs(), | 82 MACsClientServer: c.config.Crypto.macs(), |
76 MACsServerClient: c.config.Crypto.macs(), | 83 MACsServerClient: c.config.Crypto.macs(), |
77 CompressionClientServer: supportedCompressions, | 84 CompressionClientServer: supportedCompressions, |
78 CompressionServerClient: supportedCompressions, | 85 CompressionServerClient: supportedCompressions, |
79 } | 86 } |
80 kexInitPacket := marshal(msgKexInit, clientKexInit) | 87 kexInitPacket := marshal(msgKexInit, clientKexInit) |
81 » if err := c.writePacket(kexInitPacket); err != nil { | 88 » if err := c.transport.writePacket(kexInitPacket); err != nil { |
82 » » return err | 89 » » return err |
83 » } | 90 » } |
84 » packet, err := c.readPacket() | 91 » packet, err := c.transport.readPacket() |
85 if err != nil { | 92 if err != nil { |
86 return err | 93 return err |
87 } | 94 } |
88 | 95 |
89 var serverKexInit kexInitMsg | 96 var serverKexInit kexInitMsg |
90 if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil { | 97 if err = unmarshal(&serverKexInit, packet, msgKexInit); err != nil { |
91 return err | 98 return err |
92 } | 99 } |
93 | 100 |
94 algs := findAgreedAlgorithms(&clientKexInit, &serverKexInit) | 101 algs := findAgreedAlgorithms(&clientKexInit, &serverKexInit) |
95 if algs == nil { | 102 if algs == nil { |
96 return errors.New("ssh: no common algorithms") | 103 return errors.New("ssh: no common algorithms") |
97 } | 104 } |
98 | 105 |
99 if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0
] { | 106 if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0
] { |
100 // The server sent a Kex message for the wrong algorithm, | 107 // The server sent a Kex message for the wrong algorithm, |
101 // which we have to ignore. | 108 // which we have to ignore. |
102 » » if _, err := c.readPacket(); err != nil { | 109 » » if _, err := c.transport.readPacket(); err != nil { |
103 return err | 110 return err |
104 } | 111 } |
105 } | 112 } |
106 | 113 |
107 kex, ok := kexAlgoMap[algs.kex] | 114 kex, ok := kexAlgoMap[algs.kex] |
108 if !ok { | 115 if !ok { |
109 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", a
lgs.kex) | 116 return fmt.Errorf("ssh: unexpected key exchange algorithm %v", a
lgs.kex) |
110 } | 117 } |
111 | 118 |
112 magics := handshakeMagics{ | 119 magics := handshakeMagics{ |
113 clientVersion: clientVersion, | 120 clientVersion: clientVersion, |
114 serverVersion: serverVersion, | 121 serverVersion: serverVersion, |
115 clientKexInit: kexInitPacket, | 122 clientKexInit: kexInitPacket, |
116 serverKexInit: packet, | 123 serverKexInit: packet, |
117 } | 124 } |
118 » result, err := kex.Client(c, c.config.rand(), &magics) | 125 » result, err := kex.Client(c.transport, c.config.rand(), &magics) |
119 if err != nil { | 126 if err != nil { |
120 return err | 127 return err |
121 } | 128 } |
122 | 129 |
123 err = verifyHostKeySignature(algs.hostKey, result.HostKey, result.H, res
ult.Signature) | 130 err = verifyHostKeySignature(algs.hostKey, result.HostKey, result.H, res
ult.Signature) |
124 if err != nil { | 131 if err != nil { |
125 return err | 132 return err |
126 } | 133 } |
127 | 134 |
128 if checker := c.config.HostKeyChecker; checker != nil { | 135 if checker := c.config.HostKeyChecker; checker != nil { |
129 » » err = checker.Check(c.dialAddress, c.RemoteAddr(), algs.hostKey,
result.HostKey) | 136 » » err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), alg
s.hostKey, result.HostKey) |
130 if err != nil { | 137 if err != nil { |
131 return err | 138 return err |
132 } | 139 } |
133 } | 140 } |
134 | 141 |
135 c.transport.prepareKeyChange(algs, result) | 142 c.transport.prepareKeyChange(algs, result) |
136 | 143 |
137 » if err = c.writePacket([]byte{msgNewKeys}); err != nil { | 144 » if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil { |
138 » » return err | 145 » » return err |
139 » } | 146 » } |
140 » if packet, err = c.readPacket(); err != nil { | 147 » if packet, err = c.transport.readPacket(); err != nil { |
141 return err | 148 return err |
142 } | 149 } |
143 if packet[0] != msgNewKeys { | 150 if packet[0] != msgNewKeys { |
144 return UnexpectedMessageError{msgNewKeys, packet[0]} | 151 return UnexpectedMessageError{msgNewKeys, packet[0]} |
145 } | 152 } |
146 » return c.authenticate(result.H) | 153 » return c.authenticate() |
147 } | 154 } |
148 | 155 |
149 // Verify the host key obtained in the key exchange. | 156 // Verify the host key obtained in the key exchange. |
150 func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte
, signature []byte) error { | 157 func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte
, signature []byte) error { |
151 hostKey, rest, ok := ParsePublicKey(hostKeyBytes) | 158 hostKey, rest, ok := ParsePublicKey(hostKeyBytes) |
152 if len(rest) > 0 || !ok { | 159 if len(rest) > 0 || !ok { |
153 return errors.New("ssh: could not parse hostkey") | 160 return errors.New("ssh: could not parse hostkey") |
154 } | 161 } |
155 | 162 |
156 sig, rest, ok := parseSignatureBody(signature) | 163 sig, rest, ok := parseSignatureBody(signature) |
157 if len(rest) > 0 || !ok { | 164 if len(rest) > 0 || !ok { |
158 return errors.New("ssh: signature parse error") | 165 return errors.New("ssh: signature parse error") |
159 } | 166 } |
160 if sig.Format != hostKeyAlgo { | 167 if sig.Format != hostKeyAlgo { |
161 return fmt.Errorf("ssh: unexpected signature type %q", sig.Forma
t) | 168 return fmt.Errorf("ssh: unexpected signature type %q", sig.Forma
t) |
162 } | 169 } |
163 | 170 |
164 if !hostKey.Verify(data, sig.Blob) { | 171 if !hostKey.Verify(data, sig.Blob) { |
165 return errors.New("ssh: host key signature error") | 172 return errors.New("ssh: host key signature error") |
166 } | 173 } |
167 return nil | 174 return nil |
168 } | 175 } |
169 | 176 |
170 func (c *ClientConn) handleGlobalRequests() { | 177 func (c *ClientConn) handleGlobalRequests(incoming chan *ChannelRequest) { |
171 » for r := range c.mux.IncomingRequests() { | 178 » for r := range incoming { |
172 if r.WantReply { | 179 if r.WantReply { |
173 // This handles keepalive messages and matches | 180 // This handles keepalive messages and matches |
174 // the behaviour of OpenSSH. | 181 // the behaviour of OpenSSH. |
175 c.mux.AckRequest(false, nil) | 182 c.mux.AckRequest(false, nil) |
176 } | 183 } |
177 } | 184 } |
178 } | 185 } |
179 | 186 |
180 // Handle channel open messages from the remote side. | 187 // Handle channel open messages from the remote side. |
181 func (c *ClientConn) handleChannelOpens() { | 188 func (c *ClientConn) handleChannelOpens(in chan *channel) { |
182 » for ch := range c.mux.IncomingChannels() { | 189 » for ch := range in { |
183 c.handleChannelOpen(ch) | 190 c.handleChannelOpen(ch) |
184 } | 191 } |
185 } | 192 } |
186 | 193 |
187 func (c *ClientConn) handleChannelOpen(ch ChannelCreationRequest) { | 194 func (c *ClientConn) handleChannelOpen(ch *channel) { |
188 switch ch.ChannelType() { | 195 switch ch.ChannelType() { |
189 case "forwarded-tcpip": | 196 case "forwarded-tcpip": |
190 laddr, rest, ok := parseTCPAddr(ch.ExtraData()) | 197 laddr, rest, ok := parseTCPAddr(ch.ExtraData()) |
191 if !ok { | 198 if !ok { |
192 // invalid request | 199 // invalid request |
193 ch.Reject(ConnectionFailed, "could not parse TCP address
") | 200 ch.Reject(ConnectionFailed, "could not parse TCP address
") |
194 return | 201 return |
195 } | 202 } |
196 | 203 |
197 l, ok := c.forwardList.lookup(*laddr) | 204 l, ok := c.forwardList.lookup(*laddr) |
198 if !ok { | 205 if !ok { |
199 // Section 7.2, implementations MUST reject suprious inc
oming | 206 // Section 7.2, implementations MUST reject suprious inc
oming |
200 // connections. | 207 // connections. |
201 ch.Reject(Prohibited, "no forward for address") | 208 ch.Reject(Prohibited, "no forward for address") |
202 return | 209 return |
203 } | 210 } |
204 | 211 |
205 raddr, rest, ok := parseTCPAddr(rest) | 212 raddr, rest, ok := parseTCPAddr(rest) |
206 if !ok { | 213 if !ok { |
207 ch.Reject(ConnectionFailed, "could not parse TCP address
") | 214 ch.Reject(ConnectionFailed, "could not parse TCP address
") |
208 return | 215 return |
209 } | 216 } |
210 | 217 |
211 » » if channel, err := ch.Accept(); err == nil { | 218 » » if err := ch.Accept(); err == nil { |
212 » » » l <- forward{channel, raddr} | 219 » » » l <- forward{ch, raddr} |
213 } | 220 } |
214 default: | 221 default: |
215 // unknown channel type | 222 // unknown channel type |
216 ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type:
%v", ch.ChannelType())) | 223 ch.Reject(UnknownChannelType, fmt.Sprintf("unknown channel type:
%v", ch.ChannelType())) |
217 } | 224 } |
218 } | 225 } |
219 | 226 |
220 // parseTCPAddr parses the originating address from the remote into a | 227 // parseTCPAddr parses the originating address from the remote into a |
221 // *net.TCPAddr. RFC 4254 section 7.2 is mute on what to do if | 228 // *net.TCPAddr. RFC 4254 section 7.2 is mute on what to do if |
222 // parsing fails but the forwardlist requires a valid *net.TCPAddr to | 229 // parsing fails but the forwardlist requires a valid *net.TCPAddr to |
(...skipping 51 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
274 // If empty, a reasonable default is used. | 281 // If empty, a reasonable default is used. |
275 ClientVersion string | 282 ClientVersion string |
276 } | 283 } |
277 | 284 |
278 func (c *ClientConfig) rand() io.Reader { | 285 func (c *ClientConfig) rand() io.Reader { |
279 if c.Rand == nil { | 286 if c.Rand == nil { |
280 return rand.Reader | 287 return rand.Reader |
281 } | 288 } |
282 return c.Rand | 289 return c.Rand |
283 } | 290 } |
LEFT | RIGHT |