Left: | ||
Right: |
LEFT | RIGHT |
---|---|
1 // Copyright 2009 The Go Authors. All rights reserved. | 1 // Copyright 2009 The Go Authors. All rights reserved. |
2 // Use of this source code is governed by a BSD-style | 2 // Use of this source code is governed by a BSD-style |
3 // license that can be found in the LICENSE file. | 3 // license that can be found in the LICENSE file. |
4 | 4 |
5 // This file contains operations on unsigned multi-precision integers. | 5 // This file contains operations on unsigned multi-precision integers. |
6 // These are the building blocks for the operations on signed integers | 6 // These are the building blocks for the operations on signed integers |
7 // and rationals. | 7 // and rationals. |
8 | 8 |
9 // This package implements multi-precision arithmetic (big numbers). | 9 // This package implements multi-precision arithmetic (big numbers). |
10 // The following numeric types are supported: | 10 // The following numeric types are supported: |
(...skipping 24 matching lines...) Expand all Loading... | |
35 // representation of 0 is the empty or nil slice (length = 0). | 35 // representation of 0 is the empty or nil slice (length = 0). |
36 | 36 |
37 type nat []Word | 37 type nat []Word |
38 | 38 |
39 var ( | 39 var ( |
40 natOne = nat{1} | 40 natOne = nat{1} |
41 natTwo = nat{2} | 41 natTwo = nat{2} |
42 ) | 42 ) |
43 | 43 |
44 | 44 |
45 func (z nat) clear() nat { | |
46 for i := range z { | |
47 z[i] = 0 | |
48 } | |
49 return z | |
50 } | |
51 | |
52 | |
45 func (z nat) norm() nat { | 53 func (z nat) norm() nat { |
46 i := len(z) | 54 i := len(z) |
47 for i > 0 && z[i-1] == 0 { | 55 for i > 0 && z[i-1] == 0 { |
48 i-- | 56 i-- |
49 } | 57 } |
50 z = z[0:i] | 58 z = z[0:i] |
51 return z | 59 return z |
52 } | 60 } |
53 | 61 |
54 | 62 |
55 func (z nat) make(m int, clear bool) nat { | 63 func (z nat) make(m int) nat { |
56 if cap(z) > m { | 64 if cap(z) > m { |
57 » » z = z[0:m] // reuse z - has at least one extra word for a carry, if any | 65 » » return z[0:m] // reuse z - has at least one extra word for a carry, if any |
58 » » if clear { | |
59 » » » for i := range z { | |
60 » » » » z[i] = 0 | |
61 » » » } | |
62 » » } | |
63 » » return z | |
64 } | 66 } |
65 | 67 |
66 c := 4 // minimum capacity | 68 c := 4 // minimum capacity |
67 if m > c { | 69 if m > c { |
68 c = m | 70 c = m |
69 } | 71 } |
70 return make(nat, m, c+1) // +1: extra word for a carry, if any | 72 return make(nat, m, c+1) // +1: extra word for a carry, if any |
71 } | 73 } |
72 | 74 |
73 | 75 |
74 func (z nat) new(x uint64) nat { | 76 func (z nat) new(x uint64) nat { |
75 if x == 0 { | 77 if x == 0 { |
76 » » return z.make(0, false) | 78 » » return z.make(0) |
77 } | 79 } |
78 | 80 |
79 // single-digit values | 81 // single-digit values |
80 if x == uint64(Word(x)) { | 82 if x == uint64(Word(x)) { |
81 » » z = z.make(1, false) | 83 » » z = z.make(1) |
82 z[0] = Word(x) | 84 z[0] = Word(x) |
83 return z | 85 return z |
84 } | 86 } |
85 | 87 |
86 // compute number of words n required to represent x | 88 // compute number of words n required to represent x |
87 n := 0 | 89 n := 0 |
88 for t := x; t > 0; t >>= _W { | 90 for t := x; t > 0; t >>= _W { |
89 n++ | 91 n++ |
90 } | 92 } |
91 | 93 |
92 // split x into n words | 94 // split x into n words |
93 » z = z.make(n, false) | 95 » z = z.make(n) |
94 for i := 0; i < n; i++ { | 96 for i := 0; i < n; i++ { |
95 z[i] = Word(x & _M) | 97 z[i] = Word(x & _M) |
96 x >>= _W | 98 x >>= _W |
97 } | 99 } |
98 | 100 |
99 return z | 101 return z |
100 } | 102 } |
101 | 103 |
102 | 104 |
103 func (z nat) set(x nat) nat { | 105 func (z nat) set(x nat) nat { |
104 » z = z.make(len(x), false) | 106 » z = z.make(len(x)) |
105 for i, d := range x { | 107 for i, d := range x { |
106 z[i] = d | 108 z[i] = d |
107 } | 109 } |
108 return z | 110 return z |
109 } | 111 } |
110 | 112 |
111 | 113 |
112 func (z nat) add(x, y nat) nat { | 114 func (z nat) add(x, y nat) nat { |
113 m := len(x) | 115 m := len(x) |
114 n := len(y) | 116 n := len(y) |
115 | 117 |
116 switch { | 118 switch { |
117 case m < n: | 119 case m < n: |
118 return z.add(y, x) | 120 return z.add(y, x) |
119 case m == 0: | 121 case m == 0: |
120 // n == 0 because m >= n; result is 0 | 122 // n == 0 because m >= n; result is 0 |
121 » » return z.make(0, false) | 123 » » return z.make(0) |
122 case n == 0: | 124 case n == 0: |
123 // result is x | 125 // result is x |
124 return z.set(x) | 126 return z.set(x) |
125 } | 127 } |
126 // m > 0 | 128 // m > 0 |
127 | 129 |
128 » z = z.make(m, false) | 130 » z = z.make(m) |
129 c := addVV(&z[0], &x[0], &y[0], n) | 131 c := addVV(&z[0], &x[0], &y[0], n) |
130 if m > n { | 132 if m > n { |
131 c = addVW(&z[n], &x[n], c, m-n) | 133 c = addVW(&z[n], &x[n], c, m-n) |
132 } | 134 } |
133 if c > 0 { | 135 if c > 0 { |
134 z = z[0 : m+1] | 136 z = z[0 : m+1] |
135 z[m] = c | 137 z[m] = c |
136 } | 138 } |
137 | 139 |
138 return z | 140 return z |
139 } | 141 } |
140 | 142 |
141 | 143 |
142 func (z nat) sub(x, y nat) nat { | 144 func (z nat) sub(x, y nat) nat { |
143 m := len(x) | 145 m := len(x) |
144 n := len(y) | 146 n := len(y) |
145 | 147 |
146 switch { | 148 switch { |
147 case m < n: | 149 case m < n: |
148 panic("underflow") | 150 panic("underflow") |
149 case m == 0: | 151 case m == 0: |
150 // n == 0 because m >= n; result is 0 | 152 // n == 0 because m >= n; result is 0 |
151 » » return z.make(0, false) | 153 » » return z.make(0) |
152 case n == 0: | 154 case n == 0: |
153 // result is x | 155 // result is x |
154 return z.set(x) | 156 return z.set(x) |
155 } | 157 } |
156 // m > 0 | 158 // m > 0 |
157 | 159 |
158 » z = z.make(m, false) | 160 » z = z.make(m) |
159 c := subVV(&z[0], &x[0], &y[0], n) | 161 c := subVV(&z[0], &x[0], &y[0], n) |
160 if m > n { | 162 if m > n { |
161 c = subVW(&z[n], &x[n], c, m-n) | 163 c = subVW(&z[n], &x[n], c, m-n) |
162 } | 164 } |
163 if c != 0 { | 165 if c != 0 { |
164 panic("underflow") | 166 panic("underflow") |
165 } | 167 } |
166 z = z.norm() | 168 z = z.norm() |
167 | 169 |
168 return z | 170 return z |
(...skipping 28 matching lines...) Expand all Loading... | |
197 } | 199 } |
198 | 200 |
199 | 201 |
200 func (z nat) mulAddWW(x nat, y, r Word) nat { | 202 func (z nat) mulAddWW(x nat, y, r Word) nat { |
201 m := len(x) | 203 m := len(x) |
202 if m == 0 || y == 0 { | 204 if m == 0 || y == 0 { |
203 return z.new(uint64(r)) // result is r | 205 return z.new(uint64(r)) // result is r |
204 } | 206 } |
205 // m > 0 | 207 // m > 0 |
206 | 208 |
207 » z = z.make(m, false) | 209 » z = z.make(m) |
208 c := mulAddVWW(&z[0], &x[0], y, r, m) | 210 c := mulAddVWW(&z[0], &x[0], y, r, m) |
209 if c > 0 { | 211 if c > 0 { |
210 z = z[0 : m+1] | 212 z = z[0 : m+1] |
211 z[m] = c | 213 z[m] = c |
212 } | 214 } |
213 | 215 |
214 return z | 216 return z |
215 } | 217 } |
216 | 218 |
217 | 219 |
218 // Operands that are shorter than this threshold are multiplied using | 220 // basicMul multiplies x and y and leaves the result in z. |
219 // "grade school" multiplication; for larger operands the Karatsuba | 221 // The (non-normalized) result is placed in z[0 : len(x) + len(y)]. |
220 // algorithm is used. | 222 func basicMul(z, x, y nat) { |
221 // | 223 » // initialize z |
222 // The value has been found empirically for gotest -benchmarks=Fact | 224 » for i := range z[0 : len(x)+len(y)] { |
223 // on a machine running OS X on a 3.06GHz Intel Core 2 Duo. | 225 » » z[i] = 0 |
224 // | 226 » } |
225 // (To disable Karatsuba multiplication, set the threshold to a very | 227 » // multiply |
226 // large value). | 228 » for i, d := range y { |
227 const karatsubaThreshold = 245 | 229 » » if d != 0 { |
228 | 230 » » » z[len(x)+i] = addMulVVW(&z[i], &x[0], d, len(x)) |
229 // karatsubaThreshold must be >= 2. | 231 » » } |
230 // Trigger compile error if that's not true. | 232 » } |
231 const _ uint = karatsubaThreshold - 2 | 233 } |
232 | 234 |
233 | 235 |
234 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. | 236 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. |
235 // Factored out for readability - do not use outside karatsuba. | 237 // Factored out for readability - do not use outside karatsuba. |
236 func karatsubaAdd(z, x nat, n int) { | 238 func karatsubaAdd(z, x nat, n int) { |
237 if c := addVV(&z[0], &z[0], &x[0], n); c != 0 { | 239 if c := addVV(&z[0], &z[0], &x[0], n); c != 0 { |
238 addVW(&z[n], &z[n], c, n>>1) | 240 addVW(&z[n], &z[n], c, n>>1) |
239 } | 241 } |
240 } | 242 } |
241 | 243 |
242 | 244 |
243 // Like karatsubaAdd, but does subtract. | 245 // Like karatsubaAdd, but does subtract. |
244 func karatsubaSub(z, x nat, n int) { | 246 func karatsubaSub(z, x nat, n int) { |
245 if c := subVV(&z[0], &z[0], &x[0], n); c != 0 { | 247 if c := subVV(&z[0], &z[0], &x[0], n); c != 0 { |
246 subVW(&z[n], &z[n], c, n>>1) | 248 subVW(&z[n], &z[n], c, n>>1) |
247 } | 249 } |
248 } | 250 } |
249 | 251 |
252 | |
253 // Operands that are shorter than karatsubaThreshold are multiplied using | |
254 // "grade school" multiplication; for longer operands the Karatsuba algorithm | |
255 // is used. | |
256 var karatsubaThreshold int = 30 // modified by calibrate.go | |
250 | 257 |
251 // karatsuba multiplies x and y and leaves the result in z. | 258 // karatsuba multiplies x and y and leaves the result in z. |
252 // Both x and y must have the same length n and n must be a | 259 // Both x and y must have the same length n and n must be a |
253 // power of 2. The result vector z must have len(z) >= 6*n. | 260 // power of 2. The result vector z must have len(z) >= 6*n. |
254 // The (non-normalized) result is placed in z[0 : 2*n]. | 261 // The (non-normalized) result is placed in z[0 : 2*n]. |
255 func karatsuba(z, x, y nat) { | 262 func karatsuba(z, x, y nat) { |
256 n := len(y) | 263 n := len(y) |
257 | 264 |
258 » // Switch to basic multiplication if the numbers are small. | 265 » // Switch to basic multiplication if numbers are odd or small. |
259 » if n < karatsubaThreshold { | 266 » // (n is always even if karatsubaThreshold is even, but be |
260 » » // initialize z | 267 » // conservative) |
261 » » for i := 2*n - 1; i >= 0; i-- { | 268 » if n&1 != 0 || n < karatsubaThreshold || n < 2 { |
262 » » » z[i] = 0 | 269 » » basicMul(z, x, y) |
263 » » } | |
264 » » // "grade school" multiplication | |
265 » » for i, d := range y { | |
266 » » » if d != 0 { | |
267 » » » » z[n+i] = addMulVVW(&z[i], &x[0], d, n) | |
268 » » » } | |
269 » » } | |
270 return | 270 return |
271 } | 271 } |
272 » // n >= karatsubaThreshold > 1 | 272 » // n&1 == 0 && n >= karatsubaThreshold && n >= 2 |
273 | 273 |
274 // Karatsuba multiplication is based on the observation that | 274 // Karatsuba multiplication is based on the observation that |
275 // for two numbers x and y with: | 275 // for two numbers x and y with: |
276 // | 276 // |
277 // x = x1*b + x0 | 277 // x = x1*b + x0 |
278 // y = y1*b + y0 | 278 // y = y1*b + y0 |
279 // | 279 // |
280 // the product x*y can be obtained with 3 products z2, z1, z0 | 280 // the product x*y can be obtained with 3 products z2, z1, z0 |
281 // instead of 4: | 281 // instead of 4: |
282 // | 282 // |
(...skipping 12 matching lines...) Expand all Loading... | |
295 // = x1*y0 + x0*y1 | 295 // = x1*y0 + x0*y1 |
296 | 296 |
297 // split x, y into "digits" | 297 // split x, y into "digits" |
298 n2 := n >> 1 // n2 >= 1 | 298 n2 := n >> 1 // n2 >= 1 |
299 x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0 | 299 x1, x0 := x[n2:], x[0:n2] // x = x1*b + x0 |
300 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 | 300 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 |
301 | 301 |
302 // z is used for the result and temporary storage: | 302 // z is used for the result and temporary storage: |
303 // | 303 // |
304 // 6*n 5*n 4*n 3*n 2*n 1*n 0*n | 304 // 6*n 5*n 4*n 3*n 2*n 1*n 0*n |
305 » // z = [z2 copy|z0 copy| xd*yd | xd:yd | x1*y1 | x0*y0 ] | 305 » // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ] |
306 // | 306 // |
307 // For each recursive call of karatsuba, an unused slice of | 307 // For each recursive call of karatsuba, an unused slice of |
308 // z is passed in that has (at least) half the length of the | 308 // z is passed in that has (at least) half the length of the |
309 // caller's z. | 309 // caller's z. |
310 | 310 |
311 // compute z0 and z2 with the result "in place" in z | 311 // compute z0 and z2 with the result "in place" in z |
312 karatsuba(z, x0, y0) // z0 = x0*y0 | 312 karatsuba(z, x0, y0) // z0 = x0*y0 |
313 karatsuba(z[n:], x1, y1) // z2 = x1*y1 | 313 karatsuba(z[n:], x1, y1) // z2 = x1*y1 |
314 | |
315 // TODO(gri): In the following we carefully avoid underflow | |
316 // by recomputing differences and keeping track | |
317 // of sign changes. Can probably optimize this by | |
318 // simply ignoring the overflow but track sign changes | |
319 // and use this to sign extend the product xd*yd before | |
320 // adding it to z. This should remove quite a bit of code. | |
321 | 314 |
322 // compute xd (or the negative value if underflow occurs) | 315 // compute xd (or the negative value if underflow occurs) |
323 s := 1 // sign of product xd*yd | 316 s := 1 // sign of product xd*yd |
324 xd := z[2*n : 2*n+n2] | 317 xd := z[2*n : 2*n+n2] |
325 if subVV(&xd[0], &x1[0], &x0[0], n2) != 0 { // x1-x0 | 318 if subVV(&xd[0], &x1[0], &x0[0], n2) != 0 { // x1-x0 |
326 s = -s | 319 s = -s |
327 subVV(&xd[0], &x0[0], &x1[0], n2) // x0-x1 | 320 subVV(&xd[0], &x0[0], &x1[0], n2) // x0-x1 |
328 } | 321 } |
329 | 322 |
330 // compute yd (or the negative value if underflow occurs) | 323 // compute yd (or the negative value if underflow occurs) |
331 yd := z[2*n+n2 : 3*n] | 324 yd := z[2*n+n2 : 3*n] |
332 if subVV(&yd[0], &y0[0], &y1[0], n2) != 0 { // y0-y1 | 325 if subVV(&yd[0], &y0[0], &y1[0], n2) != 0 { // y0-y1 |
333 s = -s | 326 s = -s |
334 subVV(&yd[0], &y1[0], &y0[0], n2) // y1-y0 | 327 subVV(&yd[0], &y1[0], &y0[0], n2) // y1-y0 |
335 } | 328 } |
336 | 329 |
337 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 | 330 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 |
338 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 | 331 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 |
339 p := z[n*3:] | 332 p := z[n*3:] |
340 karatsuba(p, xd, yd) | 333 karatsuba(p, xd, yd) |
341 | 334 |
342 // save original z2:z0 | 335 // save original z2:z0 |
343 // (ok to use upper half of z since we're done recursing) | 336 // (ok to use upper half of z since we're done recursing) |
344 r := z[n*4:] | 337 r := z[n*4:] |
345 copy(r, z) | 338 copy(r, z) |
346 | 339 |
347 // add up all partial products | 340 // add up all partial products |
348 // | 341 // |
342 // 2*n n 0 | |
349 // z = [ z2 | z0 ] | 343 // z = [ z2 | z0 ] |
350 // + [ z0 ] | 344 // + [ z0 ] |
351 // + [ z2 ] | 345 // + [ z2 ] |
352 // + [ p ] | 346 // + [ p ] |
353 // | 347 // |
354 karatsubaAdd(z[n2:], r, n) | 348 karatsubaAdd(z[n2:], r, n) |
355 karatsubaAdd(z[n2:], r[n:], n) | 349 karatsubaAdd(z[n2:], r[n:], n) |
356 if s > 0 { | 350 if s > 0 { |
357 karatsubaAdd(z[n2:], p, n) | 351 karatsubaAdd(z[n2:], p, n) |
358 } else { | 352 } else { |
359 karatsubaSub(z[n2:], p, n) | 353 karatsubaSub(z[n2:], p, n) |
360 } | 354 } |
361 } | 355 } |
362 | 356 |
363 | 357 |
364 // alias returns true if x and y share the same base array. | 358 // alias returns true if x and y share the same base array. |
365 func alias(x, y nat) bool { | 359 func alias(x, y nat) bool { |
366 return &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] | 360 return &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] |
367 } | 361 } |
368 | 362 |
369 | 363 |
370 // addAt implements z += x*(1<<(_W*i)); z must be long enough. | 364 // addAt implements z += x*(1<<(_W*i)); z must be long enough. |
365 // (we don't use nat.add because we need z to stay the same | |
366 // slice, and we don't need to normalize z after each addition) | |
371 func addAt(z, x nat, i int) { | 367 func addAt(z, x nat, i int) { |
372 if n := len(x); n > 0 { | 368 if n := len(x); n > 0 { |
373 if c := addVV(&z[i], &z[i], &x[0], n); c != 0 { | 369 if c := addVV(&z[i], &z[i], &x[0], n); c != 0 { |
374 j := i + n | 370 j := i + n |
375 if j < len(z) { | 371 if j < len(z) { |
376 addVW(&z[j], &z[j], c, len(z)-j) | 372 addVW(&z[j], &z[j], c, len(z)-j) |
377 } | 373 } |
378 } | 374 } |
379 } | 375 } |
380 } | 376 } |
381 | 377 |
382 | 378 |
379 func max(x, y int) int { | |
380 if x > y { | |
381 return x | |
382 } | |
383 return y | |
384 } | |
385 | |
386 | |
383 func (z nat) mul(x, y nat) nat { | 387 func (z nat) mul(x, y nat) nat { |
384 m := len(x) | 388 m := len(x) |
385 n := len(y) | 389 n := len(y) |
386 | 390 |
387 switch { | 391 switch { |
388 case m < n: | 392 case m < n: |
389 return z.mul(y, x) | 393 return z.mul(y, x) |
390 case m == 0 || n == 0: | 394 case m == 0 || n == 0: |
391 » » return z.make(0, false) | 395 » » return z.make(0) |
392 case n == 1: | 396 case n == 1: |
393 return z.mulAddWW(x, y[0], 0) | 397 return z.mulAddWW(x, y[0], 0) |
394 } | 398 } |
395 // m >= n > 1 | 399 // m >= n > 1 |
396 | 400 |
397 // determine if z can be reused | 401 // determine if z can be reused |
398 if len(z) > 0 && (alias(z, x) || alias(z, y)) { | 402 if len(z) > 0 && (alias(z, x) || alias(z, y)) { |
399 z = nil // z is an alias for x or y - cannot reuse | 403 z = nil // z is an alias for x or y - cannot reuse |
400 } | 404 } |
401 | 405 |
402 » if n < karatsubaThreshold { | 406 » // use basic multiplication if the numbers are small |
403 » » // "grade school" multiplication | 407 » if n < karatsubaThreshold || n < 2 { |
404 » » z = z.make(m+n, true) | 408 » » z = z.make(m + n) |
405 » » for i, d := range y { | 409 » » basicMul(z, x, y) |
406 » » » if d != 0 { | |
407 » » » » z[m+i] = addMulVVW(&z[i], &x[0], d, m) | |
408 » » » } | |
409 » » } | |
410 return z.norm() | 410 return z.norm() |
411 } | 411 } |
412 » // m >= n && n >= karatsubaThreshold | 412 » // m >= n && n >= karatsubaThreshold && n >= 2 |
413 | 413 |
414 » // Note that even though we passed the Karatsuba threshold, | 414 » // determine largest k such that |
415 » // because we tested against n and not k (see below) we may | |
416 » // still end up using grade-school multiplication, albeit with | |
417 » // an intermediate step if the Karatsuba threshold is not a | |
418 » // power of 2. It appears that this intermediate step makes | |
419 » // things faster (e.g., the threshold is < 256 at the moment). | |
420 » // Theoretically, there are more operations involved but the numbers | |
421 » // are larger and thus "internal fragmentation" (i.e., total number | |
422 » // unused bits in leading words) may be smaller, possibly resulting | |
423 » // in fewer actual machine multiplications. | |
424 | |
425 » // Determine k such that: | |
426 // | 415 // |
427 // x = x1*b + x0 | 416 // x = x1*b + x0 |
428 // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) | 417 // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) |
429 // | |
430 // and | |
431 // | |
432 // b = 1<<(_W*k) ("base" of digits xi, yi) | 418 // b = 1<<(_W*k) ("base" of digits xi, yi) |
433 // | 419 // |
434 » k := 1 << uint(log2(Word(n))) | 420 » // and k is karatsubaThreshold multiplied by a power of 2 |
435 | 421 » k := max(karatsubaThreshold, 2) |
436 » // If x1 and/or y1 are not 0, compute product explicitly: | 422 » for k*2 <= n { |
437 » // | 423 » » k *= 2 |
438 » // x*y = x1*y1*b*b + x1*y0*b + x0*y1*b + x0*y0 | 424 » } |
425 » // k <= n | |
426 | |
427 » // multiply x0 and y0 via Karatsuba | |
428 » x0 := x[0:k] // x0 is not normalized | |
429 » y0 := y[0:k] // y0 is not normalized | |
430 » z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y | |
431 » karatsuba(z, x0, y0) | |
432 » z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage | |
433 | |
434 » // If x1 and/or y1 are not 0, add missing terms to z explicitly: | |
435 » // | |
436 » // m+n 2*k 0 | |
437 » // z = [ ... | x0*y0 ] | |
438 » // + [ x1*y1 ] | |
439 » // + [ x1*y0 ] | |
440 » // + [ x0*y1 ] | |
439 // | 441 // |
440 if k < n || m != n { | 442 if k < n || m != n { |
441 » » x1, x0 := x[k:], x[0:k].norm() // x1 is normalized because x is | 443 » » x1 := x[k:] // x1 is normalized because x is |
442 » » y1, y0 := y[k:], y[0:k].norm() // y1 is normalized because y is | 444 » » y1 := y[k:] // y1 is normalized because y is |
443 var t nat | 445 var t nat |
444 z = z.make(m+n, true) | |
445 t = t.mul(x1, y1) | 446 t = t.mul(x1, y1) |
446 » » copy(z[2*k:], t) // z may not be normalized! | 447 » » copy(z[2*k:], t) |
447 » » t = t.mul(x1, y0) | 448 » » z[2*k+len(t):].clear() // upper portion of z is garbage |
449 » » t = t.mul(x1, y0.norm()) | |
448 addAt(z, t, k) | 450 addAt(z, t, k) |
449 » » t = t.mul(x0, y1) | 451 » » t = t.mul(x0.norm(), y1) |
450 addAt(z, t, k) | 452 addAt(z, t, k) |
451 » » t = t.mul(x0, y0) | 453 » } |
452 » » addAt(z, t, 0) // (could invoke karatsuba for x0, y0 directly) | 454 |
453 » » return z.norm() | 455 » return z.norm() |
454 » } | |
455 » // k == n && m == n | |
456 | |
457 » // Both x and y have the same length k which is a power of 2 | |
458 » // and thus are directly suitable for Karatsuba multiplication. | |
459 » z = z.make(6*k, false) | |
460 » karatsuba(z, x, y) | |
461 » return z[0 : 2*n].norm() | |
462 } | 456 } |
463 | 457 |
464 | 458 |
465 // mulRange computes the product of all the unsigned integers in the | 459 // mulRange computes the product of all the unsigned integers in the |
466 // range [a, b] inclusively. If a > b (empty range), the result is 1. | 460 // range [a, b] inclusively. If a > b (empty range), the result is 1. |
467 func (z nat) mulRange(a, b uint64) nat { | 461 func (z nat) mulRange(a, b uint64) nat { |
adonovan
2020/02/12 14:46:30
What's the purpose of treating the sequence as lea
gri
2020/02/19 00:51:24
The recursive approach appears faster in practice
| |
468 switch { | 462 switch { |
469 case a == 0: | 463 case a == 0: |
470 // cut long ranges short (optimization) | 464 // cut long ranges short (optimization) |
471 return z.new(0) | 465 return z.new(0) |
472 case a > b: | 466 case a > b: |
473 return z.new(1) | 467 return z.new(1) |
474 case a == b: | 468 case a == b: |
475 return z.new(a) | 469 return z.new(a) |
476 case a+1 == b: | 470 case a+1 == b: |
477 return z.mul(nat(nil).new(a), nat(nil).new(b)) | 471 return z.mul(nat(nil).new(a), nat(nil).new(b)) |
(...skipping 10 matching lines...) Expand all Loading... | |
488 case y == 0: | 482 case y == 0: |
489 panic("division by zero") | 483 panic("division by zero") |
490 case y == 1: | 484 case y == 1: |
491 q = z.set(x) // result is x | 485 q = z.set(x) // result is x |
492 return | 486 return |
493 case m == 0: | 487 case m == 0: |
494 q = z.set(nil) // result is 0 | 488 q = z.set(nil) // result is 0 |
495 return | 489 return |
496 } | 490 } |
497 // m > 0 | 491 // m > 0 |
498 » z = z.make(m, false) | 492 » z = z.make(m) |
499 r = divWVW(&z[0], 0, &x[0], y, m) | 493 r = divWVW(&z[0], 0, &x[0], y, m) |
500 q = z.norm() | 494 q = z.norm() |
501 return | 495 return |
502 } | 496 } |
503 | 497 |
504 | 498 |
505 func (z nat) div(z2, u, v nat) (q, r nat) { | 499 func (z nat) div(z2, u, v nat) (q, r nat) { |
506 if len(v) == 0 { | 500 if len(v) == 0 { |
507 panic("division by zero") | 501 panic("division by zero") |
508 } | 502 } |
509 | 503 |
510 if u.cmp(v) < 0 { | 504 if u.cmp(v) < 0 { |
511 » » q = z.make(0, false) | 505 » » q = z.make(0) |
512 r = z2.set(u) | 506 r = z2.set(u) |
513 return | 507 return |
514 } | 508 } |
515 | 509 |
516 if len(v) == 1 { | 510 if len(v) == 1 { |
517 var rprime Word | 511 var rprime Word |
518 q, rprime = z.divW(u, v[0]) | 512 q, rprime = z.divW(u, v[0]) |
519 if rprime > 0 { | 513 if rprime > 0 { |
520 » » » r = z2.make(1, false) | 514 » » » r = z2.make(1) |
521 r[0] = rprime | 515 r[0] = rprime |
522 } else { | 516 } else { |
523 » » » r = z2.make(0, false) | 517 » » » r = z2.make(0) |
524 } | 518 } |
525 return | 519 return |
526 } | 520 } |
527 | 521 |
528 q, r = z.divLarge(z2, u, v) | 522 q, r = z.divLarge(z2, u, v) |
529 return | 523 return |
530 } | 524 } |
531 | 525 |
532 | 526 |
533 // q = (uIn-r)/v, with 0 <= r < v | 527 // q = (uIn-r)/v, with 0 <= r < v |
534 // See Knuth, Volume 2, section 4.3.1, Algorithm D. | 528 // See Knuth, Volume 2, section 4.3.1, Algorithm D. |
535 // Preconditions: | 529 // Preconditions: |
536 // len(v) >= 2 | 530 // len(v) >= 2 |
537 // len(uIn) >= len(v) | 531 // len(uIn) >= len(v) |
538 func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { | 532 func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { |
539 n := len(v) | 533 n := len(v) |
540 m := len(uIn) - len(v) | 534 m := len(uIn) - len(v) |
541 | 535 |
542 var u nat | 536 var u nat |
543 if z2 == nil || &z2[0] == &uIn[0] { | 537 if z2 == nil || &z2[0] == &uIn[0] { |
544 » » u = u.make(len(uIn)+1, true) // uIn is an alias for z2 | 538 » » u = u.make(len(uIn) + 1).clear() // uIn is an alias for z2 |
545 } else { | 539 } else { |
546 » » u = z2.make(len(uIn)+1, true) | 540 » » u = z2.make(len(uIn) + 1).clear() |
547 } | 541 } |
548 qhatv := make(nat, len(v)+1) | 542 qhatv := make(nat, len(v)+1) |
549 » q = z.make(m+1, false) | 543 » q = z.make(m + 1) |
550 | 544 |
551 // D1. | 545 // D1. |
552 shift := uint(leadingZeroBits(v[n-1])) | 546 shift := uint(leadingZeroBits(v[n-1])) |
553 v.shiftLeft(v, shift) | 547 v.shiftLeft(v, shift) |
554 u.shiftLeft(uIn, shift) | 548 u.shiftLeft(uIn, shift) |
555 u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift)) | 549 u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift)) |
556 | 550 |
557 // D2. | 551 // D2. |
558 for j := m; j >= 0; j-- { | 552 for j := m; j >= 0; j-- { |
559 // D3. | 553 // D3. |
(...skipping 209 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
769 case 64: | 763 case 64: |
770 return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) | 764 return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) |
771 default: | 765 default: |
772 panic("Unknown word size") | 766 panic("Unknown word size") |
773 } | 767 } |
774 | 768 |
775 return 0 | 769 return 0 |
776 } | 770 } |
777 | 771 |
778 | 772 |
773 // TODO(gri) Make the shift routines faster. | |
774 // Use pidigits.go benchmark as a test case. | |
775 | |
779 // To avoid losing the top n bits, z should be sized so that | 776 // To avoid losing the top n bits, z should be sized so that |
780 // len(z) == len(x) + 1. | 777 // len(z) == len(x) + 1. |
781 func (z nat) shiftLeft(x nat, n uint) nat { | 778 func (z nat) shiftLeft(x nat, n uint) nat { |
782 if len(x) == 0 { | 779 if len(x) == 0 { |
783 return x | 780 return x |
784 } | 781 } |
785 | 782 |
786 ñ := _W - n | 783 ñ := _W - n |
787 m := x[len(x)-1] | 784 m := x[len(x)-1] |
788 if len(z) > len(x) { | 785 if len(z) > len(x) { |
(...skipping 27 matching lines...) Expand all Loading... | |
816 | 813 |
817 | 814 |
818 // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2) | 815 // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2) |
819 func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 } | 816 func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 } |
820 | 817 |
821 | 818 |
822 // modW returns x % d. | 819 // modW returns x % d. |
823 func (x nat) modW(d Word) (r Word) { | 820 func (x nat) modW(d Word) (r Word) { |
824 // TODO(agl): we don't actually need to store the q value. | 821 // TODO(agl): we don't actually need to store the q value. |
825 var q nat | 822 var q nat |
826 » q = q.make(len(x), false) | 823 » q = q.make(len(x)) |
827 return divWVW(&q[0], 0, &x[0], d, len(x)) | 824 return divWVW(&q[0], 0, &x[0], d, len(x)) |
828 } | 825 } |
829 | 826 |
830 | 827 |
831 // powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd. | 828 // powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd. |
832 func (n nat) powersOfTwoDecompose() (q nat, k Word) { | 829 func (n nat) powersOfTwoDecompose() (q nat, k Word) { |
833 if len(n) == 0 { | 830 if len(n) == 0 { |
834 return n, 0 | 831 return n, 0 |
835 } | 832 } |
836 | 833 |
837 zeroWords := 0 | 834 zeroWords := 0 |
838 for n[zeroWords] == 0 { | 835 for n[zeroWords] == 0 { |
839 zeroWords++ | 836 zeroWords++ |
840 } | 837 } |
841 // One of the words must be non-zero by invariant, therefore | 838 // One of the words must be non-zero by invariant, therefore |
842 // zeroWords < len(n). | 839 // zeroWords < len(n). |
843 x := trailingZeroBits(n[zeroWords]) | 840 x := trailingZeroBits(n[zeroWords]) |
844 | 841 |
845 » q = q.make(len(n)-zeroWords, false) | 842 » q = q.make(len(n) - zeroWords) |
846 q.shiftRight(n[zeroWords:], uint(x)) | 843 q.shiftRight(n[zeroWords:], uint(x)) |
847 q = q.norm() | 844 q = q.norm() |
848 | 845 |
849 k = Word(_W*zeroWords + x) | 846 k = Word(_W*zeroWords + x) |
850 return | 847 return |
851 } | 848 } |
852 | 849 |
853 | 850 |
854 // random creates a random integer in [0..limit), using the space in z if | 851 // random creates a random integer in [0..limit), using the space in z if |
855 // possible. n is the bit length of limit. | 852 // possible. n is the bit length of limit. |
856 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { | 853 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { |
857 bitLengthOfMSW := uint(n % _W) | 854 bitLengthOfMSW := uint(n % _W) |
858 if bitLengthOfMSW == 0 { | 855 if bitLengthOfMSW == 0 { |
859 bitLengthOfMSW = _W | 856 bitLengthOfMSW = _W |
860 } | 857 } |
861 mask := Word((1 << bitLengthOfMSW) - 1) | 858 mask := Word((1 << bitLengthOfMSW) - 1) |
862 » z = z.make(len(limit), false) | 859 » z = z.make(len(limit)) |
863 | 860 |
864 for { | 861 for { |
865 for i := range z { | 862 for i := range z { |
866 switch _W { | 863 switch _W { |
867 case 32: | 864 case 32: |
868 z[i] = Word(rand.Uint32()) | 865 z[i] = Word(rand.Uint32()) |
869 case 64: | 866 case 64: |
870 z[i] = Word(rand.Uint32()) | Word(rand.Uint32()) <<32 | 867 z[i] = Word(rand.Uint32()) | Word(rand.Uint32()) <<32 |
871 } | 868 } |
872 } | 869 } |
873 | 870 |
874 z[len(limit)-1] &= mask | 871 z[len(limit)-1] &= mask |
875 | 872 |
876 if z.cmp(limit) < 0 { | 873 if z.cmp(limit) < 0 { |
877 break | 874 break |
878 } | 875 } |
879 } | 876 } |
880 | 877 |
881 return z.norm() | 878 return z.norm() |
882 } | 879 } |
883 | 880 |
884 | 881 |
885 // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It | 882 // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It |
886 // reuses the storage of z if possible. | 883 // reuses the storage of z if possible. |
887 func (z nat) expNN(x, y, m nat) nat { | 884 func (z nat) expNN(x, y, m nat) nat { |
888 if len(y) == 0 { | 885 if len(y) == 0 { |
889 » » z = z.make(1, false) | 886 » » z = z.make(1) |
890 z[0] = 1 | 887 z[0] = 1 |
891 return z | 888 return z |
892 } | 889 } |
893 | 890 |
894 if m != nil { | 891 if m != nil { |
895 // We likely end up being as long as the modulus. | 892 // We likely end up being as long as the modulus. |
896 » » z = z.make(len(m), false) | 893 » » z = z.make(len(m)) |
897 } | 894 } |
898 z = z.set(x) | 895 z = z.set(x) |
899 v := y[len(y)-1] | 896 v := y[len(y)-1] |
900 // It's invalid for the most significant word to be zero, therefore we | 897 // It's invalid for the most significant word to be zero, therefore we |
901 // will find a one bit. | 898 // will find a one bit. |
902 shift := leadingZeros(v) + 1 | 899 shift := leadingZeros(v) + 1 |
903 v <<= shift | 900 v <<= shift |
904 var q nat | 901 var q nat |
905 | 902 |
906 const mask = 1 << (_W - 1) | 903 const mask = 1 << (_W - 1) |
(...skipping 125 matching lines...) Expand 10 before | Expand all | Expand 10 after Loading... | |
1032 } | 1029 } |
1033 if y.cmp(natOne) == 0 { | 1030 if y.cmp(natOne) == 0 { |
1034 return false | 1031 return false |
1035 } | 1032 } |
1036 } | 1033 } |
1037 return false | 1034 return false |
1038 } | 1035 } |
1039 | 1036 |
1040 return true | 1037 return true |
1041 } | 1038 } |
LEFT | RIGHT |