/*--------------------------------------------------------------------
This source distribution is placed in the public domain by its author,
Jason Papadopoulos. You may use it for any purpose, free of charge,
without having to notify anyone. I disclaim any responsibility for any
errors.
Optionally, please be nice and tell me if you find this source to be
useful. Again optionally, if you add to the functionality present here
please consider making those additions public too, so that others may
benefit from your work.
--jasonp@boo.net 6/3/07
--------------------------------------------------------------------*/
#ifndef _MP_H_
#define _MP_H_
#include <util.h>
#include <string.h>
#ifdef __cplusplus
extern "C" {
#endif
/* Basic multiple-precision arithmetic implementation. Precision
   is hardwired not to exceed ~164 decimal digits (MAX_MP_WORDS
   32-bit words, i.e. 544 bits). Numbers are stored as unsigned
   binary magnitudes, in little-endian word order (val[0] is the
   least significant word).
   All inputs and results are assumed non-negative, and the
   high-order words that are not in use must be zero for all
   input operands.
   The array of bits for the number is always composed of 32-bit
   words. This is for portability: C has no standard 128-bit data
   type, so 64x64-bit multiplies and 128/64-bit divides would need
   assembly-language support */
/* Each mp_t holds up to MAX_MP_WORDS 32-bit words, i.e. at most
32 * MAX_MP_WORDS = 544 bits. MP_RADIX is 2^32 (the weight of
one word) as a double, used when converting to floating point */
#define MAX_MP_WORDS 17
#define MP_RADIX 4294967296.0
typedef struct {
uint32 nwords; /* number of words in use: val[nwords-1] is the most
                  significant nonzero word, and zero has nwords == 0.
                  Words below that may be zero; words at index
                  nwords and above must be zero */
uint32 val[MAX_MP_WORDS]; /* little-endian: val[0] is least significant */
} mp_t;
/* signed multiple-precision integers: a sign flag plus an unsigned
magnitude. POSITIVE and NEGATIVE are only defined here if no
previously included header has already defined them */
#ifndef POSITIVE
#define POSITIVE 0
#endif
#ifndef NEGATIVE
#define NEGATIVE 1
#endif
typedef struct {
uint32 sign; /* POSITIVE or NEGATIVE */
mp_t num; /* magnitude of the value */
} signed_mp_t;
/* initialize an mp_t to zero: every byte is cleared, so nwords
   becomes 0 and all MAX_MP_WORDS entries of val[] become 0 */
static INLINE void mp_clear(mp_t *a) {
	memset(a, 0, sizeof(*a));
}
/* copy the value of 'a' into 'b' (note the source-first argument
   order). If both pointers name the same object, nothing happens */
static INLINE void mp_copy(mp_t *a, mp_t *b) {
	if (b != a)
		*b = *a;
}
/* return the number of bits needed to hold an mp_t.
This is equivalent to floor(log2(a)) + 1. */
uint32 mp_bits(mp_t *a);
/* approximate the logarithm of an mp_t, returned as a double
(presumably the natural logarithm -- confirm in the implementation) */
double mp_log(mp_t *x);
/* Addition and subtraction; a + b = sum
or a - b = diff. 'b' may be a single word (the _1
variants) or another mp_t. sum or diff may overwrite
the input operands. Because mp_t values are unsigned,
the subtraction routines presumably require a >= b --
confirm in the implementation */
void mp_add(mp_t *a, mp_t *b, mp_t *sum);
void mp_add_1(mp_t *a, uint32 b, mp_t *sum);
void mp_sub(mp_t *a, mp_t *b, mp_t *diff);
void mp_sub_1(mp_t *a, uint32 b, mp_t *diff);
/* Three-way comparison of two mp_t values: -1 when a < b, 0 when
   they are equal, 1 when a > b. Since nwords marks the most
   significant nonzero word, a longer number is always larger;
   equal-length numbers are compared word by word from the top */
static INLINE int32 mp_cmp(const mp_t *a, const mp_t *b) {
	uint32 k;

	if (a->nwords != b->nwords)
		return (a->nwords > b->nwords) ? 1 : -1;

	for (k = a->nwords; k > 0; k--) {
		uint32 wa = a->val[k - 1];
		uint32 wb = b->val[k - 1];

		if (wa != wb)
			return (wa > wb) ? 1 : -1;
	}
	return 0;
}
/* quick test for zero or one mp_t. Both depend on nwords being
exact (zero has nwords == 0). mp_is_one evaluates its argument
twice, so do not pass an expression with side effects */
#define mp_is_zero(a) ((a)->nwords == 0)
#define mp_is_one(a) ((a)->nwords == 1 && (a)->val[0] == 1)
/* Shift 'a' right by 'shift' bit positions.
The result may overwrite 'a'. shift amount
must not exceed 32*MAX_MP_WORDS */
void mp_rshift(mp_t *a, uint32 shift, mp_t *res);
/* Right-shift 'a' by an amount equal to the
number of trailing zeroes. Return the shift count */
uint32 mp_rjustify(mp_t *a, mp_t *res);
/* multiply a by b. 'b' is either a 1-word
operand (mp_mul_1) or an mp_t (mp_mul). In the latter case,
the product must fit in MAX_MP_WORDS words
and may not overwrite the input operands. Whether mp_mul_1's
output 'x' may alias 'a' is not stated here -- confirm in the
implementation before relying on it. */
void mp_mul(mp_t *a, mp_t *b, mp_t *prod);
void mp_mul_1(mp_t *a, uint32 b, mp_t *x);
/* Divide a 64-bit input by a 32-bit input and return the
   remainder a mod n.

   Preconditions (not checked):
   - n must be nonzero;
   - the quotient a / n must be less than 2^32, i.e.
     (a >> 32) < n. On the x86 inline-assembly paths a larger
     quotient makes the div instruction raise a divide error;
     the portable C path has no such limit. The product of two
     32-bit values satisfies this whenever one factor is below n. */
static INLINE uint32 mp_mod64(uint64 a, uint32 n) {
	uint32 ans;
#if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))
	/* divl divides edx:eax by its operand, leaving the quotient
	   in eax and the remainder in edx. Both registers are
	   overwritten, so both must be outputs; the divisor must be
	   a register or memory operand because divl has no
	   immediate form (a "g" constraint could select one) */
	uint32 quot;
	__asm__("divl %4 \n\t"
		: "=a"(quot), "=d"(ans)
		: "0"((uint32)(a)), "1"((uint32)(a >> 32)), "rm"(n) : "cc");
	(void)quot;
#elif defined(_MSC_VER) && !defined(_WIN64)
	__asm
	{
		lea ecx,a
		mov eax,[ecx]
		mov edx,[ecx+4]
		div n
		mov ans,edx
	}
#else
	ans = (uint32)(a % n);
#endif
	return ans;
}
/* modular multiplication: compute 'a' * 'b' mod 'n'.
Multiple precision operands can have up to
MAX_MP_WORDS words each; the output 'res' can alias
a or b but not n */
void mp_modmul(mp_t *a, mp_t *b, mp_t *n, mp_t *res);
/* single-word modular multiplication a * b mod n. The full 64-bit
   product is reduced by mp_mod64, so (as that routine requires) the
   quotient must fit in 32 bits -- guaranteed when a < n or b < n */
static INLINE uint32 mp_modmul_1(uint32 a, uint32 b, uint32 n) {
	return mp_mod64((uint64)a * b, n);
}
/* General-purpose division routines. mp_divrem
divides num by denom, putting the quotient in
quot (if not NULL) and the remainder in rem
(if not NULL). No aliasing is allowed. mp_div and
mp_mod are shorthands that request only the quotient
or only the remainder. denom presumably must be
nonzero -- confirm in the implementation */
void mp_divrem(mp_t *num, mp_t *denom, mp_t *quot, mp_t *rem);
#define mp_div(n, d, q) mp_divrem(n, d, q, NULL)
#define mp_mod(n, d, rem) mp_divrem(n, d, NULL, rem)
/* Division routine where the denominator is a
single word. The quotient is written to quot
(if not NULL) and the remainder is returned.
quot may overwrite the input */
uint32 mp_divrem_1(mp_t *num, uint32 denom, mp_t *quot);
/* Divide an mp_t by a single word and return the remainder
   num mod denom; num itself is not modified.

   denom must be nonzero (not checked). A zero input
   (nwords == 0) returns 0 without reading num->val. */
static INLINE uint32 mp_mod_1(mp_t *num, uint32 denom) {
	int32 i;
	uint32 rem = 0;

	/* zero has no top word; without this check the code
	   below would read num->val[-1] */
	if (num->nwords == 0)
		return 0;

	i = (int32)num->nwords - 1;

	/* a top word already below denom is its own remainder,
	   saving one division. From here on rem < denom, so each
	   division of rem:val[i] has a quotient that fits in 32 bits */
	if (num->val[i] < denom) {
		rem = num->val[i--];
	}

#if defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))
	while (i >= 0) {
		/* divl overwrites eax with the quotient, so eax must be
		   declared as an output rather than an input only */
		uint32 quot;
		__asm__("divl %4"
			: "=a"(quot), "=d"(rem)
			: "0"(num->val[i]), "1"(rem), "rm"(denom) : "cc");
		(void)quot;
		i--;
	}
#elif defined(_MSC_VER) && !defined(_WIN64)
	__asm
	{
		push ebx
		push esi
		mov ecx,i
		or ecx,ecx
		jl L1
		mov ebx,denom
		mov esi,num
		lea esi,[esi]num.val
		mov edx,rem
	L0:	mov eax,[esi+4*ecx]
		div ebx
		dec ecx
		jge L0
		mov rem,edx
	L1:	pop esi
		pop ebx
	}
#else
	while (i >= 0) {
		uint64 acc = (uint64)rem << 32 | (uint64)num->val[i];
		rem = (uint32)(acc % denom);
		i--;
	}
#endif
	return rem;
}
/* Calculate floor(i_th root of 'a'). The return value
is zero if res is an exact i_th root of 'a'. 'estimate'
is an estimate of the answer, used to reduce the
latency of Newton iteration. If not NULL, it's
recommended that the estimate be an accurate one :)
mp_isqrt is the square-root case with no estimate */
uint32 mp_iroot(mp_t *a, mp_t *estimate, uint32 i, mp_t *res);
#define mp_isqrt(a, res) mp_iroot(a, NULL, 2, res)
/* Calculate greatest common divisor of x and y.
Any quantities may alias */
void mp_gcd(mp_t *x, mp_t *y, mp_t *out);
/* single-word greatest common divisor via Euclid's algorithm.
   Argument order does not matter; mp_gcd_1(x, 0) is x and
   mp_gcd_1(0, 0) is 0 */
static INLINE uint32 mp_gcd_1(uint32 x, uint32 y) {
	uint32 big = x;
	uint32 small = y;

	while (small != 0) {
		uint32 r = big % small;
		big = small;
		small = r;
	}
	return big;
}
/* Print routines: print the input mp_t to a file
(if f is not NULL) and also return a pointer to
a string representation of the input (requires
sufficient scratch space to be passed in). The
input is printed in radix 'base' (2 to 36). The
maximum required size for scratch space is
32*MAX_MP_WORDS+1 bytes (i.e. enough to print
out 'a' in base 2, plus the terminator) */
char * mp_print(mp_t *a, uint32 base, FILE *f, char *scratch);
#define mp_printf(a, base, scratch) mp_print(a, base, stdout, scratch)
#define mp_fprintf(a, base, f, scratch) mp_print(a, base, f, scratch)
#define mp_sprintf(a, base, scratch) mp_print(a, base, NULL, scratch)
/* A multiple-precision version of strtoul(). The
string 'str' is converted from an ascii representation
of radix 'base' (2 to 36) into an mp_t. If base is 0,
the base is assumed to be 16 if the number is preceded
by "0x", 8 if preceded by "0", and 10 otherwise. Con-
version stops at the first character that cannot belong
to radix 'base', or otherwise at the terminating NUL
character. The input is case insensitive. */
void mp_str2mp(char *str, mp_t *a, uint32 base);
/* modular exponentiation: raise 'a' to the power 'b'
mod 'n' and return the result. a and b may exceed n.
In the multiple precision case, the result may not
alias any of the inputs */
void mp_expo(mp_t *a, mp_t *b, mp_t *n, mp_t *res);
/* ordinary exponentiation: raise 'a' to the power 'b'
and return the result. The result may not alias
any of the inputs, and must fit in an mp_t */
void mp_pow(mp_t *a, mp_t *b, mp_t *res);
/* single-word modular exponentiation: a^b mod n by right-to-left
   binary exponentiation. 'a' may exceed n; it is reduced first.
   n must be nonzero. b == 0 returns 1.

   The base is reduced because mp_modmul_1 goes through mp_mod64,
   whose x86 assembly paths fault when the quotient reaches 2^32.
   Squaring an unreduced base (e.g. a = 0xffffffff, n = 3) would
   trigger this. Reduction does not change the result of the
   portable path */
static INLINE uint32 mp_expo_1(uint32 a, uint32 b, uint32 n) {
	uint32 res = 1;

	a = a % n;
	while (b) {
		if (b & 1)
			res = mp_modmul_1(res, a, n);
		a = mp_modmul_1(a, a, n);
		b = b >> 1;
	}
	return res;
}
/* Return the Legendre symbol for 'a'. This is 1 if
x * x = a (mod p) is solvable for some x, -1 if not,
and 0 if a and p have factors in common. p must be
an odd prime, and a may exceed p */
int32 mp_legendre_1(uint32 a, uint32 p);
int32 mp_legendre(mp_t *a, mp_t *p);
/* Find an inverse of 'a' modulo prime 'p', i.e. the number
x such that a * x mod p is 1. The routine assumes that
a will never exceed p */
uint32 mp_modinv_1(uint32 a, uint32 p);
/* For odd prime p, solve 'x * x = a (mod p)' for x and
return the result. Assumes legendre(a,p) = 1 (this is
not verified).
The routines below that take 'seed1' and 'seed2' use
random numbers. Since they are intended to be 'stateless',
the state of the random number generator is passed in
through those two pointers and updated as random numbers
are produced */
uint32 mp_modsqrt_1(uint32 a, uint32 p);
void mp_modsqrt(mp_t *a, mp_t *p, mp_t *res, uint32 *seed1, uint32 *seed2);
void mp_modsqrt2(mp_t *a, mp_t *p, mp_t *res, uint32 *seed1, uint32 *seed2);
/* Generate a random number between 0 and 2^bits - 1 */
void mp_rand(uint32 bits, mp_t *res, uint32 *seed1, uint32 *seed2);
/* mp_is_prime returns 1 if the input is prime and 0
otherwise. mp_random_prime generates a random number
between 0 and 2^bits - 1 which is probably prime.
mp_next_prime computes the next number greater than p
which is prime, and returns (res - p). The probability
of these routines accidentally declaring a composite
to be prime is < 4 ^ -NUM_WITNESSES, and probably is
drastically smaller than that. NUM_WITNESSES is presumably
the number of probabilistic-test rounds performed --
confirm in the implementation */
#define NUM_WITNESSES 20
int32 mp_is_prime(mp_t *p, uint32 *seed1, uint32 *seed2);
void mp_random_prime(uint32 bits, mp_t *res, uint32 *seed1, uint32 *seed2);
uint32 mp_next_prime(mp_t *p, mp_t *res, uint32 *seed1, uint32 *seed2);
/* Modular addition/subtraction: compute a +- b mod p.
   mp_modsub_1 returns a - b mod p; the result is only correctly
   reduced when a and b are both already below p (b == p is also
   tolerated and yields a) */
static INLINE uint32 mp_modsub_1(uint32 a, uint32 b, uint32 p) {
#if defined(__GNUC__) && defined(__i386__) && defined(HAVE_CMOV)
	uint32 ans;
	/* ans = a - b, then add p back if the subtraction borrowed.
	   cmov has no immediate form, so p must be a register or
	   memory operand ("rm", not "g"). ans is early-clobbered
	   because it is written before p is read, so p must not be
	   allocated to the same register */
	__asm__("xorl %%edx, %%edx \n\t"
		"subl %2, %0 \n\t"
		"cmovbl %3, %%edx \n\t"
		"addl %%edx, %0 \n\t"
		: "=&r"(ans)
		: "0"(a), "g"(b), "rm"(p) : "%edx", "cc");
	return ans;
#elif defined(_MSC_VER) && !defined(_WIN64) && defined(HAVE_CMOV)
	uint32 ans;
	__asm
	{
		mov eax,a
		mov ecx,b
		xor edx,edx
		sub eax,ecx
		cmovb edx,p
		add eax,edx
		mov ans,eax
	}
	return ans;
#else
	/* when a < b the unsigned subtraction wraps; adding p
	   wraps back to the correct residue */
	if (a >= b)
		return a - b;
	else
		return a - b + p;
#endif
}
/* modular addition a + b mod p, computed as a - (p - b) mod p.
   The result is only correctly reduced when a and b are below p.
   b == 0 works: mp_modsub_1(a, p, p) yields a */
static INLINE uint32 mp_modadd_1(uint32 a, uint32 b, uint32 p) {
	uint32 neg_b = p - b;	/* additive inverse of b modulo p */

	return mp_modsub_1(a, neg_b, p);
}
/* conversion to/from doubles. Only the (at most) three most
   significant words of an mp_t contribute to the converted
   value, and a double keeps just 53 bits of mantissa, so large
   inputs are approximated. The original author recommends
   inputs of under 100 decimal digits for useful precision */
static INLINE double mp_mp2d(mp_t *x) {
	/* convert a multiple-precision number to a double */
	uint32 n = x->nwords;
	double top;

	if (n == 0)
		return 0;
	if (n == 1)
		return (double)(x->val[0]);
	if (n == 2)
		return (double)(x->val[0]) + MP_RADIX * x->val[1];

	/* combine the three most significant words ... */
	top = (double)(x->val[n-3]) + MP_RADIX *
		((double)x->val[n-2] + MP_RADIX * x->val[n-1]);
	if (n == 3)
		return top;

	/* ... then scale by the weight of the lowest of them */
	return top * pow(2.0, 32.0 * (n - 3));
}
/* convert a signed_mp_t to a double: the magnitude comes from
   mp_mp2d and is negated unless the sign is POSITIVE */
static INLINE double mp_signed_mp2d(signed_mp_t *x) {
	double magnitude = mp_mp2d(&x->num);

	return (x->sign == POSITIVE) ? magnitude : -magnitude;
}
static INLINE void mp_d2mp(double *d, mp_t *x) {
	int32 i;
	int32 k;
	int32 top;
	int32 exponent;
	uint32 shift;
	uint32 words[3];
	uint64 int_mant;

	/* Convert a double to a multiple-precision integer.
	   It's assumed the double represents an integer; any
	   fractional bits are truncated and the sign bit is ignored
	   (the magnitude is converted). Values below 1 (including
	   zero and denormals), values too large to fit in
	   MAX_MP_WORDS words, and Inf/NaN all leave x as zero.
	   Note that a pointer to d, and not d itself, is
	   *required* for PowerPC builds (at least) to work */

	mp_clear(x);

	/* cut up the double into mantissa and exponent. Copying
	   the bits into a uint64 keeps this endian-independent;
	   memcpy is used because reading through a cast pointer
	   would violate strict aliasing */
	memcpy(&int_mant, d, sizeof(int_mant));
	exponent = ((int32)(int_mant >> 52) & 0x7ff) - 1023;
	int_mant &= ~((uint64)(0xfff) << 52);
	int_mant |= (uint64)(1) << 52;

	/* shift away any fractional bits; 'exponent' becomes the
	   bit position of the mantissa's lowest bit */
	if (exponent < 0) {
		return;
	}
	else if (exponent <= 52) {
		int_mant = int_mant >> (52 - exponent);
		exponent = 0;
	}
	else {
		exponent -= 52;
	}

	if (int_mant == 0)
		return;

	/* split the (at most 53-bit) mantissa, shifted by the
	   sub-word offset, into three 32-bit words. A 64-bit shift
	   is undefined behavior, so shift == 0 is handled apart */
	i = exponent / 32;
	shift = (uint32)(exponent % 32);
	if (shift == 0) {
		words[0] = (uint32)int_mant;
		words[1] = (uint32)(int_mant >> 32);
		words[2] = 0;
	}
	else {
		words[0] = (uint32)(int_mant << shift);
		words[1] = (uint32)(int_mant >> (32 - shift));
		words[2] = (uint32)(int_mant >> (64 - shift));
	}

	/* find the most significant nonzero word */
	top = 2;
	while (top > 0 && words[top] == 0)
		top--;

	/* values beyond MAX_MP_WORDS words (including Inf/NaN,
	   whose exponent field is all ones) would overrun val[] */
	if (i + top >= MAX_MP_WORDS)
		return;

	for (k = 0; k <= top; k++)
		x->val[i + k] = words[k];
	x->nwords = (uint32)(i + top + 1);
}
#ifdef __cplusplus
}
#endif
#endif /* !_MP_H_ */
/* syntax highlighted by Code2HTML, v. 0.9.1 */