OLD | NEW |
1 // Copied with small adaptations from the reflect package in the | 1 // Copied with small adaptations from the reflect package in the |
2 // Go source tree. | 2 // Go source tree. |
3 | 3 |
4 // Copyright 2009 The Go Authors. All rights reserved. | 4 // Copyright 2009 The Go Authors. All rights reserved. |
5 // Use of this source code is governed by a BSD-style | 5 // Use of this source code is governed by a BSD-style |
6 // license that can be found in the LICENSE file. | 6 // license that can be found in the LICENSE file. |
7 | 7 |
8 // Deep equality test via reflection | |
9 | |
10 package checkers | 8 package checkers |
11 | 9 |
12 import "reflect" | 10 import ( |
| 11 » "fmt" |
| 12 » "reflect" |
| 13 » "unsafe" |
| 14 ) |
13 | 15 |
14 // During deepValueEqual, must keep track of checks that are | 16 // During deepValueEqual, must keep track of checks that are |
15 // in progress. The comparison algorithm assumes that all | 17 // in progress. The comparison algorithm assumes that all |
16 // checks in progress are true when it reencounters them. | 18 // checks in progress are true when it reencounters them. |
17 // Visited comparisons are stored in a map indexed by visit. | 19 // Visited comparisons are stored in a map indexed by visit. |
18 type visit struct { | 20 type visit struct { |
19 a1 uintptr | 21 a1 uintptr |
20 a2 uintptr | 22 a2 uintptr |
21 typ reflect.Type | 23 typ reflect.Type |
22 } | 24 } |
23 | 25 |
| 26 type mismatchError struct { |
| 27 v1, v2 reflect.Value |
| 28 path string |
| 29 how string |
| 30 } |
| 31 |
| 32 func (err *mismatchError) Error() string { |
| 33 path := err.path |
| 34 if path == "" { |
| 35 path = "top level" |
| 36 } |
| 37 return fmt.Sprintf("mismatch at %s: %s; obtained %#v; expected %#v", pat
h, err.how, interfaceOf(err.v1), interfaceOf(err.v2)) |
| 38 } |
| 39 |
24 // Tests for deep equality using reflected types. The map argument tracks | 40 // Tests for deep equality using reflected types. The map argument tracks |
25 // comparisons that have already been seen, which allows short circuiting on | 41 // comparisons that have already been seen, which allows short circuiting on |
26 // recursive types. | 42 // recursive types. |
27 func deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) boo
l { | 43 func deepValueEqual(path string, v1, v2 reflect.Value, visited map[visit]bool, d
epth int) (ok bool, err error) { |
| 44 » errorf := func(f string, a ...interface{}) error { |
| 45 » » return &mismatchError{ |
| 46 » » » v1: v1, |
| 47 » » » v2: v2, |
| 48 » » » path: path, |
| 49 » » » how: fmt.Sprintf(f, a...), |
| 50 » » } |
| 51 » } |
28 if !v1.IsValid() || !v2.IsValid() { | 52 if !v1.IsValid() || !v2.IsValid() { |
29 » » return v1.IsValid() == v2.IsValid() | 53 » » if v1.IsValid() == v2.IsValid() { |
| 54 » » » return true, nil |
| 55 » » } |
| 56 » » return false, errorf("validity mismatch") |
30 } | 57 } |
31 if v1.Type() != v2.Type() { | 58 if v1.Type() != v2.Type() { |
32 » » return false | 59 » » return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Typ
e()) |
33 } | 60 } |
34 | 61 |
35 // if depth > 10 { panic("deepValueEqual") } // for debugging | 62 // if depth > 10 { panic("deepValueEqual") } // for debugging |
36 hard := func(k reflect.Kind) bool { | 63 hard := func(k reflect.Kind) bool { |
37 switch k { | 64 switch k { |
38 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: | 65 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct: |
39 return true | 66 return true |
40 } | 67 } |
41 return false | 68 return false |
42 } | 69 } |
43 | 70 |
44 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) { | 71 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) { |
45 addr1 := v1.UnsafeAddr() | 72 addr1 := v1.UnsafeAddr() |
46 addr2 := v2.UnsafeAddr() | 73 addr2 := v2.UnsafeAddr() |
47 if addr1 > addr2 { | 74 if addr1 > addr2 { |
48 // Canonicalize order to reduce number of entries in vis
ited. | 75 // Canonicalize order to reduce number of entries in vis
ited. |
49 addr1, addr2 = addr2, addr1 | 76 addr1, addr2 = addr2, addr1 |
50 } | 77 } |
51 | 78 |
52 // Short circuit if references are identical ... | 79 // Short circuit if references are identical ... |
53 if addr1 == addr2 { | 80 if addr1 == addr2 { |
54 » » » return true | 81 » » » return true, nil |
55 } | 82 } |
56 | 83 |
57 // ... or already seen | 84 // ... or already seen |
58 typ := v1.Type() | 85 typ := v1.Type() |
59 v := visit{addr1, addr2, typ} | 86 v := visit{addr1, addr2, typ} |
60 if visited[v] { | 87 if visited[v] { |
61 » » » return true | 88 » » » return true, nil |
62 } | 89 } |
63 | 90 |
64 // Remember for later. | 91 // Remember for later. |
65 visited[v] = true | 92 visited[v] = true |
66 } | 93 } |
67 | 94 |
68 switch v1.Kind() { | 95 switch v1.Kind() { |
69 case reflect.Array: | 96 case reflect.Array: |
70 if v1.Len() != v2.Len() { | 97 if v1.Len() != v2.Len() { |
71 » » » return false | 98 » » » // can't happen! |
| 99 » » » return false, errorf("length mismatch, %d vs %d", v1.Len
(), v2.Len()) |
72 } | 100 } |
73 for i := 0; i < v1.Len(); i++ { | 101 for i := 0; i < v1.Len(); i++ { |
74 » » » if !deepValueEqual(v1.Index(i), v2.Index(i), visited, de
pth+1) { | 102 » » » if ok, err := deepValueEqual( |
75 » » » » return false | 103 » » » » fmt.Sprintf("%s[%d]", path, i), |
76 » » » } | 104 » » » » v1.Index(i), v2.Index(i), visited, depth+1); !ok
{ |
77 » » } | 105 » » » » return false, err |
78 » » return true | 106 » » » } |
| 107 » » } |
| 108 » » return true, nil |
79 case reflect.Slice: | 109 case reflect.Slice: |
80 » » // No check for nil == nil here. | 110 » » // We treat a nil slice the same as an empty slice. |
81 | |
82 if v1.Len() != v2.Len() { | 111 if v1.Len() != v2.Len() { |
83 » » » return false | 112 » » » return false, errorf("length mismatch, %d vs %d", v1.Len
(), v2.Len()) |
84 } | 113 } |
85 if v1.Pointer() == v2.Pointer() { | 114 if v1.Pointer() == v2.Pointer() { |
86 » » » return true | 115 » » » return true, nil |
87 } | 116 } |
88 for i := 0; i < v1.Len(); i++ { | 117 for i := 0; i < v1.Len(); i++ { |
89 » » » if !deepValueEqual(v1.Index(i), v2.Index(i), visited, de
pth+1) { | 118 » » » if ok, err := deepValueEqual( |
90 » » » » return false | 119 » » » » fmt.Sprintf("%s[%d]", path, i), |
91 » » » } | 120 » » » » v1.Index(i), v2.Index(i), visited, depth+1); !ok
{ |
92 » » } | 121 » » » » return false, err |
93 » » return true | 122 » » » } |
| 123 » » } |
| 124 » » return true, nil |
94 case reflect.Interface: | 125 case reflect.Interface: |
95 if v1.IsNil() || v2.IsNil() { | 126 if v1.IsNil() || v2.IsNil() { |
96 » » » return v1.IsNil() == v2.IsNil() | 127 » » » if v1.IsNil() != v2.IsNil() { |
97 » » } | 128 » » » » return false, fmt.Errorf("nil vs non-nil interfa
ce mismatch") |
98 » » return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) | 129 » » » } |
| 130 » » » return true, nil |
| 131 » » } |
| 132 » » return deepValueEqual(path, v1.Elem(), v2.Elem(), visited, depth
+1) |
99 case reflect.Ptr: | 133 case reflect.Ptr: |
100 » » return deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1) | 134 » » return deepValueEqual("(*"+path+")", v1.Elem(), v2.Elem(), visit
ed, depth+1) |
101 case reflect.Struct: | 135 case reflect.Struct: |
102 for i, n := 0, v1.NumField(); i < n; i++ { | 136 for i, n := 0, v1.NumField(); i < n; i++ { |
103 » » » if !deepValueEqual(v1.Field(i), v2.Field(i), visited, de
pth+1) { | 137 » » » path := path + "." + v1.Type().Field(i).Name |
104 » » » » return false | 138 » » » if ok, err := deepValueEqual(path, v1.Field(i), v2.Field
(i), visited, depth+1); !ok { |
105 » » » } | 139 » » » » return false, err |
106 » » } | 140 » » » } |
107 » » return true | 141 » » } |
| 142 » » return true, nil |
108 case reflect.Map: | 143 case reflect.Map: |
109 » » // No check for nil == nil here. | 144 » » if v1.IsNil() != v2.IsNil() { |
110 | 145 » » » return false, errorf("nil vs non-nil mismatch") |
| 146 » » } |
111 if v1.Len() != v2.Len() { | 147 if v1.Len() != v2.Len() { |
112 » » » return false | 148 » » » return false, errorf("length mismatch, %d vs %d", v1.Len
(), v2.Len()) |
113 } | 149 } |
114 if v1.Pointer() == v2.Pointer() { | 150 if v1.Pointer() == v2.Pointer() { |
115 » » » return true | 151 » » » return true, nil |
116 } | 152 } |
117 for _, k := range v1.MapKeys() { | 153 for _, k := range v1.MapKeys() { |
118 » » » if !deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visit
ed, depth+1) { | 154 » » » var p string |
119 » » » » return false | 155 » » » if k.CanInterface() { |
120 » » » } | 156 » » » » p = path + "[" + fmt.Sprintf("%#v", k.Interface(
)) + "]" |
121 » » } | 157 » » » } else { |
122 » » return true | 158 » » » » p = path + "[someKey]" |
| 159 » » » } |
| 160 » » » if ok, err := deepValueEqual(p, v1.MapIndex(k), v2.MapIn
dex(k), visited, depth+1); !ok { |
| 161 » » » » return false, err |
| 162 » » » } |
| 163 » » } |
| 164 » » return true, nil |
123 case reflect.Func: | 165 case reflect.Func: |
124 if v1.IsNil() && v2.IsNil() { | 166 if v1.IsNil() && v2.IsNil() { |
125 » » » return true | 167 » » » return true, nil |
126 } | 168 } |
127 // Can't do better than this: | 169 // Can't do better than this: |
128 » » return false | 170 » » return false, errorf("non-nil functions") |
129 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.In
t64: | 171 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.In
t64: |
130 » » return v1.Int() == v2.Int() | 172 » » if v1.Int() != v2.Int() { |
| 173 » » » return false, errorf("unequal") |
| 174 » » } |
| 175 » » return true, nil |
131 case reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, refle
ct.Uint32, reflect.Uint64: | 176 case reflect.Uint, reflect.Uintptr, reflect.Uint8, reflect.Uint16, refle
ct.Uint32, reflect.Uint64: |
132 » » return v1.Uint() == v2.Uint() | 177 » » if v1.Uint() != v2.Uint() { |
| 178 » » » return false, errorf("unequal") |
| 179 » » } |
| 180 » » return true, nil |
133 case reflect.Float32, reflect.Float64: | 181 case reflect.Float32, reflect.Float64: |
134 » » return v1.Float() == v2.Float() | 182 » » if v1.Float() != v2.Float() { |
| 183 » » » return false, errorf("unequal") |
| 184 » » } |
| 185 » » return true, nil |
135 case reflect.Complex64, reflect.Complex128: | 186 case reflect.Complex64, reflect.Complex128: |
136 » » return v1.Complex() == v2.Complex() | 187 » » if v1.Complex() != v2.Complex() { |
| 188 » » » return false, errorf("unequal") |
| 189 » » } |
| 190 » » return true, nil |
137 case reflect.Bool: | 191 case reflect.Bool: |
138 » » return v1.Bool() == v2.Bool() | 192 » » if v1.Bool() != v2.Bool() { |
| 193 » » » return false, errorf("unequal") |
| 194 » » } |
| 195 » » return true, nil |
139 case reflect.String: | 196 case reflect.String: |
140 » » return v1.String() == v2.String() | 197 » » if v1.String() != v2.String() { |
| 198 » » » return false, errorf("unequal") |
| 199 » » } |
| 200 » » return true, nil |
141 case reflect.Chan, reflect.UnsafePointer: | 201 case reflect.Chan, reflect.UnsafePointer: |
142 » » return v1.Pointer() == v2.Pointer() | 202 » » if v1.Pointer() != v2.Pointer() { |
| 203 » » » return false, errorf("unequal") |
| 204 » » } |
| 205 » » return true, nil |
143 default: | 206 default: |
144 panic("unexpected type " + v1.Type().String()) | 207 panic("unexpected type " + v1.Type().String()) |
145 } | 208 } |
146 } | 209 } |
147 | 210 |
148 // DeepEqual tests for deep equality. It uses normal == equality where | 211 // DeepEqual tests for deep equality. It uses normal == equality where |
149 // possible but will scan elements of arrays, slices, maps, and fields of | 212 // possible but will scan elements of arrays, slices, maps, and fields |
150 // structs. In maps, keys are compared with == but elements use deep | 213 // of structs. In maps, keys are compared with == but elements use deep |
151 // equality. DeepEqual correctly handles recursive types. Functions are equal | 214 // equality. DeepEqual correctly handles recursive types. Functions are |
152 // only if they are both nil. | 215 // equal only if they are both nil. |
153 // DeepEqual differs from reflect.DeepEqual in that an empty | 216 // |
154 // slice is equal to a nil slice, and an empty map is equal to a nil map. | 217 // DeepEqual differs from reflect.DeepEqual in that an empty slice is |
155 func DeepEqual(a1, a2 interface{}) bool { | 218 // equal to a nil slice. If the two values compare unequal, the |
| 219 // resulting error holds the first difference encountered. |
| 220 func DeepEqual(a1, a2 interface{}) (bool, error) { |
| 221 » errorf := func(f string, a ...interface{}) error { |
| 222 » » return &mismatchError{ |
| 223 » » » v1: reflect.ValueOf(a1), |
| 224 » » » v2: reflect.ValueOf(a2), |
| 225 » » » path: "", |
| 226 » » » how: fmt.Sprintf(f, a...), |
| 227 » » } |
| 228 » } |
156 if a1 == nil || a2 == nil { | 229 if a1 == nil || a2 == nil { |
157 » » return a1 == a2 | 230 » » if a1 == a2 { |
| 231 » » » return true, nil |
| 232 » » } |
| 233 » » return false, errorf("nil vs non-nil mismatch") |
158 } | 234 } |
159 v1 := reflect.ValueOf(a1) | 235 v1 := reflect.ValueOf(a1) |
160 v2 := reflect.ValueOf(a2) | 236 v2 := reflect.ValueOf(a2) |
161 if v1.Type() != v2.Type() { | 237 if v1.Type() != v2.Type() { |
162 » » return false | 238 » » return false, errorf("type mismatch %s vs %s", v1.Type(), v2.Typ
e()) |
163 » } | 239 » } |
164 » return deepValueEqual(v1, v2, make(map[visit]bool), 0) | 240 » return deepValueEqual("", v1, v2, make(map[visit]bool), 0) |
165 } | 241 } |
| 242 |
| 243 // interfaceOf returns v.Interface() even if v.CanInterface() == false. |
| 244 // This enables us to call fmt.Printf on a value even if it's derived |
| 245 // from inside an unexported field. |
| 246 func interfaceOf(v reflect.Value) interface{} { |
| 247 » if !v.IsValid() { |
| 248 » » return nil |
| 249 » } |
| 250 » return bypassCanInterface(v).Interface() |
| 251 } |
| 252 |
| 253 type flag uintptr |
| 254 |
| 255 // copied from reflect/value.go |
| 256 const ( |
| 257 » flagRO flag = 1 << iota |
| 258 ) |
| 259 |
| 260 var flagValOffset = func() uintptr { |
| 261 » field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag") |
| 262 » if !ok { |
| 263 » » panic("reflect.Value has no flag field") |
| 264 » } |
| 265 » return field.Offset |
| 266 }() |
| 267 |
| 268 func flagField(v *reflect.Value) *flag { |
| 269 » return (*flag)(unsafe.Pointer(uintptr(unsafe.Pointer(v)) + flagValOffset
)) |
| 270 } |
| 271 |
| 272 // bypassCanInterface returns a version of v that |
| 273 // bypasses the CanInterface check. |
| 274 func bypassCanInterface(v reflect.Value) reflect.Value { |
| 275 » if !v.IsValid() || v.CanInterface() { |
| 276 » » return v |
| 277 » } |
| 278 » *flagField(&v) &^= flagRO |
| 279 » return v |
| 280 } |
| 281 |
| 282 // Sanity checks against future reflect package changes |
| 283 // to the type or semantics of the Value.flag field. |
| 284 func init() { |
| 285 » field, ok := reflect.TypeOf(reflect.Value{}).FieldByName("flag") |
| 286 » if !ok { |
| 287 » » panic("reflect.Value has no flag field") |
| 288 » } |
| 289 » if field.Type.Kind() != reflect.TypeOf(flag(0)).Kind() { |
| 290 » » panic("reflect.Value flag field has changed kind") |
| 291 » } |
| 292 » var t struct { |
| 293 » » a int |
| 294 » » A int |
| 295 » } |
| 296 » vA := reflect.ValueOf(t).FieldByName("A") |
| 297 » va := reflect.ValueOf(t).FieldByName("a") |
| 298 » flagA := *flagField(&vA) |
| 299 » flaga := *flagField(&va) |
| 300 » if flagA&flagRO != 0 || flaga&flagRO == 0 { |
| 301 » » panic("reflect.Value read-only flag has changed value") |
| 302 » } |
| 303 } |
OLD | NEW |