#include <NTL/ZZ.h>
#include <NTL/vec_ZZ.h>
#include <NTL/new.h>
NTL_START_IMPL
// Output-format flag consulted by operator<<(ostream&, const ZZ&): when
// nonzero, values are printed in hexadecimal with a "0x" prefix; when zero
// (the default), they are printed in decimal.
long ZZ::HexOutput = 0;
// Return a reference to one shared, default-constructed ZZ (the zero value).
// It is a function-local static, so it is created on first use and the same
// object is handed to every caller.
const ZZ& ZZ::zero()
{
   static ZZ zero_value;
   return zero_value;
}
// Convert the long e to a ZZ and return a reference to it.
// NOTE: the result lives in a function-local static, so it is overwritten
// by the next call to ZZ_expo; callers must use or copy it right away.
const ZZ& ZZ_expo(long e)
{
   static ZZ expo_value;
   conv(expo_value, e);
   return expo_value;
}
// x = (a + b) mod n, for a long addend b.
// b is promoted to a ZZ in a static scratch value, and the all-ZZ AddMod is
// then called.
void AddMod(ZZ& x, const ZZ& a, long b, const ZZ& n)
{
   static ZZ b_as_ZZ;
   conv(b_as_ZZ, b);
   AddMod(x, a, b_as_ZZ, n);
}
// x = (a - b) mod n, for a long subtrahend b.
// b is promoted to a ZZ in a static scratch value, and the all-ZZ SubMod is
// then called.
void SubMod(ZZ& x, const ZZ& a, long b, const ZZ& n)
{
   static ZZ b_as_ZZ;
   conv(b_as_ZZ, b);
   SubMod(x, a, b_as_ZZ, n);
}
// x = (a - b) mod n, for a long minuend a.
// a is promoted to a ZZ in a static scratch value, and the all-ZZ SubMod is
// then called.
void SubMod(ZZ& x, long a, const ZZ& b, const ZZ& n)
{
   static ZZ a_as_ZZ;
   conv(a_as_ZZ, a);
   SubMod(x, a_as_ZZ, b, n);
}
// ****** input and output
// Decimal I/O chunking parameters. They are filled in lazily by InitZZIO()
// the first time operator>> or operator<< runs; iodigits == 0 means
// "not yet initialized".
static long iodigits = 0;
static long ioradix = 0;
// iodigits is the greatest integer such that 10^{iodigits} < NTL_WSP_BOUND
// ioradix = 10^{iodigits}
// Decimal text is converted in chunks of iodigits digits, i.e. one base-ioradix
// "digit" at a time.
// Compute iodigits and ioradix from NTL_WSP_BOUND.
// iodigits becomes the number of decimal digits in (NTL_WSP_BOUND-1)/10,
// which is the largest d with 10^d < NTL_WSP_BOUND, and ioradix = 10^iodigits.
static void InitZZIO()
{
   iodigits = 0;
   ioradix = 1;

   for (long rest = (NTL_WSP_BOUND-1)/10; rest != 0; rest /= 10) {
      iodigits++;
      ioradix *= 10;
   }

   if (iodigits <= 0) Error("problem with I/O");
}
// Try to consume up to two hex digits from s.
// Returns the number of characters consumed:
//   0 - the next character is not a hex digit; byte holds its
//       CharToIntVal value, which is out of range.
//   1 - one hex digit was consumed; byte holds its value (0..15).
//   2 - two hex digits were consumed; byte holds (hi << 4) + lo.
static long HexTwoChars(long &byte, istream& s)
{
   long hi = CharToIntVal(s.peek());
   byte = hi;
   if (hi < 0 || hi > 15) return 0;
   s.get();

   long lo = CharToIntVal(s.peek());
   if (lo < 0 || lo > 15) return 1;
   s.get();

   byte = (hi << 4) + lo;
   return 2;
}
// Read a non-empty run of hex digits from s (the "0x" prefix has already
// been consumed by the caller) and store the resulting magnitude in a.
// Digits are collected two at a time into a byte buffer.
// - The buffer is filled from its top end downward, so that buf[0] is the
//   least significant byte, the order ZZFromBytes expects.
// - Each full buffer, and then any partial leftover, is shifted into a.
// - An odd trailing digit is appended as one final nibble.
static void HexReadFromStream(istream& s, ZZ& a)
{
   const long bufLen = 256; // read upto 256 bytes at a time
   static ZZ b;
   static unsigned char buf[bufLen];

   long val = CharToIntVal(s.peek());
   if (val < 0 || val > 15) Error("HexReadFromStream: bad ZZ input");

   a = 0;

   long filled = 0;   // number of bytes currently held at the top of buf
   long nRead;
   for (;;) {
      nRead = HexTwoChars(val, s);
      if (nRead != 2) break;

      filled++;
      buf[bufLen - filled] = (unsigned char) val;

      if (filled == bufLen) { // buffer is full: fold it into a
         ZZFromBytes(b, buf, bufLen);
         a <<= bufLen*8;
         a += b;
         filled = 0;
      }
   }

   if (filled > 0) { // fold in the partially filled buffer
      ZZFromBytes(b, &buf[bufLen - filled], filled);
      a <<= filled*8;
      a += b;
   }

   if (nRead == 1) { // an odd number of digits: one nibble remains
      a <<= 4;
      a += val;
   }
}
// Read a ZZ from s into x.
// - Accepts optional leading whitespace and an optional '-'.
// - Then accepts either a "0x"/"0X"-prefixed hexadecimal magnitude, or a
//   decimal magnitude.
// - Decimal input is accumulated iodigits digits at a time, so most of the
//   work is done with long arithmetic.
// Calls Error if the stream is bad or no digit follows the optional sign.
istream& operator>>(istream& s, ZZ& x)
{
   static ZZ a;

   if (!s) Error("bad ZZ input (no stream found)");

   if (!iodigits) InitZZIO();

   SkipWhiteSpace(s);

   long ch = s.peek();
   long negative = 0;
   if (ch == '-') {
      negative = 1;
      s.get();
      ch = s.peek();
   }

   long digit = CharToIntVal(ch);
   if (digit < 0 || digit > 9) Error("bad ZZ input (first digit not int [0-9])");

   if (ch == '0') {
      // A leading '0' may introduce hex ("0x"); otherwise it is skipped.
      s.get();
      ch = s.peek();

      if (ch == 'x' || ch == 'X') {
         s.get();
         HexReadFromStream(s, a);
         if (negative) negate(a, a);
         x = a;
         return s;
      }

      digit = CharToIntVal(ch);
      if (digit < 0 || digit > 9) { // the input was a lone zero
         clear(x);
         return s;
      }
      // otherwise continue with the decimal reader below
   }

   a = 0;
   long chunk = 0;     // value of the digits gathered since the last flush
   long chunkLen = 0;  // how many digits are in chunk

   for (; digit >= 0 && digit <= 9; digit = CharToIntVal(s.peek())) {
      chunk = chunk*10 + digit;
      chunkLen++;
      if (chunkLen == iodigits) {
         mul(a, a, ioradix);
         add(a, a, chunk);
         chunk = 0;
         chunkLen = 0;
      }
      s.get();
   }

   if (chunkLen != 0) {
      long scale = 1;
      for (long i = 0; i < chunkLen; i++)
         scale = scale * 10;
      mul(a, a, scale);
      add(a, a, chunk);
   }

   if (negative)
      negate(a, a);

   x = a;
   return s;
}
// The class _ZZ_local_stack should be defined in an unnamed namespace,
// but since I don't want to rely on namespaces, we just give it a funny
// name to avoid accidental name clashes.
//
// A minimal growable stack of longs. operator<< uses it to hold the
// base-ioradix digits of a ZZ.
// - The backing array is obtained with NTL_MALLOC/NTL_REALLOC in push(),
//   which are malloc-based.
// - The destructor releases the array with free().
// - Copying is disabled because the stack owns elts; a shallow copy would
//   lead to a double free.
struct _ZZ_local_stack {
   long top;    // index of the top element; -1 when the stack is empty
   long alloc;  // number of slots allocated in elts
   long *elts;  // backing storage; 0 until the first push

