"""
Low-level functions for arbitrary-precision floating-point arithmetic.
"""
__docformat__ = 'plaintext'
import math
from bisect import bisect
from random import randrange
#----------------------------------------------------------------------------#
# Some commonly needed float values #
#----------------------------------------------------------------------------#
# Regular number format:
# (-1)**sign * mantissa * 2**exponent, plus bitcount of mantissa
# Each constant below is the tuple (sign, mantissa, exponent, bitcount).
fzero = (0, 0, 0, 0)     # zero
fnzero = (1, 0, 0, 0)    # zero with the sign bit set
fone = (0, 1, 0, 1)      # 1 * 2**0
fnone = (1, 1, 0, 1)     # -(1 * 2**0)
ftwo = (0, 1, 1, 1)      # 1 * 2**1
ften = (0, 5, 1, 3)      # 5 * 2**1; the mantissa 5 is three bits wide
fhalf = (0, 1, -1, 1)    # 1 * 2**-1
# Arbitrary encoding for special numbers: zero mantissa, nonzero exponent
# (the bitcount slot of each special value is negative as well)
fnan = (0, 0, -123, -1)
finf = (0, 0, -456, -2)
fninf = (1, 0, -789, -3)
#----------------------------------------------------------------------------#
# Various utilities related to precision and bit-fiddling #
#----------------------------------------------------------------------------#
def prec_to_dps(n):
    """Convert a binary precision of n bits into the number of decimal
    digits that can be represented accurately at that precision.

    The result is never smaller than 1."""
    bits_per_digit = 3.3219280948873626
    digits = round(int(n) / bits_per_digit) - 1
    return max(1, int(digits))
def dps_to_prec(n):
    """Convert a count of n decimal digits into the number of bits
    needed to represent that many digits accurately.

    The result is never smaller than 1."""
    bits_per_digit = 3.3219280948873626
    bits = round((int(n) + 1) * bits_per_digit)
    return max(1, int(bits))
def repr_dps(n):
    """Return how many decimal digits must be printed for a number
    with n-bit precision so that the printed form identifies the
    number uniquely and it can be reconstructed from it."""
    digits = prec_to_dps(n)
    # A precision that maps to 15 digits is special-cased to 17 digits;
    # every other precision gets three extra digits.
    if digits != 15:
        return digits + 3
    return 17
