import random
import binascii

BLOCKSIZE = 128 // 8  # 16-byte blocks; OPE works on BLOCKSIZE // 2-byte half-blocks


class KeyLengthError(Exception):
    """Raised by OPE for an unsupported key length or non-positive coefficients."""


def bytes2int(b):
    """Interpret the byte string *b* as an unsigned big-endian integer."""
    value = int.from_bytes(b, 'big')
    return value


def int2bytes(i, bs=BLOCKSIZE):
    """Encode the integer *i* as exactly *bs* unsigned big-endian bytes.

    ``int.to_bytes`` raises OverflowError if *i* is negative or needs more
    than *bs* bytes.
    """
    encoded = int.to_bytes(i, bs, 'big')
    return encoded


def pad(msg):
    """Append PKCS#7-style padding to *msg*.

    Adds n bytes of value n, where n is 1..BLOCKSIZE, so that the result
    length is a multiple of BLOCKSIZE. A full block is added when *msg* is
    already aligned.
    """
    fill = BLOCKSIZE - len(msg) % BLOCKSIZE
    return msg + bytes((fill,)) * fill


def strip_padding(msg):
    """Remove padding added by pad(), if it is present.

    The last byte is taken as the padding length n. If the final n bytes all
    equal n, they are removed. Otherwise *msg* is returned unchanged, so
    unpadded data passes through. This also happens for an empty message,
    which previously raised IndexError, and for a trailing zero byte, which
    is never valid padding.
    """
    if not msg:
        return msg
    n = msg[-1]
    # n == 0 would make msg[-0:] the whole message; it is never valid padding.
    if n and msg[-n:] == bytes([n]) * n:
        return msg[:-n]
    return msg


def xor(a, b):
    """XOR two byte strings bytewise.

    Like zip(), the result is truncated to the length of the shorter input.
    """
    return bytes(x ^ y for x, y in zip(a, b))


# Order-preserving encryption (OPE): half-blocks are mapped through a polynomial with positive coefficients.
class OPE:
    """Polynomial block cipher keyed by positive integer coefficients.

    ``coefficients`` is ``[c0, c1, ..., cd]``, lowest degree first, and
    defines ``p(x) = c0 + c1*x + ... + cd*x**d``. Every coefficient is
    positive and d >= 1, so ``p`` is strictly increasing on non-negative
    integers. That makes the per-block map order-preserving and exactly
    invertible.

    To encrypt, the message is padded and split into ``BLOCKSIZE // 2``-byte
    half-blocks. Each half-block is XORed with the leading bytes of the
    previous ciphertext block (the IV for the first one). ``p`` of the result
    is then stored as one fixed-width ciphertext block.
    """

    def __init__(self, coefficients):
        """Store the polynomial coefficients, lowest degree first.

        Raises:
            KeyLengthError: if fewer than two coefficients are given (a
                constant polynomial cannot be inverted), or any is <= 0.
        """
        if len(coefficients) < 2:
            raise KeyLengthError("Needs at least two coefficients")
        if any(c <= 0 for c in coefficients):
            raise KeyLengthError("Needs exactly positive coefficients")
        self.coefficients = coefficients

    @staticmethod
    def new(key):
        """Build an OPE from a 16- or 24-byte key.

        The key is split into consecutive 8-byte big-endian integers. These
        become the coefficients c0, c1, and c2 for a 24-byte key.

        Raises:
            KeyLengthError: if the key is not 16 or 24 bytes long, or a
                derived coefficient is zero.
        """
        if len(key) not in (16, 24):
            raise KeyLengthError("Key must be of length 128bits or 192bits")
        coefficients = [bytes2int(key[j:j + 8]) for j in range(0, len(key), 8)]
        return OPE(coefficients)

    def encrypt(self, msg):
        """Encrypt the bytes *msg* under a fresh random IV.

        Returns:
            ``(IV, ciphertext)``. IV is a ``BLOCKSIZE * 8``-bit int.
            ciphertext holds one ``_cipher_blocksize()``-byte block per
            ``BLOCKSIZE // 2``-byte half-block of the padded message.
        """
        half = BLOCKSIZE // 2
        width = self._cipher_blocksize()
        IV = random.SystemRandom().getrandbits(BLOCKSIZE * 8)
        prev = int2bytes(IV)
        padded_msg = pad(msg)
        blocks = []
        for i in range(0, len(padded_msg), half):
            # xor() truncates to the shorter input, so only the first `half`
            # bytes of the previous ciphertext block are mixed in.
            mixed = xor(padded_msg[i:i + half], prev)
            prev = int2bytes(self._peval(bytes2int(mixed)), bs=width)
            blocks.append(prev)
        return IV, b"".join(blocks)

    def decrypt(self, IV, cipher):
        """Invert encrypt(): recover the plaintext from *IV* and *cipher*.

        Raises:
            ValueError: if *cipher* is not a whole number of ciphertext
                blocks, or if a block is not a value ``p`` can produce
                (corrupted data or a different key).
        """
        half = BLOCKSIZE // 2
        width = self._cipher_blocksize()
        if len(cipher) % width:
            raise ValueError("ciphertext length is not a multiple of the block size")
        prev = int2bytes(IV)
        parts = []
        for i in range(0, len(cipher), width):
            block = cipher[i:i + width]
            mixed = int2bytes(self._pinv(bytes2int(block)), bs=half)
            parts.append(xor(mixed, prev))
            prev = block
        return strip_padding(b"".join(parts))

    @staticmethod
    def _half_block_max():
        """Return the largest integer a BLOCKSIZE // 2-byte half-block encodes."""
        return (1 << (8 * (BLOCKSIZE // 2))) - 1

    def _cipher_blocksize(self):
        """Return the number of bytes per ciphertext block.

        A block must hold p(largest half-block) without overflow. Two 64-bit
        coefficients fit in 16 bytes, but a degree-2 key can need up to 24.
        A fixed BLOCKSIZE therefore made 192-bit keys raise OverflowError in
        int2bytes. The result is never below BLOCKSIZE, so 128-bit-key
        ciphertexts keep their 16-byte block layout.
        """
        needed = (self._peval(self._half_block_max()).bit_length() + 7) // 8
        return max(BLOCKSIZE, needed)

    def _peval(self, block):
        """Return p(block) = sum(c_k * block**k), evaluated with Horner's rule."""
        acc = 0
        for c in reversed(self.coefficients):
            acc = acc * block + c
        return acc

    def _pinv(self, block):
        """Return the unique half-block value x with p(x) == block.

        p is strictly increasing on non-negative integers, so a binary search
        over [0, _half_block_max()] recovers x for any polynomial degree. A
        closed-form inverse only works for degree 1.

        Raises:
            ValueError: if no half-block value maps to *block*.
        """
        lo, hi = 0, self._half_block_max()
        while lo < hi:
            mid = (lo + hi) // 2
            if self._peval(mid) < block:
                lo = mid + 1
            else:
                hi = mid
        if self._peval(lo) != block:
            raise ValueError("ciphertext block is not a valid polynomial value")
        return lo
