Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions pairing/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
import math


def improved_i_sqrt(n):
assert n >= 0
if n == 0:
return 0
i = n.bit_length() >> 1 # i = floor( (1 + floor(log_2(n))) / 2 )
m = 1 << i # m = 2^i
#
# Fact: (2^(i + 1))^2 > n, so m has at least as many bits
# as the floor of the square root of n.
#
# Proof: (2^(i+1))^2 = 2^(2i + 2) >= 2^(floor(log_2(n)) + 2)
# >= 2^(ceil(log_2(n) + 1) >= 2^(log_2(n) + 1) > 2^(log_2(n)) = n. QED.
#
while (m << i) > n: # (m<<i) = m*(2^i) = m*m
m >>= 1
i -= 1
d = n - (m << i) # d = n-m^2
for k in range(i-1, -1, -1):
j = 1 << k
new_diff = d - (((m<<1) | j) << k) # n-(m+2^k)^2 = n-m^2-2*m*2^k-2^(2k)
if new_diff >= 0:
d = new_diff
m |= j
return m


def pair(k1, k2, safe=True):
"""
Cantor pairing function
http://en.wikipedia.org/wiki/Pairing_function#Cantor_pairing_function
"""
z = int(0.5 * (k1 + k2) * (k1 + k2 + 1) + k2)
z = (k1 + k2) * (k1 + k2 + 1) // 2 + k2
if safe and (k1, k2) != depair(z):
raise ValueError("{} and {} cannot be paired".format(k1, k2))
return z
Expand All @@ -17,9 +43,9 @@ def depair(z):
Inverse of Cantor pairing function
http://en.wikipedia.org/wiki/Pairing_function#Inverting_the_Cantor_pairing_function
"""
w = math.floor((math.sqrt(8 * z + 1) - 1)/2)
t = (w**2 + w) / 2
y = int(z - t)
x = int(w - y)
w = (improved_i_sqrt(8 * z + 1) - 1)//2
t = (w**2 + w) // 2
y = z - t
x = w - y
# assert z != pair(x, y, safe=False):
return x, y
22 changes: 14 additions & 8 deletions test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,29 @@
from __future__ import print_function
from pairing import pair, depair
import timeit
import random


def test_pair(a, b):
assert depair(pair(a, b)) == (a, b)
return pair(a, b)


def run_tests():
def run_tests(num_random_tests = 0):
test_pair(22, 33)
test_pair(2**8, 2**8)
test_pair(2**16, 2**16)

try:
test_pair(2**52, 2**52)
except ValueError:
pass # long integers suffer from some imprecision at this size

test_pair(2**52, 2**52)
test_pair(2**52, 1)
test_pair(2**52, 2)
test_pair(2**52, 55555555)
for i in range(0, num_random_tests):
test_pair(random.getrandbits(512), random.getrandbits(512))
test_pair(random.getrandbits(512), 0)
test_pair(0, random.getrandbits(512))
test_pair(random.getrandbits(512), 1)
test_pair(1, random.getrandbits(512))

try:
test_pair(-1, -1)
except ValueError:
Expand All @@ -28,7 +34,7 @@ def run_tests():


if __name__ == '__main__':
print(run_tests())
print(run_tests(2000))

print("Benchmarking...")
i = 20000
Expand Down