| OLD | NEW |
|---|---|
| 1 import unittest | 1 import unittest |
| 2 import pickle | 2 import pickle |
| 3 import pickletools | 3 import pickletools |
| 4 import copy_reg | 4 import copy_reg |
| 5 | 5 |
| 6 from test.test_support import TestFailed, TESTFN, run_with_locale | 6 from test.test_support import TestFailed, TESTFN, run_with_locale |
| 7 | 7 |
| 8 from pickle import bytes_types | 8 from pickle import bytes_types |
| 9 | 9 |
| 10 # Tests that try a number of pickle protocols should have a | 10 # Tests that try a number of pickle protocols should have a |
| 11 # for proto in protocols: | 11 # for proto in protocols: |
| 12 # kind of outer loop. | 12 # kind of outer loop. |
| 13 protocols = range(pickle.HIGHEST_PROTOCOL + 1) | 13 protocols = range(pickle.HIGHEST_PROTOCOL + 1) |
| 14 | 14 |
| 15 | 15 |
| 16 # Return True if opcode code appears in the pickle, else False. | 16 # Return True if opcode code appears in the pickle, else False. |
| 17 def opcode_in_pickle(code, pickle): | 17 def opcode_in_pickle(code, pickle): |
| 18 for op, dummy, dummy in pickletools.genops(pickle): | 18 for op, dummy, dummy in pickletools.genops(pickle): |
| 19 if op.code == code.decode("latin-1"): | 19 if op.code == code.decode("latin-1"): |
| 20 return True | 20 return True |
| 21 return False | 21 return False |
| 22 | 22 |
| 23 # Return the number of times opcode code appears in pickle. | 23 # Return the number of times opcode code appears in pickle. |
| 24 def count_opcode(code, pickle): | 24 def count_opcode(code, pickle): |
| 25 n = 0 | 25 n = 0 |
| 26 for op, dummy, dummy in pickletools.genops(pickle): | 26 for op, dummy, dummy in pickletools.genops(pickle): |
| 27 if op.code == code.decode("latin-1"): | 27 if op.code == code.decode("latin-1"): |
| 28 n += 1 | 28 n += 1 |
| 29 return n | 29 return n |
| 30 | 30 |
| 31 # We can't very well test the extension registry without putting known stuff | 31 # We can't very well test the extension registry without putting known stuff |
| 32 # in it, but we have to be careful to restore its original state. Code | 32 # in it, but we have to be careful to restore its original state. Code |
| 33 # should do this: | 33 # should do this: |
| 34 # | 34 # |
| 35 # e = ExtensionSaver(extension_code) | 35 # e = ExtensionSaver(extension_code) |
| 36 # try: | 36 # try: |
| 37 # fiddle w/ the extension registry's stuff for extension_code | 37 # fiddle w/ the extension registry's stuff for extension_code |
| 38 # finally: | 38 # finally: |
| 39 # e.restore() | 39 # e.restore() |
| 40 | 40 |
| 41 class ExtensionSaver: | 41 class ExtensionSaver: |
| 42 # Remember current registration for code (if any), and remove it (if | 42 # Remember current registration for code (if any), and remove it (if |
| 43 # there is one). | 43 # there is one). |
| 44 def __init__(self, code): | 44 def __init__(self, code): |
| 45 self.code = code | 45 self.code = code |
| 46 if code in copy_reg._inverted_registry: | 46 if code in copy_reg._inverted_registry: |
| 47 self.pair = copy_reg._inverted_registry[code] | 47 self.pair = copy_reg._inverted_registry[code] |
| 48 copy_reg.remove_extension(self.pair[0], self.pair[1], code) | 48 copy_reg.remove_extension(self.pair[0], self.pair[1], code) |
| 49 else: | 49 else: |
| 50 self.pair = None | 50 self.pair = None |
| (...skipping 487 matching lines...) Show 10 above Show 10 below | |
| 538 | 538 |
| 539 @run_with_locale('LC_ALL', 'de_DE', 'fr_FR') | 539 @run_with_locale('LC_ALL', 'de_DE', 'fr_FR') |
| 540 def test_float_format(self): | 540 def test_float_format(self): |
| 541 # make sure that floats are formatted locale independent with proto 0 | 541 # make sure that floats are formatted locale independent with proto 0 |
| 542 self.assertEqual(self.dumps(1.2, 0)[0:3], b'F1.') | 542 self.assertEqual(self.dumps(1.2, 0)[0:3], b'F1.') |
| 543 | 543 |
| 544 def test_reduce(self): | 544 def test_reduce(self): |
| 545 pass | 545 pass |
| 546 | 546 |
| 547 def test_getinitargs(self): | 547 def test_getinitargs(self): |
| 548 pass | 548 pass |
| 549 | 549 |
| 550 def test_metaclass(self): | 550 def test_metaclass(self): |
| 551 a = use_metaclass() | 551 a = use_metaclass() |
| 552 for proto in protocols: | 552 for proto in protocols: |
| 553 s = self.dumps(a, proto) | 553 s = self.dumps(a, proto) |
| 554 b = self.loads(s) | 554 b = self.loads(s) |
| 555 self.assertEqual(a.__class__, b.__class__) | 555 self.assertEqual(a.__class__, b.__class__) |
| 556 | 556 |
| 557 def test_structseq(self): | 557 def test_structseq(self): |
| 558 import time | 558 import time |
| 559 import os | 559 import os |
| 560 | 560 |
| 561 t = time.localtime() | 561 t = time.localtime() |
| 562 for proto in protocols: | 562 for proto in protocols: |
| 563 s = self.dumps(t, proto) | 563 s = self.dumps(t, proto) |
| 564 u = self.loads(s) | 564 u = self.loads(s) |
| 565 self.assertEqual(t, u) | 565 self.assertEqual(t, u) |
| 566 if hasattr(os, "stat"): | 566 if hasattr(os, "stat"): |
| 567 t = os.stat(os.curdir) | 567 t = os.stat(os.curdir) |
| 568 s = self.dumps(t, proto) | 568 s = self.dumps(t, proto) |
| 569 u = self.loads(s) | 569 u = self.loads(s) |
| 570 self.assertEqual(t, u) | 570 self.assertEqual(t, u) |
| 571 if hasattr(os, "statvfs"): | 571 if hasattr(os, "statvfs"): |
| 572 t = os.statvfs(os.curdir) | 572 t = os.statvfs(os.curdir) |
| 573 s = self.dumps(t, proto) | 573 s = self.dumps(t, proto) |
| 574 u = self.loads(s) | 574 u = self.loads(s) |
| 575 self.assertEqual(t, u) | 575 self.assertEqual(t, u) |
| 576 | 576 |
| 577 # Tests for protocol 2 | 577 # Tests for protocol 2 |
| 578 | 578 |
| 579 def test_proto(self): | 579 def test_proto(self): |
| 580 build_none = pickle.NONE + pickle.STOP | 580 build_none = pickle.NONE + pickle.STOP |
| 581 for proto in protocols: | 581 for proto in protocols: |
| 582 expected = build_none | 582 expected = build_none |
| 583 if proto >= 2: | 583 if proto >= 2: |
| 584 expected = pickle.PROTO + bytes([proto]) + expected | 584 expected = pickle.PROTO + bytes([proto]) + expected |
| 585 p = self.dumps(None, proto) | 585 p = self.dumps(None, proto) |
| 586 self.assertEqual(p, expected) | 586 self.assertEqual(p, expected) |
| 587 | 587 |
| 588 oob = protocols[-1] + 1 # a future protocol | 588 oob = list(protocols)[-1] + 1 # a future protocol |
|
Alexandre Vassalotti
2008/05/03 07:12:01
The list() call should be moved to line 13, where
| |
| 589 badpickle = pickle.PROTO + bytes([oob]) + build_none | 589 badpickle = pickle.PROTO + bytes([oob]) + build_none |
| 590 try: | 590 try: |
| 591 self.loads(badpickle) | 591 self.loads(badpickle) |
| 592 except ValueError as detail: | 592 except ValueError as detail: |
| 593 self.failUnless(str(detail).startswith( | 593 self.failUnless(str(detail).startswith( |
| 594 "unsupported pickle protocol")) | 594 "unsupported pickle protocol")) |
| 595 else: | 595 else: |
| 596 self.fail("expected bad protocol number to raise ValueError") | 596 self.fail("expected bad protocol number to raise ValueError") |
| 597 | 597 |
| 598 def test_long1(self): | 598 def test_long1(self): |
| 599 x = 12345678910111213141516178920 | 599 x = 12345678910111213141516178920 |
| 600 for proto in protocols: | 600 for proto in protocols: |
| 601 s = self.dumps(x, proto) | 601 s = self.dumps(x, proto) |
| 602 y = self.loads(s) | 602 y = self.loads(s) |
| 603 self.assertEqual(x, y) | 603 self.assertEqual(x, y) |
| 604 self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2) | 604 self.assertEqual(opcode_in_pickle(pickle.LONG1, s), proto >= 2) |
| 605 | 605 |
| 606 def test_long4(self): | 606 def test_long4(self): |
| 607 x = 12345678910111213141516178920 << (256*8) | 607 x = 12345678910111213141516178920 << (256*8) |
| 608 for proto in protocols: | 608 for proto in protocols: |
| 609 s = self.dumps(x, proto) | 609 s = self.dumps(x, proto) |
| 610 y = self.loads(s) | 610 y = self.loads(s) |
| 611 self.assertEqual(x, y) | 611 self.assertEqual(x, y) |
| 612 self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2) | 612 self.assertEqual(opcode_in_pickle(pickle.LONG4, s), proto >= 2) |
| 613 | 613 |
| 614 def test_short_tuples(self): | 614 def test_short_tuples(self): |
| 615 # Map (proto, len(tuple)) to expected opcode. | 615 # Map (proto, len(tuple)) to expected opcode. |
| 616 expected_opcode = {(0, 0): pickle.TUPLE, | 616 expected_opcode = {(0, 0): pickle.TUPLE, |
| 617 (0, 1): pickle.TUPLE, | 617 (0, 1): pickle.TUPLE, |
| 618 (0, 2): pickle.TUPLE, | 618 (0, 2): pickle.TUPLE, |
| 619 (0, 3): pickle.TUPLE, | 619 (0, 3): pickle.TUPLE, |
| 620 (0, 4): pickle.TUPLE, | 620 (0, 4): pickle.TUPLE, |
| 621 | 621 |
| 622 (1, 0): pickle.EMPTY_TUPLE, | 622 (1, 0): pickle.EMPTY_TUPLE, |
| 623 (1, 1): pickle.TUPLE, | 623 (1, 1): pickle.TUPLE, |
| 624 (1, 2): pickle.TUPLE, | 624 (1, 2): pickle.TUPLE, |
| 625 (1, 3): pickle.TUPLE, | 625 (1, 3): pickle.TUPLE, |
| 626 (1, 4): pickle.TUPLE, | 626 (1, 4): pickle.TUPLE, |
| 627 | 627 |
| 628 (2, 0): pickle.EMPTY_TUPLE, | 628 (2, 0): pickle.EMPTY_TUPLE, |
| 629 (2, 1): pickle.TUPLE1, | 629 (2, 1): pickle.TUPLE1, |
| 630 (2, 2): pickle.TUPLE2, | 630 (2, 2): pickle.TUPLE2, |
| 631 (2, 3): pickle.TUPLE3, | 631 (2, 3): pickle.TUPLE3, |
| 632 (2, 4): pickle.TUPLE, | 632 (2, 4): pickle.TUPLE, |
| 633 | 633 |
| 634 (3, 0): pickle.EMPTY_TUPLE, | 634 (3, 0): pickle.EMPTY_TUPLE, |
| 635 (3, 1): pickle.TUPLE1, | 635 (3, 1): pickle.TUPLE1, |
| 636 (3, 2): pickle.TUPLE2, | 636 (3, 2): pickle.TUPLE2, |
| 637 (3, 3): pickle.TUPLE3, | 637 (3, 3): pickle.TUPLE3, |
| 638 (3, 4): pickle.TUPLE, | 638 (3, 4): pickle.TUPLE, |
| (...skipping 349 matching lines...) Show 10 above Show 10 below | |
| 988 # This class defines persistent_id() and persistent_load() | 988 # This class defines persistent_id() and persistent_load() |
| 989 # functions that should be used by the pickler. All even integers | 989 # functions that should be used by the pickler. All even integers |
| 990 # are pickled using persistent ids. | 990 # are pickled using persistent ids. |
| 991 | 991 |
| 992 def persistent_id(self, object): | 992 def persistent_id(self, object): |
| 993 if isinstance(object, int) and object % 2 == 0: | 993 if isinstance(object, int) and object % 2 == 0: |
| 994 self.id_count += 1 | 994 self.id_count += 1 |
| 995 return str(object) | 995 return str(object) |
| 996 else: | 996 else: |
| 997 return None | 997 return None |
| 998 | 998 |
| 999 def persistent_load(self, oid): | 999 def persistent_load(self, oid): |
| 1000 self.load_count += 1 | 1000 self.load_count += 1 |
| 1001 object = int(oid) | 1001 object = int(oid) |
| 1002 assert object % 2 == 0 | 1002 assert object % 2 == 0 |
| 1003 return object | 1003 return object |
| 1004 | 1004 |
| 1005 def test_persistence(self): | 1005 def test_persistence(self): |
| 1006 self.id_count = 0 | 1006 self.id_count = 0 |
| 1007 self.load_count = 0 | 1007 self.load_count = 0 |
| 1008 L = list(range(10)) | 1008 L = list(range(10)) |
| 1009 self.assertEqual(self.loads(self.dumps(L)), L) | 1009 self.assertEqual(self.loads(self.dumps(L)), L) |
| 1010 self.assertEqual(self.id_count, 5) | 1010 self.assertEqual(self.id_count, 5) |
| 1011 self.assertEqual(self.load_count, 5) | 1011 self.assertEqual(self.load_count, 5) |
| 1012 | 1012 |
| 1013 def test_bin_persistence(self): | 1013 def test_bin_persistence(self): |
| 1014 self.id_count = 0 | 1014 self.id_count = 0 |
| 1015 self.load_count = 0 | 1015 self.load_count = 0 |
| 1016 L = list(range(10)) | 1016 L = list(range(10)) |
| 1017 self.assertEqual(self.loads(self.dumps(L, 1)), L) | 1017 self.assertEqual(self.loads(self.dumps(L, 1)), L) |
| 1018 self.assertEqual(self.id_count, 5) | 1018 self.assertEqual(self.id_count, 5) |
| 1019 self.assertEqual(self.load_count, 5) | 1019 self.assertEqual(self.load_count, 5) |
| 1020 | 1020 |
| 1021 if __name__ == "__main__": | 1021 if __name__ == "__main__": |
| 1022 # Print some stuff that can be used to rewrite DATA{0,1,2} | 1022 # Print some stuff that can be used to rewrite DATA{0,1,2} |
| 1023 from pickletools import dis | 1023 from pickletools import dis |
| 1024 x = create_data() | 1024 x = create_data() |
| 1025 for i in range(3): | 1025 for i in range(3): |
| 1026 p = pickle.dumps(x, i) | 1026 p = pickle.dumps(x, i) |
| 1027 print("DATA{0} = (".format(i)) | 1027 print("DATA{0} = (".format(i)) |
| 1028 for j in range(0, len(p), 20): | 1028 for j in range(0, len(p), 20): |
| 1029 b = bytes(p[j:j+20]) | 1029 b = bytes(p[j:j+20]) |
| 1030 print(" {0!r}".format(b)) | 1030 print(" {0!r}".format(b)) |
| 1031 print(")") | 1031 print(")") |
| 1032 print() | 1032 print() |
| 1033 print("# Disassembly of DATA{0}".format(i)) | 1033 print("# Disassembly of DATA{0}".format(i)) |
| 1034 print("DATA{0}_DIS = \"\"\"\\".format(i)) | 1034 print("DATA{0}_DIS = \"\"\"\\".format(i)) |
| 1035 dis(p) | 1035 dis(p) |
| 1036 print("\"\"\"") | 1036 print("\"\"\"") |
| 1037 print() | 1037 print() |
| OLD | NEW |