# Source code for hensel.hensel

"""
Pure-Python implementation of `Hensel lifting \
<https://en.wikipedia.org/wiki/Hensel%27s_lemma>`__ for square roots modulo
a prime power.
"""
from __future__ import annotations
import doctest
from egcd import egcd

def _inv(a, m):
    """
    Compute the inverse of ``a`` modulo ``m`` from the Bezout coefficient
    of ``a`` produced by :obj:`egcd`, reduced into ``range(0, m)``.

    When ``a`` and ``m`` are not coprime no true inverse exists; the reduced
    Bezout coefficient is still returned and no exception is raised. Callers
    in this module rely on that (they pass a multiple of a prime power
    together with a power of the same prime), so do not swap this for
    ``pow(a, -1, m)``, which raises ``ValueError`` in that situation.

    >>> (_inv(7, 11) * 7) % 11
    1
    """
    _, coefficient, _ = egcd(a, m)
    return coefficient % m

def _exp(r, p):
    """
    Determine the largest ``k`` such that ``p ** k`` divides ``r``.

    >>> _exp(11 * (7 ** 3), 7)
    3
    """
    e = p
    k = 0
    while r % e == 0:
        e *= p
        k += 1

    return k

def hensel(root: int, prime: int, exponent: int = 1) -> int:
    """
    Lift a square root of a value modulo ``prime ** exponent`` to the square
    root of that same value modulo ``prime ** (exponent + 1)``.

    More specifically, let ``square`` be a nonnegative integer that is the
    least nonnegative residue of the congruence class ``root ** 2`` modulo
    ``prime ** exponent``. Use `Hensel lifting
    <https://en.wikipedia.org/wiki/Hensel%27s_lemma>`__ to return an integer
    that represents the square root modulo ``prime ** (exponent + 1)`` of the
    congruence class represented by the integer ``square`` modulo
    ``prime ** (exponent + 1)``.

    >>> hensel(4, 7)
    39
    >>> hensel(2, 7, 2)
    2

    This function implements a lifting operation even for those cases in
    which the root has the supplied prime as a factor (or is zero).

    >>> hensel(28, 7, 3)
    273
    >>> pow(28, 2, 7 ** 3) == pow(273, 2, 7 ** 4)
    True
    >>> hensel(256, 2, 12)
    512
    >>> pow(256, 2, 2 ** 12) == pow(512, 2, 2 ** 13)
    True

    This function lifts distinct roots to distinct roots when possible.

    >>> def roots(s, m):
    ...     return [r for r in range(0, m) if pow(r, 2, m) == s]
    >>> [hensel(r, 3, 5) for r in roots(81, 3 ** 5)] == roots(81, 3 ** 6)
    True

    However, when the root has the supplied prime as a factor, it may be the
    case that not all roots modulo ``prime ** (exponent + 1)`` can be obtained
    via lifting. In that case, the number of distinct roots that can be
    obtained is equivalent to the number of distinct roots that are available
    to lift.

    >>> [hensel(r, 2, 5) for r in roots(16, 2 ** 5)]
    [12, 28, 44, 60]
    >>> roots(16, 2 ** 6)
    [4, 12, 20, 28, 36, 44, 52, 60]

    Any attempt to invoke this function with arguments that do not have the
    expected types (or do not fall within the supported ranges) raises an
    exception. **If** ``prime`` **is not a prime number, the behavior of this
    function is not specified.**

    >>> hensel('abc', 7)
    Traceback (most recent call last):
      ...
    TypeError: 'str' object cannot be interpreted as an integer
    >>> hensel(2, {})
    Traceback (most recent call last):
      ...
    TypeError: 'dict' object cannot be interpreted as an integer
    >>> hensel(2, 7, [])
    Traceback (most recent call last):
      ...
    TypeError: 'list' object cannot be interpreted as an integer
    >>> hensel(2, -1)
    Traceback (most recent call last):
      ...
    ValueError: prime must be a positive integer
    >>> hensel(2, 0)
    Traceback (most recent call last):
      ...
    ValueError: prime must be a positive integer
    >>> hensel(2, 7, -1)
    Traceback (most recent call last):
      ...
    ValueError: exponent must be a nonnegative integer

    The examples below verify the correct behavior of the function on a range
    of different inputs.

    >>> all(
    ...     pow(r, 2, p ** k) == pow(hensel(r, p, k), 2, p ** (k + 1))
    ...     for k in range(0, 5)
    ...     for p in [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]
    ...     for r in range(0, p ** k)
    ... )
    True
    >>> all(
    ...     lifted.issubset(actual) and (
    ...         len(actual) == len(lifted)
    ...         or
    ...         len(actual) == len(lifted) * p
    ...     )
    ...     for k in range(0, 4)
    ...     for p in [2, 3, 5, 7, 11, 13]
    ...     for s in [pow(x, 2, p ** k) for x in range(0, p ** k)]
    ...     for lifted in [set(hensel(r, p, k) for r in roots(s, p ** k))]
    ...     for actual in [roots(s, p ** (k + 1))]
    ... )
    True
    """
    # pylint: disable=too-many-branches
    if not isinstance(root, int):
        raise TypeError(
            "'" + type(root).__name__ + "'" +
            ' object cannot be interpreted as an integer'
        )

    if not isinstance(prime, int):
        raise TypeError(
            "'" + type(prime).__name__ + "'" +
            ' object cannot be interpreted as an integer'
        )

    if not isinstance(exponent, int):
        raise TypeError(
            "'" + type(exponent).__name__ + "'" +
            ' object cannot be interpreted as an integer'
        )

    # Zero is rejected too, as the message states: with ``prime == 0`` the
    # modulus ``prime ** exponent`` is either 0 (``pow`` refuses it) or 1
    # (which would silently produce a meaningless result).
    if prime <= 0:
        raise ValueError('prime must be a positive integer')

    if exponent < 0:
        raise ValueError('exponent must be a nonnegative integer')

    square = pow(root, 2, prime ** exponent)

    # Here ``root`` carries at least ``ceil(exponent / 2)`` factors of
    # ``prime``. For odd ``exponent`` that already makes ``root ** 2`` vanish
    # modulo ``prime ** (exponent + 1)``; for even ``exponent`` one more
    # factor of ``prime`` is needed.
    if square == 0:
        return (
            root * (prime ** (1 - (exponent % 2)))
            if prime == 2 else
            root * prime if exponent % 2 == 0 else
            root
        )

    # From here on ``root`` is nonzero and ``prime >= 2``, so ``_exp`` below
    # is well defined. Roots sharing factors with ``prime`` are handled
    # modulo the reduced power ``prime ** (exponent - v)``, where ``v`` is
    # the multiplicity of ``prime`` in ``root``.
    prime_to_exponent = prime ** exponent
    prime_to_exponent_plus_one = prime_to_exponent * prime
    prime_to_exponent_adjusted = prime ** (exponent - _exp(root, prime))

    # ``base`` is the smaller of the two representatives ``+root`` and
    # ``-root`` modulo the adjusted power; ``bottom_half`` records whether
    # ``root`` itself was the one in the lower half of that range.
    offset = root % prime_to_exponent_adjusted
    bottom_half = offset < prime_to_exponent_adjusted // 2
    base = min(offset, prime_to_exponent_adjusted - offset)

    # Specialized calculation for the case in which the supplied prime is 2.
    if prime == 2:
        prime_to_exponent_adjusted //= prime
        bottom_half = (
            root % prime_to_exponent_adjusted
            < prime_to_exponent_adjusted // 2
        )
        lifted = (
            base
            if pow(base, 2, prime_to_exponent_plus_one) == square else
            prime_to_exponent_adjusted - base
        )
    else:
        # Basic Hensel lifting (sufficient for roots that are coprime with
        # the modulus).
        def _lift(value: int) -> int:
            multiple = (
                ((_inv(value, prime) * _inv(2, prime)) % prime)
                * (((square - pow(root, 2)) // prime_to_exponent) % prime)
            ) % prime
            return (
                (value + (multiple * prime_to_exponent))
                % prime_to_exponent_plus_one
            )

        if root % prime != 0:
            return _lift(root)

        lifted = _lift(base)

        # Perform additional work because the root is not coprime with the
        # modulus. Determine the multiple of the additional prime power
        # needed to adjust the lifted root. ``_inv`` is deliberately applied
        # to a non-coprime argument here (it returns the Bezout coefficient
        # rather than raising).
        multiple = (
            (
                ((square - (lifted * lifted)) % prime_to_exponent_plus_one)
                * _inv(
                    2 * lifted * prime_to_exponent_adjusted,
                    prime_to_exponent_plus_one
                )
            ) % prime_to_exponent_plus_one
        ) // prime_to_exponent
        lifted = (
            lifted
            + ((prime_to_exponent_adjusted * multiple)
               % prime_to_exponent_plus_one)
        ) % prime_to_exponent_plus_one

    # Rebuild the result within the segment of ``root`` relative to the
    # adjusted power, mirroring the lifted offset if ``root`` was in the
    # upper half of its range.
    segment = root // prime_to_exponent_adjusted
    return (
        (segment * prime_to_exponent_adjusted * prime) + lifted
        if bottom_half else
        ((segment + 1) * prime_to_exponent_adjusted * prime) - lifted
    )
if __name__ == '__main__':  # pragma: no cover
    # Executing this module as a script runs the doctest examples embedded
    # in its docstrings.
    doctest.testmod()