def giant_steps(start, target):
    """Return a list of integers ~= [start, 2*start, ..., target/2,
    target] describing suitable precision steps for Newton's method.

    The list is built backwards from target by repeatedly taking
    x//2 + 1 until a value no larger than 2*start is reached, and is
    returned in increasing order. For example, giant_steps(50, 1000)
    gives [64, 126, 251, 501, 1000].

    If target <= 2*start the result is simply [target].

    Raises ValueError when start < 1 and target > 2*start: the halving
    sequence x//2 + 1 can never drop below 1, so it would never reach
    2*start and the loop would not terminate."""
    if start < 1 and target > start*2:
        raise ValueError("giant_steps requires start >= 1 (got %r)" % (start,))
    steps = [target]
    # Append in place instead of rebuilding the list on every iteration
    while steps[-1] > start*2:
        steps.append(steps[-1]//2 + 1)
    steps.reverse()
    return steps
def rshift(x, n):
    """Shift the integer x right by n bits, rounding toward floor.

    A negative n is accepted, unlike with the plain expression
    (x >> n); it shifts x left by -n bits instead."""
    if n < 0:
        return x << (-n)
    return x >> n
def lshift(x, n):
    """Shift the integer x left by n bits.

    A negative n is accepted, unlike with the plain expression
    (x << n); it shifts x right by -n bits with the default (floor)
    rounding instead."""
    if n < 0:
        return x >> (-n)
    return x << n
def trailing(n):
    """Return how many zero bits abs(n) ends with; 0 is returned for
    n == 0."""
    zeros = 0
    if n:
        # Drop low bits one at a time until the lowest bit is set
        while not (n & 1):
            n >>= 1
            zeros += 1
    return zeros
# Small powers of 2
powers = [1<<_ for _ in range(300)]

def bitcount(n):
    """Return the size in bits of the nonnegative integer n."""
    # For n < 2**299 the position of n among the sorted powers of two
    # is exactly its bit size.
    size = bisect(powers, n)
    if size == 300:
        # Larger n: estimate the size with a float logarithm, back off a
        # few bits to absorb any error in the estimate, and count the
        # remaining high bits exactly with the lookup table.
        shift = int(math.log(n, 2)) - 4
        size = shift + bctable[n >> shift]
    return size
# Used to avoid slow function calls as far as possible.
# trailtable[i] is trailing(i) for 0 <= i < 256 and bctable[i] is
# bitcount(i) for 0 <= i < 1024. Both tables are indexed directly
# (e.g. trailtable[man & 255], bctable[man]), so they must be real lists:
# on Python 3 a bare map() is a lazy iterator that cannot be subscripted.
# Wrapping in list() gives the same result on Python 2.
trailtable = list(map(trailing, range(256)))
bctable = list(map(bitcount, range(1024)))
#----------------------------------------------------------------------------#
# Rounding #
#----------------------------------------------------------------------------#
# All supported rounding modes.
# The mode strings are interned because the rounding code compares them
# by identity ("rounding is round_nearest"). The builtin intern() only
# exists on Python 2; on Python 3 the same function lives in sys.
try:
    intern
except NameError:
    from sys import intern
round_nearest = intern('n')
round_floor = intern('f')
round_ceiling = intern('c')
round_up = intern('u')
round_down = intern('d')
# General-purpose mantissa rounding. Most rounding elsewhere is done
# inline for efficiency instead of going through this function.
def round_int(x, n, rounding):
    """Return the integer x divided by 2**n, rounded as selected by
    rounding (one of the round_* constants, compared by identity).

    round_nearest breaks ties toward the even result and shifts by
    n-1, so it needs n >= 1. None is returned if rounding is not one
    of the known modes."""
    if rounding is round_nearest:
        if x >= 0:
            # Keep one extra bit: bit 0 of q is the first discarded bit
            # and bit 1 is the lowest bit of the result.
            q = x >> (n-1)
            lower_bits = h_mask[n<300]
            if q & 1 and ((q & 2) or (x & lower_bits[n])):
                # More than half, or exactly half with an odd result
                return (q>>1)+1
            return q>>1
        # Round the magnitude, then restore the sign
        return -round_int(-x, n, rounding)
    if rounding is round_floor:
        return x >> n
    if rounding is round_ceiling:
        return -((-x) >> n)
    if rounding is round_down or rounding is round_up:
        # A plain shift floors. That is the right direction for
        # round_down on nonnegative x and for round_up on negative x;
        # the other two cases need the ceiling instead.
        if (x >= 0) == (rounding is round_down):
            return x >> n
        return -((-x) >> n)
    return None
# Masks selecting the low n-1 bits of a number. Round-to-nearest uses
# them to test whether anything nonzero lies below the half-way bit.
class h_mask_big:
    """Computes the mask on demand, for shift counts not covered by
    the precomputed table."""
    def __getitem__(self, n):
        half = 1 << (n-1)
        return half - 1
# Precomputed masks for 0 <= n < 300 (entry 0 is simply 0)
h_mask_small = [0]
h_mask_small += [((1<<(_-1))-1) for _ in range(1, 300)]
# Indexed as h_mask[n<300][n]: False (0) picks the computed masks,
# True (1) picks the table.
h_mask = [h_mask_big(), h_mask_small]
# The >> operator rounds to floor. shifts_down[rounding][sign]
# tells whether this is the right direction to use, or if the
# number should be negated before shifting
# Keys are the rounding mode letters ('f' floor, 'c' ceiling, 'd' down,
# 'u' up); each value is indexed by the sign bit (0 or 1). A 1 means the
# mantissa can be shifted directly, a 0 means negate, shift, negate back.
shifts_down = {'f':(1,0), 'c':(0,1), 'd':(1,1), 'u':(0,0)}
#----------------------------------------------------------------------------#
# Normalization of raw mpfs #
#----------------------------------------------------------------------------#
# This function is called almost every time an mpf is created.
# It has been optimized accordingly.
def normalize(sign, man, exp, bc, prec, rounding):
    """
    Create a raw mpf tuple with value (-1)**sign * man * 2**exp and
    normalized mantissa. The mantissa is rounded in the specified
    direction if its size exceeds the precision. Trailing zero bits
    are also stripped from the mantissa to ensure that the
    representation is canonical.

    Returns the tuple (sign, man, exp, bc), or fzero if man is zero.

    Conditions on the input:
    * The input must represent a regular (finite) number
    * The sign bit must be 0 or 1
    * The mantissa must be positive
    * The exponent must be an integer
    * The bitcount must be exact
    If these conditions are not met, use from_man_exp, fpos, or any
    of the conversion functions to create normalized raw mpf tuples.
    """
    # A zero mantissa always yields fzero, whatever the sign bit is
    if not man:
        return fzero
    # Cut mantissa down to size if larger than target precision
    # n is the number of excess low bits that have to be rounded away
    n = bc - prec
    if n > 0:
        if rounding is round_nearest:
            # Keep one extra bit: bit 0 of t is the first discarded bit
            # and bit 1 of t is the lowest bit that will be kept.
            t = man >> (n-1)
            # Round up when the first discarded bit is set and either the
            # kept part is odd or some lower discarded bit is set (ties
            # go to the even mantissa). h_mask[n<300][n] masks the n-1
            # bits below the first discarded bit.
            if t & 1 and ((t & 2) or (man & h_mask[n<300][n])):
                man = (t>>1)+1
            else:
                man = t>>1
        elif shifts_down[rounding][sign]:
            # A plain shift (floor of the magnitude) is the right direction
            man >>= n
        else:
            # Ceiling of the magnitude: negate, floor-shift, negate back
            man = -((-man)>>n)
        exp += n
        bc = prec
    # Strip trailing bits
    if not man & 1:
        t = trailtable[man & 255]
        # t is 0 here only when the whole low byte is zero (man is even)
        if not t:
            # Drop whole zero bytes first, then look up the remainder
            while not man & 255:
                man >>= 8
                exp += 8
                bc -= 8
            t = trailtable[man & 255]
        man >>= t
        exp += t
        bc -= t
    # Bit count can be wrong if the input mantissa was 1 less than
    # a power of 2 and got rounded up, thereby adding an extra bit.
    # With trailing bits removed, all powers of two have mantissa 1,
    # so this is easy to check for.
    if man == 1:
        bc = 1
    return sign, man, exp, bc
#----------------------------------------------------------------------------#
# Conversion functions #
#----------------------------------------------------------------------------#
def from_man_exp(man, exp, prec=None, rounding=None):
    """Build a raw mpf tuple from a (man, exp) pair, where man may be
    negative.

    With a precision given, the result is rounded by normalize. Without
    one (prec is None or 0) the value is stored exactly, with trailing
    zero bits moved from the mantissa into the exponent."""
    # Split the signed mantissa into a sign bit and a magnitude
    if man < 0:
        sign, man = 1, -man
    else:
        sign = 0
    # Small magnitudes get their bit size from the lookup table
    if man >= 1024:
        bc = bitcount(man)
    else:
        bc = bctable[man]
    if prec:
        return normalize(sign, man, exp, bc, prec, rounding)
    # Exact storage
    if not man:
        return fzero
    while not man & 1:
        man >>= 1
        exp += 1
        bc -= 1
    return (sign, man, exp, bc)
def from_int(n, prec=None, rounding=No