diff --git a/.travis.yml b/.travis.yml index 4e6de96..bbae422 100644 --- a/.travis.yml +++ b/.travis.yml @@ -8,11 +8,13 @@ matrix: - python: "3.6" - python: "3.7" - python: "3.8" - - python: "3.9-dev" - - python: "pypy2.7-6.0" - - python: "pypy3.5-6.0" - allow_failures: - - python: "3.9-dev" + - python: "3.9" + - python: "3.10" + - python: "3.11" + - python: "3.12" + - python: "3.13" + - python: "pypy2.7-7.3.9" + - python: "pypy3.10" script: - python -m unittest discover diff --git a/README.md b/README.md index e6b32c3..7bcf851 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,18 @@ We tried other Python libraries such as [python-ecdsa], [fast-ecdsa] and other less famous ones, but we didn't find anything that suited our needs. The first one was pure Python, but it was too slow. The second one mixed Python and C and it was really fast, but we were unable to use it in our current infrastructure, which required pure Python code. -For this reason, we decided to create something simple, compatible with OpenSSL and fast using elegant math such as Jacobian Coordinates to speed up the ECDSA. Starkbank-ECDSA is fully compatible with Python2 and Python3. +For this reason, we decided to create something simple, compatible with OpenSSL and fast using elegant math such as Jacobian Coordinates to speed up the ECDSA. Starkbank-ECDSA is fully compatible with Python 2.7 and Python 3. + +### Security + +starkbank-ecdsa includes the following security features: + +- **Hedged RFC 6979 nonces**: Deterministic k derivation with fresh random entropy mixed into K-init (RFC 6979 §3.6), eliminating the catastrophic risk of nonce reuse that leaks private keys while preserving protection even if the RNG fails +- **Low-S signature normalization**: Prevents signature malleability (BIP-62) +- **Public key on-curve validation**: Blocks invalid-curve attacks during verification +- **Montgomery ladder scalar multiplication**: Constant-operation point multiplication to mitigate timing side channels +- **Hash truncation**: Correctly handles hash functions larger than the curve order (e.g. SHA-512 with secp256k1) +- **Extended Euclidean modular inverse**: Implemented in pure Python for portability (Python 2.7+ and 3.x); transparently uses the C-level `pow(x, -1, n)` fast path on CPython 3.8+ for a roughly order-of-magnitude speedup over Fermat's little theorem on 256-bit operands ### Installation @@ -16,19 +27,21 @@ pip install starkbank-ecdsa ### Curves -We currently support `secp256k1`, but you can add more curves to the project. You just need to use the curve.add() function. +We currently support `secp256k1` and `prime256v1` (P-256), but you can add more curves to the project. You just need to use the curve.add() function. ### Speed -We ran a test on a MAC Pro i7 2017. The libraries were run 100 times and the averages displayed bellow were obtained: +We ran a test on an Apple Silicon Mac with Python 3.14. The libraries were run 500 times on secp256k1 with SHA-256 and deterministic (RFC 6979) nonces, and the averages displayed below were obtained: + +| Library | sign | verify | +|-----------------|:------:|:------:| +| [python-ecdsa] | ~1.0ms | ~3.6ms | +| [fast-ecdsa] | ~1.0ms | ~1.3ms | +| starkbank-ecdsa | ~0.6ms | ~1.7ms | -| Library | sign | verify | -| ------------------ |:-------------:| -------:| -| [python-ecdsa] | 121.3ms | 65.1ms | -| [fast-ecdsa] | 0.1ms | 0.2ms | -| starkbank-ecdsa | 4.1ms | 7.8ms | +Our pure Python code cannot compete with C-based libraries backed by GMP's hand-tuned assembly, but it matches the fastest pure-Python implementation on signing and is roughly `30%` faster on verification. -Our pure Python code cannot compete with C based libraries, but it's `6x faster` to verify and `23x faster` to sign than other pure Python libraries. +Performance is driven by Jacobian coordinates, a branch-balanced Montgomery ladder for variable-base scalar multiplication, a precomputed affine table of powers-of-two multiples of the generator (`[G, 2G, 4G, …, 2ⁿG]`) combined with a width-2 NAF of the scalar to eliminate doublings during signing, a mixed affine+Jacobian addition fast path, curve-specific shortcuts in point doubling (A=0 for secp256k1, A=-3 for prime256v1), the secp256k1 GLV endomorphism to split 256-bit scalars into two ~128-bit halves for a 4-scalar simultaneous multi-exponentiation during verification, Shamir's trick with Joint Sparse Form as the fallback path for curves without an efficient endomorphism, and the extended Euclidean algorithm for modular inversion. ### Sample Code @@ -219,7 +232,14 @@ python3 -m unittest discover python2 -m unittest discover ``` +### Run benchmark + +``` +python3 benchmark.py +python2 benchmark.py +``` + -[python-ecdsa]: https://github.com/warner/python-ecdsa +[python-ecdsa]: https://github.com/tlsfuzzer/python-ecdsa [fast-ecdsa]: https://github.com/AntonKueltz/fastecdsa [Stark Bank]: https://starkbank.com diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 0000000..fcee4cb --- /dev/null +++ b/benchmark.py @@ -0,0 +1,102 @@ +import time +from hashlib import sha256 +from ellipticcurve import Ecdsa, PrivateKey + + +ROUNDS = 500 +MESSAGE = "This is a benchmark test message" + + +def benchmarkStarkbank(): + privateKey = PrivateKey() + publicKey = privateKey.publicKey() + + sig = Ecdsa.sign(MESSAGE, privateKey) + Ecdsa.verify(MESSAGE, sig, publicKey) + + start = time.time() + for _ in range(ROUNDS): + sig = Ecdsa.sign(MESSAGE, privateKey) + signTime = (time.time() - start) / ROUNDS * 1000 + + start = time.time() + for _ in range(ROUNDS): + Ecdsa.verify(MESSAGE, sig, publicKey) + verifyTime = (time.time() - start) / ROUNDS * 1000 + + return signTime, verifyTime + + +def benchmarkPythonEcdsa(): + try: + from ecdsa import SigningKey, SECP256k1 + except ImportError: + return None, None + + sk = SigningKey.generate(curve=SECP256k1) + vk = sk.verifying_key + data = MESSAGE.encode() + + sig = sk.sign_deterministic(data, hashfunc=sha256) + vk.verify(sig, data, hashfunc=sha256) + + start = time.time() + for _ in range(ROUNDS): + sig = sk.sign_deterministic(data, hashfunc=sha256) + signTime = (time.time() - start) / ROUNDS * 1000 + + start = time.time() + for _ in range(ROUNDS): + vk.verify(sig, data, hashfunc=sha256) + verifyTime = (time.time() - start) / ROUNDS * 1000 + + return signTime, verifyTime + + +def benchmarkFastEcdsa(): + try: + from fastecdsa import curve, ecdsa, keys + except ImportError: + return None, None + + privateKey, publicKey = keys.gen_keypair(curve.secp256k1) + + r, s = ecdsa.sign(MESSAGE, privateKey, curve=curve.secp256k1) + ecdsa.verify((r, s), MESSAGE, publicKey, curve=curve.secp256k1) + + start = time.time() + for _ in range(ROUNDS): + r, s = ecdsa.sign(MESSAGE, privateKey, curve=curve.secp256k1) + signTime = (time.time() - start) / ROUNDS * 1000 + + start = time.time() + for _ in range(ROUNDS): + ecdsa.verify((r, s), MESSAGE, publicKey, curve=curve.secp256k1) + verifyTime = (time.time() - start) / ROUNDS * 1000 + + return signTime, verifyTime + + +def formatTime(ms): + return "n/a" if ms is None else "{:.1f}ms".format(ms) + + +def main(): + results = [ + ("python-ecdsa", benchmarkPythonEcdsa()), + ("fast-ecdsa", benchmarkFastEcdsa()), + ("starkbank-ecdsa", benchmarkStarkbank()), + ] + + print("") + print("ECDSA benchmark on secp256k1 ({} rounds)".format(ROUNDS)) + print("-" * 48) + print("{:<20} {:>12} {:>12}".format("library", "sign", "verify")) + print("-" * 48) + for name, (signMs, verifyMs) in results: + print("{:<20} {:>12} {:>12}".format(name, formatTime(signMs), formatTime(verifyMs))) + print("") + + +if __name__ == "__main__": + main() diff --git a/ellipticcurve/curve.py b/ellipticcurve/curve.py index df3e119..a11b7d3 100644 --- a/ellipticcurve/curve.py +++ b/ellipticcurve/curve.py @@ -1,3 +1,4 @@ +# coding: utf-8 # # Elliptic Curve Equation # @@ -9,15 +10,19 @@ class CurveFp: - def __init__(self, A, B, P, N, Gx, Gy, name, oid, nistName=None): + def __init__(self, A, B, P, N, Gx, Gy, name, oid, nistName=None, glvParams=None): self.A = A self.B = B self.P = P self.N = N + self.nBitLength = N.bit_length() self.G = Point(Gx, Gy) self.name = name self.nistName = nistName self.oid = oid # ASN.1 Object Identifier + # GLV endomorphism parameters (only for curves that support one, + # e.g. secp256k1). None means no endomorphism; fall back to Shamir+JSF. + self.glvParams = glvParams def contains(self, p): """ @@ -69,7 +74,18 @@ def getByOid(oid): N=0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141, Gx=0x79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798, Gy=0x483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8, - oid=[1, 3, 132, 0, 10] + oid=[1, 3, 132, 0, 10], + # GLV endomorphism φ((x,y)) = (β·x, y), equivalent to λ·P. + # Basis vectors from Gauss reduction; used to split a 256-bit scalar k + # into two ~128-bit scalars (k1, k2) with k ≡ k1 + k2·λ (mod N). + glvParams={ + "beta": 0x7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee, + "lambda": 0x5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72, + "a1": 0x3086d221a7d46bcde86c90e49284eb15, + "b1": -0xe4437ed6010e88286f547fa90abfe4c3, + "a2": 0x114ca50f7a8e2f3f657c1108d9d44cfd8, + "b2": 0x3086d221a7d46bcde86c90e49284eb15, + }, ) prime256v1 = CurveFp( diff --git a/ellipticcurve/ecdsa.py b/ellipticcurve/ecdsa.py index ea809e4..2fa07c0 100644 --- a/ellipticcurve/ecdsa.py +++ b/ellipticcurve/ecdsa.py @@ -10,37 +10,46 @@ class Ecdsa: @classmethod def sign(cls, message, privateKey, hashfunc=sha256): - byteMessage = hashfunc(toBytes(message)).digest() - numberMessage = numberFromByteString(byteMessage) curve = privateKey.curve + byteMessage = hashfunc(toBytes(message)).digest() + numberMessage = numberFromByteString(byteMessage, curve.nBitLength) r, s, randSignPoint = 0, 0, None + kIterator = RandomInteger.rfc6979(byteMessage, privateKey.secret, curve, hashfunc) while r == 0 or s == 0: - randNum = RandomInteger.between(1, curve.N - 1) - randSignPoint = Math.multiply(curve.G, n=randNum, A=curve.A, P=curve.P, N=curve.N) + randNum = next(kIterator) + randSignPoint = Math.multiplyGenerator(curve, randNum) r = randSignPoint.x % curve.N s = ((numberMessage + r * privateKey.secret) * (Math.inv(randNum, curve.N))) % curve.N recoveryId = randSignPoint.y & 1 - if randSignPoint.y > curve.N: + if randSignPoint.x >= curve.N: recoveryId += 2 + if s > curve.N // 2: + s = curve.N - s + recoveryId ^= 1 return Signature(r=r, s=s, recoveryId=recoveryId) @classmethod def verify(cls, message, signature, publicKey, hashfunc=sha256): - byteMessage = hashfunc(toBytes(message)).digest() - numberMessage = numberFromByteString(byteMessage) curve = publicKey.curve + byteMessage = hashfunc(toBytes(message)).digest() + numberMessage = numberFromByteString(byteMessage, curve.nBitLength) r = signature.r s = signature.s + if not 1 <= r <= curve.N - 1: return False if not 1 <= s <= curve.N - 1: return False + if not curve.contains(publicKey.point): + return False inv = Math.inv(s, curve.N) - u1 = Math.multiply(curve.G, n=(numberMessage * inv) % curve.N, N=curve.N, A=curve.A, P=curve.P) - u2 = Math.multiply(publicKey.point, n=(r * inv) % curve.N, N=curve.N, A=curve.A, P=curve.P) - v = Math.add(u1, u2, A=curve.A, P=curve.P) + v = Math.multiplyAndAdd( + curve.G, (numberMessage * inv) % curve.N, + publicKey.point, (r * inv) % curve.N, + curve=curve, + ) if v.isAtInfinity(): return False return v.x % curve.N == r diff --git a/ellipticcurve/math.py b/ellipticcurve/math.py index 981ab4e..716b2ca 100644 --- a/ellipticcurve/math.py +++ b/ellipticcurve/math.py @@ -1,3 +1,4 @@ +# coding: utf-8 from .point import Point @@ -5,7 +6,112 @@ class Math: @classmethod def modularSquareRoot(cls, value, prime): - return pow(value, (prime + 1) // 4, prime) + """Tonelli-Shanks algorithm for modular square root. Works for all odd primes.""" + if value == 0: + return 0 + if prime == 2: + return value % 2 + + # Factor out powers of 2: prime - 1 = Q * 2^S + Q = prime - 1 + S = 0 + while Q % 2 == 0: + Q //= 2 + S += 1 + + if S == 1: # prime = 3 (mod 4) + return pow(value, (prime + 1) // 4, prime) + + # Find a quadratic non-residue z + z = 2 + while pow(z, (prime - 1) // 2, prime) != prime - 1: + z += 1 + + M = S + c = pow(z, Q, prime) + t = pow(value, Q, prime) + R = pow(value, (Q + 1) // 2, prime) + + while True: + if t == 1: + return R + + # Find the least i such that t^(2^i) = 1 (mod prime) + i = 1 + temp = (t * t) % prime + while temp != 1: + temp = (temp * temp) % prime + i += 1 + + b = pow(c, 1 << (M - i - 1), prime) + M = i + c = (b * b) % prime + t = (t * c) % prime + R = (R * b) % prime + + @classmethod + def multiplyGenerator(cls, curve, n): + """ + Fast scalar multiplication n*G using a precomputed affine table of + powers-of-two multiples of G and the width-2 NAF of n. Every non-zero + NAF digit triggers one mixed add and zero doublings, trading the ~256 + doublings of a windowed method for ~86 adds on average — a large net + reduction in field multiplications for 256-bit scalars. + + :param curve: Elliptic curve with generator G + :param n: Scalar multiplier + :return: Point n*G + """ + if n < 0 or n >= curve.N: + n = n % curve.N + if n == 0: + return Point(0, 0, 0) + + table = cls._generatorPowersTable(curve) + A, P = curve.A, curve.P + _add = cls._jacobianAdd + + r = Point(0, 0, 1) + i = 0 + k = n + while k > 0: + if k & 1: + digit = 2 - (k & 3) # -1 or +1 + k -= digit + g = table[i] + if digit == 1: + r = _add(r, g, A, P) + else: + r = _add(r, Point(g.x, P - g.y, 1), A, P) + k >>= 1 + i += 1 + return cls._fromJacobian(r, P) + + @classmethod + def _generatorPowersTable(cls, curve): + """ + Build [G, 2G, 4G, ..., 2^nBitLength * G] in affine (z=1) form, so each + add in multiplyGenerator hits the mixed-add fast path. + """ + cached = getattr(curve, "_generatorPowersTable_", None) + if cached is not None: + return cached + A, P = curve.A, curve.P + current = Point(curve.G.x, curve.G.y, 1) + table = [current] + # NAF of an nBitLength-bit scalar can be up to nBitLength+1 digits. + for _ in range(curve.nBitLength): + doubled = cls._jacobianDouble(current, A, P) + if doubled.y == 0: + current = doubled + else: + zInv = cls.inv(doubled.z, P) + zInv2 = (zInv * zInv) % P + zInv3 = (zInv2 * zInv) % P + current = Point((doubled.x * zInv2) % P, (doubled.y * zInv3) % P, 1) + table.append(current) + curve._generatorPowersTable_ = table + return table @classmethod def multiply(cls, p, n, N, A, P): @@ -38,32 +144,129 @@ def add(cls, p, q, A, P): cls._jacobianAdd(cls._toJacobian(p), cls._toJacobian(q), A, P), P, ) + @classmethod + def multiplyAndAdd(cls, p1, n1, p2, n2, N=None, A=None, P=None, curve=None): + """ + Compute n1*p1 + n2*p2. If ``curve`` is given and exposes ``glvParams`` + (e.g. secp256k1), uses the GLV endomorphism to split both scalars into + ~128-bit halves and run a 4-scalar simultaneous multi-exponentiation. + Otherwise falls back to Shamir's trick with JSF. Not constant-time — + use only with public scalars (e.g. verification). + + :param p1: First point + :param n1: First scalar + :param p2: Second point + :param n2: Second scalar + :param N: Order of the elliptic curve (ignored when ``curve`` is given) + :param A: Coefficient of the first-order term (ignored when ``curve`` is given) + :param P: Prime defining the field (ignored when ``curve`` is given) + :param curve: Optional curve object; enables GLV if ``curve.glvParams`` is set + :return: Point n1*p1 + n2*p2 + """ + if curve is not None: + N, A, P = curve.N, curve.A, curve.P + if curve.glvParams is not None: + return cls._glvMultiplyAndAdd(p1, n1, p2, n2, curve) + return cls._fromJacobian( + cls._shamirMultiply( + cls._toJacobian(p1), n1, + cls._toJacobian(p2), n2, + N, A, P, + ), P, + ) + + @classmethod + def _glvMultiplyAndAdd(cls, p1, n1, p2, n2, curve): + """ + Compute n1*p1 + n2*p2 using the GLV endomorphism. Splits each 256-bit + scalar into two ~128-bit scalars via k ≡ k1 + k2·λ (mod N), then runs + a 4-scalar simultaneous double-and-add over (p1, φ(p1), p2, φ(p2)) + with a 16-entry precomputed table of subset sums. Halves the loop + length versus the plain Shamir path. + """ + glv = curve.glvParams + N, A, P = curve.N, curve.A, curve.P + beta = glv["beta"] + + k1, k2 = cls._glvDecompose(n1 % N, glv, N) + k3, k4 = cls._glvDecompose(n2 % N, glv, N) + + # Base points (affine, z=1) — φ((x,y)) = (β·x mod P, y). + bases = [ + Point(p1.x, p1.y, 1), + Point((beta * p1.x) % P, p1.y, 1), + Point(p2.x, p2.y, 1), + Point((beta * p2.x) % P, p2.y, 1), + ] + scalars = [k1, k2, k3, k4] + for i in range(4): + if scalars[i] < 0: + scalars[i] = -scalars[i] + bases[i] = Point(bases[i].x, P - bases[i].y, 1) + + # Precompute table[idx] = sum of bases[i] selected by bits of idx. + _add = cls._jacobianAdd + table = [Point(0, 0, 1)] * 16 + for idx in range(1, 16): + low = idx & -idx + i = low.bit_length() - 1 + table[idx] = _add(table[idx ^ low], bases[i], A, P) + + _double = cls._jacobianDouble + maxLen = max(s.bit_length() for s in scalars) + r = Point(0, 0, 1) + s0, s1, s2, s3 = scalars + for bit in range(maxLen - 1, -1, -1): + r = _double(r, A, P) + idx = ((s0 >> bit) & 1) | (((s1 >> bit) & 1) << 1) \ + | (((s2 >> bit) & 1) << 2) | (((s3 >> bit) & 1) << 3) + if idx: + r = _add(r, table[idx], A, P) + + return cls._fromJacobian(r, P) + + @staticmethod + def _glvDecompose(k, glv, N): + """ + Decompose k into (k1, k2) with k ≡ k1 + k2·λ (mod N) and + |k1|, |k2| ~ √N. Babai rounding against the precomputed basis + {(a1, b1), (a2, b2)}; k1 and k2 may be negative. + """ + a1, b1, a2, b2 = glv["a1"], glv["b1"], glv["a2"], glv["b2"] + halfN = N // 2 + c1 = (b2 * k + halfN) // N + c2 = (-b1 * k + halfN) // N + k1 = k - c1 * a1 - c2 * a2 + k2 = -c1 * b1 - c2 * b2 + return k1, k2 + @classmethod def inv(cls, x, n): """ - Extended Euclidean Algorithm. It's the 'division' in elliptic curves + Modular inverse via the Extended Euclidean Algorithm. Implemented in + pure Python for compatibility with Python 2.7+ and 3.x. CPython 3.8+ + users get a faster C-level implementation via ``pow(x, -1, n)`` that + this falls back to when available. - :param x: Divisor + :param x: Divisor (must be coprime to n) :param n: Mod for division :return: Value representing the division + :raises ValueError: when x is 0 mod n (no inverse exists) """ - if x == 0: - return 0 + if x % n == 0: + raise ValueError("0 has no modular inverse") - lm = 1 - hm = 0 - low = x % n - high = n + try: + return pow(x, -1, n) + except (TypeError, ValueError): + pass + lm, hm = 1, 0 + low, high = x % n, n while low > 1: r = high // low - nm = hm - lm * r - nw = high - low * r - high = low - hm = lm - low = nw - lm = nm - + lm, hm = hm - lm * r, lm + low, high = high - low * r, low return lm % n @classmethod @@ -85,6 +288,9 @@ def _fromJacobian(cls, p, P): :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p) :return: Point in default coordinates """ + if p.y == 0: + return Point(0, 0, 0) + z = cls.inv(p.z, P) x = (p.x * z ** 2) % P y = (p.y * z ** 3) % P @@ -101,15 +307,23 @@ def _jacobianDouble(cls, p, A, P): :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p) :return: Point that represents the sum of First and Second Point """ - if p.y == 0: + py = p.y + if py == 0: return Point(0, 0, 0) - ysq = (p.y ** 2) % P - S = (4 * p.x * ysq) % P - M = (3 * p.x ** 2 + A * p.z ** 4) % P - nx = (M**2 - 2 * S) % P - ny = (M * (S - nx) - 8 * ysq ** 2) % P - nz = (2 * p.y * p.z) % P + px, pz = p.x, p.z + ysq = (py * py) % P + S = (4 * px * ysq) % P + pz2 = (pz * pz) % P + if A == 0: + M = (3 * px * px) % P + elif A == -3 or A == P - 3: + M = (3 * (px - pz2) * (px + pz2)) % P + else: + M = (3 * px * px + A * pz2 * pz2) % P + nx = (M * M - 2 * S) % P + ny = (M * (S - nx) - 8 * ysq * ysq) % P + nz = (2 * py * pz) % P return Point(nx, ny, nz) @@ -129,10 +343,21 @@ def _jacobianAdd(cls, p, q, A, P): if q.y == 0: return p - U1 = (p.x * q.z ** 2) % P - U2 = (q.x * p.z ** 2) % P - S1 = (p.y * q.z ** 3) % P - S2 = (q.y * p.z ** 3) % P + px, py, pz = p.x, p.y, p.z + qx, qy, qz = q.x, q.y, q.z + + pz2 = (pz * pz) % P + U2 = (qx * pz2) % P + S2 = (qy * pz2 * pz) % P + + if qz == 1: + # Mixed affine+Jacobian add: qz²=qz³=1 saves four multiplications. + U1 = px + S1 = py + else: + qz2 = (qz * qz) % P + U1 = (px * qz2) % P + S1 = (py * qz2 * qz) % P if U1 == U2: if S1 != S2: @@ -144,38 +369,139 @@ def _jacobianAdd(cls, p, q, A, P): H2 = (H * H) % P H3 = (H * H2) % P U1H2 = (U1 * H2) % P - nx = (R ** 2 - H3 - 2 * U1H2) % P + nx = (R * R - H3 - 2 * U1H2) % P ny = (R * (U1H2 - nx) - S1 * H3) % P - nz = (H * p.z * q.z) % P + nz = (H * pz) % P if qz == 1 else (H * pz * qz) % P return Point(nx, ny, nz) @classmethod def _jacobianMultiply(cls, p, n, N, A, P): """ - Multily point and scalar in elliptic curves - - :param p: First Point to mutiply - :param n: Scalar to mutiply + Multiply point and scalar in elliptic curves using a branch-balanced + Montgomery ladder: each scalar bit triggers exactly one add and one + double in swapped order, masking simple branch-timing leaks. Note: + Python's bignum arithmetic is NOT constant-time per operation, so + total execution time still leaks through bignum-op duration. True + constant-time ECDSA is not achievable in pure Python. + + :param p: First Point to multiply + :param n: Scalar to multiply :param N: Order of the elliptic curve :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p) :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p) - :return: Point that represents the sum of First and Second Point + :return: Point that represents the scalar multiplication """ if p.y == 0 or n == 0: return Point(0, 0, 1) - if n == 1: - return p - if n < 0 or n >= N: - return cls._jacobianMultiply(p, n % N, N, A, P) + n = n % N - if (n % 2) == 0: - return cls._jacobianDouble( - cls._jacobianMultiply(p, n // 2, N, A, P), A, P - ) + if n == 0: + return Point(0, 0, 1) - return cls._jacobianAdd( - cls._jacobianDouble(cls._jacobianMultiply(p, n // 2, N, A, P), A, P), p, A, P - ) + _add = cls._jacobianAdd + _double = cls._jacobianDouble + + # Montgomery ladder: always performs one add and one double per bit + r0 = Point(0, 0, 1) + r1 = Point(p.x, p.y, p.z) + + for i in range(n.bit_length() - 1, -1, -1): + if (n >> i) & 1 == 0: + r1 = _add(r0, r1, A, P) + r0 = _double(r0, A, P) + else: + r0 = _add(r0, r1, A, P) + r1 = _double(r1, A, P) + + return r0 + + @classmethod + def _shamirMultiply(cls, jp1, n1, jp2, n2, N, A, P): + """ + Compute n1*p1 + n2*p2 using Shamir's trick with Joint Sparse Form + (Solinas 2001). JSF picks signed digits in {-1, 0, 1} so at most ~l/2 + digit pairs are non-zero, versus ~3l/4 for the raw binary form. Not + constant-time — use only with public scalars (e.g. verification). + + :param jp1: First point in Jacobian coordinates + :param n1: First scalar + :param jp2: Second point in Jacobian coordinates + :param n2: Second scalar + :param N: Order of the elliptic curve + :param A: Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod p) + :param P: Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod p) + :return: Point n1*p1 + n2*p2 in Jacobian coordinates + """ + if n1 < 0 or n1 >= N: + n1 = n1 % N + if n2 < 0 or n2 >= N: + n2 = n2 % N + + if n1 == 0 and n2 == 0: + return Point(0, 0, 1) + + _add = cls._jacobianAdd + _double = cls._jacobianDouble + + def neg(pt): + return Point(pt.x, 0 if pt.y == 0 else P - pt.y, pt.z) + + jp1p2 = _add(jp1, jp2, A, P) + jp1mp2 = _add(jp1, neg(jp2), A, P) + addTable = { + (1, 0): jp1, + (-1, 0): neg(jp1), + (0, 1): jp2, + (0, -1): neg(jp2), + (1, 1): jp1p2, + (-1, -1): neg(jp1p2), + (1, -1): jp1mp2, + (-1, 1): neg(jp1mp2), + } + + digits = cls._jsfDigits(n1, n2) + r = Point(0, 0, 1) + for u0, u1 in digits: + r = _double(r, A, P) + if u0 or u1: + r = _add(r, addTable[(u0, u1)], A, P) + + return r + + @staticmethod + def _jsfDigits(k0, k1): + """ + Joint Sparse Form of (k0, k1): list of signed-digit pairs (u0, u1) in + {-1, 0, 1}, ordered MSB-first. At most one of any two consecutive pairs + is non-zero, giving density ~1/2 instead of ~3/4 from raw binary. + """ + digits = [] + d0 = 0 + d1 = 0 + while k0 + d0 != 0 or k1 + d1 != 0: + a0 = k0 + d0 + a1 = k1 + d1 + if a0 & 1: + u0 = 1 if (a0 & 3) == 1 else -1 + if (a0 & 7) in (3, 5) and (a1 & 3) == 2: + u0 = -u0 + else: + u0 = 0 + if a1 & 1: + u1 = 1 if (a1 & 3) == 1 else -1 + if (a1 & 7) in (3, 5) and (a0 & 3) == 2: + u1 = -u1 + else: + u1 = 0 + digits.append((u0, u1)) + if 2 * d0 == 1 + u0: + d0 = 1 - d0 + if 2 * d1 == 1 + u1: + d1 = 1 - d1 + k0 >>= 1 + k1 >>= 1 + digits.reverse() + return digits diff --git a/ellipticcurve/privateKey.py b/ellipticcurve/privateKey.py index df6fb4d..e5b1a28 100644 --- a/ellipticcurve/privateKey.py +++ b/ellipticcurve/privateKey.py @@ -12,6 +12,8 @@ class PrivateKey: def __init__(self, curve=secp256k1, secret=None): self.curve = curve self.secret = secret or RandomInteger.between(1, curve.N - 1) + if not 1 <= self.secret <= curve.N - 1: + raise Exception("Secret must be in range [1, N-1] for curve {name}".format(name=curve.name)) def publicKey(self): curve = self.curve diff --git a/ellipticcurve/publicKey.py b/ellipticcurve/publicKey.py index 3ebb593..2d14337 100644 --- a/ellipticcurve/publicKey.py +++ b/ellipticcurve/publicKey.py @@ -90,7 +90,10 @@ def fromCompressed(cls, string, curve=secp256k1): raise Exception("Compressed string should start with 02 or 03") x = intFromHex(xHex) y = curve.y(x, isEven=parityTag == _evenTag) - return cls(point=Point(x, y), curve=curve) + point = Point(x, y) + if not curve.contains(point): + raise Exception("Point ({x},{y}) is not valid for curve {name}".format(x=x, y=y, name=curve.name)) + return cls(point=point, curve=curve) _evenTag = "02" diff --git a/ellipticcurve/signature.py b/ellipticcurve/signature.py index 3084f8f..fc134b5 100644 --- a/ellipticcurve/signature.py +++ b/ellipticcurve/signature.py @@ -24,12 +24,17 @@ def toBase64(self, withRecoveryId=False): def fromDer(cls, string, recoveryByte=False): recoveryId = None if recoveryByte: - recoveryId = string[0] if isinstance(string[0], intTypes) else ord(string[0]) - recoveryId -= 27 + rawByte = string[0] if isinstance(string[0], intTypes) else ord(string[0]) + if not 27 <= rawByte <= 30: + raise Exception("Recovery byte must be in [27, 30], got {b}".format(b=rawByte)) + recoveryId = rawByte - 27 string = string[1:] hexadecimal = hexFromByteString(string) - return cls._fromString(string=hexadecimal, recoveryId=recoveryId) + signature = cls._fromString(string=hexadecimal, recoveryId=recoveryId) + if byteStringFromHex(signature._toString()) != string: + raise Exception("Signature is not in canonical DER form") + return signature @classmethod def fromBase64(cls, string, recoveryByte=False): diff --git a/ellipticcurve/utils/binary.py b/ellipticcurve/utils/binary.py index 348887f..02832da 100644 --- a/ellipticcurve/utils/binary.py +++ b/ellipticcurve/utils/binary.py @@ -21,8 +21,13 @@ def byteStringFromHex(hexadecimal): return safeBinaryFromHex(hexadecimal) -def numberFromByteString(byteString): - return intFromHex(hexFromByteString(byteString)) +def numberFromByteString(byteString, bitLength=None): + number = intFromHex(hexFromByteString(byteString)) + if bitLength is not None: + hashBitLen = len(byteString) * 8 + if hashBitLen > bitLength: + number >>= (hashBitLen - bitLength) + return number def base64FromByteString(byteString): diff --git a/ellipticcurve/utils/integer.py b/ellipticcurve/utils/integer.py index 180f200..02716ac 100644 --- a/ellipticcurve/utils/integer.py +++ b/ellipticcurve/utils/integer.py @@ -1,4 +1,13 @@ -from random import SystemRandom +# coding: utf-8 +from hmac import new as hmacNew +from .binary import numberFromByteString, hexFromInt, byteStringFromHex + +try: + from secrets import randbelow as _randbelow +except ImportError: + from random import SystemRandom + _systemRandom = SystemRandom() + _randbelow = lambda n: _systemRandom.randrange(n) class RandomInteger: @@ -13,4 +22,47 @@ def between(cls, min, max): :return: """ - return SystemRandom().randrange(min, max + 1) + return min + _randbelow(max - min + 1) + + @classmethod + def rfc6979(cls, hashBytes, secret, curve, hashfunc): + """Generate nonce values per hedged RFC 6979: deterministic k derivation + with fresh random entropy mixed into K-init (RFC 6979 §3.6). Same message + and key yield different signatures, while preserving RFC 6979's protection + against RNG failures.""" + orderBitLen = curve.nBitLength + orderByteLen = (orderBitLen + 7) // 8 + + secretHex = hexFromInt(secret).zfill(orderByteLen * 2) + secretBytes = byteStringFromHex(secretHex) + + hashReduced = numberFromByteString(hashBytes, orderBitLen) % curve.N + hashHex = hexFromInt(hashReduced).zfill(orderByteLen * 2) + hashOctets = byteStringFromHex(hashHex) + + extraEntropy = byteStringFromHex( + hexFromInt(cls.between(0, (1 << (orderByteLen * 8)) - 1)).zfill(orderByteLen * 2) + ) + + hLen = hashfunc().digest_size + V = b'\x01' * hLen + K = b'\x00' * hLen + + K = hmacNew(K, V + b'\x00' + secretBytes + hashOctets + extraEntropy, hashfunc).digest() + V = hmacNew(K, V, hashfunc).digest() + K = hmacNew(K, V + b'\x01' + secretBytes + hashOctets + extraEntropy, hashfunc).digest() + V = hmacNew(K, V, hashfunc).digest() + + while True: + T = b'' + while len(T) * 8 < orderBitLen: + V = hmacNew(K, V, hashfunc).digest() + T += V + + k = numberFromByteString(T, orderBitLen) + + if 1 <= k <= curve.N - 1: + yield k + + K = hmacNew(K, V + b'\x00', hashfunc).digest() + V = hmacNew(K, V, hashfunc).digest() diff --git a/ellipticcurve/utils/pem.py b/ellipticcurve/utils/pem.py index 1e58b40..622faef 100644 --- a/ellipticcurve/utils/pem.py +++ b/ellipticcurve/utils/pem.py @@ -3,7 +3,10 @@ def getPemContent(pem, template): pattern = template.format(content="(.*)") - return search("".join(pattern.splitlines()), "".join(pem.splitlines())).group(1) + match = search("".join(pattern.splitlines()), "".join(pem.splitlines())) + if match is None: + raise Exception("PEM content does not match expected template") + return match.group(1) def createPem(content, template): diff --git a/setup.py b/setup.py index 5b3c03d..0de7f4f 100644 --- a/setup.py +++ b/setup.py @@ -18,17 +18,18 @@ author="Stark Bank", author_email="developers@starkbank.com", keywords=["ecdsa", "elliptic curve", "elliptic", "curve", "stark bank", "starkbank", "cryptograph", "secp256k1", "prime256v1"], + python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*", version="2.2.0" ) -### Create a source distribution: +### Create a source distribution and a universal wheel: -#Run ```python setup.py sdist``` inside the project directory. +#Run ```python setup.py sdist bdist_wheel``` inside the project directory. -### Install twine: +### Install twine and wheel: -#```pip install twine``` +#```pip install twine wheel``` ### Upload package to pypi: diff --git a/tests/testOpenSSL.py b/tests/testOpenSSL.py index 0d4f605..a871124 100644 --- a/tests/testOpenSSL.py +++ b/tests/testOpenSSL.py @@ -1,16 +1,20 @@ +from os.path import dirname, join from unittest.case import TestCase from ellipticcurve import Ecdsa, PrivateKey, PublicKey, Signature, File +FIXTURES = dirname(__file__) + + class OpensslTest(TestCase): def testAssign(self): # Generated by: openssl ecparam -name secp256k1 -genkey -out privateKey.pem - privateKeyPem = File.read("privateKey.pem") + privateKeyPem = File.read(join(FIXTURES, "privateKey.pem")) privateKey = PrivateKey.fromPem(privateKeyPem) - message = File.read("message.txt") + message = File.read(join(FIXTURES, "message.txt")) signature = Ecdsa.sign(message=message, privateKey=privateKey) @@ -21,12 +25,12 @@ def testAssign(self): def testVerifySignature(self): # openssl ec -in privateKey.pem -pubout -out publicKey.pem - publicKeyPem = File.read("publicKey.pem") + publicKeyPem = File.read(join(FIXTURES, "publicKey.pem")) # openssl dgst -sha256 -sign privateKey.pem -out signature.binary message.txt - signatureDer = File.read("signatureDer.txt", "rb") + signatureDer = File.read(join(FIXTURES, "signatureDer.txt"), "rb") - message = File.read("message.txt") + message = File.read(join(FIXTURES, "message.txt")) publicKey = PublicKey.fromPem(publicKeyPem) diff --git a/tests/testSecurity.py b/tests/testSecurity.py new file mode 100644 index 0000000..c999d46 --- /dev/null +++ b/tests/testSecurity.py @@ -0,0 +1,371 @@ +from unittest.case import TestCase +from hashlib import sha256, sha512 +from ellipticcurve import Ecdsa, PrivateKey, PublicKey, Signature +from ellipticcurve.curve import secp256k1, prime256v1 +from ellipticcurve.point import Point +from ellipticcurve.math import Math +from ellipticcurve.utils.binary import hexFromInt + + +class Prime256v1PublicKeyDerivationTest(TestCase): + """RFC 6979 A.2.5 public key derivation. Signatures are hedged, so r/s + no longer match fixed test vectors, but pubkey derivation is unchanged.""" + + def setUp(self): + self.privateKey = PrivateKey( + curve=prime256v1, + secret=0xC9AFA9D845BA75166B5C215767B1D6934E50C3DB36E89B127B8A622B120F6721, + ) + self.publicKey = self.privateKey.publicKey() + + def testPublicKeyMatchesRfc(self): + self.assertEqual( + self.publicKey.point.x, + 0x60FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB6, + ) + self.assertEqual( + self.publicKey.point.y, + 0x7903FE1008B8BC99A41AE9E95628BC64F2F1B20C2D7E9F5177A3C294D4462299, + ) + + def testSampleMessageRoundTrip(self): + sig = Ecdsa.sign("sample", self.privateKey) + self.assertLessEqual(sig.s, prime256v1.N // 2) + self.assertTrue(Ecdsa.verify("sample", sig, self.publicKey)) + + def testTestMessageRoundTrip(self): + sig = Ecdsa.sign("test", self.privateKey) + self.assertLessEqual(sig.s, prime256v1.N // 2) + self.assertTrue(Ecdsa.verify("test", sig, self.publicKey)) + + +class Secp256k1PublicKeyDerivationTest(TestCase): + """secp256k1 with secret=1 (pubkey = generator G).""" + + def setUp(self): + self.privateKey = PrivateKey(curve=secp256k1, secret=1) + self.publicKey = self.privateKey.publicKey() + + def testPublicKeyIsGenerator(self): + self.assertEqual(self.publicKey.point.x, secp256k1.G.x) + self.assertEqual(self.publicKey.point.y, secp256k1.G.y) + + def testSampleMessageRoundTrip(self): + sig = Ecdsa.sign("sample", self.privateKey) + self.assertTrue(Ecdsa.verify("sample", sig, self.publicKey)) + + def testTestMessageRoundTrip(self): + sig = Ecdsa.sign("test", self.privateKey) + self.assertTrue(Ecdsa.verify("test", sig, self.publicKey)) + + +class MalleabilityTest(TestCase): + + def testSignAlwaysProducesLowS(self): + for _ in range(100): + privateKey = PrivateKey() + signature = Ecdsa.sign("test message", privateKey) + self.assertLessEqual(signature.s, privateKey.curve.N // 2) + + def testHighSSignatureStillVerifies(self): + """verify() accepts high-s for OpenSSL compatibility; sign() prevents malleability""" + privateKey = PrivateKey() + publicKey = privateKey.publicKey() + message = "test message" + + signature = Ecdsa.sign(message, privateKey) + highS = Signature(r=signature.r, s=privateKey.curve.N - signature.s) + + self.assertTrue(Ecdsa.verify(message, signature, publicKey)) + self.assertTrue(Ecdsa.verify(message, highS, publicKey)) + + +class PublicKeyValidationTest(TestCase): + + def testRejectOffCurvePublicKey(self): + privateKey = PrivateKey() + publicKey = privateKey.publicKey() + message = "test message" + + signature = Ecdsa.sign(message, privateKey) + + offCurvePoint = Point(publicKey.point.x, publicKey.point.y + 1) + offCurveKey = PublicKey(point=offCurvePoint, curve=publicKey.curve) + + self.assertFalse(Ecdsa.verify(message, signature, offCurveKey)) + + def testFromStringRejectsOffCurvePoint(self): + p = PrivateKey().publicKey() + badY = hexFromInt(p.point.y + 1).zfill(2 * p.curve.length()) + badHex = hexFromInt(p.point.x).zfill(2 * p.curve.length()) + badY + with self.assertRaises(Exception): + PublicKey.fromString(badHex, curve=p.curve) + + def testFromStringRejectsInfinityPoint(self): + zeroHex = "00" * (2 * secp256k1.length()) + with self.assertRaises(Exception): + PublicKey.fromString(zeroHex, curve=secp256k1) + + +class ForgeryAttemptTest(TestCase): + + def setUp(self): + self.privateKey = PrivateKey() + self.publicKey = self.privateKey.publicKey() + self.message = "authentic message" + self.signature = Ecdsa.sign(self.message, self.privateKey) + + def testRejectZeroSignature(self): + self.assertFalse(Ecdsa.verify(self.message, Signature(0, 0), self.publicKey)) + + def testRejectREqualsZero(self): + self.assertFalse(Ecdsa.verify(self.message, Signature(0, self.signature.s), self.publicKey)) + + def testRejectSEqualsZero(self): + self.assertFalse(Ecdsa.verify(self.message, Signature(self.signature.r, 0), self.publicKey)) + + def testRejectREqualsN(self): + N = self.publicKey.curve.N + self.assertFalse(Ecdsa.verify(self.message, Signature(N, self.signature.s), self.publicKey)) + + def testRejectSEqualsN(self): + N = self.publicKey.curve.N + self.assertFalse(Ecdsa.verify(self.message, Signature(self.signature.r, N), self.publicKey)) + + def testRejectRExceedsN(self): + N = self.publicKey.curve.N + self.assertFalse(Ecdsa.verify(self.message, Signature(N + 1, self.signature.s), self.publicKey)) + + def testRejectArbitrarySignature(self): + self.assertFalse(Ecdsa.verify(self.message, Signature(1, 1), self.publicKey)) + + def testRejectBoundarySignature(self): + N = self.publicKey.curve.N + self.assertFalse(Ecdsa.verify(self.message, Signature(N - 1, N - 1), self.publicKey)) + + def testWrongKeyRejected(self): + otherKey = PrivateKey().publicKey() + self.assertFalse(Ecdsa.verify(self.message, self.signature, otherKey)) + + +class HedgedSignatureTest(TestCase): + + def testSameInputsProduceDifferentSignatures(self): + privateKey = PrivateKey() + message = "test message" + + signature1 = Ecdsa.sign(message, privateKey) + signature2 = Ecdsa.sign(message, privateKey) + + self.assertTrue(signature1.r != signature2.r or signature1.s != signature2.s) + + def testDifferentMessagesDifferentSignatures(self): + privateKey = PrivateKey() + + signature1 = Ecdsa.sign("message 1", privateKey) + signature2 = Ecdsa.sign("message 2", privateKey) + + self.assertTrue(signature1.r != signature2.r or signature1.s != signature2.s) + + def testDifferentKeysDifferentSignatures(self): + message = "test message" + + signature1 = Ecdsa.sign(message, PrivateKey()) + signature2 = Ecdsa.sign(message, PrivateKey()) + + self.assertTrue(signature1.r != signature2.r or signature1.s != signature2.s) + + +class EdgeCaseMessageTest(TestCase): + + def setUp(self): + self.privateKey = PrivateKey() + self.publicKey = self.privateKey.publicKey() + + def _signAndVerify(self, message): + sig = Ecdsa.sign(message, self.privateKey) + self.assertTrue(Ecdsa.verify(message, sig, self.publicKey)) + self.assertFalse(Ecdsa.verify(message + "x", sig, self.publicKey)) + + def testEmptyMessage(self): + self._signAndVerify("") + + def testSingleCharMessage(self): + self._signAndVerify("a") + + def testUnicodeMessage(self): + self._signAndVerify("\u00e9\u00e8\u00ea\u00eb") + + def testEmojiMessage(self): + self._signAndVerify("\U0001f512\U0001f511") + + def testNullByteMessage(self): + self._signAndVerify("before\x00after") + + def testLongMessage(self): + self._signAndVerify("a" * 10000) + + def testNewlinesAndWhitespace(self): + self._signAndVerify(" line1\n\tline2\r\n ") + + +class SerializationRoundTripTest(TestCase): + + def setUp(self): + self.privateKey = PrivateKey() + self.publicKey = self.privateKey.publicKey() + self.message = "round-trip test" + self.signature = Ecdsa.sign(self.message, self.privateKey) + + def testSignatureDerRoundTrip(self): + der = self.signature.toDer() + restored = Signature.fromDer(der) + self.assertEqual(restored.r, self.signature.r) + self.assertEqual(restored.s, self.signature.s) + self.assertTrue(Ecdsa.verify(self.message, restored, self.publicKey)) + + def testSignatureBase64RoundTrip(self): + b64 = self.signature.toBase64() + restored = Signature.fromBase64(b64) + self.assertEqual(restored.r, self.signature.r) + self.assertEqual(restored.s, self.signature.s) + self.assertTrue(Ecdsa.verify(self.message, restored, self.publicKey)) + + def testSignatureDerWithRecoveryIdRoundTrip(self): + der = self.signature.toDer(withRecoveryId=True) + restored = Signature.fromDer(der, recoveryByte=True) + self.assertEqual(restored.r, self.signature.r) + self.assertEqual(restored.s, self.signature.s) + self.assertEqual(restored.recoveryId, self.signature.recoveryId) + + def testPrivateKeyPemRoundTrip(self): + pem = self.privateKey.toPem() + restored = PrivateKey.fromPem(pem) + self.assertEqual(restored.secret, self.privateKey.secret) + self.assertEqual(restored.curve.name, self.privateKey.curve.name) + + def testPrivateKeyDerRoundTrip(self): + der = self.privateKey.toDer() + restored = PrivateKey.fromDer(der) + self.assertEqual(restored.secret, self.privateKey.secret) + + def testPublicKeyPemRoundTrip(self): + pem = self.publicKey.toPem() + restored = PublicKey.fromPem(pem) + self.assertEqual(restored.point.x, self.publicKey.point.x) + self.assertEqual(restored.point.y, self.publicKey.point.y) + + def testPublicKeyCompressedRoundTrip(self): + compressed = self.publicKey.toCompressed() + restored = PublicKey.fromCompressed(compressed, curve=self.publicKey.curve) + self.assertEqual(restored.point.x, self.publicKey.point.x) + self.assertEqual(restored.point.y, self.publicKey.point.y) + self.assertTrue(Ecdsa.verify(self.message, self.signature, restored)) + + def testPublicKeyCompressedEvenAndOdd(self): + """Ensure both even-y and odd-y keys round-trip through compression""" + for _ in range(20): + pk = PrivateKey() + pub = pk.publicKey() + compressed = pub.toCompressed() + restored = PublicKey.fromCompressed(compressed, curve=pub.curve) + self.assertEqual(restored.point.x, pub.point.x) + self.assertEqual(restored.point.y, pub.point.y) + + def testPrime256v1KeyRoundTrip(self): + pk = PrivateKey(curve=prime256v1) + pem = pk.toPem() + restored = PrivateKey.fromPem(pem) + self.assertEqual(restored.secret, pk.secret) + self.assertEqual(restored.curve.name, "prime256v1") + + +class TonelliShanksTest(TestCase): + + def testPrimeCongruent1Mod4(self): + # P = 17: 17 - 1 = 16 = 2^4, S = 4, exercises full Tonelli-Shanks + P = 17 + for value in range(1, P): + if pow(value, (P - 1) // 2, P) == 1: + root = Math.modularSquareRoot(value, P) + self.assertEqual((root * root) % P, value) + + def testPrimeCongruent5Mod8(self): + # P = 13: 13 - 1 = 12 = 3 * 2^2, S = 2 + P = 13 + for value in range(1, P): + if pow(value, (P - 1) // 2, P) == 1: + root = Math.modularSquareRoot(value, P) + self.assertEqual((root * root) % P, value) + + def testPrimeCongruent3Mod4(self): + # P = 7: fast path (S = 1) + P = 7 + for value in range(1, P): + if pow(value, (P - 1) // 2, P) == 1: + root = Math.modularSquareRoot(value, P) + self.assertEqual((root * root) % P, value) + + def testZeroValue(self): + self.assertEqual(Math.modularSquareRoot(0, 17), 0) + + +class HashTruncationTest(TestCase): + + def testSignVerifyWithSha512(self): + privateKey = PrivateKey() + publicKey = privateKey.publicKey() + message = "test message" + + signature = Ecdsa.sign(message, privateKey, hashfunc=sha512) + + self.assertTrue(Ecdsa.verify(message, signature, publicKey, hashfunc=sha512)) + self.assertFalse(Ecdsa.verify("wrong message", signature, publicKey, hashfunc=sha512)) + + def testSha512SignaturesAreHedged(self): + privateKey = PrivateKey() + message = "test message" + + signature1 = Ecdsa.sign(message, privateKey, hashfunc=sha512) + signature2 = Ecdsa.sign(message, privateKey, hashfunc=sha512) + + self.assertTrue(signature1.r != signature2.r or signature1.s != signature2.s) + + def testHashMismatchFails(self): + privateKey = PrivateKey() + publicKey = privateKey.publicKey() + message = "test message" + + signature = Ecdsa.sign(message, privateKey, hashfunc=sha256) + self.assertFalse(Ecdsa.verify(message, signature, publicKey, hashfunc=sha512)) + + +class Prime256v1SecurityTest(TestCase): + + def testSignVerify(self): + privateKey = PrivateKey(curve=prime256v1) + publicKey = privateKey.publicKey() + message = "test message" + + signature = Ecdsa.sign(message, privateKey) + + self.assertLessEqual(signature.s, prime256v1.N // 2) + self.assertTrue(Ecdsa.verify(message, signature, publicKey)) + + def testSignaturesAreHedged(self): + privateKey = PrivateKey(curve=prime256v1) + message = "test message" + + signature1 = Ecdsa.sign(message, privateKey) + signature2 = Ecdsa.sign(message, privateKey) + + self.assertTrue(signature1.r != signature2.r or signature1.s != signature2.s) + + def testWrongCurveKeyFails(self): + """A signature made with secp256k1 should not verify with a prime256v1 key""" + k1Key = PrivateKey(curve=secp256k1) + p256Key = PrivateKey(curve=prime256v1) + message = "cross-curve test" + + sig = Ecdsa.sign(message, k1Key) + self.assertFalse(Ecdsa.verify(message, sig, p256Key.publicKey()))