Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(2194)

Delta Between Two Patch Sets: src/pkg/big/nat.go

Issue 1004042: code review 1004042: big: implemented Karatsuba multiplication (Closed)
Left Patch Set: code review 1004042: big: implemented Karatsuba multiplication Created 13 years, 11 months ago
Right Patch Set: code review 1004042: big: implemented Karatsuba multiplication Created 13 years, 11 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « src/pkg/big/int_test.go ('k') | src/pkg/big/nat_test.go » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2009 The Go Authors. All rights reserved. 1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style 2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file. 3 // license that can be found in the LICENSE file.
4 4
5 // This file contains operations on unsigned multi-precision integers. 5 // This file contains operations on unsigned multi-precision integers.
6 // These are the building blocks for the operations on signed integers 6 // These are the building blocks for the operations on signed integers
7 // and rationals. 7 // and rationals.
8 8
9 // This package implements multi-precision arithmetic (big numbers). 9 // This package implements multi-precision arithmetic (big numbers).
10 // The following numeric types are supported: 10 // The following numeric types are supported:
(...skipping 19 matching lines...) Expand all
30 // with the digits x[i] as the slice elements. 30 // with the digits x[i] as the slice elements.
31 // 31 //
32 // A number is normalized if the slice contains no leading 0 digits. 32 // A number is normalized if the slice contains no leading 0 digits.
33 // During arithmetic operations, denormalized values may occur but are 33 // During arithmetic operations, denormalized values may occur but are
34 // always normalized before returning the final result. The normalized 34 // always normalized before returning the final result. The normalized
35 // representation of 0 is the empty or nil slice (length = 0). 35 // representation of 0 is the empty or nil slice (length = 0).
36 36
37 type nat []Word 37 type nat []Word
38 38
39 var ( 39 var (
40 » natZero = nat(nil) 40 » natOne = nat{1}
41 » natOne = nat{1} 41 » natTwo = nat{2}
42 » natTwo = nat{2}
43 ) 42 )
43
44
45 func (z nat) clear() nat {
46 for i := range z {
47 z[i] = 0
48 }
49 return z
50 }
44 51
45 52
46 func (z nat) norm() nat { 53 func (z nat) norm() nat {
47 i := len(z) 54 i := len(z)
48 for i > 0 && z[i-1] == 0 { 55 for i > 0 && z[i-1] == 0 {
49 i-- 56 i--
50 } 57 }
51 z = z[0:i] 58 z = z[0:i]
52 return z 59 return z
53 } 60 }
54 61
55 62
56 func (z nat) make(m int, clear bool) nat { 63 func (z nat) make(m int) nat {
57 if cap(z) > m { 64 if cap(z) > m {
58 » » z = z[0:m] // reuse z - has at least one extra word for a carry, if any 65 » » return z[0:m] // reuse z - has at least one extra word for a carry, if any
59 » » if clear {
60 » » » for i := range z {
61 » » » » z[i] = 0
62 » » » }
63 » » }
64 » » return z
65 } 66 }
66 67
67 c := 4 // minimum capacity 68 c := 4 // minimum capacity
68 if m > c { 69 if m > c {
69 c = m 70 c = m
70 } 71 }
71 return make(nat, m, c+1) // +1: extra word for a carry, if any 72 return make(nat, m, c+1) // +1: extra word for a carry, if any
72 } 73 }
73 74
74 75
75 func (z nat) new(x uint64) nat { 76 func (z nat) new(x uint64) nat {
76 if x == 0 { 77 if x == 0 {
77 » » return z.make(0, false) 78 » » return z.make(0)
78 } 79 }
79 80
80 // single-digit values 81 // single-digit values
81 if x == uint64(Word(x)) { 82 if x == uint64(Word(x)) {
82 » » z = z.make(1, false) 83 » » z = z.make(1)
83 z[0] = Word(x) 84 z[0] = Word(x)
84 return z 85 return z
85 } 86 }
86 87
87 // compute number of words n required to represent x 88 // compute number of words n required to represent x
88 n := 0 89 n := 0
89 for t := x; t > 0; t >>= _W { 90 for t := x; t > 0; t >>= _W {
90 n++ 91 n++
91 } 92 }
92 93
93 // split x into n words 94 // split x into n words
94 » z = z.make(n, false) 95 » z = z.make(n)
95 for i := 0; i < n; i++ { 96 for i := 0; i < n; i++ {
96 z[i] = Word(x & _M) 97 z[i] = Word(x & _M)
97 x >>= _W 98 x >>= _W
98 } 99 }
99 100
100 return z 101 return z
101 } 102 }
102 103
103 104
104 func (z nat) set(x nat) nat { 105 func (z nat) set(x nat) nat {
105 » z = z.make(len(x), false) 106 » z = z.make(len(x))
106 for i, d := range x { 107 for i, d := range x {
107 z[i] = d 108 z[i] = d
108 } 109 }
109 return z 110 return z
110 } 111 }
111 112
112 113
113 func (z nat) add(x, y nat) nat { 114 func (z nat) add(x, y nat) nat {
114 m := len(x) 115 m := len(x)
115 n := len(y) 116 n := len(y)
116 117
117 switch { 118 switch {
118 case m < n: 119 case m < n:
119 return z.add(y, x) 120 return z.add(y, x)
120 case m == 0: 121 case m == 0:
121 // n == 0 because m >= n; result is 0 122 // n == 0 because m >= n; result is 0
122 » » return z.make(0, false) 123 » » return z.make(0)
123 case n == 0: 124 case n == 0:
124 // result is x 125 // result is x
125 return z.set(x) 126 return z.set(x)
126 } 127 }
127 // m > 0 128 // m > 0
128 129
129 » z = z.make(m, false) 130 » z = z.make(m)
130 c := addVV(&z[0], &x[0], &y[0], n) 131 c := addVV(&z[0], &x[0], &y[0], n)
131 if m > n { 132 if m > n {
132 c = addVW(&z[n], &x[n], c, m-n) 133 c = addVW(&z[n], &x[n], c, m-n)
133 } 134 }
134 if c > 0 { 135 if c > 0 {
135 z = z[0 : m+1] 136 z = z[0 : m+1]
136 z[m] = c 137 z[m] = c
137 } 138 }
138 139
139 return z 140 return z
140 } 141 }
141 142
142 143
143 func (z nat) sub(x, y nat) nat { 144 func (z nat) sub(x, y nat) nat {
144 m := len(x) 145 m := len(x)
145 n := len(y) 146 n := len(y)
146 147
147 switch { 148 switch {
148 case m < n: 149 case m < n:
149 panic("underflow") 150 panic("underflow")
150 case m == 0: 151 case m == 0:
151 // n == 0 because m >= n; result is 0 152 // n == 0 because m >= n; result is 0
152 » » return z.make(0, false) 153 » » return z.make(0)
153 case n == 0: 154 case n == 0:
154 // result is x 155 // result is x
155 return z.set(x) 156 return z.set(x)
156 } 157 }
157 // m > 0 158 // m > 0
158 159
159 » z = z.make(m, false) 160 » z = z.make(m)
160 c := subVV(&z[0], &x[0], &y[0], n) 161 c := subVV(&z[0], &x[0], &y[0], n)
161 if m > n { 162 if m > n {
162 c = subVW(&z[n], &x[n], c, m-n) 163 c = subVW(&z[n], &x[n], c, m-n)
163 } 164 }
164 if c != 0 { 165 if c != 0 {
165 panic("underflow") 166 panic("underflow")
166 } 167 }
167 z = z.norm() 168 z = z.norm()
168 169
169 return z 170 return z
(...skipping 28 matching lines...) Expand all
198 } 199 }
199 200
200 201
201 func (z nat) mulAddWW(x nat, y, r Word) nat { 202 func (z nat) mulAddWW(x nat, y, r Word) nat {
202 m := len(x) 203 m := len(x)
203 if m == 0 || y == 0 { 204 if m == 0 || y == 0 {
204 return z.new(uint64(r)) // result is r 205 return z.new(uint64(r)) // result is r
205 } 206 }
206 // m > 0 207 // m > 0
207 208
208 » z = z.make(m, false) 209 » z = z.make(m)
209 c := mulAddVWW(&z[0], &x[0], y, r, m) 210 c := mulAddVWW(&z[0], &x[0], y, r, m)
210 if c > 0 { 211 if c > 0 {
211 z = z[0 : m+1] 212 z = z[0 : m+1]
212 z[m] = c 213 z[m] = c
213 } 214 }
214 215
215 return z 216 return z
216 } 217 }
217 218
218 219
219 // Operands that are shorter than this threshold are multiplied using 220 // basicMul multiplies x and y and leaves the result in z.
220 // "grade school" multiplication; for larger operands the Karatsuba 221 // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
221 // algorithm is used. 222 func basicMul(z, x, y nat) {
222 // 223 » // initialize z
223 // The value has been found empirically for gotest -benchmarks=Fact 224 » for i := range z[0 : len(x)+len(y)] {
224 // on a machine running OS X on a 3.06GHz Intel Core 2 Duo. 225 » » z[i] = 0
225 // 226 » }
226 // (To disable Karatsuba multiplication, set the threshold to a very 227 » // multiply
227 // large value). 228 » for i, d := range y {
228 const karatsubaThreshold = 245 229 » » if d != 0 {
229 230 » » » z[len(x)+i] = addMulVVW(&z[i], &x[0], d, len(x))
230 func init() { 231 » » }
231 » if karatsubaThreshold <= 1 { 232 » }
232 » » panic("karatsubaThreshold must be > 1") 233 }
233 » } 234
234 } 235
235 236 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
236 237 // Factored out for readability - do not use outside karatsuba.
237 func karatsubaAdd(z, x nat, n int) { 238 func karatsubaAdd(z, x nat, n int) {
238 if c := addVV(&z[0], &z[0], &x[0], n); c != 0 { 239 if c := addVV(&z[0], &z[0], &x[0], n); c != 0 {
239 addVW(&z[n], &z[n], c, n>>1) 240 addVW(&z[n], &z[n], c, n>>1)
240 } 241 }
241 } 242 }
242 243
243 244
245 // Like karatsubaAdd, but does subtract.
244 func karatsubaSub(z, x nat, n int) { 246 func karatsubaSub(z, x nat, n int) {
245 if c := subVV(&z[0], &z[0], &x[0], n); c != 0 { 247 if c := subVV(&z[0], &z[0], &x[0], n); c != 0 {
246 subVW(&z[n], &z[n], c, n>>1) 248 subVW(&z[n], &z[n], c, n>>1)
247 } 249 }
248 } 250 }
249 251
250 252
253 // Operands that are shorter than karatsubaThreshold are multiplied using
254 // "grade school" multiplication; for longer operands the Karatsuba algorithm
255 // is used.
256 var karatsubaThreshold int = 30 // modified by calibrate.go
257
251 // karatsuba multiplies x and y and leaves the result in z. 258 // karatsuba multiplies x and y and leaves the result in z.
252 // Both x and y must have the same length and n must be a 259 // Both x and y must have the same length n and n must be a
253 // power of 2. The result vector z must have len(z) >= 6*n. 260 // power of 2. The result vector z must have len(z) >= 6*n.
254 // The (non-normalized) result is placed in z[0 : 2*n]. 261 // The (non-normalized) result is placed in z[0 : 2*n].
255 func karatsuba(z, x, y nat) { 262 func karatsuba(z, x, y nat) {
256 n := len(y) 263 n := len(y)
257 264
258 » // Switch to basic multiplication if the numbers are small. 265 » // Switch to basic multiplication if numbers are odd or small.
259 » if n < karatsubaThreshold { 266 » // (n is always even if karatsubaThreshold is even, but be
260 » » // initialize z 267 » // conservative)
261 » » for i := 2*n - 1; i >= 0; i-- { 268 » if n&1 != 0 || n < karatsubaThreshold || n < 2 {
262 » » » z[i] = 0 269 » » basicMul(z, x, y)
263 » » }
264 » » // "grade school" multiplication
265 » » for i, d := range y {
266 » » » if d != 0 {
267 » » » » z[n+i] = addMulVVW(&z[i], &x[0], d, n)
268 » » » }
269 » » }
270 return 270 return
271 } 271 }
272 » // n >= karatsubaThreshold > 1 272 » // n&1 == 0 && n >= karatsubaThreshold && n >= 2
273 273
274 // Karatsuba multiplication is based on the observation that 274 // Karatsuba multiplication is based on the observation that
275 // for two numbers x and y with: 275 // for two numbers x and y with:
276 // 276 //
277 // x = x1*b + x0 277 // x = x1*b + x0
278 // y = y1*b + y0 278 // y = y1*b + y0
279 // 279 //
280 // the product x*y can be obtained with 3 products z2, z1, z0 280 // the product x*y can be obtained with 3 products z2, z1, z0
281 // instead of 4: 281 // instead of 4:
282 // 282 //
(...skipping 12 matching lines...) Expand all
295 // = x1*y0 + x0*y1 295 // = x1*y0 + x0*y1
296 296
297 // split x, y into "digits" 297 // split x, y into "digits"
298 n2 := n >> 1 // n2 >= 1 298 n2 := n >> 1 // n2 >= 1
299 x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0 299 x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
300 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 300 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
301 301
302 // z is used for the result and temporary storage: 302 // z is used for the result and temporary storage:
303 // 303 //
304 // 6*n 5*n 4*n 3*n 2*n 1*n 0*n 304 // 6*n 5*n 4*n 3*n 2*n 1*n 0*n
305 » // z = [z2 copy|z0 copy| xd*yd | xd:yd | x1*y1 | x0*y0 ] 305 » // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
306 // 306 //
307 // For each recursive call of karatsuba, an unused slice of 307 // For each recursive call of karatsuba, an unused slice of
308 // z is passed in that has (at least) half the length of the 308 // z is passed in that has (at least) half the length of the
309 // caller's z. 309 // caller's z.
310 310
311 // compute z0 and z2 with the result "in place" in z 311 // compute z0 and z2 with the result "in place" in z
312 karatsuba(z, x0, y0) // z0 = x0*y0 312 karatsuba(z, x0, y0) // z0 = x0*y0
313 karatsuba(z[n:], x1, y1) // z2 = x1*y1 313 karatsuba(z[n:], x1, y1) // z2 = x1*y1
314
315 // TODO(gri): In the following we carefully avoid underflow
316 // by recomputing differences and keeping track
317 // of sign changes. Can probably optimize this by
318 // simply ignoring the overflow but track sign changes
319 // and use this to sign extend the product xd*yd before
320 // adding it to z. This should remove quite a bit of code.
321 314
322 // compute xd (or the negative value if underflow occurs) 315 // compute xd (or the negative value if underflow occurs)
323 s := 1 // sign of product xd*yd 316 s := 1 // sign of product xd*yd
324 xd := z[2*n : 2*n+n2] 317 xd := z[2*n : 2*n+n2]
325 if subVV(&xd[0], &x1[0], &x0[0], n2) != 0 { // x1-x0 318 if subVV(&xd[0], &x1[0], &x0[0], n2) != 0 { // x1-x0
326 s = -s 319 s = -s
327 subVV(&xd[0], &x0[0], &x1[0], n2) // x0-x1 320 subVV(&xd[0], &x0[0], &x1[0], n2) // x0-x1
328 } 321 }
329 322
330 // compute yd (or the negative value if underflow occurs) 323 // compute yd (or the negative value if underflow occurs)
331 yd := z[2*n+n2 : 3*n] 324 yd := z[2*n+n2 : 3*n]
332 if subVV(&yd[0], &y0[0], &y1[0], n2) != 0 { // y0-y1 325 if subVV(&yd[0], &y0[0], &y1[0], n2) != 0 { // y0-y1
333 s = -s 326 s = -s
334 subVV(&yd[0], &y1[0], &y0[0], n2) // y1-y0 327 subVV(&yd[0], &y1[0], &y0[0], n2) // y1-y0
335 } 328 }
336 329
337 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 330 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
338 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 331 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
339 p := z[n*3:] 332 p := z[n*3:]
340 karatsuba(p, xd, yd) 333 karatsuba(p, xd, yd)
341 334
342 // save original z2:z0 335 // save original z2:z0
343 // (ok to use upper half of z since we're done recursing) 336 // (ok to use upper half of z since we're done recursing)
344 r := z[n*4:] 337 r := z[n*4:]
345 copy(r, z) 338 copy(r, z)
346 339
347 // add up all partial products 340 // add up all partial products
348 // 341 //
342 // 2*n n 0
349 // z = [ z2 | z0 ] 343 // z = [ z2 | z0 ]
350 // + [ z0 ] 344 // + [ z0 ]
351 // + [ z2 ] 345 // + [ z2 ]
352 // + [ p ] 346 // + [ p ]
353 // 347 //
354 karatsubaAdd(z[n2:], r, n) 348 karatsubaAdd(z[n2:], r, n)
355 karatsubaAdd(z[n2:], r[n:], n) 349 karatsubaAdd(z[n2:], r[n:], n)
356 if s > 0 { 350 if s > 0 {
357 karatsubaAdd(z[n2:], p, n) 351 karatsubaAdd(z[n2:], p, n)
358 } else { 352 } else {
359 karatsubaSub(z[n2:], p, n) 353 karatsubaSub(z[n2:], p, n)
360 } 354 }
361 } 355 }
362 356
363 357
364 // alias returns true if x and y share the same base array. 358 // alias returns true if x and y share the same base array.
365 func alias(x, y nat) bool { 359 func alias(x, y nat) bool {
366 return &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] 360 return &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
367 } 361 }
368 362
369 363
370 // addAt implements z += x*(1<<(_W*i)); z must be long enough. 364 // addAt implements z += x*(1<<(_W*i)); z must be long enough.
365 // (we don't use nat.add because we need z to stay the same
366 // slice, and we don't need to normalize z after each addition)
371 func addAt(z, x nat, i int) { 367 func addAt(z, x nat, i int) {
372 if n := len(x); n > 0 { 368 if n := len(x); n > 0 {
373 if c := addVV(&z[i], &z[i], &x[0], n); c != 0 { 369 if c := addVV(&z[i], &z[i], &x[0], n); c != 0 {
374 j := i + n 370 j := i + n
375 if j < len(z) { 371 if j < len(z) {
376 addVW(&z[j], &z[j], c, len(z)-j) 372 addVW(&z[j], &z[j], c, len(z)-j)
377 } 373 }
378 } 374 }
379 } 375 }
380 } 376 }
381 377
382 378
379 func max(x, y int) int {
380 if x > y {
381 return x
382 }
383 return y
384 }
385
386
383 func (z nat) mul(x, y nat) nat { 387 func (z nat) mul(x, y nat) nat {
384 m := len(x) 388 m := len(x)
385 n := len(y) 389 n := len(y)
386 390
387 switch { 391 switch {
388 case m < n: 392 case m < n:
389 return z.mul(y, x) 393 return z.mul(y, x)
390 case m == 0 || n == 0: 394 case m == 0 || n == 0:
391 » » return z.make(0, false) 395 » » return z.make(0)
392 case n == 1: 396 case n == 1:
393 return z.mulAddWW(x, y[0], 0) 397 return z.mulAddWW(x, y[0], 0)
394 } 398 }
395 » // m >= n && m > 1 && n > 1 399 » // m >= n > 1
396 400
397 // determine if z can be reused 401 // determine if z can be reused
398 if len(z) > 0 && (alias(z, x) || alias(z, y)) { 402 if len(z) > 0 && (alias(z, x) || alias(z, y)) {
399 z = nil // z is an alias for x or y - cannot reuse 403 z = nil // z is an alias for x or y - cannot reuse
400 } 404 }
401 405
402 » if n < karatsubaThreshold { 406 » // use basic multiplication if the numbers are small
403 » » // "grade school" multiplication 407 » if n < karatsubaThreshold || n < 2 {
404 » » z = z.make(m+n, true) 408 » » z = z.make(m + n)
405 » » for i, d := range y { 409 » » basicMul(z, x, y)
406 » » » if d != 0 {
407 » » » » z[m+i] = addMulVVW(&z[i], &x[0], d, m)
408 » » » }
409 » » }
410 return z.norm() 410 return z.norm()
411 } 411 }
412 » // m >= n && n >= karatsubaThreshold 412 » // m >= n && n >= karatsubaThreshold && n >= 2
413 413
414 » // Note that even though we passed the Karatsuba threshold, 414 » // determine largest k such that
415 » // because we tested against n and not k (see below) we may
416 » // still end up using grade-school multiplication, albeit with
417 » // an intermediate step if the Karatsuba theshold is not a
418 » // power of 2. It appears that this intermediate step makes
419 » // things faster (e.g., the threshold is < 256 at the moment).
420 » // Theoretically, there are more operations involved but the numbers
421 » // are larger and thus "internal fragmentation" (i.e., total number
422 » // unused bits in leading words) may be smaller, possibly resulting
423 » // in fewer actual machine multiplications.
424
425 » // Determine k such that:
426 // 415 //
427 // x = x1*b + x0 416 // x = x1*b + x0
428 // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) 417 // y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
429 //
430 // and
431 //
432 // b = 1<<(_W*k) ("base" of digits xi, yi) 418 // b = 1<<(_W*k) ("base" of digits xi, yi)
433 // 419 //
434 » k := 1 << uint(log2(Word(n))) 420 » // and k is karatsubaThreshold multiplied by a power of 2
435 421 » k := max(karatsubaThreshold, 2)
436 » // If x1 and/or y1 are not 0, compute product explicitly: 422 » for k*2 <= n {
437 » // 423 » » k *= 2
438 » // x*y = x1*y1*b*b + x1*y0*b + x0*y1*b + x0*y0 424 » }
425 » // k <= n
426
427 » // multiply x0 and y0 via Karatsuba
428 » x0 := x[0:k] // x0 is not normalized
429 » y0 := y[0:k] // y0 is not normalized
430 » z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and ful l result of x*y
431 » karatsuba(z, x0, y0)
432 » z = z[0 : m+n] // z has final length but may be incomplete, upper portio n is garbage
433
434 » // If x1 and/or y1 are not 0, add missing terms to z explicitly:
435 » //
436 » // m+n 2*k 0
437 » // z = [ ... | x0*y0 ]
438 » // + [ x1*y1 ]
439 » // + [ x1*y0 ]
440 » // + [ x0*y1 ]
439 // 441 //
440 if k < n || m != n { 442 if k < n || m != n {
441 » » x1, x0 := x[k:], x[0:k].norm() // x1 is normalized because x is 443 » » x1 := x[k:] // x1 is normalized because x is
442 » » y1, y0 := y[k:], y[0:k].norm() // y1 is normalized because y is 444 » » y1 := y[k:] // y1 is normalized because y is
443 » » z = z.make(m+n, true) 445 » » var t nat
444 » » copy(z[2*k:], natZero.mul(x1, y1)) // z may not be normalized! 446 » » t = t.mul(x1, y1)
445 » » addAt(z, natZero.mul(x1, y0), k) 447 » » copy(z[2*k:], t)
446 » » addAt(z, natZero.mul(x0, y1), k) 448 » » z[2*k+len(t):].clear() // upper portion of z is garbage
447 » » addAt(z, natZero.mul(x0, y0), 0) // (could invoke karatsuba for x0, y0 directly) 449 » » t = t.mul(x1, y0.norm())
448 » » return z.norm() 450 » » addAt(z, t, k)
449 » } 451 » » t = t.mul(x0.norm(), y1)
450 » // k == n && m == n 452 » » addAt(z, t, k)
451 453 » }
452 » // Both x and y have the same length k which is a power of 2 454
453 » // and thus are directly suitable for Karatsuba multiplication. 455 » return z.norm()
454 » z = z.make(6*k, false)
455 » karatsuba(z, x, y)
456 » return z[0 : 2*n].norm()
457 } 456 }
458 457
459 458
460 // mulRange computes the product of all the unsigned integers in the 459 // mulRange computes the product of all the unsigned integers in the
461 // range [a, b] inclusively. If a > b (empty range), the result is 1. 460 // range [a, b] inclusively. If a > b (empty range), the result is 1.
462 func (z nat) mulRange(a, b uint64) nat { 461 func (z nat) mulRange(a, b uint64) nat {
adonovan 2020/02/12 14:46:30 What's the purpose of treating the sequence as lea
gri 2020/02/19 00:51:24 The recursive approach appears faster in practice
463 switch { 462 switch {
464 case a == 0: 463 case a == 0:
465 // cut long ranges short (optimization) 464 // cut long ranges short (optimization)
466 return z.new(0) 465 return z.new(0)
467 case a > b: 466 case a > b:
468 return z.new(1) 467 return z.new(1)
469 case a == b: 468 case a == b:
470 return z.new(a) 469 return z.new(a)
471 case a+1 == b: 470 case a+1 == b:
472 » » return z.mul(natZero.new(a), natZero.new(b)) 471 » » return z.mul(nat(nil).new(a), nat(nil).new(b))
473 } 472 }
474 m := (a + b) / 2 473 m := (a + b) / 2
475 » return z.mul(natZero.mulRange(a, m), natZero.mulRange(m+1, b)) 474 » return z.mul(nat(nil).mulRange(a, m), nat(nil).mulRange(m+1, b))
476 } 475 }
477 476
478 477
479 // q = (x-r)/y, with 0 <= r < y 478 // q = (x-r)/y, with 0 <= r < y
480 func (z nat) divW(x nat, y Word) (q nat, r Word) { 479 func (z nat) divW(x nat, y Word) (q nat, r Word) {
481 m := len(x) 480 m := len(x)
482 switch { 481 switch {
483 case y == 0: 482 case y == 0:
484 panic("division by zero") 483 panic("division by zero")
485 case y == 1: 484 case y == 1:
486 q = z.set(x) // result is x 485 q = z.set(x) // result is x
487 return 486 return
488 case m == 0: 487 case m == 0:
489 q = z.set(nil) // result is 0 488 q = z.set(nil) // result is 0
490 return 489 return
491 } 490 }
492 // m > 0 491 // m > 0
493 » z = z.make(m, false) 492 » z = z.make(m)
494 r = divWVW(&z[0], 0, &x[0], y, m) 493 r = divWVW(&z[0], 0, &x[0], y, m)
495 q = z.norm() 494 q = z.norm()
496 return 495 return
497 } 496 }
498 497
499 498
500 func (z nat) div(z2, u, v nat) (q, r nat) { 499 func (z nat) div(z2, u, v nat) (q, r nat) {
501 if len(v) == 0 { 500 if len(v) == 0 {
502 panic("division by zero") 501 panic("division by zero")
503 } 502 }
504 503
505 if u.cmp(v) < 0 { 504 if u.cmp(v) < 0 {
506 » » q = z.make(0, false) 505 » » q = z.make(0)
507 r = z2.set(u) 506 r = z2.set(u)
508 return 507 return
509 } 508 }
510 509
511 if len(v) == 1 { 510 if len(v) == 1 {
512 var rprime Word 511 var rprime Word
513 q, rprime = z.divW(u, v[0]) 512 q, rprime = z.divW(u, v[0])
514 if rprime > 0 { 513 if rprime > 0 {
515 » » » r = z2.make(1, false) 514 » » » r = z2.make(1)
516 r[0] = rprime 515 r[0] = rprime
517 } else { 516 } else {
518 » » » r = z2.make(0, false) 517 » » » r = z2.make(0)
519 } 518 }
520 return 519 return
521 } 520 }
522 521
523 q, r = z.divLarge(z2, u, v) 522 q, r = z.divLarge(z2, u, v)
524 return 523 return
525 } 524 }
526 525
527 526
528 // q = (uIn-r)/v, with 0 <= r < y 527 // q = (uIn-r)/v, with 0 <= r < y
529 // See Knuth, Volume 2, section 4.3.1, Algorithm D. 528 // See Knuth, Volume 2, section 4.3.1, Algorithm D.
530 // Preconditions: 529 // Preconditions:
531 // len(v) >= 2 530 // len(v) >= 2
532 // len(uIn) >= len(v) 531 // len(uIn) >= len(v)
533 func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { 532 func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
534 n := len(v) 533 n := len(v)
535 m := len(uIn) - len(v) 534 m := len(uIn) - len(v)
536 535
537 var u nat 536 var u nat
538 if z2 == nil || &z2[0] == &uIn[0] { 537 if z2 == nil || &z2[0] == &uIn[0] {
539 » » u = u.make(len(uIn)+1, true) // uIn is an alias for z2 538 » » u = u.make(len(uIn) + 1).clear() // uIn is an alias for z2
540 } else { 539 } else {
541 » » u = z2.make(len(uIn)+1, true) 540 » » u = z2.make(len(uIn) + 1).clear()
542 } 541 }
543 qhatv := make(nat, len(v)+1) 542 qhatv := make(nat, len(v)+1)
544 » q = z.make(m+1, false) 543 » q = z.make(m + 1)
545 544
546 // D1. 545 // D1.
547 shift := uint(leadingZeroBits(v[n-1])) 546 shift := uint(leadingZeroBits(v[n-1]))
548 v.shiftLeft(v, shift) 547 v.shiftLeft(v, shift)
549 u.shiftLeft(uIn, shift) 548 u.shiftLeft(uIn, shift)
550 u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift)) 549 u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift))
551 550
552 // D2. 551 // D2.
553 for j := m; j >= 0; j-- { 552 for j := m; j >= 0; j-- {
554 // D3. 553 // D3.
(...skipping 136 matching lines...) Expand 10 before | Expand all | Expand 10 after
691 690
692 if len(x) == 0 { 691 if len(x) == 0 {
693 return "0" 692 return "0"
694 } 693 }
695 694
696 // allocate buffer for conversion 695 // allocate buffer for conversion
697 i := (x.log2()+1)/log2(Word(base)) + 1 // +1: round up 696 i := (x.log2()+1)/log2(Word(base)) + 1 // +1: round up
698 s := make([]byte, i) 697 s := make([]byte, i)
699 698
700 // don't destroy x 699 // don't destroy x
701 » q := natZero.set(x) 700 » q := nat(nil).set(x)
702 701
703 // convert 702 // convert
704 for len(q) > 0 { 703 for len(q) > 0 {
705 i-- 704 i--
706 var r Word 705 var r Word
707 q, r = q.divW(q, Word(base)) 706 q, r = q.divW(q, Word(base))
708 s[i] = "0123456789abcdef"[r] 707 s[i] = "0123456789abcdef"[r]
709 } 708 }
710 709
711 return string(s[i:]) 710 return string(s[i:])
(...skipping 52 matching lines...) Expand 10 before | Expand all | Expand 10 after
764 case 64: 763 case 64:
765 return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) 764 return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
766 default: 765 default:
767 panic("Unknown word size") 766 panic("Unknown word size")
768 } 767 }
769 768
770 return 0 769 return 0
771 } 770 }
772 771
773 772
773 // TODO(gri) Make the shift routines faster.
774 // Use pidigits.go benchmark as a test case.
775
774 // To avoid losing the top n bits, z should be sized so that 776 // To avoid losing the top n bits, z should be sized so that
775 // len(z) == len(x) + 1. 777 // len(z) == len(x) + 1.
776 func (z nat) shiftLeft(x nat, n uint) nat { 778 func (z nat) shiftLeft(x nat, n uint) nat {
777 if len(x) == 0 { 779 if len(x) == 0 {
778 return x 780 return x
779 } 781 }
780 782
781 ñ := _W - n 783 ñ := _W - n
782 m := x[len(x)-1] 784 m := x[len(x)-1]
783 if len(z) > len(x) { 785 if len(z) > len(x) {
(...skipping 27 matching lines...) Expand all
811 813
812 814
813 // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2) 815 // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2)
814 func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 } 816 func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 }
815 817
816 818
817 // modW returns x % d. 819 // modW returns x % d.
818 func (x nat) modW(d Word) (r Word) { 820 func (x nat) modW(d Word) (r Word) {
819 // TODO(agl): we don't actually need to store the q value. 821 // TODO(agl): we don't actually need to store the q value.
820 var q nat 822 var q nat
821 » q = q.make(len(x), false) 823 » q = q.make(len(x))
822 return divWVW(&q[0], 0, &x[0], d, len(x)) 824 return divWVW(&q[0], 0, &x[0], d, len(x))
823 } 825 }
824 826
825 827
826 // powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd. 828 // powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd.
827 func (n nat) powersOfTwoDecompose() (q nat, k Word) { 829 func (n nat) powersOfTwoDecompose() (q nat, k Word) {
828 if len(n) == 0 { 830 if len(n) == 0 {
829 return n, 0 831 return n, 0
830 } 832 }
831 833
832 zeroWords := 0 834 zeroWords := 0
833 for n[zeroWords] == 0 { 835 for n[zeroWords] == 0 {
834 zeroWords++ 836 zeroWords++
835 } 837 }
836 // One of the words must be non-zero by invariant, therefore 838 // One of the words must be non-zero by invariant, therefore
837 // zeroWords < len(n). 839 // zeroWords < len(n).
838 x := trailingZeroBits(n[zeroWords]) 840 x := trailingZeroBits(n[zeroWords])
839 841
840 » q = q.make(len(n)-zeroWords, false) 842 » q = q.make(len(n) - zeroWords)
841 q.shiftRight(n[zeroWords:], uint(x)) 843 q.shiftRight(n[zeroWords:], uint(x))
842 q = q.norm() 844 q = q.norm()
843 845
844 k = Word(_W*zeroWords + x) 846 k = Word(_W*zeroWords + x)
845 return 847 return
846 } 848 }
847 849
848 850
849 // random creates a random integer in [0..limit), using the space in z if 851 // random creates a random integer in [0..limit), using the space in z if
850 // possible. n is the bit length of limit. 852 // possible. n is the bit length of limit.
851 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { 853 func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
852 bitLengthOfMSW := uint(n % _W) 854 bitLengthOfMSW := uint(n % _W)
853 if bitLengthOfMSW == 0 { 855 if bitLengthOfMSW == 0 {
854 bitLengthOfMSW = _W 856 bitLengthOfMSW = _W
855 } 857 }
856 mask := Word((1 << bitLengthOfMSW) - 1) 858 mask := Word((1 << bitLengthOfMSW) - 1)
857 » z = z.make(len(limit), false) 859 » z = z.make(len(limit))
858 860
859 for { 861 for {
860 for i := range z { 862 for i := range z {
861 switch _W { 863 switch _W {
862 case 32: 864 case 32:
863 z[i] = Word(rand.Uint32()) 865 z[i] = Word(rand.Uint32())
864 case 64: 866 case 64:
865 z[i] = Word(rand.Uint32()) | Word(rand.Uint32()) <<32 867 z[i] = Word(rand.Uint32()) | Word(rand.Uint32()) <<32
866 } 868 }
867 } 869 }
868 870
869 z[len(limit)-1] &= mask 871 z[len(limit)-1] &= mask
870 872
871 if z.cmp(limit) < 0 { 873 if z.cmp(limit) < 0 {
872 break 874 break
873 } 875 }
874 } 876 }
875 877
876 return z.norm() 878 return z.norm()
877 } 879 }
878 880
879 881
880 // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It 882 // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It
881 // reuses the storage of z if possible. 883 // reuses the storage of z if possible.
882 func (z nat) expNN(x, y, m nat) nat { 884 func (z nat) expNN(x, y, m nat) nat {
883 if len(y) == 0 { 885 if len(y) == 0 {
884 » » z = z.make(1, false) 886 » » z = z.make(1)
885 z[0] = 1 887 z[0] = 1
886 return z 888 return z
887 } 889 }
888 890
889 if m != nil { 891 if m != nil {
890 // We likely end up being as long as the modulus. 892 // We likely end up being as long as the modulus.
891 » » z = z.make(len(m), false) 893 » » z = z.make(len(m))
892 } 894 }
893 z = z.set(x) 895 z = z.set(x)
894 v := y[len(y)-1] 896 v := y[len(y)-1]
895 // It's invalid for the most significant word to be zero, therefore we 897 // It's invalid for the most significant word to be zero, therefore we
896 // will find a one bit. 898 // will find a one bit.
897 shift := leadingZeros(v) + 1 899 shift := leadingZeros(v) + 1
898 v <<= shift 900 v <<= shift
899 var q nat 901 var q nat
900 902
901 const mask = 1 << (_W - 1) 903 const mask = 1 << (_W - 1)
(...skipping 92 matching lines...) Expand 10 before | Expand all | Expand 10 after
994 if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 || 996 if r%3 == 0 || r%5 == 0 || r%7 == 0 || r%11 == 0 ||
995 r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 { 997 r%13 == 0 || r%17 == 0 || r%19 == 0 || r%23 == 0 || r%29 == 0 {
996 return false 998 return false
997 } 999 }
998 1000
999 if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 || 1001 if _W == 64 && (r%31 == 0 || r%37 == 0 || r%41 == 0 ||
1000 r%43 == 0 || r%47 == 0 || r%53 == 0) { 1002 r%43 == 0 || r%47 == 0 || r%53 == 0) {
1001 return false 1003 return false
1002 } 1004 }
1003 1005
1004 » nm1 := natZero.sub(n, natOne) 1006 » nm1 := nat(nil).sub(n, natOne)
1005 // 1<<k * q = nm1; 1007 // 1<<k * q = nm1;
1006 q, k := nm1.powersOfTwoDecompose() 1008 q, k := nm1.powersOfTwoDecompose()
1007 1009
1008 » nm3 := natZero.sub(nm1, natTwo) 1010 » nm3 := nat(nil).sub(nm1, natTwo)
1009 rand := rand.New(rand.NewSource(int64(n[0]))) 1011 rand := rand.New(rand.NewSource(int64(n[0])))
1010 1012
1011 var x, y, quotient nat 1013 var x, y, quotient nat
1012 nm3Len := nm3.len() 1014 nm3Len := nm3.len()
1013 1015
1014 NextRandom: 1016 NextRandom:
1015 for i := 0; i < reps; i++ { 1017 for i := 0; i < reps; i++ {
1016 x = x.random(rand, nm3, nm3Len) 1018 x = x.random(rand, nm3, nm3Len)
1017 x = x.add(x, natTwo) 1019 x = x.add(x, natTwo)
1018 y = y.expNN(x, q, n) 1020 y = y.expNN(x, q, n)
1019 if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { 1021 if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 {
1020 continue 1022 continue
1021 } 1023 }
1022 for j := Word(1); j < k; j++ { 1024 for j := Word(1); j < k; j++ {
1023 y = y.mul(y, y) 1025 y = y.mul(y, y)
1024 quotient, y = quotient.div(y, y, n) 1026 quotient, y = quotient.div(y, y, n)
1025 if y.cmp(nm1) == 0 { 1027 if y.cmp(nm1) == 0 {
1026 continue NextRandom 1028 continue NextRandom
1027 } 1029 }
1028 if y.cmp(natOne) == 0 { 1030 if y.cmp(natOne) == 0 {
1029 return false 1031 return false
1030 } 1032 }
1031 } 1033 }
1032 return false 1034 return false
1033 } 1035 }
1034 1036
1035 return true 1037 return true
1036 } 1038 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b