URI:
       tmpmul.c - plan9port - [fork] Plan 9 from user space
  HTML git clone git://src.adamsgaard.dk/plan9port
   DIR Log
   DIR Files
   DIR Refs
   DIR README
   DIR LICENSE
       ---
       tmpmul.c (3182B)
       ---
            1 #include "os.h"
            2 #include <mp.h>
            3 #include "dat.h"
            4 
            5 /* */
            6 /*  from knuth's 1969 seminumerical algorithms, pp 233-235 and pp 258-260 */
            7 /* */
            8 /*  the inner loop is performed by mpvecdigmuladd, which may be */
            9 /*  implemented as an assembly language routine. */
           10 /* */
           11 /*  the karatsuba trade-off (KARATSUBAMIN) was set empirically by */
           12 /*  measuring the algorithms on a 400 MHz Pentium II. */
           13 /* */
           14 
           15 /* karatsuba like (see knuth pg 258) */
           16 /* prereq: p is already zeroed */
           17 static void
           18 mpkaratsuba(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
           19 {
           20         mpdigit *t, *u0, *u1, *v0, *v1, *u0v0, *u1v1, *res, *diffprod;
           21         int u0len, u1len, v0len, v1len, reslen;
           22         int sign, n;
           23 
           24         /* divide each piece in half */
           25         n = alen/2;
           26         if(alen&1)
           27                 n++;
           28         u0len = n;
           29         u1len = alen-n;
           30         if(blen > n){
           31                 v0len = n;
           32                 v1len = blen-n;
           33         } else {
           34                 v0len = blen;
           35                 v1len = 0;
           36         }
           37         u0 = a;
           38         u1 = a + u0len;
           39         v0 = b;
           40         v1 = b + v0len;
           41 
           42         /* room for the partial products */
           43         t = mallocz(Dbytes*5*(2*n+1), 1);
           44         if(t == nil)
           45                 sysfatal("mpkaratsuba: %r");
           46         u0v0 = t;
           47         u1v1 = t + (2*n+1);
           48         diffprod = t + 2*(2*n+1);
           49         res = t + 3*(2*n+1);
           50         reslen = 4*n+1;
           51 
           52         /* t[0] = (u1-u0) */
           53         sign = 1;
           54         if(mpveccmp(u1, u1len, u0, u0len) < 0){
           55                 sign = -1;
           56                 mpvecsub(u0, u0len, u1, u1len, u0v0);
           57         } else
           58                 mpvecsub(u1, u1len, u0, u1len, u0v0);
           59 
           60         /* t[1] = (v0-v1) */
           61         if(mpveccmp(v0, v0len, v1, v1len) < 0){
           62                 sign *= -1;
           63                 mpvecsub(v1, v1len, v0, v1len, u1v1);
           64         } else
           65                 mpvecsub(v0, v0len, v1, v1len, u1v1);
           66 
           67         /* t[4:5] = (u1-u0)*(v0-v1) */
           68         mpvecmul(u0v0, u0len, u1v1, v0len, diffprod);
           69 
           70         /* t[0:1] = u1*v1 */
           71         memset(t, 0, 2*(2*n+1)*Dbytes);
           72         if(v1len > 0)
           73                 mpvecmul(u1, u1len, v1, v1len, u1v1);
           74 
           75         /* t[2:3] = u0v0 */
           76         mpvecmul(u0, u0len, v0, v0len, u0v0);
           77 
           78         /* res = u0*v0<<n + u0*v0 */
           79         mpvecadd(res, reslen, u0v0, u0len+v0len, res);
           80         mpvecadd(res+n, reslen-n, u0v0, u0len+v0len, res+n);
           81 
           82         /* res += u1*v1<<n + u1*v1<<2*n */
           83         if(v1len > 0){
           84                 mpvecadd(res+n, reslen-n, u1v1, u1len+v1len, res+n);
           85                 mpvecadd(res+2*n, reslen-2*n, u1v1, u1len+v1len, res+2*n);
           86         }
           87 
           88         /* res += (u1-u0)*(v0-v1)<<n */
           89         if(sign < 0)
           90                 mpvecsub(res+n, reslen-n, diffprod, u0len+v0len, res+n);
           91         else
           92                 mpvecadd(res+n, reslen-n, diffprod, u0len+v0len, res+n);
           93         memmove(p, res, (alen+blen)*Dbytes);
           94 
           95         free(t);
           96 }
           97 
           98 #define KARATSUBAMIN 32
           99 
          100 void
          101 mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p)
          102 {
          103         int i;
          104         mpdigit d;
          105         mpdigit *t;
          106 
          107         /* both mpvecdigmuladd and karatsuba are fastest when a is the longer vector */
          108         if(alen < blen){
          109                 i = alen;
          110                 alen = blen;
          111                 blen = i;
          112                 t = a;
          113                 a = b;
          114                 b = t;
          115         }
          116         if(blen == 0){
          117                 memset(p, 0, Dbytes*(alen+blen));
          118                 return;
          119         }
          120 
          121         if(alen >= KARATSUBAMIN && blen > 1){
          122                 /* O(n^1.585) */
          123                 mpkaratsuba(a, alen, b, blen, p);
          124         } else {
          125                 /* O(n^2) */
          126                 for(i = 0; i < blen; i++){
          127                         d = b[i];
          128                         if(d != 0)
          129                                 mpvecdigmuladd(a, alen, d, &p[i]);
          130                 }
          131         }
          132 }
          133 
          134 void
          135 mpmul(mpint *b1, mpint *b2, mpint *prod)
          136 {
          137         mpint *oprod;
          138 
          139         oprod = nil;
          140         if(prod == b1 || prod == b2){
          141                 oprod = prod;
          142                 prod = mpnew(0);
          143         }
          144 
          145         prod->top = 0;
          146         mpbits(prod, (b1->top+b2->top+1)*Dbits);
          147         mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p);
          148         prod->top = b1->top+b2->top+1;
          149         prod->sign = b1->sign*b2->sign;
          150         mpnorm(prod);
          151 
          152         if(oprod != nil){
          153                 mpassign(prod, oprod);
          154                 mpfree(prod);
          155         }
          156 }