   _ZZ_local_stack() { top = -1; alloc = 0; elts = 0; }
   ~_ZZ_local_stack() { free(elts); }

   // Remove and return the top element; the stack must not be empty.
   long pop() { return elts[top--]; }
   long empty() { return (top == -1); }
   void push(long x);

private:
   // Not copyable (declared, never defined): the stack owns elts.
   _ZZ_local_stack(const _ZZ_local_stack&);
   void operator=(const _ZZ_local_stack&);
};
void _ZZ_local_stack::push(long x)
{
if (alloc == 0) {
alloc = 100;
elts = (long *) NTL_MALLOC(alloc, sizeof(long), 0);
}
top++;
if (top + 1 > alloc) {
alloc = 2*alloc;
elts = (long *) NTL_REALLOC(elts, alloc, sizeof(long), 0);
}
if (!elts) {
Error("out of space in ZZ output");
}
elts[top] = x;
}
// Write the decimal digits of d (assumed 0 <= d < ioradix) to s.
// If justify is nonzero, the output is left-padded with '0' to exactly
// iodigits characters. If justify is zero, no padding is written, so d == 0
// prints nothing.
// The scratch buffer of iodigits chars is allocated once, on the first call.
static
void PrintDigits(ostream& s, long d, long justify)
{
   static char *buf = 0;

   if (!buf) {
      buf = (char *) NTL_MALLOC(iodigits, 1, 0);
      if (!buf) Error("out of memory");
   }

   // Collect digits least significant first.
   long len;
   for (len = 0; d != 0; len++, d /= 10)
      buf[len] = IntValToChar(d % 10);

   if (justify) {
      for (long pad = iodigits - len; pad > 0; pad--)
         s << "0";
   }

   // Emit them most significant first.
   for (long k = len - 1; k >= 0; k--)
      s << buf[k];
}
// Write buf[0..numBytes-1] to s as hex, two digits per byte.
// The bytes are treated as little-endian (buf[0] least significant), so they
// are printed starting from buf[numBytes-1]. numBytes must be positive.
static void HexPrintBytes(ostream& s, unsigned char* buf, long numBytes)
{
   if (numBytes <= 0) Error("HexPrintBytes: numBytes non-positive");

   for (long idx = numBytes - 1; idx >= 0; idx--) {
      long byte = buf[idx];
      s << IntValToChar((byte >> 4) & 0xf);  // high nibble
      s << IntValToChar(byte & 0xf);         // low nibble
   }
}
// Write a to s.
// - Zero is printed as "0"; negative values get a leading '-'.
// - If ZZ::HexOutput is nonzero, the magnitude is printed in hexadecimal
//   with a "0x" prefix.
// - Otherwise the magnitude is printed in decimal, one base-ioradix chunk
//   (iodigits decimal digits) at a time.
ostream& operator<<(ostream& s, const ZZ& a)
{
   static ZZ b;
   static _ZZ_local_stack S;
   long r;
   long k;

   if (!iodigits) InitZZIO();

   b = a;

   k = sign(b);

   if (k == 0) {
      s << "0";
      return s;
   }

   if (k < 0) {
      s << "-";
      negate(b, b);
   }

   if (ZZ::HexOutput) { // output in Hexadecimal format
      s << "0x";
      long nb = NumBytes(b);
      // NTL_NEW_OP may expand to a nothrow form of new (see NTL/new.h), so a
      // null result must be checked before BytesFromZZ writes through it.
      unsigned char* txt = NTL_NEW_OP unsigned char[nb];
      if (!txt) Error("out of memory");
      BytesFromZZ(txt, b, nb);
      HexPrintBytes(s, txt, nb);
      delete [] txt;
      return s;
   }

   // Peel off base-ioradix digits, least significant first.
   do {
      r = DivRem(b, b, ioradix);
      S.push(r);
   } while (!IsZero(b));

   // Print them most significant first. Only the leading chunk is printed
   // without zero padding.
   r = S.pop();
   PrintDigits(s, r, 0);

   while (!S.empty()) {
      r = S.pop();
      PrintDigits(s, r, 1);
   }

   return s;
}
// Return gcd(|a|, |b|) using Euclid's algorithm, so GCD(a, 0) == |a|.
// Calls Error if either argument is below -NTL_MAX_LONG, since such a value
// cannot be negated in a long.
long GCD(long a, long b)
{
   if (a < 0) {
      if (a < -NTL_MAX_LONG) Error("GCD: integer overflow");
      a = -a;
   }

   if (b < 0) {
      if (b < -NTL_MAX_LONG) Error("GCD: integer overflow");
      b = -b;
   }

   while (b != 0) {
      long rem = a % b;
      a = b;
      b = rem;
   }

   return a;
}
// Extended Euclid: set d = gcd(|a|, |b|) and find s, t with s*a + t*b = d.
// The signs of the inputs are stripped first. The resulting coefficients
// are then negated for each negated input, so the identity holds for the
// original a and b.
// Calls Error if a or b is below -NTL_MAX_LONG.
void XGCD(long& d, long& s, long& t, long a, long b)
{
   long aneg = 0, bneg = 0;

   if (a < 0) {
      if (a < -NTL_MAX_LONG) Error("XGCD: integer overflow");
      a = -a;
      aneg = 1;
   }

   if (b < 0) {
      if (b < -NTL_MAX_LONG) Error("XGCD: integer overflow");
      b = -b;
      bneg = 1;
   }

   // Invariant: rCur = sCur*a + tCur*b and rNext = sNext*a + tNext*b.
   long rCur = a, rNext = b;
   long sCur = 1, sNext = 0;
   long tCur = 0, tNext = 1;

   while (rNext != 0) {
      long q = rCur / rNext;
      long rem = rCur % rNext;
      rCur = rNext;
      rNext = rem;

      long sNew = sCur - q*sNext;
      sCur = sNext;
      sNext = sNew;

      long tNew = tCur - q*tNext;
      tCur = tNext;
      tNext = tNew;
   }

   if (aneg)
      sCur = -sCur;

   if (bneg)
      tCur = -tCur;

   d = rCur;
   s = sCur;
   t = tCur;
}
long InvMod(long a, long n)
{
long d, s, t;
XGCD(d, s, t, a, n);
if (d != 1) Error("InvMod: inverse undefined");
if (s < 0)
return s + n;
else
return s;
}
// Return a^ee mod n by right-to-left binary exponentiation.
// For negative ee, a^|ee| is computed and then inverted with InvMod.
// |ee| is formed in unsigned arithmetic, so ee == LONG_MIN is handled.
long PowerMod(long a, long ee, long n)
{
   unsigned long e = (ee < 0) ? -((unsigned long) ee) : (unsigned long) ee;

   long result = 1;
   long base = a;

   for (; e != 0; e >>= 1) {
      if (e & 1) result = MulMod(result, base, n);
      base = MulMod(base, base, n);
   }

   if (ee < 0) result = InvMod(result, n);

   return result;
}
long ProbPrime(long n, long NumTests)
{
long m, x, y, z;
long i, j, k;
if (n <= 1) return 0;
if (n == 2) return 1;
if (n % 2 == 0) return 0;
if (n