Rietveld Code Review Tool
Help | Bug tracker | Discussion group | Source code | Sign in
(295)

Delta Between Two Patch Sets: src/pkg/big/nat.go

Issue 1004042: code review 1004042: big: implemented Karatsuba multiplication (Closed)
Left Patch Set: code review 1004042: big: implemented Karatsuba multiplication Created 13 years, 11 months ago
Right Patch Set: code review 1004042: big: implemented Karatsuba multiplication Created 13 years, 11 months ago
Left:
Right:
Use n/p to move between diff chunks; N/P to move between comments. Please Sign in to add in-line comments.
Jump to:
Left: Side by side diff | Download
Right: Side by side diff | Download
« no previous file with change/comment | « src/pkg/big/int_test.go ('k') | src/pkg/big/nat_test.go » ('j') | no next file with change/comment »
Toggle Intra-line Diffs ('i') | Expand Comments ('e') | Collapse Comments ('c') | Show Comments Hide Comments ('s')
LEFTRIGHT
1 // Copyright 2009 The Go Authors. All rights reserved. 1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style 2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file. 3 // license that can be found in the LICENSE file.
4 4
5 // This file contains operations on unsigned multi-precision integers. 5 // This file contains operations on unsigned multi-precision integers.
6 // These are the building blocks for the operations on signed integers 6 // These are the building blocks for the operations on signed integers
7 // and rationals. 7 // and rationals.
8 8
9 // This package implements multi-precision arithmetic (big numbers). 9 // This package implements multi-precision arithmetic (big numbers).
10 // The following numeric types are supported: 10 // The following numeric types are supported:
(...skipping 24 matching lines...) Expand all
35 // representation of 0 is the empty or nil slice (length = 0). 35 // representation of 0 is the empty or nil slice (length = 0).
36 36
37 type nat []Word 37 type nat []Word
38 38
39 var ( 39 var (
40 natOne = nat{1} 40 natOne = nat{1}
41 natTwo = nat{2} 41 natTwo = nat{2}
42 ) 42 )
43 43
44 44
45 func (z nat) clear() nat {
46 for i := range z {
47 z[i] = 0
48 }
49 return z
50 }
51
52
45 func (z nat) norm() nat { 53 func (z nat) norm() nat {
46 i := len(z) 54 i := len(z)
47 for i > 0 && z[i-1] == 0 { 55 for i > 0 && z[i-1] == 0 {
48 i-- 56 i--
49 } 57 }
50 z = z[0:i] 58 z = z[0:i]
51 return z 59 return z
52 } 60 }
53 61
54 62
55 func (z nat) make(m int, clear bool) nat { 63 func (z nat) make(m int) nat {
56 if cap(z) > m { 64 if cap(z) > m {
57 » » z = z[0:m] // reuse z - has at least one extra word for a carry, if any 65 » » return z[0:m] // reuse z - has at least one extra word for a carry, if any
58 » » if clear {
59 » » » for i := range z {
60 » » » » z[i] = 0
61 » » » }
62 » » }
63 » » return z
64 } 66 }
65 67
66 c := 4 // minimum capacity 68 c := 4 // minimum capacity
67 if m > c { 69 if m > c {
68 c = m 70 c = m
69 } 71 }
70 return make(nat, m, c+1) // +1: extra word for a carry, if any 72 return make(nat, m, c+1) // +1: extra word for a carry, if any
71 } 73 }
72 74
73 75
74 func (z nat) new(x uint64) nat { 76 func (z nat) new(x uint64) nat {
75 if x == 0 { 77 if x == 0 {
76 » » return z.make(0, false) 78 » » return z.make(0)
77 } 79 }
78 80
79 // single-digit values 81 // single-digit values
80 if x == uint64(Word(x)) { 82 if x == uint64(Word(x)) {
81 » » z = z.make(1, false) 83 » » z = z.make(1)
82 z[0] = Word(x) 84 z[0] = Word(x)
83 return z 85 return z
84 } 86 }
85 87
86 // compute number of words n required to represent x 88 // compute number of words n required to represent x
87 n := 0 89 n := 0
88 for t := x; t > 0; t >>= _W { 90 for t := x; t > 0; t >>= _W {
89 n++ 91 n++
90 } 92 }
91 93
92 // split x into n words 94 // split x into n words
93 » z = z.make(n, false) 95 » z = z.make(n)
94 for i := 0; i < n; i++ { 96 for i := 0; i < n; i++ {
95 z[i] = Word(x & _M) 97 z[i] = Word(x & _M)
96 x >>= _W 98 x >>= _W
97 } 99 }
98 100
99 return z 101 return z
100 } 102 }
101 103
102 104
103 func (z nat) set(x nat) nat { 105 func (z nat) set(x nat) nat {
104 » z = z.make(len(x), false) 106 » z = z.make(len(x))
105 for i, d := range x { 107 for i, d := range x {
106 z[i] = d 108 z[i] = d
107 } 109 }
108 return z 110 return z
109 } 111 }
110 112
111 113
112 func (z nat) add(x, y nat) nat { 114 func (z nat) add(x, y nat) nat {
113 m := len(x) 115 m := len(x)
114 n := len(y) 116 n := len(y)
115 117
116 switch { 118 switch {
117 case m < n: 119 case m < n:
118 return z.add(y, x) 120 return z.add(y, x)
119 case m == 0: 121 case m == 0:
120 // n == 0 because m >= n; result is 0 122 // n == 0 because m >= n; result is 0
121 » » return z.make(0, false) 123 » » return z.make(0)
122 case n == 0: 124 case n == 0:
123 // result is x 125 // result is x
124 return z.set(x) 126 return z.set(x)
125 } 127 }
126 // m > 0 128 // m > 0
127 129
128 » z = z.make(m, false) 130 » z = z.make(m)
129 c := addVV(&z[0], &x[0], &y[0], n) 131 c := addVV(&z[0], &x[0], &y[0], n)
130 if m > n { 132 if m > n {
131 c = addVW(&z[n], &x[n], c, m-n) 133 c = addVW(&z[n], &x[n], c, m-n)
132 } 134 }
133 if c > 0 { 135 if c > 0 {
134 z = z[0 : m+1] 136 z = z[0 : m+1]
135 z[m] = c 137 z[m] = c
136 } 138 }
137 139
138 return z 140 return z
139 } 141 }
140 142
141 143
142 func (z nat) sub(x, y nat) nat { 144 func (z nat) sub(x, y nat) nat {
143 m := len(x) 145 m := len(x)
144 n := len(y) 146 n := len(y)
145 147
146 switch { 148 switch {
147 case m < n: 149 case m < n:
148 panic("underflow") 150 panic("underflow")
149 case m == 0: 151 case m == 0:
150 // n == 0 because m >= n; result is 0 152 // n == 0 because m >= n; result is 0
151 » » return z.make(0, false) 153 » » return z.make(0)
152 case n == 0: 154 case n == 0:
153 // result is x 155 // result is x
154 return z.set(x) 156 return z.set(x)
155 } 157 }
156 // m > 0 158 // m > 0
157 159
158 » z = z.make(m, false) 160 » z = z.make(m)
159 c := subVV(&z[0], &x[0], &y[0], n) 161 c := subVV(&z[0], &x[0], &y[0], n)
160 if m > n { 162 if m > n {
161 c = subVW(&z[n], &x[n], c, m-n) 163 c = subVW(&z[n], &x[n], c, m-n)
162 } 164 }
163 if c != 0 { 165 if c != 0 {
164 panic("underflow") 166 panic("underflow")
165 } 167 }
166 z = z.norm() 168 z = z.norm()
167 169
168 return z 170 return z
(...skipping 28 matching lines...) Expand all
197 } 199 }
198 200
199 201
200 func (z nat) mulAddWW(x nat, y, r Word) nat { 202 func (z nat) mulAddWW(x nat, y, r Word) nat {
201 m := len(x) 203 m := len(x)
202 if m == 0 || y == 0 { 204 if m == 0 || y == 0 {
203 return z.new(uint64(r)) // result is r 205 return z.new(uint64(r)) // result is r
204 } 206 }
205 // m > 0 207 // m > 0
206 208
207 » z = z.make(m, false) 209 » z = z.make(m)
208 c := mulAddVWW(&z[0], &x[0], y, r, m) 210 c := mulAddVWW(&z[0], &x[0], y, r, m)
209 if c > 0 { 211 if c > 0 {
210 z = z[0 : m+1] 212 z = z[0 : m+1]
211 z[m] = c 213 z[m] = c
212 } 214 }
213 215
214 return z 216 return z
215 } 217 }
216 218
217 219
218 // Operands that are shorter than this threshold are multiplied using 220 // basicMul multiplies x and y and leaves the result in z.
219 // "grade school" multiplication; for larger operands the Karatsuba 221 // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
220 // algorithm is used. 222 func basicMul(z, x, y nat) {
221 // 223 » // initialize z
222 // The value has been found empirically for gotest -benchmarks=Fact 224 » for i := range z[0 : len(x)+len(y)] {
223 // on a machine running OS X on a 3.06GHz Intel Core 2 Duo. 225 » » z[i] = 0
224 // 226 » }
225 // (To disable Karatsuba multiplication, set the threshold to a very 227 » // multiply
226 // large value). 228 » for i, d := range y {
227 const karatsubaThreshold = 245 229 » » if d != 0 {
228 230 » » » z[len(x)+i] = addMulVVW(&z[i], &x[0], d, len(x))
229 // karatsubaThreshold must be >= 2. 231 » » }
230 // Trigger compile error if that's not true. 232 » }
231 const _ uint = karatsubaThreshold - 2 233 }
232 234
233 235
234 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks. 236 // Fast version of z[0:n+n>>1].add(z[0:n+n>>1], x[0:n]) w/o bounds checks.
235 // Factored out for readability - do not use outside karatsuba. 237 // Factored out for readability - do not use outside karatsuba.
236 func karatsubaAdd(z, x nat, n int) { 238 func karatsubaAdd(z, x nat, n int) {
237 if c := addVV(&z[0], &z[0], &x[0], n); c != 0 { 239 if c := addVV(&z[0], &z[0], &x[0], n); c != 0 {
238 addVW(&z[n], &z[n], c, n>>1) 240 addVW(&z[n], &z[n], c, n>>1)
239 } 241 }
240 } 242 }
241 243
242 244
243 // Like karatsubaAdd, but does subtract. 245 // Like karatsubaAdd, but does subtract.
244 func karatsubaSub(z, x nat, n int) { 246 func karatsubaSub(z, x nat, n int) {
245 if c := subVV(&z[0], &z[0], &x[0], n); c != 0 { 247 if c := subVV(&z[0], &z[0], &x[0], n); c != 0 {
246 subVW(&z[n], &z[n], c, n>>1) 248 subVW(&z[n], &z[n], c, n>>1)
247 } 249 }
248 } 250 }
249 251
252
253 // Operands that are shorter than karatsubaThreshold are multiplied using
254 // "grade school" multiplication; for longer operands the Karatsuba algorithm
255 // is used.
256 var karatsubaThreshold int = 30 // modified by calibrate.go
250 257
251 // karatsuba multiplies x and y and leaves the result in z. 258 // karatsuba multiplies x and y and leaves the result in z.
252 // Both x and y must have the same length n and n must be a 259 // Both x and y must have the same length n and n must be a
253 // power of 2. The result vector z must have len(z) >= 6*n. 260 // power of 2. The result vector z must have len(z) >= 6*n.
254 // The (non-normalized) result is placed in z[0 : 2*n]. 261 // The (non-normalized) result is placed in z[0 : 2*n].
255 func karatsuba(z, x, y nat) { 262 func karatsuba(z, x, y nat) {
256 n := len(y) 263 n := len(y)
257 264
258 » // Switch to basic multiplication if the numbers are small. 265 » // Switch to basic multiplication if numbers are odd or small.
259 » if n < karatsubaThreshold { 266 » // (n is always even if karatsubaThreshold is even, but be
260 » » // initialize z 267 » // conservative)
261 » » for i := 2*n - 1; i >= 0; i-- { 268 » if n&1 != 0 || n < karatsubaThreshold || n < 2 {
262 » » » z[i] = 0 269 » » basicMul(z, x, y)
263 » » }
264 » » // "grade school" multiplication
265 » » for i, d := range y {
266 » » » if d != 0 {
267 » » » » z[n+i] = addMulVVW(&z[i], &x[0], d, n)
268 » » » }
269 » » }
270 return 270 return
271 } 271 }
272 » // n >= karatsubaThreshold > 1 272 » // n&1 == 0 && n >= karatsubaThreshold && n >= 2
273 273
274 // Karatsuba multiplication is based on the observation that 274 // Karatsuba multiplication is based on the observation that
275 // for two numbers x and y with: 275 // for two numbers x and y with:
276 // 276 //
277 // x = x1*b + x0 277 // x = x1*b + x0
278 // y = y1*b + y0 278 // y = y1*b + y0
279 // 279 //
280 // the product x*y can be obtained with 3 products z2, z1, z0 280 // the product x*y can be obtained with 3 products z2, z1, z0
281 // instead of 4: 281 // instead of 4:
282 // 282 //
(...skipping 12 matching lines...) Expand all
295 // = x1*y0 + x0*y1 295 // = x1*y0 + x0*y1
296 296
297 // split x, y into "digits" 297 // split x, y into "digits"
298 n2 := n >> 1 // n2 >= 1 298 n2 := n >> 1 // n2 >= 1
299 x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0 299 x1, x0 := x[n2:], x[0:n2] // x = x1*b + y0
300 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0 300 y1, y0 := y[n2:], y[0:n2] // y = y1*b + y0
301 301
302 // z is used for the result and temporary storage: 302 // z is used for the result and temporary storage:
303 // 303 //
304 // 6*n 5*n 4*n 3*n 2*n 1*n 0*n 304 // 6*n 5*n 4*n 3*n 2*n 1*n 0*n
305 » // z = [z2 copy|z0 copy| xd*yd | xd:yd | x1*y1 | x0*y0 ] 305 » // z = [z2 copy|z0 copy| xd*yd | yd:xd | x1*y1 | x0*y0 ]
306 // 306 //
307 // For each recursive call of karatsuba, an unused slice of 307 // For each recursive call of karatsuba, an unused slice of
308 // z is passed in that has (at least) half the length of the 308 // z is passed in that has (at least) half the length of the
309 // caller's z. 309 // caller's z.
310 310
311 // compute z0 and z2 with the result "in place" in z 311 // compute z0 and z2 with the result "in place" in z
312 karatsuba(z, x0, y0) // z0 = x0*y0 312 karatsuba(z, x0, y0) // z0 = x0*y0
313 karatsuba(z[n:], x1, y1) // z2 = x1*y1 313 karatsuba(z[n:], x1, y1) // z2 = x1*y1
314
315 // TODO(gri): In the following we carefully avoid underflow
316 // by recomputing differences and keeping track
317 // of sign changes. Can probably optimize this by
318 // simply ignoring the overflow but track sign changes
319 // and use this to sign extend the product xd*yd before
320 // adding it to z. This should remove quite a bit of code.
321 314
322 // compute xd (or the negative value if underflow occurs) 315 // compute xd (or the negative value if underflow occurs)
323 s := 1 // sign of product xd*yd 316 s := 1 // sign of product xd*yd
324 xd := z[2*n : 2*n+n2] 317 xd := z[2*n : 2*n+n2]
325 if subVV(&xd[0], &x1[0], &x0[0], n2) != 0 { // x1-x0 318 if subVV(&xd[0], &x1[0], &x0[0], n2) != 0 { // x1-x0
326 s = -s 319 s = -s
327 subVV(&xd[0], &x0[0], &x1[0], n2) // x0-x1 320 subVV(&xd[0], &x0[0], &x1[0], n2) // x0-x1
328 } 321 }
329 322
330 // compute yd (or the negative value if underflow occurs) 323 // compute yd (or the negative value if underflow occurs)
331 yd := z[2*n+n2 : 3*n] 324 yd := z[2*n+n2 : 3*n]
332 if subVV(&yd[0], &y0[0], &y1[0], n2) != 0 { // y0-y1 325 if subVV(&yd[0], &y0[0], &y1[0], n2) != 0 { // y0-y1
333 s = -s 326 s = -s
334 subVV(&yd[0], &y1[0], &y0[0], n2) // y1-y0 327 subVV(&yd[0], &y1[0], &y0[0], n2) // y1-y0
335 } 328 }
336 329
337 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0 330 // p = (x1-x0)*(y0-y1) == x1*y0 - x1*y1 - x0*y0 + x0*y1 for s > 0
338 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0 331 // p = (x0-x1)*(y0-y1) == x0*y0 - x0*y1 - x1*y0 + x1*y1 for s < 0
339 p := z[n*3:] 332 p := z[n*3:]
340 karatsuba(p, xd, yd) 333 karatsuba(p, xd, yd)
341 334
342 // save original z2:z0 335 // save original z2:z0
343 // (ok to use upper half of z since we're done recursing) 336 // (ok to use upper half of z since we're done recursing)
344 r := z[n*4:] 337 r := z[n*4:]
345 copy(r, z) 338 copy(r, z)
346 339
347 // add up all partial products 340 // add up all partial products
348 // 341 //
342 // 2*n n 0
349 // z = [ z2 | z0 ] 343 // z = [ z2 | z0 ]
350 // + [ z0 ] 344 // + [ z0 ]
351 // + [ z2 ] 345 // + [ z2 ]
352 // + [ p ] 346 // + [ p ]
353 // 347 //
354 karatsubaAdd(z[n2:], r, n) 348 karatsubaAdd(z[n2:], r, n)
355 karatsubaAdd(z[n2:], r[n:], n) 349 karatsubaAdd(z[n2:], r[n:], n)
356 if s > 0 { 350 if s > 0 {
357 karatsubaAdd(z[n2:], p, n) 351 karatsubaAdd(z[n2:], p, n)
358 } else { 352 } else {
359 karatsubaSub(z[n2:], p, n) 353 karatsubaSub(z[n2:], p, n)
360 } 354 }
361 } 355 }
362 356
363 357
364 // alias returns true if x and y share the same base array. 358 // alias returns true if x and y share the same base array.
365 func alias(x, y nat) bool { 359 func alias(x, y nat) bool {
366 return &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] 360 return &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
367 } 361 }
368 362
369 363
370 // addAt implements z += x*(1<<(_W*i)); z must be long enough. 364 // addAt implements z += x*(1<<(_W*i)); z must be long enough.
365 // (we don't use nat.add because we need z to stay the same
366 // slice, and we don't need to normalize z after each addition)
371 func addAt(z, x nat, i int) { 367 func addAt(z, x nat, i int) {
372 if n := len(x); n > 0 { 368 if n := len(x); n > 0 {
373 if c := addVV(&z[i], &z[i], &x[0], n); c != 0 { 369 if c := addVV(&z[i], &z[i], &x[0], n); c != 0 {
374 j := i + n 370 j := i + n
375 if j < len(z) { 371 if j < len(z) {
376 addVW(&z[j], &z[j], c, len(z)-j) 372 addVW(&z[j], &z[j], c, len(z)-j)
377 } 373 }
378 } 374 }
379 } 375 }
380 } 376 }
381 377
382 378
379 func max(x, y int) int {
380 if x > y {
381 return x
382 }
383 return y
384 }
385
386
383 func (z nat) mul(x, y nat) nat { 387 func (z nat) mul(x, y nat) nat {
384 m := len(x) 388 m := len(x)
385 n := len(y) 389 n := len(y)
386 390
387 switch { 391 switch {
388 case m < n: 392 case m < n:
389 return z.mul(y, x) 393 return z.mul(y, x)
390 case m == 0 || n == 0: 394 case m == 0 || n == 0:
391 » » return z.make(0, false) 395 » » return z.make(0)
392 case n == 1: 396 case n == 1:
393 return z.mulAddWW(x, y[0], 0) 397 return z.mulAddWW(x, y[0], 0)
394 } 398 }
395 // m >= n > 1 399 // m >= n > 1
396 400
397 // determine if z can be reused 401 // determine if z can be reused
398 if len(z) > 0 && (alias(z, x) || alias(z, y)) { 402 if len(z) > 0 && (alias(z, x) || alias(z, y)) {
399 z = nil // z is an alias for x or y - cannot reuse 403 z = nil // z is an alias for x or y - cannot reuse
400 } 404 }
401 405
402 » if n < karatsubaThreshold { 406 » // use basic multiplication if the numbers are small
403 » » // "grade school" multiplication 407 » if n < karatsubaThreshold || n < 2 {
404 » » z = z.make(m+n, true) 408 » » z = z.make(m + n)
405 » » for i, d := range y { 409 » » basicMul(z, x, y)
406 » » » if d != 0 {
407 » » » » z[m+i] = addMulVVW(&z[i], &x[0], d, m)
408 » » » }
409 » » }
410 return z.norm() 410 return z.norm()
411 } 411 }
412 » // m >= n && n >= karatsubaThreshold 412 » // m >= n && n >= karatsubaThreshold && n >= 2
413 413
414 » // Note that even though we passed the Karatsuba threshold, 414 » // determine largest k such that
415 » // because we tested against n and not k (see below) we may
416 » // still end up using grade-school multiplication, albeit with
417 » // an intermediate step if the Karatsuba theshold is not a
418 » // power of 2. It appears that this intermediate step makes
419 » // things faster (e.g., the threshold is < 256 at the moment).
420 » // Theoretically, there are more operations involved but the numbers
421 » // are larger and thus "internal fragmentation" (i.e., total number
422 » // unused bits in leading words) may be smaller, possibly resulting
423 » // in fewer actual machine multiplications.
424
425 » // Determine k such that:
426 // 415 //
427 // x = x1*b + x0 416 // x = x1*b + x0
428 // y = y1*b + y0 (and k <= len(y), which implies k <= len(x)) 417 // y = y1*b + y0 (and k <= len(y), which implies k <= len(x))
429 //
430 // and
431 //
432 // b = 1<<(_W*k) ("base" of digits xi, yi) 418 // b = 1<<(_W*k) ("base" of digits xi, yi)
433 // 419 //
434 » k := 1 << uint(log2(Word(n))) 420 » // and k is karatsubaThreshold multiplied by a power of 2
435 421 » k := max(karatsubaThreshold, 2)
436 » // If x1 and/or y1 are not 0, compute product explicitly: 422 » for k*2 <= n {
437 » // 423 » » k *= 2
438 » // x*y = x1*y1*b*b + x1*y0*b + x0*y1*b + x0*y0 424 » }
425 » // k <= n
426
427 » // multiply x0 and y0 via Karatsuba
428 » x0 := x[0:k] // x0 is not normalized
429 » y0 := y[0:k] // y0 is not normalized
430 » z = z.make(max(6*k, m+n)) // enough space for karatsuba of x0*y0 and full result of x*y
431 » karatsuba(z, x0, y0)
432 » z = z[0 : m+n] // z has final length but may be incomplete, upper portion is garbage
433
434 » // If x1 and/or y1 are not 0, add missing terms to z explicitly:
435 » //
436 » // m+n 2*k 0
437 » // z = [ ... | x0*y0 ]
438 » // + [ x1*y1 ]
439 » // + [ x1*y0 ]
440 » // + [ x0*y1 ]
439 // 441 //
440 if k < n || m != n { 442 if k < n || m != n {
441 » » x1, x0 := x[k:], x[0:k].norm() // x1 is normalized because x is 443 » » x1 := x[k:] // x1 is normalized because x is
442 » » y1, y0 := y[k:], y[0:k].norm() // y1 is normalized because y is 444 » » y1 := y[k:] // y1 is normalized because y is
443 var t nat 445 var t nat
444 z = z.make(m+n, true)
445 t = t.mul(x1, y1) 446 t = t.mul(x1, y1)
446 » » copy(z[2*k:], t) // z may not be normalized! 447 » » copy(z[2*k:], t)
447 » » t = t.mul(x1, y0) 448 » » z[2*k+len(t):].clear() // upper portion of z is garbage
449 » » t = t.mul(x1, y0.norm())
448 addAt(z, t, k) 450 addAt(z, t, k)
449 » » t = t.mul(x0, y1) 451 » » t = t.mul(x0.norm(), y1)
450 addAt(z, t, k) 452 addAt(z, t, k)
451 » » t = t.mul(x0, y0) 453 » }
452 » » addAt(z, t, 0) // (could invoke karatsuba for x0, y0 directly) 454
453 » » return z.norm() 455 » return z.norm()
454 » }
455 » // k == n && m == n
456
457 » // Both x and y have the same length k which is a power of 2
458 » // and thus are directly suitable for Karatsuba multiplication.
459 » z = z.make(6*k, false)
460 » karatsuba(z, x, y)
461 » return z[0 : 2*n].norm()
462 } 456 }
463 457
464 458
465 // mulRange computes the product of all the unsigned integers in the 459 // mulRange computes the product of all the unsigned integers in the
466 // range [a, b] inclusively. If a > b (empty range), the result is 1. 460 // range [a, b] inclusively. If a > b (empty range), the result is 1.
467 func (z nat) mulRange(a, b uint64) nat { 461 func (z nat) mulRange(a, b uint64) nat {
adonovan 2020/02/12 14:46:30 What's the purpose of treating the sequence as lea
gri 2020/02/19 00:51:24 The recursive approach appears faster in practice
468 switch { 462 switch {
469 case a == 0: 463 case a == 0:
470 // cut long ranges short (optimization) 464 // cut long ranges short (optimization)
471 return z.new(0) 465 return z.new(0)
472 case a > b: 466 case a > b:
473 return z.new(1) 467 return z.new(1)
474 case a == b: 468 case a == b:
475 return z.new(a) 469 return z.new(a)
476 case a+1 == b: 470 case a+1 == b:
477 return z.mul(nat(nil).new(a), nat(nil).new(b)) 471 return z.mul(nat(nil).new(a), nat(nil).new(b))
(...skipping 10 matching lines...) Expand all
488 case y == 0: 482 case y == 0:
489 panic("division by zero") 483 panic("division by zero")
490 case y == 1: 484 case y == 1:
491 q = z.set(x) // result is x 485 q = z.set(x) // result is x
492 return 486 return
493 case m == 0: 487 case m == 0:
494 q = z.set(nil) // result is 0 488 q = z.set(nil) // result is 0
495 return 489 return
496 } 490 }
497 // m > 0 491 // m > 0
498 » z = z.make(m, false) 492 » z = z.make(m)
499 r = divWVW(&z[0], 0, &x[0], y, m) 493 r = divWVW(&z[0], 0, &x[0], y, m)
500 q = z.norm() 494 q = z.norm()
501 return 495 return
502 } 496 }
503 497
504 498
505 func (z nat) div(z2, u, v nat) (q, r nat) { 499 func (z nat) div(z2, u, v nat) (q, r nat) {
506 if len(v) == 0 { 500 if len(v) == 0 {
507 panic("division by zero") 501 panic("division by zero")
508 } 502 }
509 503
510 if u.cmp(v) < 0 { 504 if u.cmp(v) < 0 {
511 » » q = z.make(0, false) 505 » » q = z.make(0)
512 r = z2.set(u) 506 r = z2.set(u)
513 return 507 return
514 } 508 }
515 509
516 if len(v) == 1 { 510 if len(v) == 1 {
517 var rprime Word 511 var rprime Word
518 q, rprime = z.divW(u, v[0]) 512 q, rprime = z.divW(u, v[0])
519 if rprime > 0 { 513 if rprime > 0 {
520 » » » r = z2.make(1, false) 514 » » » r = z2.make(1)
521 r[0] = rprime 515 r[0] = rprime
522 } else { 516 } else {
523 » » » r = z2.make(0, false) 517 » » » r = z2.make(0)
524 } 518 }
525 return 519 return
526 } 520 }
527 521
528 q, r = z.divLarge(z2, u, v) 522 q, r = z.divLarge(z2, u, v)
529 return 523 return
530 } 524 }
531 525
532 526
533 // q = (uIn-r)/v, with 0 <= r < y 527 // q = (uIn-r)/v, with 0 <= r < y
534 // See Knuth, Volume 2, section 4.3.1, Algorithm D. 528 // See Knuth, Volume 2, section 4.3.1, Algorithm D.
535 // Preconditions: 529 // Preconditions:
536 // len(v) >= 2 530 // len(v) >= 2
537 // len(uIn) >= len(v) 531 // len(uIn) >= len(v)
538 func (z nat) divLarge(z2, uIn, v nat) (q, r nat) { 532 func (z nat) divLarge(z2, uIn, v nat) (q, r nat) {
539 n := len(v) 533 n := len(v)
540 m := len(uIn) - len(v) 534 m := len(uIn) - len(v)
541 535
542 var u nat 536 var u nat
543 if z2 == nil || &z2[0] == &uIn[0] { 537 if z2 == nil || &z2[0] == &uIn[0] {
544 » » u = u.make(len(uIn)+1, true) // uIn is an alias for z2 538 » » u = u.make(len(uIn) + 1).clear() // uIn is an alias for z2
545 } else { 539 } else {
546 » » u = z2.make(len(uIn)+1, true) 540 » » u = z2.make(len(uIn) + 1).clear()
547 } 541 }
548 qhatv := make(nat, len(v)+1) 542 qhatv := make(nat, len(v)+1)
549 » q = z.make(m+1, false) 543 » q = z.make(m + 1)
550 544
551 // D1. 545 // D1.
552 shift := uint(leadingZeroBits(v[n-1])) 546 shift := uint(leadingZeroBits(v[n-1]))
553 v.shiftLeft(v, shift) 547 v.shiftLeft(v, shift)
554 u.shiftLeft(uIn, shift) 548 u.shiftLeft(uIn, shift)
555 u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift)) 549 u[len(uIn)] = uIn[len(uIn)-1] >> (_W - uint(shift))
556 550
557 // D2. 551 // D2.
558 for j := m; j >= 0; j-- { 552 for j := m; j >= 0; j-- {
559 // D3. 553 // D3.
(...skipping 209 matching lines...) Expand 10 before | Expand all | Expand 10 after
769 case 64: 763 case 64:
770 return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58]) 764 return int(deBruijn64Lookup[((x&-x)*(deBruijn64&_M))>>58])
771 default: 765 default:
772 panic("Unknown word size") 766 panic("Unknown word size")
773 } 767 }
774 768
775 return 0 769 return 0
776 } 770 }
777 771
778 772
773 // TODO(gri) Make the shift routines faster.
774 // Use pidigits.go benchmark as a test case.
775
779 // To avoid losing the top n bits, z should be sized so that 776 // To avoid losing the top n bits, z should be sized so that
780 // len(z) == len(x) + 1. 777 // len(z) == len(x) + 1.
781 func (z nat) shiftLeft(x nat, n uint) nat { 778 func (z nat) shiftLeft(x nat, n uint) nat {
782 if len(x) == 0 { 779 if len(x) == 0 {
783 return x 780 return x
784 } 781 }
785 782
786 ñ := _W - n 783 ñ := _W - n
787 m := x[len(x)-1] 784 m := x[len(x)-1]
788 if len(z) > len(x) { 785 if len(z) > len(x) {
(...skipping 27 matching lines...) Expand all
816 813
817 814
818 // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2) 815 // greaterThan returns true iff (x1<<_W + x2) > (y1<<_W + y2)
819 func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 } 816 func greaterThan(x1, x2, y1, y2 Word) bool { return x1 > y1 || x1 == y1 && x2 > y2 }
820 817
821 818
822 // modW returns x % d. 819 // modW returns x % d.
823 func (x nat) modW(d Word) (r Word) { 820 func (x nat) modW(d Word) (r Word) {
824 // TODO(agl): we don't actually need to store the q value. 821 // TODO(agl): we don't actually need to store the q value.
825 var q nat 822 var q nat
826 » q = q.make(len(x), false) 823 » q = q.make(len(x))
827 return divWVW(&q[0], 0, &x[0], d, len(x)) 824 return divWVW(&q[0], 0, &x[0], d, len(x))
828 } 825 }
829 826
830 827
831 // powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd. 828 // powersOfTwoDecompose finds q and k such that q * 1<<k = n and q is odd.
832 func (n nat) powersOfTwoDecompose() (q nat, k Word) { 829 func (n nat) powersOfTwoDecompose() (q nat, k Word) {
833 if len(n) == 0 { 830 if len(n) == 0 {
834 return n, 0 831 return n, 0
835 } 832 }
836 833
837 zeroWords := 0 834 zeroWords := 0
838 for n[zeroWords] == 0 { 835 for n[zeroWords] == 0 {
839 zeroWords++ 836 zeroWords++
840 } 837 }
841 // One of the words must be non-zero by invariant, therefore 838 // One of the words must be non-zero by invariant, therefore
842 // zeroWords < len(n). 839 // zeroWords < len(n).
843 x := trailingZeroBits(n[zeroWords]) 840 x := trailingZeroBits(n[zeroWords])
844 841
845 » q = q.make(len(n)-zeroWords, false) 842 » q = q.make(len(n) - zeroWords)
846 q.shiftRight(n[zeroWords:], uint(x)) 843 q.shiftRight(n[zeroWords:], uint(x))
847 q = q.norm() 844 q = q.norm()
848 845
849 k = Word(_W*zeroWords + x) 846 k = Word(_W*zeroWords + x)
850 return 847 return
851 } 848 }
852 849
853 850
854 // random creates a random integer in [0..limit), using the space in z if 851 // random creates a random integer in [0..limit), using the space in z if
855 // possible. n is the bit length of limit. 852 // possible. n is the bit length of limit.
856 func (z nat) random(rand *rand.Rand, limit nat, n int) nat { 853 func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
857 bitLengthOfMSW := uint(n % _W) 854 bitLengthOfMSW := uint(n % _W)
858 if bitLengthOfMSW == 0 { 855 if bitLengthOfMSW == 0 {
859 bitLengthOfMSW = _W 856 bitLengthOfMSW = _W
860 } 857 }
861 mask := Word((1 << bitLengthOfMSW) - 1) 858 mask := Word((1 << bitLengthOfMSW) - 1)
862 » z = z.make(len(limit), false) 859 » z = z.make(len(limit))
863 860
864 for { 861 for {
865 for i := range z { 862 for i := range z {
866 switch _W { 863 switch _W {
867 case 32: 864 case 32:
868 z[i] = Word(rand.Uint32()) 865 z[i] = Word(rand.Uint32())
869 case 64: 866 case 64:
870 z[i] = Word(rand.Uint32()) | Word(rand.Uint32()) <<32 867 z[i] = Word(rand.Uint32()) | Word(rand.Uint32()) <<32
871 } 868 }
872 } 869 }
873 870
874 z[len(limit)-1] &= mask 871 z[len(limit)-1] &= mask
875 872
876 if z.cmp(limit) < 0 { 873 if z.cmp(limit) < 0 {
877 break 874 break
878 } 875 }
879 } 876 }
880 877
881 return z.norm() 878 return z.norm()
882 } 879 }
883 880
884 881
885 // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It 882 // If m != nil, expNN calculates x**y mod m. Otherwise it calculates x**y. It
886 // reuses the storage of z if possible. 883 // reuses the storage of z if possible.
887 func (z nat) expNN(x, y, m nat) nat { 884 func (z nat) expNN(x, y, m nat) nat {
888 if len(y) == 0 { 885 if len(y) == 0 {
889 » » z = z.make(1, false) 886 » » z = z.make(1)
890 z[0] = 1 887 z[0] = 1
891 return z 888 return z
892 } 889 }
893 890
894 if m != nil { 891 if m != nil {
895 // We likely end up being as long as the modulus. 892 // We likely end up being as long as the modulus.
896 » » z = z.make(len(m), false) 893 » » z = z.make(len(m))
897 } 894 }
898 z = z.set(x) 895 z = z.set(x)
899 v := y[len(y)-1] 896 v := y[len(y)-1]
900 // It's invalid for the most significant word to be zero, therefore we 897 // It's invalid for the most significant word to be zero, therefore we
901 // will find a one bit. 898 // will find a one bit.
902 shift := leadingZeros(v) + 1 899 shift := leadingZeros(v) + 1
903 v <<= shift 900 v <<= shift
904 var q nat 901 var q nat
905 902
906 const mask = 1 << (_W - 1) 903 const mask = 1 << (_W - 1)
(...skipping 125 matching lines...) Expand 10 before | Expand all | Expand 10 after
1032 } 1029 }
1033 if y.cmp(natOne) == 0 { 1030 if y.cmp(natOne) == 0 {
1034 return false 1031 return false
1035 } 1032 }
1036 } 1033 }
1037 return false 1034 return false
1038 } 1035 }
1039 1036
1040 return true 1037 return true
1041 } 1038 }
LEFTRIGHT

Powered by Google App Engine
RSS Feeds Recent Issues | This issue
This is Rietveld f62528b