OLD | NEW |
1 package rpc_test | 1 package rpc_test |
2 | 2 |
3 import ( | 3 import ( |
4 "encoding/json" | 4 "encoding/json" |
5 "fmt" | 5 "fmt" |
6 "io" | 6 "io" |
7 . "launchpad.net/gocheck" | 7 . "launchpad.net/gocheck" |
8 "launchpad.net/juju-core/log" | 8 "launchpad.net/juju-core/log" |
9 "launchpad.net/juju-core/rpc" | 9 "launchpad.net/juju-core/rpc" |
10 "launchpad.net/juju-core/testing" | 10 "launchpad.net/juju-core/testing" |
(...skipping 23 matching lines...) Expand all Loading... |
34 type callError callInfo | 34 type callError callInfo |
35 | 35 |
36 func (e *callError) Error() string { | 36 func (e *callError) Error() string { |
37 return fmt.Sprintf("error calling %s", e.method) | 37 return fmt.Sprintf("error calling %s", e.method) |
38 } | 38 } |
39 | 39 |
40 type stringVal struct { | 40 type stringVal struct { |
41 Val string | 41 Val string |
42 } | 42 } |
43 | 43 |
44 type TRoot struct { | 44 type Root struct { |
45 mu sync.Mutex | 45 mu sync.Mutex |
46 calls []*callInfo | 46 calls []*callInfo |
47 returnErr bool | 47 returnErr bool |
48 simple map[string]*SimpleMethods | 48 simple map[string]*SimpleMethods |
49 delayed map[string]*DelayedMethods | 49 delayed map[string]*DelayedMethods |
50 errorInst *ErrorMethods | 50 errorInst *ErrorMethods |
51 } | 51 } |
52 | 52 |
53 func (r *TRoot) callError(rcvr interface{}, name string, arg interface{}) error
{ | 53 func (r *Root) callError(rcvr interface{}, name string, arg interface{}) error { |
54 if r.returnErr { | 54 if r.returnErr { |
55 return &callError{rcvr, name, arg} | 55 return &callError{rcvr, name, arg} |
56 } | 56 } |
57 return nil | 57 return nil |
58 } | 58 } |
59 | 59 |
60 func (r *TRoot) SimpleMethods(id string) (*SimpleMethods, error) { | 60 func (r *Root) SimpleMethods(id string) (*SimpleMethods, error) { |
61 r.mu.Lock() | 61 r.mu.Lock() |
62 defer r.mu.Unlock() | 62 defer r.mu.Unlock() |
63 if a := r.simple[id]; a != nil { | 63 if a := r.simple[id]; a != nil { |
64 return a, nil | 64 return a, nil |
65 } | 65 } |
66 return nil, fmt.Errorf("unknown SimpleMethods id") | 66 return nil, fmt.Errorf("unknown SimpleMethods id") |
67 } | 67 } |
68 | 68 |
69 func (r *TRoot) DelayedMethods(id string) (*DelayedMethods, error) { | 69 func (r *Root) DelayedMethods(id string) (*DelayedMethods, error) { |
70 r.mu.Lock() | 70 r.mu.Lock() |
71 defer r.mu.Unlock() | 71 defer r.mu.Unlock() |
72 if a := r.delayed[id]; a != nil { | 72 if a := r.delayed[id]; a != nil { |
73 return a, nil | 73 return a, nil |
74 } | 74 } |
75 return nil, fmt.Errorf("unknown DelayedMethods id") | 75 return nil, fmt.Errorf("unknown DelayedMethods id") |
76 } | 76 } |
77 | 77 |
78 func (r *TRoot) ErrorMethods(id string) (*ErrorMethods, error) { | 78 func (r *Root) ErrorMethods(id string) (*ErrorMethods, error) { |
79 if r.errorInst == nil { | 79 if r.errorInst == nil { |
80 return nil, fmt.Errorf("no error methods") | 80 return nil, fmt.Errorf("no error methods") |
81 } | 81 } |
82 return r.errorInst, nil | 82 return r.errorInst, nil |
83 } | 83 } |
84 | 84 |
85 func (t *TRoot) called(rcvr interface{}, method string, arg interface{}) { | 85 func (t *Root) called(rcvr interface{}, method string, arg interface{}) { |
86 t.mu.Lock() | 86 t.mu.Lock() |
87 t.calls = append(t.calls, &callInfo{rcvr, method, arg}) | 87 t.calls = append(t.calls, &callInfo{rcvr, method, arg}) |
88 t.mu.Unlock() | 88 t.mu.Unlock() |
89 } | 89 } |
90 | 90 |
91 type SimpleMethods struct { | 91 type SimpleMethods struct { |
92 » root *TRoot | 92 » root *Root |
93 id string | 93 id string |
94 } | 94 } |
95 | 95 |
96 // Each Call method is named in this standard form: | 96 // Each Call method is named in this standard form: |
97 // | 97 // |
98 // Call<narg>r<nret><e> | 98 // Call<narg>r<nret><e> |
99 // | 99 // |
100 // where narg is the number of arguments, nret is the number of returned | 100 // where narg is the number of arguments, nret is the number of returned |
101 // values (not including the error) and e is the letter 'e' if the | 101 // values (not including the error) and e is the letter 'e' if the |
102 // method returns an error. | 102 // method returns an error. |
(...skipping 50 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
153 | 153 |
154 type ErrorMethods struct { | 154 type ErrorMethods struct { |
155 err error | 155 err error |
156 } | 156 } |
157 | 157 |
158 func (e *ErrorMethods) Call() error { | 158 func (e *ErrorMethods) Call() error { |
159 return e.err | 159 return e.err |
160 } | 160 } |
161 | 161 |
162 func (*suite) TestRPC(c *C) { | 162 func (*suite) TestRPC(c *C) { |
163 » root := &TRoot{ | 163 » root := &Root{ |
164 simple: make(map[string]*SimpleMethods), | 164 simple: make(map[string]*SimpleMethods), |
165 } | 165 } |
166 root.simple["a99"] = &SimpleMethods{root: root, id: "a99"} | 166 root.simple["a99"] = &SimpleMethods{root: root, id: "a99"} |
167 client, srvDone := newRPCClientServer(c, root, nil) | 167 client, srvDone := newRPCClientServer(c, root, nil) |
168 for narg := 0; narg < 2; narg++ { | 168 for narg := 0; narg < 2; narg++ { |
169 for nret := 0; nret < 2; nret++ { | 169 for nret := 0; nret < 2; nret++ { |
170 for nerr := 0; nerr < 2; nerr++ { | 170 for nerr := 0; nerr < 2; nerr++ { |
171 retErr := nerr != 0 | 171 retErr := nerr != 0 |
172 root.testCall(c, client, narg, nret, retErr, fal
se) | 172 root.testCall(c, client, narg, nret, retErr, fal
se) |
173 if retErr { | 173 if retErr { |
174 root.testCall(c, client, narg, nret, ret
Err, true) | 174 root.testCall(c, client, narg, nret, ret
Err, true) |
175 } | 175 } |
176 } | 176 } |
177 } | 177 } |
178 } | 178 } |
179 client.Close() | 179 client.Close() |
180 err := chanReadError(c, srvDone, "server done") | 180 err := chanReadError(c, srvDone, "server done") |
181 c.Assert(err, IsNil) | 181 c.Assert(err, IsNil) |
182 } | 182 } |
183 | 183 |
184 func (root *TRoot) testCall(c *C, client *rpc.Client, narg, nret int, retErr, te
stErr bool) { | 184 func (root *Root) testCall(c *C, client *rpc.Client, narg, nret int, retErr, tes
tErr bool) { |
185 root.calls = nil | 185 root.calls = nil |
186 root.returnErr = testErr | 186 root.returnErr = testErr |
187 e := "" | 187 e := "" |
188 if retErr { | 188 if retErr { |
189 e = "e" | 189 e = "e" |
190 } | 190 } |
191 method := fmt.Sprintf("Call%dr%d%s", narg, nret, e) | 191 method := fmt.Sprintf("Call%dr%d%s", narg, nret, e) |
192 c.Logf("test call %s", method) | 192 c.Logf("test call %s", method) |
193 var r stringVal | 193 var r stringVal |
194 err := client.Call("SimpleMethods", "a99", method, stringVal{"arg"}, &r) | 194 err := client.Call("SimpleMethods", "a99", method, stringVal{"arg"}, &r) |
(...skipping 18 matching lines...) Expand all Loading... |
213 c.Assert(r, Equals, stringVal{method + " ret"}) | 213 c.Assert(r, Equals, stringVal{method + " ret"}) |
214 } | 214 } |
215 } | 215 } |
216 | 216 |
217 func (*suite) TestConcurrentCalls(c *C) { | 217 func (*suite) TestConcurrentCalls(c *C) { |
218 start1 := make(chan string) | 218 start1 := make(chan string) |
219 start2 := make(chan string) | 219 start2 := make(chan string) |
220 ready1 := make(chan struct{}) | 220 ready1 := make(chan struct{}) |
221 ready2 := make(chan struct{}) | 221 ready2 := make(chan struct{}) |
222 | 222 |
223 » root := &TRoot{ | 223 » root := &Root{ |
224 delayed: map[string]*DelayedMethods{ | 224 delayed: map[string]*DelayedMethods{ |
225 "1": {ready: ready1, done: start1}, | 225 "1": {ready: ready1, done: start1}, |
226 "2": {ready: ready2, done: start2}, | 226 "2": {ready: ready2, done: start2}, |
227 }, | 227 }, |
228 } | 228 } |
229 | 229 |
230 client, srvDone := newRPCClientServer(c, root, nil) | 230 client, srvDone := newRPCClientServer(c, root, nil) |
231 call := func(id string, done chan<- struct{}) { | 231 call := func(id string, done chan<- struct{}) { |
232 var r stringVal | 232 var r stringVal |
233 err := client.Call("DelayedMethods", id, "Delay", nil, &r) | 233 err := client.Call("DelayedMethods", id, "Delay", nil, &r) |
(...skipping 27 matching lines...) Expand all Loading... |
261 | 261 |
262 func (e *codedError) Error() string { | 262 func (e *codedError) Error() string { |
263 return e.m | 263 return e.m |
264 } | 264 } |
265 | 265 |
266 func (e *codedError) ErrorCode() string { | 266 func (e *codedError) ErrorCode() string { |
267 return e.code | 267 return e.code |
268 } | 268 } |
269 | 269 |
270 func (*suite) TestErrorCode(c *C) { | 270 func (*suite) TestErrorCode(c *C) { |
271 » root := &TRoot{ | 271 » root := &Root{ |
272 errorInst: &ErrorMethods{&codedError{"message", "code"}}, | 272 errorInst: &ErrorMethods{&codedError{"message", "code"}}, |
273 } | 273 } |
274 client, srvDone := newRPCClientServer(c, root, nil) | 274 client, srvDone := newRPCClientServer(c, root, nil) |
275 err := client.Call("ErrorMethods", "", "Call", nil, nil) | 275 err := client.Call("ErrorMethods", "", "Call", nil, nil) |
276 c.Assert(err, ErrorMatches, `server error: message \(code\)`) | 276 c.Assert(err, ErrorMatches, `server error: message \(code\)`) |
277 c.Assert(err.(rpc.ErrorCoder).ErrorCode(), Equals, "code") | 277 c.Assert(err.(rpc.ErrorCoder).ErrorCode(), Equals, "code") |
278 client.Close() | 278 client.Close() |
279 err = chanReadError(c, srvDone, "server done") | 279 err = chanReadError(c, srvDone, "server done") |
280 c.Assert(err, IsNil) | 280 c.Assert(err, IsNil) |
281 } | 281 } |
282 | 282 |
283 func (*suite) TestTransformErrors(c *C) { | 283 func (*suite) TestTransformErrors(c *C) { |
284 » root := &TRoot{ | 284 » root := &Root{ |
285 errorInst: &ErrorMethods{&codedError{"message", "code"}}, | 285 errorInst: &ErrorMethods{&codedError{"message", "code"}}, |
286 } | 286 } |
287 tfErr := func(err error) error { | 287 tfErr := func(err error) error { |
288 c.Check(err, NotNil) | 288 c.Check(err, NotNil) |
289 if e, ok := err.(*codedError); ok { | 289 if e, ok := err.(*codedError); ok { |
290 return &codedError{ | 290 return &codedError{ |
291 m: "transformed: " + e.m, | 291 m: "transformed: " + e.m, |
292 code: "transformed: " + e.code, | 292 code: "transformed: " + e.code, |
293 } | 293 } |
294 } | 294 } |
(...skipping 17 matching lines...) Expand all Loading... |
312 }) | 312 }) |
313 | 313 |
314 client.Close() | 314 client.Close() |
315 err = chanReadError(c, srvDone, "server done") | 315 err = chanReadError(c, srvDone, "server done") |
316 c.Assert(err, IsNil) | 316 c.Assert(err, IsNil) |
317 } | 317 } |
318 | 318 |
319 func (*suite) TestServerWaitsForOutstandingCalls(c *C) { | 319 func (*suite) TestServerWaitsForOutstandingCalls(c *C) { |
320 ready := make(chan struct{}) | 320 ready := make(chan struct{}) |
321 start := make(chan string) | 321 start := make(chan string) |
322 » root := &TRoot{ | 322 » root := &Root{ |
323 delayed: map[string]*DelayedMethods{ | 323 delayed: map[string]*DelayedMethods{ |
324 "1": { | 324 "1": { |
325 ready: ready, | 325 ready: ready, |
326 done: start, | 326 done: start, |
327 }, | 327 }, |
328 }, | 328 }, |
329 } | 329 } |
330 client, srvDone := newRPCClientServer(c, root, nil) | 330 client, srvDone := newRPCClientServer(c, root, nil) |
331 done := make(chan struct{}) | 331 done := make(chan struct{}) |
332 go func() { | 332 go func() { |
(...skipping 19 matching lines...) Expand all Loading... |
352 func chanRead(c *C, ch <-chan struct{}, what string) { | 352 func chanRead(c *C, ch <-chan struct{}, what string) { |
353 select { | 353 select { |
354 case <-ch: | 354 case <-ch: |
355 return | 355 return |
356 case <-time.After(3 * time.Second): | 356 case <-time.After(3 * time.Second): |
357 c.Fatalf("timeout on channel read %s", what) | 357 c.Fatalf("timeout on channel read %s", what) |
358 } | 358 } |
359 } | 359 } |
360 | 360 |
361 func (*suite) TestCompatibility(c *C) { | 361 func (*suite) TestCompatibility(c *C) { |
362 » root := &TRoot{ | 362 » root := &Root{ |
363 simple: make(map[string]*SimpleMethods), | 363 simple: make(map[string]*SimpleMethods), |
364 } | 364 } |
365 a0 := &SimpleMethods{root: root, id: "a0"} | 365 a0 := &SimpleMethods{root: root, id: "a0"} |
366 root.simple["a0"] = a0 | 366 root.simple["a0"] = a0 |
367 | 367 |
368 client, srvDone := newRPCClientServer(c, root, nil) | 368 client, srvDone := newRPCClientServer(c, root, nil) |
369 call := func(method string, arg, ret interface{}) (passedArg interface{}
) { | 369 call := func(method string, arg, ret interface{}) (passedArg interface{}
) { |
370 root.calls = nil | 370 root.calls = nil |
371 err := client.Call("SimpleMethods", "a0", method, arg, ret) | 371 err := client.Call("SimpleMethods", "a0", method, arg, ret) |
372 c.Assert(err, IsNil) | 372 c.Assert(err, IsNil) |
(...skipping 26 matching lines...) Expand all Loading... |
399 arg = call("Call1r0", stringVal{"x"}, &r) | 399 arg = call("Call1r0", stringVal{"x"}, &r) |
400 c.Assert(arg, Equals, stringVal{"x"}) | 400 c.Assert(arg, Equals, stringVal{"x"}) |
401 c.Assert(r, Equals, extra{}) | 401 c.Assert(r, Equals, extra{}) |
402 | 402 |
403 client.Close() | 403 client.Close() |
404 err := chanReadError(c, srvDone, "server done") | 404 err := chanReadError(c, srvDone, "server done") |
405 c.Assert(err, IsNil) | 405 c.Assert(err, IsNil) |
406 } | 406 } |
407 | 407 |
408 func (*suite) TestBadCall(c *C) { | 408 func (*suite) TestBadCall(c *C) { |
409 » root := &TRoot{ | 409 » root := &Root{ |
410 simple: make(map[string]*SimpleMethods), | 410 simple: make(map[string]*SimpleMethods), |
411 } | 411 } |
412 a0 := &SimpleMethods{root: root, id: "a0"} | 412 a0 := &SimpleMethods{root: root, id: "a0"} |
413 root.simple["a0"] = a0 | 413 root.simple["a0"] = a0 |
414 client, srvDone := newRPCClientServer(c, root, nil) | 414 client, srvDone := newRPCClientServer(c, root, nil) |
415 | 415 |
416 err := client.Call("BadSomething", "a0", "No", nil, nil) | 416 err := client.Call("BadSomething", "a0", "No", nil, nil) |
417 c.Assert(err, ErrorMatches, `server error: unknown object type "BadSomet
hing"`) | 417 c.Assert(err, ErrorMatches, `server error: unknown object type "BadSomet
hing"`) |
418 | 418 |
419 err = client.Call("SimpleMethods", "xx", "No", nil, nil) | 419 err = client.Call("SimpleMethods", "xx", "No", nil, nil) |
420 c.Assert(err, ErrorMatches, `server error: no such request "No" on Simpl
eMethods`) | 420 c.Assert(err, ErrorMatches, `server error: no such request "No" on Simpl
eMethods`) |
421 | 421 |
422 err = client.Call("SimpleMethods", "xx", "Call0r0", nil, nil) | 422 err = client.Call("SimpleMethods", "xx", "Call0r0", nil, nil) |
423 c.Assert(err, ErrorMatches, "server error: unknown SimpleMethods id") | 423 c.Assert(err, ErrorMatches, "server error: unknown SimpleMethods id") |
424 | 424 |
425 client.Close() | 425 client.Close() |
426 err = chanReadError(c, srvDone, "server done") | 426 err = chanReadError(c, srvDone, "server done") |
427 c.Assert(err, IsNil) | 427 c.Assert(err, IsNil) |
428 } | 428 } |
429 | 429 |
430 func (*suite) TestErrorAfterClientClose(c *C) { | 430 func (*suite) TestErrorAfterClientClose(c *C) { |
431 » client, srvDone := newRPCClientServer(c, &TRoot{}, nil) | 431 » client, srvDone := newRPCClientServer(c, &Root{}, nil) |
432 err := client.Close() | 432 err := client.Close() |
433 c.Assert(err, IsNil) | 433 c.Assert(err, IsNil) |
434 err = client.Call("Foo", "", "Bar", nil, nil) | 434 err = client.Call("Foo", "", "Bar", nil, nil) |
435 c.Assert(err, Equals, rpc.ErrShutdown) | 435 c.Assert(err, Equals, rpc.ErrShutdown) |
436 err = chanReadError(c, srvDone, "server done") | 436 err = chanReadError(c, srvDone, "server done") |
437 c.Assert(err, IsNil) | 437 c.Assert(err, IsNil) |
438 } | 438 } |
439 | 439 |
| 440 type KillerRoot struct { |
| 441 killed bool |
| 442 Root |
| 443 } |
| 444 |
| 445 func (r *KillerRoot) Kill() { |
| 446 r.killed = true |
| 447 } |
| 448 |
| 449 func (*suite) TestRootIsKilled(c *C) { |
| 450 root := &KillerRoot{} |
| 451 client, srvDone := newRPCClientServer(c, root, nil) |
| 452 err := client.Close() |
| 453 c.Assert(err, IsNil) |
| 454 err = chanReadError(c, srvDone, "server done") |
| 455 c.Assert(err, IsNil) |
| 456 c.Assert(root.killed, Equals, true) |
| 457 } |
| 458 |
440 func chanReadError(c *C, ch <-chan error, what string) error { | 459 func chanReadError(c *C, ch <-chan error, what string) error { |
441 select { | 460 select { |
442 case e := <-ch: | 461 case e := <-ch: |
443 return e | 462 return e |
444 case <-time.After(3 * time.Second): | 463 case <-time.After(3 * time.Second): |
445 c.Fatalf("timeout on channel read %s", what) | 464 c.Fatalf("timeout on channel read %s", what) |
446 } | 465 } |
447 panic("unreachable") | 466 panic("unreachable") |
448 } | 467 } |
449 | 468 |
450 // newRPCClientServer starts an RPC server serving a connection from a | 469 // newRPCClientServer starts an RPC server serving a connection from a |
451 // single client. When the server has finished serving the connection, | 470 // single client. When the server has finished serving the connection, |
452 // it sends a value on done. | 471 // it sends a value on done. |
453 func newRPCClientServer(c *C, root interface{}, tfErr func(error) error) (client
*rpc.Client, done <-chan error) { | 472 func newRPCClientServer(c *C, root interface{}, tfErr func(error) error) (client
*rpc.Client, done <-chan error) { |
454 » srv, err := rpc.NewServer(&TRoot{}, tfErr) | 473 » srv, err := rpc.NewServer(root, tfErr) |
455 c.Assert(err, IsNil) | 474 c.Assert(err, IsNil) |
456 | 475 |
457 l, err := net.Listen("tcp", ":0") | 476 l, err := net.Listen("tcp", ":0") |
458 c.Assert(err, IsNil) | 477 c.Assert(err, IsNil) |
459 defer l.Close() | 478 defer l.Close() |
460 | 479 |
461 srvDone := make(chan error, 1) | 480 srvDone := make(chan error, 1) |
462 go func() { | 481 go func() { |
463 conn, err := l.Accept() | 482 conn, err := l.Accept() |
464 if err != nil { | 483 if err != nil { |
(...skipping 92 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... |
557 } | 576 } |
558 } | 577 } |
559 | 578 |
560 func NewJSONClientCodec(c io.ReadWriteCloser) rpc.ClientCodec { | 579 func NewJSONClientCodec(c io.ReadWriteCloser) rpc.ClientCodec { |
561 return &generalClientCodec{ | 580 return &generalClientCodec{ |
562 Closer: c, | 581 Closer: c, |
563 enc: json.NewEncoder(c), | 582 enc: json.NewEncoder(c), |
564 dec: json.NewDecoder(c), | 583 dec: json.NewDecoder(c), |
565 } | 584 } |
566 } | 585 } |
OLD | NEW |