Left: | ||
Right: |
LEFT | RIGHT |
---|---|
1 // Copyright 2009 The Go Authors. All rights reserved. | 1 // Copyright 2009 The Go Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style | 2 // Use of this source code is governed by a BSD-style |
3 // license that can be found in the LICENSE file. | 3 // license that can be found in the LICENSE file. |
4 | 4 |
5 // This file contains operations on unsigned multi-precision integers. | 5 // This file contains operations on unsigned multi-precision integers. |
6 // These are the building blocks for the operations on signed integers | 6 // These are the building blocks for the operations on signed integers |
7 // and rationals. | 7 // and rationals. |
8 | 8 |
9 // This package implements multi-precision arithmetic (big numbers). | 9 // This package implements multi-precision arithmetic (big numbers). |
10 // The following numeric types are supported: | 10 // The following numeric types are supported: |
(...skipping 19 matching lines...) Expand all Loading... | |
30 // with the digits x[i] as the slice elements. | 30 // with the digits x[i] as the slice elements. |
31 // | 31 // |
32 // A number is normalized if the slice contains no leading 0 digits. | 32 // A number is normalized if the slice contains no leading 0 digits. |
33 // During arithmetic operations, denormalized values may occur but are | 33 // During arithmetic operations, denormalized values may occur but are |
34 // always normalized before returning the final result. The normalized | 34 // always normalized before returning the final result. The normalized |
35 // representation of 0 is the empty or nil slice (length = 0). | 35 // representation of 0 is the empty or nil slice (length = 0). |
36 | 36 |
37 type nat []Word | 37 type nat []Word |
38 | 38 |
39 var ( | 39 var ( |
40 » natZero = nat(nil) | 40 » natOne = nat{1} |
41 » natOne = nat{1} | 41 » natTwo = nat{2} |
42 » natTwo = nat{2} | |
43 ) | 42 ) |
43 | |
44 | |
45 func (z nat) clear() nat { | |
46 for i := range z { | |
47 z[i] = 0 | |
48 } | |
49 return z | |
50 } | |
44 | 51 |
45 | 52 |
46 func (z nat) norm() nat { | 53 func (z nat) norm() nat { |
47 i := len(z) | 54 i := len(z) |
48 for i > 0 && z[i-1] == 0 { | 55 for i > 0 && z[i-1] == 0 { |
49 i-- | 56 i-- |
50 } | 57 } |
51 z = z[0:i] | 58 z = z[0:i] |
52 return z | 59 return z |
53 } | 60 } |
54 | 61 |
55 | 62 |
56 func (z nat) make(m int, clear bool) nat { | 63 func (z nat) make(m int) nat { |
57 if cap(z) > m { | 64 if cap(z) > m { |
58 » » z = z[0:m] // reuse z - has at least one extra word for a carry, if any | 65 » » return z[0:m] // reuse z - has at least one extra word for a carry, if any |
59 » » if clear { | |
60 » » » for i := range z { | |
61 » » » » z[i] = 0 | |
62 » » » } | |
63 » » } | |
64 » » return z | |
65 } | 66 } |
66 | 67 |
67 c := 4 // minimum capacity | 68 c := 4 // minimum capacity |
68 if m > c { | 69 if m > c { |
69 c = m | 70 c = m |
70 } | 71 } |
71 return make(nat, m, c+1) // +1: extra word for a carry, if any | 72 return make(nat, m, c+1) // +1: extra word for a carry, if any |
72 } | 73 } |
73 | 74 |
74 | 75 |
75 func (z nat) new(x uint64) nat { | 76 func (z nat) new(x uint64) nat { |
76 if x == 0 { | 77 if x == 0 { |
77 » » return z.make(0, false) | 78 » » return z.make(0) |
78 } | 79 } |
79 | 80 |
80 // single-digit values | 81 // single-digit values |
81 if x == uint64(Word(x)) { | 82 if x == uint64(Word(x)) { |
82 » » z = z.make(1, false) | 83 » » z = z.make(1) |
83 z[0] = Word(x) | 84 z[0] = Word(x) |
84 return z | 85 return z |
85 } | 86 } |
86 | 87 |
87 // compute number of words n required to represent x | 88 // compute number of words n required to represent x |
88 n := 0 | 89 n := 0 |
89 for t := x; t > 0; t >>= _W { | 90 for t := x; t > 0; t >>= _W { |
90 n++ | 91 n++ |
91 } | 92 } |
92 | 93 |
93 // split x into n words | 94 // split x into n words |
94 » z = z.make(n, false) | 95 » z = z.make(n) |
95 for i := 0; i < n; i++ { | 96 for i := 0; i < n; i++ { |
96 z[i] = Word(x & _M) | 97 z[i] = Word(x & _M) |
97 x >>= _W | 98 x >>= _W |
98 } | 99 } |
99 | 100 |
100 return z | 101 return z |
101 } | 102 } |
102 | 103 |
103 | 104 |
104 func (z nat) set(x nat) nat { | 105 func (z nat) set(x nat) nat { |
105 » z = z.make(len(x), false) | 106 » z = z.make(len(x)) |
106 for i, d := range x { | 107 for i, d := range x { |
107 z[i] = d | 108 z[i] = d |
108 } | 109 } |
109 return z | 110 return z |
110 } | 111 } |
111 | 112 |
112 | 113 |
113 func (z nat) add(x, y nat) nat { | 114 func (z nat) add(x, y nat) nat { |
114 m := len(x) | 115 m := len(x) |
115 n := len(y) | 116 n := len(y) |
116 | 117 |
117 switch { | 118 switch { |
118 case m < n: | 119 case m < n: |
119 return z.add(y, x) | 120 return z.add(y, x) |
120 case m == 0: | 121 case m == 0: |
121 // n == 0 because m >= n; result is 0 | 122 // n == 0 because m >= n; result is 0 |
122 » » return z.make(0, false) | 123 » » return z.make(0) |
123 case n == 0: | 124 case n == 0: |
124 // result is x | 125 // result is x |
125 return z.set(x) | 126 return z.set(x) |
126 } | 127 } |
127 // m > 0 | 128 // m > 0 |
128 | 129 |
129 » z = z.make(m, false) | 130 » z = z.make(m) |
130 c := addVV(&z[0], &x[0], &y[0], n) | 131 c := addVV(&z[0], &x[0], &y[0], n) |
131 if m > n { | 132 if m > n { |
132 c = addVW(&z[n], &x[n], c, m-n) | 133 c = addVW(&z[n], &x[n], c, m-n) |
133 } | 134 } |
134 if c > 0 { | 135 if c > 0 { |
135 z = z[0 : m+1] | 136 z = z[0 : m+1] |
136 z[m] = c | 137 z[m] = c |
137 } | 138 } |
138 | 139 |
139 return z | 140 return z |
140 } | 141 } |
141 | 142 |
142 | 143 |
143 func (z nat) sub(x, y nat) nat { | 144 func (z nat) sub(x, y nat) nat { |
144 m := len(x) | 145 m := len(x) |
145 n := len(y) | 146 n := len(y) |
146 | 147 |
147 switch { | 148 switch { |
148 case m < n: | 149 case m < n: |
149 panic("underflow") | 150 panic("underflow") |
150 case m == 0: | 151 case m == 0: |
151 // n == 0 because m >= n; result is 0 | 152 // n == 0 because m >= n; result is 0 |
152 » » return z.make(0, false) | 153 » » return z.make(0) |
153 case n == 0: | 154 case n == 0: |
154 // result is x | 155 // result is x |
155 return z.set(x) | 156 return z.set(x) |
156 } | 157 } |
157 // m > 0 | 158 // m > 0 |
158 | 159 |
159 » z = z.make(m, false) | 160 » z = z.make(m) |
160 c := subVV(&z[0], &x[0], &y[0], n) | 161 c := subVV(&z[0], &x[0], &y[0], n) |
161 if m > n { | 162 if m > n { |
162 c = subVW(&z[n], &x[n], c, m-n) | 163 c = subVW(&z[n], &x[n], c, m-n) |
163 } | 164 } |
164 if c != 0 { | 165 if c != 0 { |
165 panic("underflow") | 166 panic("underflow") |
166 } | 167 } |
167 z = z.norm() | 168 z = z.norm() |
168 | 169 |
169 return z | 170 return z |
(...skipping 28 matching lines...) Expand all Loading... | |
198 } | 199 } |
199 | 200 |
200 | 201 |
201 func (z nat) mulAddWW(x nat, y, r Word) nat { | 202 func (z nat) mulAddWW(x nat, y, r Word) nat { |
202 m := len(x) | 203 m := len(x) |
203 if m == 0 || y == 0 { | 204 if m == 0 || y == 0 { |
204 return z.new(uint64(r)) // result is r | 205 return z.new(uint64(r)) // result is r |
205 } | 206 } |
206 // m > 0 | 207 // m > 0 |
207 | 208 |
208 » z = z.make(m, false) | 209 » z = z.make(m) |
209 c := mulAddVWW(&z[0], &x[0], y, r, m) | 210 c := mulAddVWW(&z[0], &x[0], y, r, m) |
210 if c > 0 { | 211 if c > 0 { |
211 z = z[0 : m+1] | 212 z = z[0 : m+1] |
212 z[m] = c | 213 z[m] = c |
213 } | 214 } |
214 | 215 |
215 return z | 216 return z |
216 } | 217 } |
217 | 218 |
218 | 219 |
219 // Operands that are shorter than this threshold are multiplied using | 220 // basicMul multiplies x and y and leaves the result in z. |
220 // "grade school" multiplication; for larger operands the Karatsuba | 221 // The (non-normalized) result is placed in z[0 : len(x) + len(y)]. |
221 // algorithm is used. | 222 func basicMul(z, x, y nat) { |
222 // | 223 » // initialize z |
223 // The value has been found empirically for gotest -benchmarks=Fact | 224 » for i := range z[0 : len(x)+len(y)] { |
224 // on a machine running OS X on a 3.06GHz Intel Core 2 Duo. | 225 » » z[i] = 0 |
225 // | 226 » } |
226 // (To disable Karatsuba multiplication, set the threshold to a very | 227 » // multiply |
227 // large value). | 228 » for i, d := range y { |
228 const karatsubaThreshold = 245 | 229 » » if d != 0 { |
229 | 230 » » » z[len(x)+i] = addMulVVW(&z[i], &x[0], d, len(x)) |
230 func init() { | 231 » » } |
231 » if karatsubaThreshold <= 1 { | 232 » } |
232 » » panic("karatsubaThreshold must be > 1") | 233 } |
233 » } | 234 |
234 } | 235 |
235 | 236 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. |
236 | 237 // Factored out for readability - do not use outside karatsuba. |
237 func karatsubaAdd(z, x nat, n int) { | 238 func karatsubaAdd(z, x nat, n int) { |
238 if c := addVV(&z[0], &z[0], &x[0], n); c != 0 { | 239 if c := addVV(&z[0], &z[0], &x[0], n); c != 0 { |
239 addVW(&z[n], &z[n], c, n>>1) | 240 addVW(&z[n], &z[n], c, n>>1) |
240 } | 241 } |
241 } | 242 } |
242 | 243 |
243 | 244 |
245 // Like karatsubaAdd, but does subtract. | |
244 func karatsubaSub(z, x nat, n int) { | 246 func karatsubaSub(z, x nat, n int) { |
245 if c := subVV(&z[0], &z[0], &x[0], n); c != 0 { | 247 if c := subVV(&z[0], &z[0], &x[0], n); c != 0 { |
246 subVW(&z[n], &z[n], c, n>>1) | 248 subVW(&z[n], &z[n], c, n>>1) |
247 } | 249 } |
248 } | 250 } |
249 | 251 |
250 | 252 |
253 // Operands that are shorter than karatsubaThreshold are multiplied using | |
254 // "grade school" multiplication; for longer operands the Karatsuba algorithm | |
255 // is used. | |
256 var karatsubaThreshold int = 30 // modified by calibrate.go | |
257 | |
251 // karatsuba multiplies x and y and leaves the result in z. | 258 // karatsuba multiplies x and y and leaves the result in z. |
252 // Both x and y must have the same length and n must be a | 259 // Both x and y must have the same length n and n must be a |
253 // power of 2. The result vector z must have len(z) >= 6*n. | 260 // power of 2. The result vector z must have len(z) >= 6*n. |
254 // The (non-normalized) result is placed in z[0 : 2*n]. | 261 // The (non-normalized) result is placed in z[0 : 2*n]. |
255 func karatsuba(z, x, y nat) { | 262 func karatsuba(z, x, y nat) { |
256 n := len(y) | 263 n := len(y) |
257 | 264 |
258 » // Switch to basic multiplication if the numbers are small. | 265 » // Switch to basic multiplication if numbers are odd or small. |
259 » if n < karatsubaThreshold { | 266 » // (n is always even if karatsubaThreshold is even, but be |
260 » » // initialize z | 267 » // conservative) |
261 » » for i := 2*n - 1; i >= 0; i-- { | 268 » if n&1 != 0 || n < karatsubaThreshold || n < 2 { |
262 » » » z[i] = 0 | 269 » » basicMul(z, x, y) |
263 » » } | |
264 » » // "grade school" multiplication | |
265 » » for i, d := range y { | |
266 » » » if d != 0 { | |
267 » » » » z[n+i] = addMulVVW(&z[i], &x[0], d, n) | |
268 » » » } | |
269 » » } | |
270 return | 270 return |
271 } | 271 } |
272 » // n >= karatsubaThreshold > 1 | 272 » // n&1 == 0 && n >= karatsubaThreshold && n >= 2 |
273 | 273 |
274 // Karatsuba multiplication is based on the observation that | 274 // Karatsuba multiplication is based on the observation that |
275 // for two numbers x and y with: | 275 // for two numbers x and y with: |
276 // | 276 // |
277 // x = x1*b + x0 | 277 // x = x1*b + x0 |
278 // y = y1*b + y0 | 278 // y = y1*b + y0 |
279 // | 279 // |
280 // the product x*y can be obtained with 3 products z2, z1, z0 | 280 // the product x*y can be obtained with 3 products z2, z1, z0 |
281 // instead of 4: | 281 // instead of 4: |
282 // | 282 // |
(...skipping 12 matching lines...) Expand all Loading... | |
295 // = x1*y0 + x0*y1 | 295 // = x1*y0 + x0*y1 |
296 | 296 |
297 // split x, y into "digits" | 297 // split x, y into "digits" |
298 n2 := n >> 1 // n2 >= 1 | 298 n2 := n >> 1 // n2 >= 1 |
299 x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0 | 299 x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0 |
300 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 | 300 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 |
301 | 301 |
302 // z is used for the result and temporary storage: | 302 // z is used for the result and temporary storage: |
303 // | 303 // |
304 // 6*n 5*n 4*n 3*n 2*n 1*n 0*n | 304 // 6*n 5*n 4*n 3*n 2*n 1*n 0*n |
305 » // z = [z2 copy|z0 copy| xd*yd | xd:yd | x1*y1 | x0*y0 ] | 305 » // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ] |
306 // | 306 // |
307 // For each recursive call of karatsuba, an unused slice of | 307 // For each recursive call of karatsuba, an unused slice of |
308 // z is passed in that has (at least) half the length of the | 308 // z is passed in that has (at least) half the length of the |
309 // caller's z. | 309 // caller's z. |
310 | 310 |
311 // compute z0 and z2 with the result "in place" in z | 311 // compute z0 and z2 with the result "in place" in z |
312 karatsuba(z, x0, y0) // z0 = x0*y0 | 312 karatsuba(z, x0, y0) // z0 = x0*y0 |
313 karatsuba(z[n:], x1, y1) // z2 = x1*y1 | 313 karatsuba(z[n:], x1, y1) // z2 = x1*y1 |
314 | |
315 // TODO(gri): In the following we carefully avoid underflow | |
316 // by recomputing differences and keeping track | |
317 // of sign changes. Can probably optimize this by | |
318 // simply ignoring the overflow but track sign changes | |
319 // and use this to sign extend the product xd*yd before | |
320 // adding it to z. This should remove quite a bit of code. | |
321 | 314 |
322 // compute xd (or the negative value if underflow occurs) | 315 // compute xd (or the negative value if underflow occurs) |
323 s := 1 // sign of product xd*yd | 316 s := 1 // sign of product xd*yd |
324 xd := z[2*n : 2*n+n2] | 317 xd := z[2*n : 2*n+n2] |
325 if subVV(&xd[0], &x1[0], &x0[0], n2) != 0 { // x1-x0 | 318 if subVV(&xd[0], &x1[0], &x0[0], n2) != 0 { // x1-x0 |
326 s = -s | 319 s = -s |
327 subVV(&xd[0], &x0[0], &x1[0], n2) // x0-x1 | 320 subVV(&xd[0], &x0[0], &x1[0], n2) // x0-x1 |
328 } | 321 } |
329 | 322 |
330 // compute yd (or the negative value if underflow occurs) | 323 // compute yd (or the negative value if underflow occurs) |
331 yd := z[2*n+n2 : 3*n] | 324 yd := z[2*n+n2 : 3*n] |
332 if subVV(&yd[0], &y0[0], &y1[0], n2) != 0 { // y0-y1 | 325 if subVV(&yd[0], &y0[0], &y1[0], n2) != 0 { // y0-y1 |
333 s = -s | 326 s = -s |
334 subVV(&yd[0], &y1[0], &y0[0], n2) // y1-y0 | 327 subVV(&yd[0], &y1[0], &y0[0], n2) // y1-y0 |
335 } | 328 } |
336 | 329 |
337 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 | 330 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 |
338 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 | 331 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 |
339 p := z[n*3:] | 332 p := z[n*3:] |
340 karatsuba(p, xd, yd) | 333 karatsuba(p, xd, yd) |
341 | 334 |
342 // save original z2:z0 | 335 // save original z2:z0 |
343 // (ok to use upper half of z since we're done recursing) | 336 // (ok to use upper half of z since we're done recursing) |
344 r := z[n*4:] | 337 r := z[n*4:] |
345 copy(r, z) | 338 copy(r, z) |
346 | 339 |
347 // add up all partial products | 340 // add up all partial products |
348 // | 341 // |
342 // 2*n n 0 | |
349 // z = [ z2 | z0 ] | 343 // z = [ z2 | z0 ] |
350 // + [ z0 ] | 344 // + [ z0 ] |
351 // + [ z2 ] | 345 // + [ z2 ] |
352 // + [ p ] | 346 // + [ p ] |
353 // | 347 // |
354 karatsubaAdd(z[n2:], r, n) | 348 karatsubaAdd(z[n2:], r, n) |
355 karatsubaAdd(z[n2:], r[n:], n) | 349 karatsubaAdd(z[n2:], r[n:], n) |
356 if s > 0 { | 350 if s > 0 { |
357 karatsubaAdd(z[n2:], p, n) | 351 karatsubaAdd(z[n2:], p, n) |
358 } else { | 352 } else { |
359 karatsubaSub(z[n2:], p, n) | 353 karatsubaSub(z[n2:], p, n) |
360 } | 354 } |
361 } | 355 } |
362 | 356 |
363 | 357 |
364 // alias returns true if x and y share the same base array. | 358 // alias returns true if x and y share the same base array. |
365 func alias(x, y nat) bool { | 359 func alias(x, y nat) bool { |
366 return &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] | 360 return &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] |
367 } | 361 } |
368 | 362 |
369 | 363 |
370 // addAt implements z += x*(1<<(_W*i)); z must be long enough. | 364 // addAt implements z += x*(1<<(_W*i)); z must be long enough. |
365 // (we don't use nat.add because we need z to stay the same | |
366 // slice, and we don't need to normalize z after each addition) | |
371 func addAt(z, x nat, i int) { | 367 func addAt(z, x nat, i int) { |
372 if n := len(x); n > 0 { | 368 if n := len(x); n > 0 { |
373 if c := addVV(&z[i], &z[i], &x[0], n); c != 0 { | 369 if c := addVV(&z[i], &z[i], &x[0], n); c != 0 { |
374 j := i + n | 370 j := i + n |
375 if j < len(z) { | 371 if j < len(z) { |
376 addVW(&z[j], &z[j], c, len(z)-j) | 372 addVW(&z[j], &z[j], c, len(z)-j) |
377 } | 373 } |
378 } | 374 } |
379 } | 375 } |
380 } | 376 } |
381 | 377 |
382 | 378 |
379 func max(x, y int) int { | |
380 if x > y { | |
381 return x | |
382 } | |
383 return y | |
384 } | |
385 | |
386 | |
383 func (z nat) mul(x, y nat) nat { | 387 func (z nat) mul(x, y nat) nat { |
384 m := len(x) | 388 m := len(x) |
385 n := len(y) | 389 n := len(y) |
386 | 390 |
387 switch { | 391 switch { |
388 case m < n: | 392 case m < n: |
389 return z.mul(y, x) | 393 return z.mul(y, x) |
390 case m == 0 || n == 0: | 394 case m == 0 || n == 0: |
391 » » return z.make(0, false) | 395 » » return z.make(0) |
392 case n == 1: | 396 case n == 1: |
393 return z.mulAddWW(x, y[0], 0) | 397 return z.mulAddWW(x, y[0], 0) |
394 } | 398 } |
395 » // m >= n && m > 1 && n > 1 | 399 » // m >= n > 1 |
396 | 400 |
397 // determine if z can be reused | 401 // determine if z can be reused |
398 if len(z) > 0 && (alias(z, x) || alias(z, y)) { | 402 if len(z) > 0 && (alias(z, x) || alias(z, y)) { |
399 z = nil // z is an alias for x or y - cannot reuse | 403 z = nil // z is an alias for x or y - cannot reuse |
400 } | 404 } |
401 | 405 |
402 » if n < karatsubaThreshold { | 406 » // use basic multiplication if the numbers are small |
403 » » // "grade school" multiplication | 407 » if n < karatsubaThreshold || n < 2 { |
404 » » z = z.make(m+n, true) | 408 » » z = z.make(m + n) |
405 » » for i, d := range y { | 409 » » basicMul(z, x, y) |
406 » » » if d != 0 { | |
407 » » » » z[m+i] = addMulVVW(&z[i], &x[0], d, m) | |
408 » » » } | |
409 » » } | |
410 return z.norm() | 410 return z.norm() |
411 } | 411 } |
412 » // m >= n && n >= karatsubaThreshold | 412 » // m >= n && n >= karatsubaThreshold && n >= 2 |
413 | 413 |
414 » // Note that even though we passed the Karatsuba threshold, | 414 » // determine largest k such that |
415 » // because we tested against n and not k (see below) we may | |
416 » // still end up using grade-school multiplication, albeit with | |
417 » // an intermediate step if the Karatsuba theshold is not a | |
418 » // power of 2. It appears that this intermediate step makes | |
419 » // things faster (e.g., the threshold is < 256 at the moment). | |
420 » // Theoretically, there are more operations involved but the numbers | |
421 » // are larger and thus "internal fragmentation" (i.e., total number | |
422 » // unused bits in leading words) may be smaller, possibly resulting | |
423 » // in fewer actual machine multiplications. | |
424 | |
425 » // Determine k such that: | |
426 // | 415 // |
427 // x = x1*b + x0 | 416 // x = x1*b + x0 |
428 // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) | 417 // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) |
429 // | |
430 // and | |
431 // | |
432 // b = 1<<(_W*k) ("base" of digits xi, yi) | 418 // b = 1<<(_W*k) ("base" of digits xi, yi) |
433 // | 419 // |
434 » k := 1 << uint(log2(Word(n))) | 420 » // and k is karatsubaThreshold multiplied by a power of 2 |
435 | 421 » k := max(karatsubaThreshold, 2) |
436 » // If x1 and/or y1 are not 0, compute product explicitly: | 422 » for k*2 <= n { |
437 » // | 423 » » k *= 2 |
438 » // x*y = x1*y1*b*b + x1*y0*b + x0*y1*b + x0*y0 | 424 » } |
425 » // k <= n | |
426 | |
427 » // multiply x0 and y0 via Karatsuba | |
428 » x0 := x[0:k] // x0 is not normalized | |
429 » y0 := y[0:k] // y0 is not normalized | |
430 » z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y | |
431 » karatsuba(z, x0, y0) | |
432 » z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage | |
433 | |
434 » // If x1 and/or y1 are not 0, add missing terms to z explicitly: | |
435 » // | |
436 » // m+n 2*k 0 | |
437 » // z = [ ... | x0*y0 ] | |
438 » // + [ x1*y1 ] | |
439 » // + [ x1*y0 ] | |
440 » // + [ x0*y1 ] | |
439 // | 441 // |
440 if k < n || m != n { | 442 if k < n || m != n { |
441 » » x1, x0 := x[k:], x[0:k].norm() // x1 is normalized because x is | 443 » » x1 := x[k:] // x1 is normalized because x is |
442 » » y1, y0 := y[k:], y[0:k].norm() // y1 is normalized because y is | 444 » » y1 := y[k:] // y1 is normalized because y is |
443 » » z = z.make(m+n, true) | 445 » » var t nat |
444 » » copy(z[2*k:], natZero.mul(x1, y1)) // z may not be normalized! | 446 » » t = t.mul(x1, y1) |
445 » » addAt(z, natZero.mul(x1, y0), k) | 447 » » copy(z[2*k:], t) |
446 » » addAt(z, natZero.mul(x0, y1), k) | 448 » » z[2*k+len(t):].clear() // upper portion of z is garbage |
447 » » addAt(z, natZero.mul(x0, y0), 0) // (could invoke karatsuba for x0, y0 directly) | 449 » » t = t.mul(x1, y0.norm()) |
448 » » return z.norm() | 450 » » addAt(z, t, k) |
449 » } | 451 » » t = t.mul(x0.norm(), y1) |
450 » // k == n && m == n | 452 » » addAt(z, t, k) |
451 | 453 » } |
452 » // Both x and y have the same length k which is a power of 2 | 454 |
453 » // and thus are directly suitable for Karatsuba multiplication. | 455 » return z.norm() |
454 » z = z.make(6*k, false) | |
455 » karatsuba(z, x, y) | |
456 » return z[0 : 2*n].norm() | |
457 } | 456 } |
458 | 457 |
459 | 458 |
460 // mulRange computes the product of all the unsigned integers in the | 459 // mulRange computes the product of all the unsigned integers in the |
461 // range [a, b] inclusively. If a > b (empty range), the result is 1. | 460 // range [a, b] inclusively. If a > b (empty range), the result is 1. |
462 func (z nat) mulRange(a, b uint64) nat { | 461 func (z nat) mulRange(a, b uint64) nat { |
adonovan
2020/02/12 14:46:30
What's the purpose of treating the sequence as lea
gri
2020/02/19 00:51:24
The recursive approach appears faster in practice
| |
463 switch { | 462 switch { |
464 case a == 0: | 463 case a == 0: |
465 // cut long ranges short (optimization) | 464 // cut long ranges short (optimization) |
466 return z.new(0) | 465 return z.new(0) |
467 case a > b: | 466 case a > b: |
468 return z.new(1) | 467 return z.new(1) |
469 case a == b: | 468 case a == b: |
470 return z.new(a) | 469 return z.new(a) |
471 case a+1 == b: | 470 case a+1 == b: |
472 » » return z.mul(natZero.new(a), natZero.new(b)) | 471 » » return z.mul(nat(nil).new(a), nat(nil).new(b)) |
473 } | 472 } |
474 m := (a + b) / 2 | 473 m := (a + b) / 2 |
475 » return z.mul(natZero.mulRange(a, m), natZero.mulRange(m+1, b)) | 474 » return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b)) |
476 } | 475 } |
477 | 476 |
478 | 477 |
479 // q = (x-r)/y, with 0 <= r < y | 478 // q = (x-r)/y, with 0 <= r < y |
480 func (z nat) divW(x nat, y Word) (q nat, r Word) { | 479 func (z nat) divW(x nat, y Word) (q nat, r Word) { |
481 m := len(x) | 480 m := len(x) |
482 switch { | 481 switch { |
483 case y == 0: | 482 case y == 0: |
484 panic("division by zero") | 483 panic("division by zero") |
485 case y == 1: | 484 case y == 1: |
486 q = z.set(x) // result is x | 485 q = z.set(x) // result is x |
487 return | 486 return |
488 case m == 0: | 487 case m == 0: |
489 q = z.set(nil) // result is 0 | 488 q = z.set(nil) // result is 0 |
490 return | 489 return |
491 } | 490 } |
492 // m > 0 | 491 // m > 0 |
493 » z = z.make(m, false) | 492 » z = z.make(m) |
494 r = divWVW(&z[0], 0, &x[0], y, m) | 493 r = divWVW(&z[0], 0, &x[0], y, m) |
495 q = z.norm() | 494 q = z.norm() |
496 return | 495 return |
497 } | 496 } |
498 | 497 |
499 | 498 |
500 func (z nat) div(z2, u, v nat) (q, r nat) { | 499 func (z nat) div(z2, u, v nat) (q, r nat) { |
501 if len(v) == 0 { | 500 if len(v) == 0 { |
502 panic("division by zero") | 501 panic("division by zero") |
503 } | 502 } |
504 | 503 |
505 if u.cmp(v) < 0 { | 504 if u.cmp(v) < 0 { |
506 » » q = z.make(0, false) | 505 » » q = z.make(0) |
507 r = z2.set(u) | 506 r = z2.set(u) |
508 return | 507 return |
509 } | 508 } |
510 | 509 |
511 if len(v) == 1 { | 510 if len(v) == 1 { |
512 var rprime Word | 511 var rprime Word |
513 q, rprime = z.divW(u, v[0]) | 512 q, rprime = z.divW(u, v[0]) |
514 if rprime > 0 { | 513 if rprime > 0 { |
515 » » » r = z2.make(1, false) | 514 » » » r = z2.make(1) |
516 r[0] = rprime | 515 r[0] = rprime |
517 } else { | 516 } else { |
518 » » » r = z2.make(0, false) | 517 » » » r = z2.make(0) |
519 } | 518 } |
520 return | 519 return |
521 } | 520 } |
522 | 521 |
523 q, r = z.divLarge(z2, u, v) | 522 q, r = z.divLarge(z2, u, v) |
524 return | 523 return |
525 } | 524 } |
526 | 525 |
527 | 526 |
528 // q = (uIn-r)/v, with 0 <= r < v | 527 // q = (uIn-r)/v, with 0 <= r < v |
529 // See Knuth, Volume 2, section 4.3.1, Algorithm D. | 528 // See Knuth, Volume 2, section 4.3.1, Algorithm D. |
530 // Preconditions: | 529 // Preconditions: |
531 // len(v) >= 2 | 530 // len(v) >= 2 |
532 // len(uIn) >= len(v) | 531 // len(uIn) >= len(v) |
533 func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { | 532 func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { |
534 n := len(v) | 533 n := len(v) |
535 m := len(uIn) - len(v) | 534 m := len(uIn) - len(v) |
536 | 535 |
537 var u nat | 536 var u nat |
538 if z2 == nil || &z2[0] == &uIn[0] { | 537 if z2 == nil || &z2[0] == &uIn[0] { |
539 » » u = u.make(len(uIn)+1, true) // uIn is an alias for z2 | 538 » » u = u.make(len(uIn) + 1).clear() // uIn is an alias for z2 |
540 } else { | 539 } else { |
541 » » u = z2.make(len(uIn)+1, true) | 540 » » u = z2.make(len(uIn) + 1).clear() |
542 } | 541 } |
543 qhatv := make(nat, len(v)+1) | 542 qhatv := make(nat, len(v)+1) |
544 » q = z.make(m+1, false) | 543 » q = z.make(m + 1) |
545 | 544 |
546 // D1. | 545 // D1. |
547 shift := uint(leadingZeroBits(v[n-1])) | 546 shift := uint(leadingZeroBits(v[n-1])) |
548 v.shiftLeft(v, shift) | 547 v.shiftLeft(v, shift) |
549 u.shiftLeft(uIn, shift) | 548 u.shiftLeft(uIn, shift) |
550 u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift)) | 549 u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift)) |
551 | 550 |
552 // D2. | 551 // D2. |
553 for j := m; j >= 0; j-- { | 552 for j := m; j >= 0; j-- { |
554 // D3. | 553 // D3. |
(...skipping 136 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
691 | 690 |
692 if len(x) == 0 { | 691 if len(x) == 0 { |
693 return "0" | 692 return "0" |
694 } | 693 } |
695 | 694 |
696 // allocate buffer for conversion | 695 // allocate buffer for conversion |
697 i := (x.log2()+1)/log2(Word(base)) + 1 // +1: round up | 696 i := (x.log2()+1)/log2(Word(base)) + 1 // +1: round up |
698 s := make([]byte, i) | 697 s := make([]byte, i) |
699 | 698 |
700 // don't destroy x | 699 // don't destroy x |
701 » q := natZero.set(x) | 700 » q := nat(nil).set(x) |
702 | 701 |
703 // convert | 702 // convert |
704 for len(q) > 0 { | 703 for len(q) > 0 { |
705 i-- | 704 i-- |
706 var r Word | 705 var r Word |
707 q, r = q.divW(q, Word(base)) | 706 q, r = q.divW(q, Word(base)) |
708 s[i] = "0123456789abcdef"[r] | 707 s[i] = "0123456789abcdef"[r] |
709 } | 708 } |
710 | 709 |
711 return string(s[i:]) | 710 return string(s[i:]) |
(...skipping 52 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
764 case 64: | 763 case 64: |
765 return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) | 764 return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) |
766 default: | 765 default: |
767 panic("Unknown word size") | 766 panic("Unknown word size") |
768 } | 767 } |
769 | 768 |
770 return 0 | 769 return 0 |
771 } | 770 } |
772 | 771 |
773 | 772 |
773 // TODO(gri) Make the shift routines faster. | |
774 // Use pidigits.go benchmark as a test case. | |
775 | |
774 // To avoid losing the top n bits, z should be sized so that | 776 // To avoid losing the top n bits, z should be sized so that |
775 // len(z) == len(x) + 1. | 777 // len(z) == len(x) + 1. |
776 func (z nat) shiftLeft(x nat, n uint) nat { | 778 func (z nat) shiftLeft(x nat, n uint) nat { |
777 if len(x) == 0 { | 779 if len(x) == 0 { |
778 return x | 780 return x |
779 } | 781 } |
780 | 782 |
781 ñ := _W - n | 783 ñ := _W - n |
782 m := x[len(x)-1] | 784 m := x[len(x)-1] |
783 if len(z) > len(x) { | 785 if len(z) > len(x) { |
(...skipping 27 matching lines...) Expand all Loading... | |
811 | 813 |
812 | 814 |
813 // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2) | 815 // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2) |
814 func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 } | 816 func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 } |
815 | 817 |
816 | 818 |
817 // modW returns x % d. | 819 // modW returns x % d. |
818 func (x nat) modW(d Word) (r Word) { | 820 func (x nat) modW(d Word) (r Word) { |
819 // TODO(agl): we don't actually need to store the q value. | 821 // TODO(agl): we don't actually need to store the q value. |
820 var q nat | 822 var q nat |
821 » q = q.make(len(x), false) | 823 » q = q.make(len(x)) |
822 return divWVW(&q[0], 0, &x[0], d, len(x)) | 824 return divWVW(&q[0], 0, &x[0], d, len(x)) |
823 } | 825 } |
824 | 826 |
825 | 827 |
826 // powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd. | 828 // powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd. |
827 func (n nat) powersOfTwoDecompose() (q nat, k Word) { | 829 func (n nat) powersOfTwoDecompose() (q nat, k Word) { |
828 if len(n) == 0 { | 830 if len(n) == 0 { |
829 return n, 0 | 831 return n, 0 |
830 } | 832 } |
831 | 833 |
832 zeroWords := 0 | 834 zeroWords := 0 |
833 for n[zeroWords] == 0 { | 835 for n[zeroWords] == 0 { |
834 zeroWords++ | 836 zeroWords++ |
835 } | 837 } |
836 // One of the words must be non-zero by invariant, therefore | 838 // One of the words must be non-zero by invariant, therefore |
837 // zeroWords < len(n). | 839 // zeroWords < len(n). |
838 x := trailingZeroBits(n[zeroWords]) | 840 x := trailingZeroBits(n[zeroWords]) |
839 | 841 |
840 » q = q.make(len(n)-zeroWords, false) | 842 » q = q.make(len(n) - zeroWords) |
841 q.shiftRight(n[zeroWords:], uint(x)) | 843 q.shiftRight(n[zeroWords:], uint(x)) |
842 q = q.norm() | 844 q = q.norm() |
843 | 845 |
844 k = Word(_W*zeroWords + x) | 846 k = Word(_W*zeroWords + x) |
845 return | 847 return |
846 } | 848 } |
847 | 849 |
848 | 850 |
849 // random creates a random integer in [0..limit), using the space in z if | 851 // random creates a random integer in [0..limit), using the space in z if |
850 // possible. n is the bit length of limit. | 852 // possible. n is the bit length of limit. |
851 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { | 853 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { |
852 bitLengthOfMSW := uint(n % _W) | 854 bitLengthOfMSW := uint(n % _W) |
853 if bitLengthOfMSW == 0 { | 855 if bitLengthOfMSW == 0 { |
854 bitLengthOfMSW = _W | 856 bitLengthOfMSW = _W |
855 } | 857 } |
856 mask := Word((1 << bitLengthOfMSW) - 1) | 858 mask := Word((1 << bitLengthOfMSW) - 1) |
857 » z = z.make(len(limit), false) | 859 » z = z.make(len(limit)) |
858 | 860 |
859 for { | 861 for { |
860 for i := range z { | 862 for i := range z { |
861 switch _W { | 863 switch _W { |
862 case 32: | 864 case 32: |
863 z[i] = Word(rand.Uint32()) | 865 z[i] = Word(rand.Uint32()) |
864 case 64: | 866 case 64: |
865 z[i] = Word(rand.Uint32()) | Word(rand.Uint32()) <<32 | 867 z[i] = Word(rand.Uint32()) | Word(rand.Uint32()) <<32 |
866 } | 868 } |
867 } | 869 } |
868 | 870 |
869 z[len(limit)-1] &= mask | 871 z[len(limit)-1] &= mask |
870 | 872 |
871 if z.cmp(limit) < 0 { | 873 if z.cmp(limit) < 0 { |
872 break | 874 break |
873 } | 875 } |
874 } | 876 } |
875 | 877 |
876 return z.norm() | 878 return z.norm() |
877 } | 879 } |
878 | 880 |
879 | 881 |
880 // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It | 882 // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It |
881 // reuses the storage of z if possible. | 883 // reuses the storage of z if possible. |
882 func (z nat) expNN(x, y, m nat) nat { | 884 func (z nat) expNN(x, y, m nat) nat { |
883 if len(y) == 0 { | 885 if len(y) == 0 { |
884 » » z = z.make(1, false) | 886 » » z = z.make(1) |
885 z[0] = 1 | 887 z[0] = 1 |
886 return z | 888 return z |
887 } | 889 } |
888 | 890 |
889 if m != nil { | 891 if m != nil { |
890 // We likely end up being as long as the modulus. | 892 // We likely end up being as long as the modulus. |
891 » » z = z.make(len(m), false) | 893 » » z = z.make(len(m)) |
892 } | 894 } |
893 z = z.set(x) | 895 z = z.set(x) |
894 v := y[len(y)-1] | 896 v := y[len(y)-1] |
895 // It's invalid for the most significant word to be zero, therefore we | 897 // It's invalid for the most significant word to be zero, therefore we |
896 // will find a one bit. | 898 // will find a one bit. |
897 shift := leadingZeros(v) + 1 | 899 shift := leadingZeros(v) + 1 |
898 v <<= shift | 900 v <<= shift |
899 var q nat | 901 var q nat |
900 | 902 |
901 const mask = 1 << (_W - 1) | 903 const mask = 1 << (_W - 1) |
(...skipping 92 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
994 if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 || | 996 if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 || |
995 r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 { | 997 r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 { |
996 return false | 998 return false |
997 } | 999 } |
998 | 1000 |
999 if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 || | 1001 if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 || |
1000 r%43 == 0 || r%47 == 0 || r%53 == 0) { | 1002 r%43 == 0 || r%47 == 0 || r%53 == 0) { |
1001 return false | 1003 return false |
1002 } | 1004 } |
1003 | 1005 |
1004 » nm1 := natZero.sub(n, natOne) | 1006 » nm1 := nat(nil).sub(n, natOne) |
1005 // 1<<k * q = nm1; | 1007 // 1<<k * q = nm1; |
1006 q, k := nm1.powersOfTwoDecompose() | 1008 q, k := nm1.powersOfTwoDecompose() |
1007 | 1009 |
1008 » nm3 := natZero.sub(nm1, natTwo) | 1010 » nm3 := nat(nil).sub(nm1, natTwo) |
1009 rand := rand.New(rand.NewSource(int64(n[0]))) | 1011 rand := rand.New(rand.NewSource(int64(n[0]))) |
1010 | 1012 |
1011 var x, y, quotient nat | 1013 var x, y, quotient nat |
1012 nm3Len := nm3.len() | 1014 nm3Len := nm3.len() |
1013 | 1015 |
1014 NextRandom: | 1016 NextRandom: |
1015 for i := 0; i < reps; i++ { | 1017 for i := 0; i < reps; i++ { |
1016 x = x.random(rand, nm3, nm3Len) | 1018 x = x.random(rand, nm3, nm3Len) |
1017 x = x.add(x, natTwo) | 1019 x = x.add(x, natTwo) |
1018 y = y.expNN(x, q, n) | 1020 y = y.expNN(x, q, n) |
1019 if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { | 1021 if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { |
1020 continue | 1022 continue |
1021 } | 1023 } |
1022 for j := Word(1); j < k; j++ { | 1024 for j := Word(1); j < k; j++ { |
1023 y = y.mul(y, y) | 1025 y = y.mul(y, y) |
1024 quotient, y = quotient.div(y, y, n) | 1026 quotient, y = quotient.div(y, y, n) |
1025 if y.cmp(nm1) == 0 { | 1027 if y.cmp(nm1) == 0 { |
1026 continue NextRandom | 1028 continue NextRandom |
1027 } | 1029 } |
1028 if y.cmp(natOne) == 0 { | 1030 if y.cmp(natOne) == 0 { |
1029 return false | 1031 return false |
1030 } | 1032 } |
1031 } | 1033 } |
1032 return false | 1034 return false |
1033 } | 1035 } |
1034 | 1036 |
1035 return true | 1037 return true |
1036 } | 1038 } |
LEFT | RIGHT |