diff --git a/README.md b/README.md index 14a388d..2805dc8 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,19 @@ -## A lightweight and fast ECDSA +## A lightweight and fast pure Java ECDSA ### Overview -This is a pure Java implementation of the Elliptic Curve Digital Signature Algorithm (ECDSA). It is compatible with Java 8+ and OpenSSL. It uses some elegant math such as Jacobian Coordinates to speed up the ECDSA. +This is a pure Java implementation of the Elliptic Curve Digital Signature Algorithm (ECDSA). It is compatible with Java 8+ and OpenSSL. It uses elegant math such as Jacobian Coordinates to speed up the ECDSA on pure Java. + +### Security + +starkbank-ecdsa includes the following security features: + +- **RFC 6979 deterministic nonces**: Eliminates the catastrophic risk of nonce reuse that leaks private keys +- **Low-S signature normalization**: Prevents signature malleability (BIP-62) +- **Public key on-curve validation**: Blocks invalid-curve attacks during verification +- **Montgomery ladder scalar multiplication**: Constant-operation point multiplication to mitigate timing side channels +- **Hash truncation**: Correctly handles hash functions larger than the curve order (e.g. SHA-512 with secp256k1) +- **Fermat's little theorem for modular inverse**: More uniform execution time than the extended Euclidean algorithm ### Installation @@ -13,7 +24,7 @@ In pom.xml: com.starkbank starkbank-ecdsa - 1.0.2 + 2.0.0 ``` @@ -24,20 +35,21 @@ mvn clean install ### Curves -We currently support `secp256k1`, but it's super easy to add more curves to the project. Just add them on `Curve.java` +We currently support `secp256k1` and `prime256v1` (P-256), but you can add more curves to the project. Just use `Curve.add()`. ### Speed -We ran a test on JDK 13.0.1 on a MAC Pro i5 2019. The libraries ran 100 times and showed the average times displayed bellow: +We ran a test on JDK 21.0.10 on a MAC Pro. The libraries were run 100 times and the averages displayed below were obtained: | Library | sign | verify | | ------------------ |:-------------:| -------:| -| [java.security] | 0.9ms | 2.4ms | -| starkbank-ecdsa | 4.3ms | 9.9ms | +| starkbank-ecdsa | 0.8ms | 1.3ms | + +Performance is driven by Jacobian coordinates, a branch-balanced Montgomery ladder for variable-base scalar multiplication, a precomputed affine table of powers-of-two multiples of the generator (`[G, 2G, 4G, ..., 2^n*G]`) combined with a width-2 NAF of the scalar to eliminate doublings during signing, a mixed affine+Jacobian addition fast path, curve-specific shortcuts in point doubling (A=0 for secp256k1, A=-3 for prime256v1), the secp256k1 GLV endomorphism to split 256-bit scalars into two ~128-bit halves for a 4-scalar simultaneous multi-exponentiation during verification, Shamir's trick with Joint Sparse Form as the fallback path for curves without an efficient endomorphism, and the extended Euclidean algorithm for modular inversion. ### Sample Code -How to use it: +How to sign a json message for [Stark Bank]: ```java import com.starkbank.ellipticcurve.PrivateKey; @@ -46,26 +58,93 @@ import com.starkbank.ellipticcurve.Signature; import com.starkbank.ellipticcurve.Ecdsa; -public class GenerateKeys{ +// Generate privateKey from PEM string +PrivateKey privateKey = PrivateKey.fromPem("-----BEGIN EC PRIVATE KEY-----\n...\n-----END EC PRIVATE KEY-----"); + +String message = "{\"transfers\": [{\"amount\": 100000000}]}"; + +Signature signature = Ecdsa.sign(message, privateKey); + +// Generate Signature in base64 +System.out.println(signature.toBase64()); + +// To double check if the message matches the signature: +PublicKey publicKey = privateKey.publicKey(); +System.out.println(Ecdsa.verify(message, signature, publicKey)); +``` + +Simple use: + +```java +import com.starkbank.ellipticcurve.PrivateKey; +import com.starkbank.ellipticcurve.PublicKey; +import com.starkbank.ellipticcurve.Signature; +import com.starkbank.ellipticcurve.Ecdsa; + + +// Generate new Keys +PrivateKey privateKey = new PrivateKey(); +PublicKey publicKey = privateKey.publicKey(); + +String message = "My test message"; + +// Generate Signature +Signature signature = Ecdsa.sign(message, privateKey); + +// To verify if the signature is valid +System.out.println(Ecdsa.verify(message, signature, publicKey)); +``` + +How to add more curves: + +```java +import com.starkbank.ellipticcurve.Curve; +import com.starkbank.ellipticcurve.PrivateKey; +import com.starkbank.ellipticcurve.PublicKey; +import java.math.BigInteger; + +Curve newCurve = new Curve( + new BigInteger("f1fd178c0b3ad58f10126de8ce42435b3961adbcabc8ca6de8fcf353d86e9c00", 16), + new BigInteger("ee353fca5428a9300d4aba754a44c00fdfec0c9ae4b1a1803075ed967b7bb73f", 16), + new BigInteger("f1fd178c0b3ad58f10126de8ce42435b3961adbcabc8ca6de8fcf353d86e9c03", 16), + new BigInteger("f1fd178c0b3ad58f10126de8ce42435b53dc67e140d2bf941ffdd459c6d655e1", 16), + new BigInteger("b6b3d4c356c139eb31183d4749d423958c27d2dcaf98b70164c97a2dd98f5cff", 16), + new BigInteger("6142e0f7c8b204911f9271f0f3ecef8c2701c307e8e4c9e183115a1554062cfb", 16), + "frp256v1", + new long[]{1, 2, 250, 1, 223, 101, 256, 1} +); + +Curve.add(newCurve); + +PrivateKey privateKey = new PrivateKey(newCurve, null); +PublicKey publicKey = privateKey.publicKey(); +System.out.println(publicKey.toPem()); +``` + +How to generate compressed public key: - public static void main(String[] args){ - // Generate Keys - PrivateKey privateKey = new PrivateKey(); - PublicKey publicKey = privateKey.publicKey(); +```java +import com.starkbank.ellipticcurve.PrivateKey; +import com.starkbank.ellipticcurve.PublicKey; - String message = "Testing message"; - // Generate Signature - Signature signature = Ecdsa.sign(message, privateKey); +PrivateKey privateKey = new PrivateKey(); +PublicKey publicKey = privateKey.publicKey(); +String compressedPublicKey = publicKey.toCompressed(); - // Verify if signature is valid - boolean verified = Ecdsa.verify(message, signature, publicKey) ; +System.out.println(compressedPublicKey); +``` + +How to recover a compressed public key: + +```java +import com.starkbank.ellipticcurve.PublicKey; - // Return the signature verification status - System.out.println("Verified: " + verified); +String compressedPublicKey = "0252972572d465d016d4c501887b8df303eee3ed602c056b1eb09260dfa0da0ab2"; +PublicKey publicKey = PublicKey.fromCompressed(compressedPublicKey); - } -} +System.out.println(publicKey.toPem()); ``` + ### OpenSSL This library is compatible with OpenSSL, so you can use it to generate keys: @@ -78,10 +157,10 @@ openssl ec -in privateKey.pem -pubout -out publicKey.pem Create a message.txt file and sign it: ``` -openssl dgst -sha256 -sign privateKey.pem -out signatureBinary.txt message.txt +openssl dgst -sha256 -sign privateKey.pem -out signatureDer.txt message.txt ``` -It's time to verify: +To verify, do this: ```java import com.starkbank.ellipticcurve.Ecdsa; @@ -90,65 +169,52 @@ import com.starkbank.ellipticcurve.Signature; import com.starkbank.ellipticcurve.utils.ByteString; import com.starkbank.ellipticcurve.utils.File; +String publicKeyPem = File.read("publicKey.pem"); +byte[] signatureBin = File.readBytes("signatureDer.txt"); +String message = File.read("message.txt"); -public class VerifyKeys { +PublicKey publicKey = PublicKey.fromPem(publicKeyPem); +Signature signature = Signature.fromDer(new ByteString(signatureBin)); - public static void main(String[] args){ - // Read files - String publicKeyPem = File.read("publicKey.pem"); - byte[] signatureBin = File.readBytes("signatureBinary.txt"); - String message = File.read("message.txt"); - - ByteString byteString = new ByteString(signatureBin); - - PublicKey publicKey = PublicKey.fromPem(publicKeyPem); - Signature signature = Signature.fromDer(byteString); - - // Get verification status: - boolean verified = Ecdsa.verify(message, signature, publicKey); - System.out.println("Verification status: " + verified); - } -} +System.out.println(Ecdsa.verify(message, signature, publicKey)); ``` You can also verify it on terminal: ``` -openssl dgst -sha256 -verify publicKey.pem -signature signatureBinary.txt message.txt +openssl dgst -sha256 -verify publicKey.pem -signature signatureDer.txt message.txt ``` -NOTE: If you want to create a Digital Signature to use in the [Stark Bank], you need to convert the binary signature to base64. +NOTE: If you want to create a Digital Signature to use with [Stark Bank], you need to convert the binary signature to base64. ``` -openssl base64 -in signatureBinary.txt -out signatureBase64.txt +openssl base64 -in signatureDer.txt -out signatureBase64.txt ``` -You can also verify it with this library: +You can do the same with this library: ```java -import com.starkbank.ellipticcurve.utils.ByteString; import com.starkbank.ellipticcurve.Signature; +import com.starkbank.ellipticcurve.utils.ByteString; import com.starkbank.ellipticcurve.utils.File; +byte[] signatureBin = File.readBytes("signatureDer.txt"); +Signature signature = Signature.fromDer(new ByteString(signatureBin)); -public class GenerateSignature { - - public static void main(String[] args) { - // Load signature file - byte[] signatureBin = File.readBytes("signatureBinary.txt"); - Signature signature = Signature.fromDer(new ByteString(signatureBin)); - // Print signature - System.out.println(signature.toBase64()); - } -} +System.out.println(signature.toBase64()); ``` -[Stark Bank]: https://starkbank.com +### Run unit tests -### Run all unit tests ```shell gradle test ``` -[ecdsa-python]: https://github.com/starkbank/ecdsa-python +### Run benchmark + +```shell +gradle run +``` + +[Stark Bank]: https://starkbank.com [java.security]: https://docs.oracle.com/javase/7/docs/api/index.html diff --git a/build.gradle b/build.gradle index e2d8b94..d20b601 100644 --- a/build.gradle +++ b/build.gradle @@ -3,104 +3,41 @@ plugins { } group 'com.starkbank.ellipticcurve' -version '1.0.2' +version '2.0.0' -sourceCompatibility = 1.7 +sourceCompatibility = 1.8 +targetCompatibility = 1.8 repositories { mavenCentral() } dependencies { - testCompile group: 'junit', name: 'junit', version: '4.12' - testCompile group: 'org.mockito', name: 'mockito-core', version: '2.22.0' + testImplementation 'junit:junit:4.13.2' } -apply plugin: 'maven' -apply plugin: 'signing' +test { + testLogging { + events "passed", "failed", "skipped" + showExceptions true + showCauses true + showStackTraces true + exceptionFormat "full" + } +} archivesBaseName = "starkbank-ecdsa" task javadocJar(type: Jar) { - classifier = 'javadoc' + archiveClassifier = 'javadoc' from javadoc } task sourcesJar(type: Jar) { - classifier = 'sources' + archiveClassifier = 'sources' from sourceSets.main.allSource } artifacts { archives javadocJar, sourcesJar } - -signing { - sign configurations.archives -} - -uploadArchives { - repositories { - mavenDeployer { - beforeDeployment { MavenDeployment deployment -> signing.signPom(deployment) } - - repository(url: "https://oss.sonatype.org/service/local/staging/deploy/maven2/") { - authentication( - userName: project.hasProperty('ossrhUsername') ? project.property('ossrhUsername') : 'username', - password: project.hasProperty('ossrhPassword') ? project.property('ossrhPassword') : 'password' - ) - } - - snapshotRepository(url: "https://oss.sonatype.org/content/repositories/snapshots/") { - authentication( - userName: project.hasProperty('ossrhUsername') ? project.property('ossrhUsername') : 'username', - password: project.hasProperty('ossrhPassword') ? project.property('ossrhPassword') : 'password' - ) - } - - pom.project { - name 'StarkBank ECDSA library' - packaging 'jar' - // optionally artifactId can be defined here - description 'Pure Java ECDSA library by Stark Bank' - url 'https://github.com/starkbank/ecdsa-java' - - scm { - connection 'scm:git:git://github.com/starkbank/ecdsa-java.git' - developerConnection 'scm:git:ssh://github.com/starkbank/ecdsa-java.git' - url 'https://github.com/starkbank/ecdsa-java/' - } - - licenses { - license { - name 'MIT License' - url 'https://github.com/starkbank/sdk-java/blob/master/LICENSE' - } - } - - developers { - developer { - id 'rcmstark' - name 'Rafael Stark' - email 'rafael@starkbank.com' - } - developer { - id 'daltonmenezes' - name 'Dalton Menezes' - email 'dalton.menezes@starkbank.com' - } - developer { - id 'cdottori' - name 'Caio Dottori' - email 'caio.dottori@starkbank.com' - } - developer { - id 'thalesmello' - name 'Thales Mello' - email 'thalesmello@gmail.com' - } - } - } - } - } -} diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 6c9a224..5c82cb0 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.6-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.14-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/src/main/java/com/starkbank/ellipticcurve/Benchmark.java b/src/main/java/com/starkbank/ellipticcurve/Benchmark.java new file mode 100644 index 0000000..ce12802 --- /dev/null +++ b/src/main/java/com/starkbank/ellipticcurve/Benchmark.java @@ -0,0 +1,38 @@ +package com.starkbank.ellipticcurve; + + +public class Benchmark { + + private static final int ROUNDS = 100; + + public static void main(String[] args) { + PrivateKey privateKey = new PrivateKey(); + PublicKey publicKey = privateKey.publicKey(); + String message = "This is a benchmark test message"; + + // Warmup + Signature sig = Ecdsa.sign(message, privateKey); + Ecdsa.verify(message, sig, publicKey); + + // Benchmark sign + long start = System.nanoTime(); + for (int i = 0; i < ROUNDS; i++) { + sig = Ecdsa.sign(message, privateKey); + } + double signTime = (System.nanoTime() - start) / 1e6 / ROUNDS; + + // Benchmark verify + start = System.nanoTime(); + for (int i = 0; i < ROUNDS; i++) { + Ecdsa.verify(message, sig, publicKey); + } + double verifyTime = (System.nanoTime() - start) / 1e6 / ROUNDS; + + System.out.println(); + System.out.printf("starkbank-ecdsa benchmark (%d rounds)%n", ROUNDS); + System.out.println("---------------------------------------"); + System.out.printf("sign: %.1fms%n", signTime); + System.out.printf("verify: %.1fms%n", verifyTime); + System.out.println(); + } +} diff --git a/src/main/java/com/starkbank/ellipticcurve/Curve.java b/src/main/java/com/starkbank/ellipticcurve/Curve.java index 5d285c0..c4533e9 100644 --- a/src/main/java/com/starkbank/ellipticcurve/Curve.java +++ b/src/main/java/com/starkbank/ellipticcurve/Curve.java @@ -2,24 +2,57 @@ import java.math.BigInteger; import java.util.*; + /** * Elliptic Curve Equation. * y^2 = x^3 + A*x + B (mod P) - * */ - public class Curve { + /** + * GLV endomorphism parameters for curves that support one (e.g. secp256k1). + * phi((x, y)) = (beta * x mod P, y) corresponds to lambda * P. Basis vectors + * (a1, b1), (a2, b2) from Gauss reduction used to split a 256-bit scalar k + * into two ~128-bit scalars (k1, k2) with k = k1 + k2*lambda (mod N). + */ + public static final class GLVParams { + public final BigInteger beta; + public final BigInteger lambda; + public final BigInteger a1; + public final BigInteger b1; + public final BigInteger a2; + public final BigInteger b2; + + public GLVParams(BigInteger beta, BigInteger lambda, + BigInteger a1, BigInteger b1, + BigInteger a2, BigInteger b2) { + this.beta = beta; + this.lambda = lambda; + this.a1 = a1; + this.b1 = b1; + this.a2 = a2; + this.b2 = b2; + } + } + public BigInteger A; public BigInteger B; public BigInteger P; public BigInteger N; + public int nBitLength; public Point G; public String name; + public String nistName; public long[] oid; + // null means no endomorphism; fall back to Shamir + JSF. + public GLVParams glvParams; + + // Precomputed window table for fixed-base generator multiplication. + // Lazily populated by Math.generatorTable and published via a volatile + // store; safe to read without locks (Points are effectively immutable). + volatile Point[] generatorTable; /** - * * @param A A * @param B B * @param P P @@ -30,39 +63,72 @@ public class Curve { * @param oid oid */ public Curve(BigInteger A, BigInteger B, BigInteger P, BigInteger N, BigInteger Gx, BigInteger Gy, String name, long[] oid) { + this(A, B, P, N, Gx, Gy, name, oid, null); + } + + /** + * @param A A + * @param B B + * @param P P + * @param N N + * @param Gx Gx + * @param Gy Gy + * @param name name + * @param oid oid + * @param nistName nistName + */ + public Curve(BigInteger A, BigInteger B, BigInteger P, BigInteger N, BigInteger Gx, BigInteger Gy, String name, long[] oid, String nistName) { + this(A, B, P, N, Gx, Gy, name, oid, nistName, null); + } + + /** + * @param A A + * @param B B + * @param P P + * @param N N + * @param Gx Gx + * @param Gy Gy + * @param name name + * @param oid oid + * @param nistName nistName + * @param glvParams GLV endomorphism parameters, or null + */ + public Curve(BigInteger A, BigInteger B, BigInteger P, BigInteger N, BigInteger Gx, BigInteger Gy, String name, long[] oid, String nistName, GLVParams glvParams) { this.A = A; this.B = B; this.P = P; this.N = N; + this.nBitLength = N.bitLength(); this.G = new Point(Gx, Gy); this.name = name; + this.nistName = nistName; this.oid = oid; + this.glvParams = glvParams; } /** * Verify if the point `p` is on the curve * * @param p Point p = Point(x, y) - * @return true if point is in the curve otherwise false + * @return true if point is on the curve otherwise false */ public boolean contains(Point p) { - if (p.x.compareTo(BigInteger.ZERO) < 0) { + if (p.x.compareTo(BigInteger.ZERO) < 0 || p.x.compareTo(this.P.subtract(BigInteger.ONE)) > 0) { return false; } - if (p.x.compareTo(this.P) >= 0) { + if (p.y.compareTo(BigInteger.ZERO) < 0 || p.y.compareTo(this.P.subtract(BigInteger.ONE)) > 0) { return false; } - if (p.y.compareTo(BigInteger.ZERO) < 0) { - return false; - } - if (p.y.compareTo(this.P) >= 0) { - return false; - } - return p.y.pow(2).subtract(p.x.pow(3).add(A.multiply(p.x)).add(B)).mod(P).intValue() == 0; + // y^2 - (x^3 + A*x + B) mod P == 0 + BigInteger lhs = p.y.modPow(BigInteger.TWO, this.P); + BigInteger rhs = p.x.modPow(BigInteger.valueOf(3), this.P) + .add(this.A.multiply(p.x)) + .add(this.B) + .mod(this.P); + return lhs.equals(rhs); } /** - * * @return int */ public int length() { @@ -70,8 +136,24 @@ public int length() { } /** + * Compute the y coordinate for a given x on the curve * + * @param x the x coordinate + * @param isEven whether the y coordinate should be even + * @return the y coordinate */ + public BigInteger y(BigInteger x, boolean isEven) { + BigInteger ySquared = x.modPow(BigInteger.valueOf(3), this.P) + .add(this.A.multiply(x)) + .add(this.B) + .mod(this.P); + BigInteger y = Math.modularSquareRoot(ySquared, this.P); + if (isEven != y.mod(BigInteger.TWO).equals(BigInteger.ZERO)) { + y = this.P.subtract(y); + } + return y; + } + public static final Curve secp256k1 = new Curve( BigInteger.ZERO, BigInteger.valueOf(7), @@ -80,25 +162,73 @@ public int length() { new BigInteger("79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798", 16), new BigInteger("483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8", 16), "secp256k1", - new long[]{1, 3, 132, 0, 10} + new long[]{1, 3, 132, 0, 10}, + null, + // GLV endomorphism phi((x, y)) = (beta * x, y), equivalent to lambda * P. + // Basis vectors from Gauss reduction; used to split a 256-bit scalar k + // into two ~128-bit scalars (k1, k2) with k = k1 + k2 * lambda (mod N). + new GLVParams( + new BigInteger("7ae96a2b657c07106e64479eac3434e99cf0497512f58995c1396c28719501ee", 16), + new BigInteger("5363ad4cc05c30e0a5261c028812645a122e22ea20816678df02967c1b23bd72", 16), + new BigInteger("3086d221a7d46bcde86c90e49284eb15", 16), + new BigInteger("-e4437ed6010e88286f547fa90abfe4c3", 16), + new BigInteger("114ca50f7a8e2f3f657c1108d9d44cfd8", 16), + new BigInteger("3086d221a7d46bcde86c90e49284eb15", 16) + ) + ); + + public static final Curve prime256v1 = new Curve( + new BigInteger("ffffffff00000001000000000000000000000000fffffffffffffffffffffffc", 16), + new BigInteger("5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b", 16), + new BigInteger("ffffffff00000001000000000000000000000000ffffffffffffffffffffffff", 16), + new BigInteger("ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551", 16), + new BigInteger("6b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c296", 16), + new BigInteger("4fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f5", 16), + "prime256v1", + new long[]{1, 2, 840, 10045, 3, 1, 7}, + "P-256" ); + public static final Curve p256 = prime256v1; + + public static final List supportedCurves = new ArrayList<>(); + + public static final Map curvesByOid = new HashMap<>(); + + static { + add(secp256k1); + add(prime256v1); + } + /** + * Register a curve so it can be looked up by OID * + * @param curve the curve to register */ - public static final List supportedCurves = new ArrayList(); + public static void add(Curve curve) { + supportedCurves.add(curve); + curvesByOid.put(Arrays.hashCode(curve.oid), curve); + } /** + * Look up a curve by OID * + * @param oid the OID to look up + * @return the curve */ - public static final Map curvesByOid = new HashMap(); - - static { - supportedCurves.add(secp256k1); - - for (Object c : supportedCurves) { - Curve curve = (Curve) c; - curvesByOid.put(Arrays.hashCode(curve.oid), curve); + public static Curve getByOid(long[] oid) { + Curve curve = curvesByOid.get(Arrays.hashCode(oid)); + if (curve == null) { + StringBuilder names = new StringBuilder(); + for (int i = 0; i < supportedCurves.size(); i++) { + if (i > 0) names.append(", "); + names.append(supportedCurves.get(i).name); + } + throw new RuntimeException(String.format( + "Unknown curve with oid %s; The following are registered: %s", + Arrays.toString(oid), names.toString() + )); } + return curve; } } diff --git a/src/main/java/com/starkbank/ellipticcurve/Ecdsa.java b/src/main/java/com/starkbank/ellipticcurve/Ecdsa.java index a3cf289..75279a7 100644 --- a/src/main/java/com/starkbank/ellipticcurve/Ecdsa.java +++ b/src/main/java/com/starkbank/ellipticcurve/Ecdsa.java @@ -1,33 +1,55 @@ package com.starkbank.ellipticcurve; -import com.starkbank.ellipticcurve.utils.BinaryAscii; import com.starkbank.ellipticcurve.utils.RandomInteger; import java.math.BigInteger; +import java.nio.charset.StandardCharsets; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.util.Iterator; public class Ecdsa { /** + * Sign a message using the private key with a specified hash function. * * @param message message * @param privateKey privateKey * @param hashfunc hashfunc * @return Signature */ - public static Signature sign(String message, PrivateKey privateKey, MessageDigest hashfunc) { - byte[] hashMessage = hashfunc.digest(message.getBytes()); - BigInteger numberMessage = BinaryAscii.numberFromString(hashMessage); Curve curve = privateKey.curve; - BigInteger randNum = RandomInteger.between(BigInteger.ONE, curve.N); - Point randomSignPoint = Math.multiply(curve.G, randNum, curve.N, curve.A, curve.P); - BigInteger r = randomSignPoint.x.mod(curve.N); - BigInteger s = ((numberMessage.add(r.multiply(privateKey.secret))).multiply(Math.inv(randNum, curve.N))).mod(curve.N); - return new Signature(r, s); + byte[] byteMessage = hashfunc.digest(message.getBytes(StandardCharsets.UTF_8)); + BigInteger numberMessage = RandomInteger.numberFromByteString(byteMessage, curve.nBitLength); + + String hmacAlgorithm = getHmacAlgorithm(hashfunc.getAlgorithm()); + Iterator kIterator = RandomInteger.rfc6979(byteMessage, privateKey.secret, curve, hmacAlgorithm); + + BigInteger r = BigInteger.ZERO, s = BigInteger.ZERO; + Point randSignPoint = null; + while (r.equals(BigInteger.ZERO) || s.equals(BigInteger.ZERO)) { + BigInteger randNum = kIterator.next(); + randSignPoint = Math.multiplyGenerator(curve, randNum); + r = randSignPoint.x.mod(curve.N); + s = numberMessage.add(r.multiply(privateKey.secret)).multiply(Math.inv(randNum, curve.N)).mod(curve.N); + } + + int recoveryId = randSignPoint.y.testBit(0) ? 1 : 0; + if (randSignPoint.y.compareTo(curve.N) > 0) { + recoveryId += 2; + } + // Low-S normalization + BigInteger halfN = curve.N.shiftRight(1); + if (s.compareTo(halfN) > 0) { + s = curve.N.subtract(s); + recoveryId ^= 1; + } + + return new Signature(r, s, recoveryId); } /** + * Sign a message using the private key with SHA-256. * * @param message message * @param privateKey privateKey @@ -42,6 +64,7 @@ public static Signature sign(String message, PrivateKey privateKey) { } /** + * Verify a signature against a message and public key with a specified hash function. * * @param message message * @param signature signature @@ -50,29 +73,28 @@ public static Signature sign(String message, PrivateKey privateKey) { * @return boolean */ public static boolean verify(String message, Signature signature, PublicKey publicKey, MessageDigest hashfunc) { - byte[] hashMessage = hashfunc.digest(message.getBytes()); - BigInteger numberMessage = BinaryAscii.numberFromString(hashMessage); Curve curve = publicKey.curve; + byte[] byteMessage = hashfunc.digest(message.getBytes(StandardCharsets.UTF_8)); + BigInteger numberMessage = RandomInteger.numberFromByteString(byteMessage, curve.nBitLength); BigInteger r = signature.r; BigInteger s = signature.s; - if (r.compareTo(new BigInteger(String.valueOf(1))) < 0) { + if (r.compareTo(BigInteger.ONE) < 0 || r.compareTo(curve.N.subtract(BigInteger.ONE)) > 0) { return false; } - if (r.compareTo(curve.N) >= 0) { + if (s.compareTo(BigInteger.ONE) < 0 || s.compareTo(curve.N.subtract(BigInteger.ONE)) > 0) { return false; } - if (s.compareTo(new BigInteger(String.valueOf(1))) < 0) { + if (!curve.contains(publicKey.point)) { return false; } - if (s.compareTo(curve.N) >= 0) { - return false; - } - - BigInteger w = Math.inv(s, curve.N); - Point u1 =Math.multiply(curve.G, numberMessage.multiply(w).mod(curve.N), curve.N, curve.A, curve.P); - Point u2 = Math.multiply(publicKey.point, r.multiply(w).mod(curve.N), curve.N, curve.A, curve.P); - Point v = Math.add(u1, u2, curve.A, curve.P); + + BigInteger inv = Math.inv(s, curve.N); + Point v = Math.multiplyAndAdd( + curve.G, numberMessage.multiply(inv).mod(curve.N), + publicKey.point, r.multiply(inv).mod(curve.N), + curve + ); if (v.isAtInfinity()) { return false; } @@ -80,7 +102,8 @@ public static boolean verify(String message, Signature signature, PublicKey publ } /** - * + * Verify a signature against a message and public key with SHA-256. + * * @param message message * @param signature signature * @param publicKey publicKey @@ -93,4 +116,13 @@ public static boolean verify(String message, Signature signature, PublicKey publ throw new IllegalStateException("Could not find SHA-256 message digest in provided java environment"); } } + + /** + * Convert a MessageDigest algorithm name to the corresponding HMAC algorithm name. + */ + private static String getHmacAlgorithm(String digestAlgorithm) { + // MessageDigest names like "SHA-256" -> "HmacSHA256" + String normalized = digestAlgorithm.replace("-", ""); + return "Hmac" + normalized; + } } diff --git a/src/main/java/com/starkbank/ellipticcurve/Math.java b/src/main/java/com/starkbank/ellipticcurve/Math.java index 8ae022a..814dfc7 100644 --- a/src/main/java/com/starkbank/ellipticcurve/Math.java +++ b/src/main/java/com/starkbank/ellipticcurve/Math.java @@ -4,15 +4,82 @@ public final class Math { + private static final BigInteger TWO = BigInteger.valueOf(2); + private static final BigInteger THREE = BigInteger.valueOf(3); + private static final BigInteger FOUR = BigInteger.valueOf(4); + private static final BigInteger EIGHT = BigInteger.valueOf(8); + + /** + * Tonelli-Shanks algorithm for modular square root. Works for all odd primes. + * + * @param value the value to compute the square root of + * @param prime the prime modulus + * @return the modular square root + */ + public static BigInteger modularSquareRoot(BigInteger value, BigInteger prime) { + if (value.equals(BigInteger.ZERO)) { + return BigInteger.ZERO; + } + if (prime.equals(TWO)) { + return value.mod(TWO); + } + + // Factor out powers of 2: prime - 1 = Q * 2^S + BigInteger Q = prime.subtract(BigInteger.ONE); + int S = 0; + while (Q.mod(TWO).equals(BigInteger.ZERO)) { + Q = Q.divide(TWO); + S++; + } + + if (S == 1) { + // prime = 3 (mod 4) fast path + return value.modPow(prime.add(BigInteger.ONE).divide(FOUR), prime); + } + + // Find a quadratic non-residue z + BigInteger z = TWO; + BigInteger primeMinusOne = prime.subtract(BigInteger.ONE); + BigInteger halfPrimeMinusOne = primeMinusOne.divide(TWO); + while (!z.modPow(halfPrimeMinusOne, prime).equals(primeMinusOne)) { + z = z.add(BigInteger.ONE); + } + + int M = S; + BigInteger c = z.modPow(Q, prime); + BigInteger t = value.modPow(Q, prime); + BigInteger R = value.modPow(Q.add(BigInteger.ONE).divide(TWO), prime); + + while (true) { + if (t.equals(BigInteger.ONE)) { + return R; + } + + // Find the least i such that t^(2^i) = 1 (mod prime) + int i = 1; + BigInteger temp = t.multiply(t).mod(prime); + while (!temp.equals(BigInteger.ONE)) { + temp = temp.multiply(temp).mod(prime); + i++; + } + + BigInteger b = c.modPow(BigInteger.ONE.shiftLeft(M - i - 1), prime); + M = i; + c = b.multiply(b).mod(prime); + t = t.multiply(c).mod(prime); + R = R.multiply(b).mod(prime); + } + } + /** * Fast way to multiply point and scalar in elliptic curves * * @param p First Point to multiply * @param n Scalar to multiply * @param N Order of the elliptic curve - * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) * @param A Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod P) - * @return Point that represents the sum of First and Second Point + * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) + * @return Point that represents the scalar multiplication */ public static Point multiply(Point p, BigInteger n, BigInteger N, BigInteger A, BigInteger P) { return fromJacobian(jacobianMultiply(toJacobian(p), n, N, A, P), P); @@ -27,37 +94,152 @@ public static Point multiply(Point p, BigInteger n, BigInteger N, BigInteger A, * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) * @return Point that represents the sum of First and Second Point */ - public static Point add(Point p, Point q, BigInteger A, BigInteger P) { return fromJacobian(jacobianAdd(toJacobian(p), toJacobian(q), A, P), P); } /** - * Extended Euclidean Algorithm. It's the 'division' in elliptic curves + * Compute n1*p1 + n2*p2 using Shamir's trick with JSF. + * Not constant-time -- use only with public scalars (e.g. verification). * - * @param x Divisor + * @param p1 First point + * @param n1 First scalar + * @param p2 Second point + * @param n2 Second scalar + * @param N Order of the elliptic curve + * @param A Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod P) + * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) + * @return Point n1*p1 + n2*p2 + */ + public static Point multiplyAndAdd(Point p1, BigInteger n1, Point p2, BigInteger n2, BigInteger N, BigInteger A, BigInteger P) { + return fromJacobian( + shamirMultiply(toJacobian(p1), n1, toJacobian(p2), n2, N, A, P), + P + ); + } + + /** + * Compute n1*p1 + n2*p2. If the curve exposes GLV parameters (e.g. + * secp256k1), uses the GLV endomorphism to split both scalars into + * ~128-bit halves and run a 4-scalar simultaneous multi-exponentiation. + * Otherwise falls back to Shamir's trick with JSF. Not constant-time -- + * use only with public scalars (e.g. verification). + * + * @param p1 First point + * @param n1 First scalar + * @param p2 Second point + * @param n2 Second scalar + * @param curve Elliptic curve; enables GLV if curve.glvParams is set + * @return Point n1*p1 + n2*p2 + */ + public static Point multiplyAndAdd(Point p1, BigInteger n1, Point p2, BigInteger n2, Curve curve) { + if (curve.glvParams != null) { + return glvMultiplyAndAdd(p1, n1, p2, n2, curve); + } + return fromJacobian( + shamirMultiply(toJacobian(p1), n1, toJacobian(p2), n2, curve.N, curve.A, curve.P), + curve.P + ); + } + + /** + * Modular inverse via the extended Euclidean algorithm + * (BigInteger.modInverse). Roughly 2-3x faster than Fermat's little + * theorem for 256-bit operands. + * + * @param x Divisor (must be coprime to n) * @param n Mod for division - * @return Value representing the division + * @return Value representing the modular inverse */ public static BigInteger inv(BigInteger x, BigInteger n) { - if (x.compareTo(BigInteger.ZERO) == 0) { - return BigInteger.ZERO; + if (x.mod(n).equals(BigInteger.ZERO)) { + throw new ArithmeticException("0 has no modular inverse"); + } + return x.modInverse(n); + } + + /** + * Fast scalar multiplication n*G using a precomputed affine table of + * powers-of-two multiples of G and the width-2 NAF of n. Every non-zero + * NAF digit triggers one mixed add and zero doublings, trading the ~256 + * doublings of a windowed method for ~86 adds on average -- a large net + * reduction in field multiplications for 256-bit scalars. + * + * @param curve Elliptic curve with generator G + * @param n Scalar multiplier + * @return Point n*G + */ + public static Point multiplyGenerator(Curve curve, BigInteger n) { + if (n.signum() < 0 || n.compareTo(curve.N) >= 0) { + n = n.mod(curve.N); + } + if (n.equals(BigInteger.ZERO)) { + return new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ZERO); } - BigInteger lm = BigInteger.ONE; - BigInteger hm = BigInteger.ZERO; - BigInteger high = n; - BigInteger low = x.mod(n); - BigInteger r, nm, nw; - while (low.compareTo(BigInteger.ONE) > 0) { - r = high.divide(low); - nm = hm.subtract(lm.multiply(r)); - nw = high.subtract(low.multiply(r)); - high = low; - hm = lm; - low = nw; - lm = nm; + + Point[] table = generatorTable(curve); + BigInteger A = curve.A; + BigInteger P = curve.P; + + Point r = new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); + int i = 0; + BigInteger k = n; + while (k.signum() > 0) { + if (k.testBit(0)) { + // Low two bits of k: 1 -> digit +1, 3 -> digit -1. + int low2 = (k.testBit(1) ? 2 : 0) | 1; + int digit = 2 - low2; // +1 or -1 + k = digit == 1 ? k.subtract(BigInteger.ONE) : k.add(BigInteger.ONE); + Point g = table[i]; + if (digit == 1) { + r = jacobianAdd(r, g, A, P); + } else { + r = jacobianAdd(r, new Point(g.x, P.subtract(g.y), BigInteger.ONE), A, P); + } + } + k = k.shiftRight(1); + i++; + } + return fromJacobian(r, P); + } + + /** + * Build [G, 2G, 4G, ..., 2^nBitLength * G] in affine (z=1) form, so each + * add in multiplyGenerator hits the mixed-add fast path. Idempotent: + * repeated calls return the same array. + * + * @param curve Elliptic curve whose generator is tabulated + * @return Powers-of-two table of affine points + */ + static Point[] generatorTable(Curve curve) { + Point[] cached = curve.generatorTable; + if (cached != null) { + return cached; + } + BigInteger A = curve.A; + BigInteger P = curve.P; + Point current = new Point(curve.G.x, curve.G.y, BigInteger.ONE); + // NAF of an nBitLength-bit scalar can be up to nBitLength+1 digits. + Point[] table = new Point[curve.nBitLength + 1]; + table[0] = current; + for (int i = 1; i <= curve.nBitLength; i++) { + Point doubled = jacobianDouble(current, A, P); + if (doubled.y.equals(BigInteger.ZERO)) { + current = doubled; + } else { + BigInteger zInv = inv(doubled.z, P); + BigInteger zInv2 = zInv.multiply(zInv).mod(P); + BigInteger zInv3 = zInv2.multiply(zInv).mod(P); + current = new Point( + doubled.x.multiply(zInv2).mod(P), + doubled.y.multiply(zInv3).mod(P), + BigInteger.ONE + ); + } + table[i] = current; } - return lm.mod(n); + curve.generatorTable = table; + return table; } /** @@ -66,8 +248,7 @@ public static BigInteger inv(BigInteger x, BigInteger n) { * @param p the point you want to transform * @return Point in Jacobian coordinates */ - public static Point toJacobian(Point p) { - + static Point toJacobian(Point p) { return new Point(p.x, p.y, BigInteger.ONE); } @@ -78,7 +259,10 @@ public static Point toJacobian(Point p) { * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) * @return Point in default coordinates */ - public static Point fromJacobian(Point p, BigInteger P) { + static Point fromJacobian(Point p, BigInteger P) { + if (p.y.equals(BigInteger.ZERO)) { + return new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ZERO); + } BigInteger z = inv(p.z, P); BigInteger x = p.x.multiply(z.pow(2)).mod(P); BigInteger y = p.y.multiply(z.pow(3)).mod(P); @@ -88,21 +272,32 @@ public static Point fromJacobian(Point p, BigInteger P) { /** * Double a point in elliptic curves * - * @param p the point you want to transform + * @param p the point you want to double * @param A Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod P) * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) - * @return the result point doubled in elliptic curves + * @return the result point doubled */ - public static Point jacobianDouble(Point p, BigInteger A, BigInteger P) { - if (p.y == null || p.y.equals(BigInteger.ZERO)) { + static Point jacobianDouble(Point p, BigInteger A, BigInteger P) { + if (p.y.equals(BigInteger.ZERO)) { return new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ZERO); } - BigInteger ysq = p.y.pow(2).mod(P); - BigInteger S = BigInteger.valueOf(4).multiply(p.x).multiply(ysq).mod(P); - BigInteger M = BigInteger.valueOf(3).multiply(p.x.pow(2)).add(A.multiply(p.z.pow(4))).mod(P); - BigInteger nx = M.pow(2).subtract(BigInteger.valueOf(2).multiply(S)).mod(P); - BigInteger ny = M.multiply(S.subtract(nx)).subtract(BigInteger.valueOf(8).multiply(ysq.pow(2))).mod(P); - BigInteger nz = BigInteger.valueOf(2).multiply(p.y).multiply(p.z).mod(P); + BigInteger px = p.x, py = p.y, pz = p.z; + BigInteger ysq = py.multiply(py).mod(P); + BigInteger S = FOUR.multiply(px).multiply(ysq).mod(P); + BigInteger pz2 = pz.multiply(pz).mod(P); + BigInteger M; + if (A.signum() == 0) { + // A = 0 (secp256k1): skip A*pz^4 term + M = THREE.multiply(px).multiply(px).mod(P); + } else if (A.equals(P.subtract(THREE))) { + // A = -3 (prime256v1): M = 3*(px - pz^2)*(px + pz^2) + M = THREE.multiply(px.subtract(pz2)).multiply(px.add(pz2)).mod(P); + } else { + M = THREE.multiply(px).multiply(px).add(A.multiply(pz2).multiply(pz2)).mod(P); + } + BigInteger nx = M.multiply(M).subtract(TWO.multiply(S)).mod(P); + BigInteger ny = M.multiply(S.subtract(nx)).subtract(EIGHT.multiply(ysq).multiply(ysq)).mod(P); + BigInteger nz = TWO.multiply(py).multiply(pz).mod(P); return new Point(nx, ny, nz); } @@ -115,60 +310,282 @@ public static Point jacobianDouble(Point p, BigInteger A, BigInteger P) { * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) * @return Point that represents the sum of First and Second Point */ - public static Point jacobianAdd(Point p, Point q, BigInteger A, BigInteger P) { - if (p.y == null || p.y.equals(BigInteger.ZERO)) { + static Point jacobianAdd(Point p, Point q, BigInteger A, BigInteger P) { + if (p.y.equals(BigInteger.ZERO)) { return q; } - if (q.y == null || q.y.equals(BigInteger.ZERO)) { + if (q.y.equals(BigInteger.ZERO)) { return p; } - BigInteger U1 = p.x.multiply(q.z.pow(2)).mod(P); - BigInteger U2 = q.x.multiply(p.z.pow(2)).mod(P); - BigInteger S1 = p.y.multiply(q.z.pow(3)).mod(P); - BigInteger S2 = q.y.multiply(p.z.pow(3)).mod(P); - if (U1.compareTo(U2) == 0) { - if (S1.compareTo(S2) != 0) { + BigInteger px = p.x, py = p.y, pz = p.z; + BigInteger qx = q.x, qy = q.y, qz = q.z; + + BigInteger pz2 = pz.multiply(pz).mod(P); + BigInteger U2 = qx.multiply(pz2).mod(P); + BigInteger S2 = qy.multiply(pz2).multiply(pz).mod(P); + + BigInteger U1, S1; + boolean qzIsOne = qz.equals(BigInteger.ONE); + if (qzIsOne) { + // Mixed affine+Jacobian add: qz^2 = qz^3 = 1 saves four multiplications. + U1 = px; + S1 = py; + } else { + BigInteger qz2 = qz.multiply(qz).mod(P); + U1 = px.multiply(qz2).mod(P); + S1 = py.multiply(qz2).multiply(qz).mod(P); + } + + if (U1.equals(U2)) { + if (!S1.equals(S2)) { return new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); } return jacobianDouble(p, A, P); } + BigInteger H = U2.subtract(U1); BigInteger R = S2.subtract(S1); BigInteger H2 = H.multiply(H).mod(P); BigInteger H3 = H.multiply(H2).mod(P); BigInteger U1H2 = U1.multiply(H2).mod(P); - BigInteger nx = R.pow(2).subtract(H3).subtract(BigInteger.valueOf(2).multiply(U1H2)).mod(P); + BigInteger nx = R.multiply(R).subtract(H3).subtract(TWO.multiply(U1H2)).mod(P); BigInteger ny = R.multiply(U1H2.subtract(nx)).subtract(S1.multiply(H3)).mod(P); - BigInteger nz = H.multiply(p.z).multiply(q.z).mod(P); + BigInteger nz = qzIsOne ? H.multiply(pz).mod(P) : H.multiply(pz).multiply(qz).mod(P); return new Point(nx, ny, nz); } /** - * Multiply point and scalar in elliptic curves + * Multiply point and scalar in elliptic curves using Montgomery ladder + * for constant-time execution. * * @param p First Point to multiply * @param n Scalar to multiply * @param N Order of the elliptic curve * @param A Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod P) * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) - * @return Point that represents the product of First Point and scalar + * @return Point that represents the scalar multiplication */ - public static Point jacobianMultiply(Point p, BigInteger n, BigInteger N, BigInteger A, BigInteger P) { - if (BigInteger.ZERO.compareTo(p.y) == 0 || BigInteger.ZERO.compareTo(n) == 0) { + static Point jacobianMultiply(Point p, BigInteger n, BigInteger N, BigInteger A, BigInteger P) { + if (p.y.equals(BigInteger.ZERO) || n.equals(BigInteger.ZERO)) { return new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); } - if (BigInteger.ONE.compareTo(n) == 0) { - return p; + + if (n.signum() < 0 || n.compareTo(N) >= 0) { + n = n.mod(N); } - if (n.compareTo(BigInteger.ZERO) < 0 || n.compareTo(N) >= 0) { - return jacobianMultiply(p, n.mod(N), N, A, P); + + if (n.equals(BigInteger.ZERO)) { + return new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); } - if (n.mod(BigInteger.valueOf(2)).compareTo(BigInteger.ZERO) == 0) { - return jacobianDouble(jacobianMultiply(p, n.divide(BigInteger.valueOf(2)), N, A, P), A, P); + + // Montgomery ladder: always performs one add and one double per bit + Point r0 = new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); + Point r1 = new Point(p.x, p.y, p.z); + + for (int i = n.bitLength() - 1; i >= 0; i--) { + if (!n.testBit(i)) { + r1 = jacobianAdd(r0, r1, A, P); + r0 = jacobianDouble(r0, A, P); + } else { + r0 = jacobianAdd(r0, r1, A, P); + r1 = jacobianDouble(r1, A, P); + } } - if (n.mod(BigInteger.valueOf(2)).compareTo(BigInteger.ONE) == 0) { - return jacobianAdd(jacobianDouble(jacobianMultiply(p, n.divide(BigInteger.valueOf(2)), N, A, P), A, P), p, A, P); + + return r0; + } + + /** + * Compute n1*p1 + n2*p2 using Shamir's trick with Joint Sparse Form + * (Solinas 2001). JSF picks signed digits in {-1, 0, 1} so at most ~l/2 + * digit pairs are non-zero, versus ~3l/4 for the raw binary form. Not + * constant-time -- use only with public scalars (e.g. verification). + * + * @param jp1 First point in Jacobian coordinates + * @param n1 First scalar + * @param jp2 Second point in Jacobian coordinates + * @param n2 Second scalar + * @param N Order of the elliptic curve + * @param A Coefficient of the first-order term of the equation Y^2 = X^3 + A*X + B (mod P) + * @param P Prime number in the module of the equation Y^2 = X^3 + A*X + B (mod P) + * @return Point n1*p1 + n2*p2 in Jacobian coordinates + */ + static Point shamirMultiply(Point jp1, BigInteger n1, Point jp2, BigInteger n2, BigInteger N, BigInteger A, BigInteger P) { + if (n1.signum() < 0 || n1.compareTo(N) >= 0) { + n1 = n1.mod(N); + } + if (n2.signum() < 0 || n2.compareTo(N) >= 0) { + n2 = n2.mod(N); + } + + if (n1.signum() == 0 && n2.signum() == 0) { + return new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); + } + + Point jp1p2 = jacobianAdd(jp1, jp2, A, P); + Point negJp2 = negate(jp2, P); + Point jp1mp2 = jacobianAdd(jp1, negJp2, A, P); + Point negJp1 = negate(jp1, P); + Point negJp1p2 = negate(jp1p2, P); + Point negJp1mp2 = negate(jp1mp2, P); + + // addTable[(u0, u1)]: index by (u0+1)*3 + (u1+1), with (0,0) unused. + // (1,0) -> jp1; (-1,0) -> -jp1; (0,1) -> jp2; (0,-1) -> -jp2; + // (1,1) -> jp1+jp2; (-1,-1) -> -(jp1+jp2); (1,-1) -> jp1-jp2; (-1,1) -> -(jp1-jp2). + Point[] addTable = new Point[9]; + addTable[(1 + 1) * 3 + (0 + 1)] = jp1; + addTable[(-1 + 1) * 3 + (0 + 1)] = negJp1; + addTable[(0 + 1) * 3 + (1 + 1)] = jp2; + addTable[(0 + 1) * 3 + (-1 + 1)] = negJp2; + addTable[(1 + 1) * 3 + (1 + 1)] = jp1p2; + addTable[(-1 + 1) * 3 + (-1 + 1)] = negJp1p2; + addTable[(1 + 1) * 3 + (-1 + 1)] = jp1mp2; + addTable[(-1 + 1) * 3 + (1 + 1)] = negJp1mp2; + + int[][] digits = jsfDigits(n1, n2); + Point r = new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); + for (int[] pair : digits) { + int u0 = pair[0], u1 = pair[1]; + r = jacobianDouble(r, A, P); + if (u0 != 0 || u1 != 0) { + r = jacobianAdd(r, addTable[(u0 + 1) * 3 + (u1 + 1)], A, P); + } + } + + return r; + } + + /** + * Negate a point in Jacobian coordinates: (x, y, z) -> (x, -y, z). + */ + private static Point negate(Point p, BigInteger P) { + if (p.y.signum() == 0) { + return new Point(p.x, BigInteger.ZERO, p.z); + } + return new Point(p.x, P.subtract(p.y), p.z); + } + + /** + * Compute n1*p1 + n2*p2 using the GLV endomorphism. Splits each 256-bit + * scalar into two ~128-bit scalars via k = k1 + k2*lambda (mod N), then + * runs a 4-scalar simultaneous double-and-add over (p1, phi(p1), p2, phi(p2)) + * with a 16-entry precomputed table of subset sums. Halves the loop + * length versus the plain Shamir path. + */ + static Point glvMultiplyAndAdd(Point p1, BigInteger n1, Point p2, BigInteger n2, Curve curve) { + Curve.GLVParams glv = curve.glvParams; + BigInteger N = curve.N, A = curve.A, P = curve.P; + BigInteger beta = glv.beta; + + BigInteger[] d1 = glvDecompose(n1.mod(N), glv, N); + BigInteger[] d2 = glvDecompose(n2.mod(N), glv, N); + BigInteger k1 = d1[0], k2 = d1[1], k3 = d2[0], k4 = d2[1]; + + // Base points (affine, z=1) -- phi((x, y)) = (beta*x mod P, y). + Point[] bases = new Point[]{ + new Point(p1.x, p1.y, BigInteger.ONE), + new Point(beta.multiply(p1.x).mod(P), p1.y, BigInteger.ONE), + new Point(p2.x, p2.y, BigInteger.ONE), + new Point(beta.multiply(p2.x).mod(P), p2.y, BigInteger.ONE), + }; + BigInteger[] scalars = new BigInteger[]{k1, k2, k3, k4}; + for (int i = 0; i < 4; i++) { + if (scalars[i].signum() < 0) { + scalars[i] = scalars[i].negate(); + bases[i] = new Point(bases[i].x, P.subtract(bases[i].y), BigInteger.ONE); + } + } + + // Precompute table[idx] = sum of bases[i] selected by bits of idx. + Point[] table = new Point[16]; + table[0] = new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); + for (int idx = 1; idx < 16; idx++) { + int low = idx & -idx; + int i = Integer.numberOfTrailingZeros(low); + table[idx] = jacobianAdd(table[idx ^ low], bases[i], A, P); + } + + int maxLen = 0; + for (BigInteger s : scalars) { + if (s.bitLength() > maxLen) { + maxLen = s.bitLength(); + } + } + Point r = new Point(BigInteger.ZERO, BigInteger.ZERO, BigInteger.ONE); + BigInteger s0 = scalars[0], s1 = scalars[1], s2 = scalars[2], s3 = scalars[3]; + for (int bit = maxLen - 1; bit >= 0; bit--) { + r = jacobianDouble(r, A, P); + int idx = (s0.testBit(bit) ? 1 : 0) + | (s1.testBit(bit) ? 2 : 0) + | (s2.testBit(bit) ? 4 : 0) + | (s3.testBit(bit) ? 8 : 0); + if (idx != 0) { + r = jacobianAdd(r, table[idx], A, P); + } + } + + return fromJacobian(r, P); + } + + /** + * Decompose k into (k1, k2) with k = k1 + k2*lambda (mod N) and + * |k1|, |k2| ~ sqrt(N). Babai rounding against the precomputed basis + * {(a1, b1), (a2, b2)}; k1 and k2 may be negative. + */ + static BigInteger[] glvDecompose(BigInteger k, Curve.GLVParams glv, BigInteger N) { + BigInteger a1 = glv.a1, b1 = glv.b1, a2 = glv.a2, b2 = glv.b2; + BigInteger halfN = N.shiftRight(1); + BigInteger c1 = b2.multiply(k).add(halfN).divide(N); + BigInteger c2 = b1.negate().multiply(k).add(halfN).divide(N); + BigInteger k1 = k.subtract(c1.multiply(a1)).subtract(c2.multiply(a2)); + BigInteger k2 = c1.negate().multiply(b1).subtract(c2.multiply(b2)); + return new BigInteger[]{k1, k2}; + } + + /** + * Joint Sparse Form of (k0, k1): list of signed-digit pairs (u0, u1) in + * {-1, 0, 1}, ordered MSB-first. At most one of any two consecutive pairs + * is non-zero, giving density ~1/2 instead of ~3/4 from raw binary. + */ + static int[][] jsfDigits(BigInteger k0, BigInteger k1) { + java.util.List digits = new java.util.ArrayList<>(); + int d0 = 0; + int d1 = 0; + while (k0.signum() != 0 || d0 != 0 || k1.signum() != 0 || d1 != 0) { + int low0 = (k0.testBit(0) ? 1 : 0) | (k0.testBit(1) ? 2 : 0) | (k0.testBit(2) ? 4 : 0); + int low1 = (k1.testBit(0) ? 1 : 0) | (k1.testBit(1) ? 2 : 0) | (k1.testBit(2) ? 4 : 0); + int a0 = (low0 + d0) & 7; + int a1 = (low1 + d1) & 7; + int u0; + if ((a0 & 1) != 0) { + u0 = ((a0 & 3) == 1) ? 1 : -1; + if (((a0 & 7) == 3 || (a0 & 7) == 5) && (a1 & 3) == 2) { + u0 = -u0; + } + } else { + u0 = 0; + } + int u1; + if ((a1 & 1) != 0) { + u1 = ((a1 & 3) == 1) ? 1 : -1; + if (((a1 & 7) == 3 || (a1 & 7) == 5) && (a0 & 3) == 2) { + u1 = -u1; + } + } else { + u1 = 0; + } + digits.add(new int[]{u0, u1}); + if (2 * d0 == 1 + u0) { + d0 = 1 - d0; + } + if (2 * d1 == 1 + u1) { + d1 = 1 - d1; + } + k0 = k0.shiftRight(1); + k1 = k1.shiftRight(1); } - return null; + // Reverse in place (MSB-first). + java.util.Collections.reverse(digits); + return digits.toArray(new int[0][]); } } diff --git a/src/main/java/com/starkbank/ellipticcurve/Point.java b/src/main/java/com/starkbank/ellipticcurve/Point.java index 9c1af57..0cab751 100644 --- a/src/main/java/com/starkbank/ellipticcurve/Point.java +++ b/src/main/java/com/starkbank/ellipticcurve/Point.java @@ -9,18 +9,14 @@ public class Point { public BigInteger z; /** - * * @param x x * @param y y */ public Point(BigInteger x, BigInteger y) { - this.x = x; - this.y = y; - this.z = BigInteger.ZERO; + this(x, y, BigInteger.ZERO); } /** - * * @param x x * @param y y * @param z z @@ -34,4 +30,9 @@ public Point(BigInteger x, BigInteger y, BigInteger z) { public boolean isAtInfinity() { return this.y.equals(BigInteger.ZERO); } + + @Override + public String toString() { + return String.format("(%s, %s, %s)", x, y, z); + } } diff --git a/src/main/java/com/starkbank/ellipticcurve/PrivateKey.java b/src/main/java/com/starkbank/ellipticcurve/PrivateKey.java index 515d200..3cbbef2 100644 --- a/src/main/java/com/starkbank/ellipticcurve/PrivateKey.java +++ b/src/main/java/com/starkbank/ellipticcurve/PrivateKey.java @@ -4,7 +4,6 @@ import com.starkbank.ellipticcurve.utils.BinaryAscii; import com.starkbank.ellipticcurve.utils.RandomInteger; import java.math.BigInteger; -import java.util.Arrays; public class PrivateKey { @@ -13,25 +12,24 @@ public class PrivateKey { public BigInteger secret; /** - * + * Generate a new random private key on secp256k1 */ public PrivateKey() { this(Curve.secp256k1, null); - secret = RandomInteger.between(BigInteger.ONE, curve.N); } /** + * Create a private key on a specified curve. If secret is null, a random one is generated. * * @param curve curve * @param secret secret */ public PrivateKey(Curve curve, BigInteger secret) { this.curve = curve; - this.secret = secret; + this.secret = secret != null ? secret : RandomInteger.between(BigInteger.ONE, curve.N.subtract(BigInteger.ONE)); } /** - * * @return PublicKey */ public PublicKey publicKey() { @@ -41,7 +39,19 @@ public PublicKey publicKey() { } /** + * Get the hex string representation of the private key secret * + * @return hex string + */ + public String toString() { + String hex = this.secret.toString(16); + if (hex.length() % 2 != 0) { + hex = "0" + hex; + } + return hex; + } + + /** * @return ByteString */ public ByteString toByteString() { @@ -49,7 +59,6 @@ public ByteString toByteString() { } /** - * * @return ByteString */ public ByteString toDer() { @@ -62,16 +71,13 @@ public ByteString toDer() { } /** - * * @return String */ public String toPem() { return Der.toPem(this.toDer(), "EC PRIVATE KEY"); } - /** - * * @param string string * @return PrivateKey */ @@ -81,16 +87,14 @@ public static PrivateKey fromPem(String string) { } /** - * * @param string string - * @return Privatekey + * @return PrivateKey */ public static PrivateKey fromDer(String string) { return fromDer(new ByteString(string.getBytes())); } /** - * * @param string ByteString * @return PrivateKey */ @@ -126,17 +130,11 @@ public static PrivateKey fromDer(ByteString string) { if (!"".equals(empty.toString())) { throw new RuntimeException(String.format("trailing junk after DER privkey curve_oid: %s", BinaryAscii.hexFromBinary(empty))); } - Curve curve = (Curve) Curve.curvesByOid.get(Arrays.hashCode(oidCurve)); - if (curve == null) { - throw new RuntimeException(String.format("Unknown curve with oid %s. I only know about these: %s", Arrays.toString(oidCurve), Arrays.toString(Curve.supportedCurves.toArray()))); - } + Curve curve = Curve.getByOid(oidCurve); if (privkeyStr.length() < curve.length()) { int l = curve.length() - privkeyStr.length(); byte[] bytes = new byte[l + privkeyStr.length()]; - for (int i = 0; i < curve.length() - privkeyStr.length(); i++) { - bytes[i] = 0; - } byte[] privateKey = privkeyStr.getBytes(); System.arraycopy(privateKey, 0, bytes, l, bytes.length - l); privkeyStr = new ByteString(bytes); @@ -146,7 +144,6 @@ public static PrivateKey fromDer(ByteString string) { } /** - * * @param string byteString * @param curve curve * @return PrivateKey @@ -156,7 +153,6 @@ public static PrivateKey fromString(ByteString string, Curve curve) { } /** - * * @param string string * @return PrivateKey */ @@ -165,11 +161,21 @@ public static PrivateKey fromString(String string) { } /** - * * @param string byteString * @return PrivateKey */ public static PrivateKey fromString(ByteString string) { return PrivateKey.fromString(string, Curve.secp256k1); } + + /** + * Create a PrivateKey from a hex string and curve + * + * @param hexString hex string representation of the secret + * @param curve the curve + * @return PrivateKey + */ + public static PrivateKey fromString(String hexString, Curve curve) { + return new PrivateKey(curve, new BigInteger(hexString, 16)); + } } diff --git a/src/main/java/com/starkbank/ellipticcurve/PublicKey.java b/src/main/java/com/starkbank/ellipticcurve/PublicKey.java index 1011680..72cbf33 100644 --- a/src/main/java/com/starkbank/ellipticcurve/PublicKey.java +++ b/src/main/java/com/starkbank/ellipticcurve/PublicKey.java @@ -2,9 +2,9 @@ import com.starkbank.ellipticcurve.utils.ByteString; import com.starkbank.ellipticcurve.utils.Der; import com.starkbank.ellipticcurve.utils.BinaryAscii; +import java.math.BigInteger; import java.util.Arrays; import static com.starkbank.ellipticcurve.Curve.secp256k1; -import static com.starkbank.ellipticcurve.Curve.supportedCurves; public class PublicKey { @@ -12,8 +12,11 @@ public class PublicKey { public Point point; public Curve curve; + private static final String EVEN_TAG = "02"; + private static final String ODD_TAG = "03"; + private static final long[] ECDSA_PUBLIC_KEY_OID = new long[]{1, 2, 840, 10045, 2, 1}; + /** - * * @param point point * @param curve curve */ @@ -23,7 +26,6 @@ public PublicKey(Point point, Curve curve) { } /** - * * @return ByteString */ public ByteString toByteString() { @@ -31,7 +33,6 @@ public ByteString toByteString() { } /** - * * @param encoded encoded * @return ByteString */ @@ -39,24 +40,50 @@ public ByteString toByteString(boolean encoded) { ByteString xStr = BinaryAscii.stringFromNumber(point.x, curve.length()); ByteString yStr = BinaryAscii.stringFromNumber(point.y, curve.length()); xStr.insert(yStr.getBytes()); - if(encoded) { - xStr.insert(0, new byte[]{0, 4} ); + if (encoded) { + xStr.insert(0, new byte[]{0, 4}); } return xStr; } /** + * Get the hex string representation of the public key point * + * @param encoded whether to include the 0004 prefix + * @return hex string + */ + public String toString(boolean encoded) { + int baseLength = 2 * curve.length(); + String xHex = leftPad(point.x.toString(16), baseLength); + String yHex = leftPad(point.y.toString(16), baseLength); + String string = xHex + yHex; + if (encoded) { + return "0004" + string; + } + return string; + } + + /** + * Get the compressed hex string representation of the public key + * + * @return compressed hex string + */ + public String toCompressed() { + int baseLength = 2 * curve.length(); + String parityTag = point.y.mod(BigInteger.TWO).equals(BigInteger.ZERO) ? EVEN_TAG : ODD_TAG; + String xHex = leftPad(point.x.toString(16), baseLength); + return parityTag + xHex; + } + + /** * @return ByteString */ public ByteString toDer() { - long[] oidEcPublicKey = new long[]{1, 2, 840, 10045, 2, 1}; - ByteString encodeEcAndOid = Der.encodeSequence(Der.encodeOid(oidEcPublicKey), Der.encodeOid(curve.oid)); + ByteString encodeEcAndOid = Der.encodeSequence(Der.encodeOid(ECDSA_PUBLIC_KEY_OID), Der.encodeOid(curve.oid)); return Der.encodeSequence(encodeEcAndOid, Der.encodeBitString(this.toByteString(true))); } /** - * * @return String */ public String toPem() { @@ -64,7 +91,6 @@ public String toPem() { } /** - * * @param string string * @return PublicKey */ @@ -73,7 +99,6 @@ public static PublicKey fromPem(String string) { } /** - * * @param string byteString * @return PublicKey */ @@ -82,36 +107,40 @@ public static PublicKey fromDer(ByteString string) { ByteString s1 = str[0]; ByteString empty = str[1]; if (!empty.isEmpty()) { - throw new RuntimeException (String.format("trailing junk after DER pubkey: %s", BinaryAscii.hexFromBinary(empty))); + throw new RuntimeException(String.format("trailing junk after DER pubkey: %s", BinaryAscii.hexFromBinary(empty))); } str = Der.removeSequence(s1); ByteString s2 = str[0]; ByteString pointStrBitstring = str[1]; Object[] o = Der.removeObject(s2); + long[] publicKeyOid = (long[]) o[0]; ByteString rest = (ByteString) o[1]; o = Der.removeObject(rest); long[] oidCurve = (long[]) o[0]; empty = (ByteString) o[1]; if (!empty.isEmpty()) { - throw new RuntimeException (String.format("trailing junk after DER pubkey objects: %s", BinaryAscii.hexFromBinary(empty))); + throw new RuntimeException(String.format("trailing junk after DER pubkey objects: %s", BinaryAscii.hexFromBinary(empty))); } - Curve curve = (Curve) Curve.curvesByOid.get(Arrays.hashCode(oidCurve)); - if (curve == null) { - throw new RuntimeException(String.format("Unknown curve with oid %s. I only know about these: %s", Arrays.toString(oidCurve), Arrays.toString(supportedCurves.toArray()))); + if (!Arrays.equals(publicKeyOid, ECDSA_PUBLIC_KEY_OID)) { + throw new RuntimeException(String.format( + "The Public Key Object Identifier (OID) should be %s, but %s was found instead", + Arrays.toString(ECDSA_PUBLIC_KEY_OID), Arrays.toString(publicKeyOid) + )); } + Curve curve = Curve.getByOid(oidCurve); + str = Der.removeBitString(pointStrBitstring); ByteString pointStr = str[0]; empty = str[1]; if (!empty.isEmpty()) { - throw new RuntimeException (String.format("trailing junk after pubkey pointstring: %s", BinaryAscii.hexFromBinary(empty))); + throw new RuntimeException(String.format("trailing junk after pubkey pointstring: %s", BinaryAscii.hexFromBinary(empty))); } return PublicKey.fromString(pointStr.substring(2), curve); } /** - * * @param string byteString * @param curve curve * @param validatePoint validatePoint @@ -142,7 +171,6 @@ public static PublicKey fromString(ByteString string, Curve curve, boolean valid } /** - * * @param string byteString * @param curve curve * @return PublicKey @@ -152,7 +180,6 @@ public static PublicKey fromString(ByteString string, Curve curve) { } /** - * * @param string byteString * @param validatePoint validatePoint * @return PublicKey @@ -162,11 +189,88 @@ public static PublicKey fromString(ByteString string, boolean validatePoint) { } /** - * * @param string byteString * @return PublicKey */ public static PublicKey fromString(ByteString string) { return fromString(string, true); } + + /** + * Create a PublicKey from a hex string representation + * + * @param hexString hex string of x+y coordinates, optionally with 0004 prefix + * @param curve the curve + * @param validatePoint whether to validate the point + * @return PublicKey + */ + public static PublicKey fromString(String hexString, Curve curve, boolean validatePoint) { + int baseLength = 2 * curve.length(); + if (hexString.length() > 2 * baseLength && hexString.startsWith("0004")) { + hexString = hexString.substring(4); + } + String xs = hexString.substring(0, baseLength); + String ys = hexString.substring(baseLength); + Point p = new Point(new BigInteger(xs, 16), new BigInteger(ys, 16)); + PublicKey publicKey = new PublicKey(p, curve); + if (!validatePoint) { + return publicKey; + } + if (p.isAtInfinity()) { + throw new RuntimeException("Public Key point is at infinity"); + } + if (!curve.contains(p)) { + throw new RuntimeException(String.format("Point (%s,%s) is not valid for curve %s", p.x, p.y, curve.name)); + } + if (!Math.multiply(p, curve.N, curve.N, curve.A, curve.P).isAtInfinity()) { + throw new RuntimeException(String.format("Point (%s,%s) * %s.N is not at infinity", p.x, p.y, curve.name)); + } + return publicKey; + } + + /** + * Create a PublicKey from a hex string representation using default curve + * + * @param hexString hex string + * @param curve the curve + * @return PublicKey + */ + public static PublicKey fromString(String hexString, Curve curve) { + return fromString(hexString, curve, true); + } + + /** + * Create a PublicKey from a compressed hex string + * + * @param compressedHex the compressed public key hex (02/03 prefix + x coordinate) + * @param curve the curve + * @return PublicKey + */ + public static PublicKey fromCompressed(String compressedHex, Curve curve) { + String parityTag = compressedHex.substring(0, 2); + String xHex = compressedHex.substring(2); + if (!parityTag.equals(EVEN_TAG) && !parityTag.equals(ODD_TAG)) { + throw new RuntimeException("Compressed string should start with 02 or 03"); + } + BigInteger x = new BigInteger(xHex, 16); + BigInteger y = curve.y(x, parityTag.equals(EVEN_TAG)); + return new PublicKey(new Point(x, y), curve); + } + + /** + * Create a PublicKey from a compressed hex string using secp256k1 + * + * @param compressedHex the compressed public key hex + * @return PublicKey + */ + public static PublicKey fromCompressed(String compressedHex) { + return fromCompressed(compressedHex, secp256k1); + } + + private static String leftPad(String s, int length) { + while (s.length() < length) { + s = "0" + s; + } + return s; + } } diff --git a/src/main/java/com/starkbank/ellipticcurve/Signature.java b/src/main/java/com/starkbank/ellipticcurve/Signature.java index 1ede9f2..5fbaae7 100644 --- a/src/main/java/com/starkbank/ellipticcurve/Signature.java +++ b/src/main/java/com/starkbank/ellipticcurve/Signature.java @@ -11,39 +11,86 @@ public class Signature { public BigInteger r; public BigInteger s; + public Integer recoveryId; /** - * * @param r r * @param s s */ public Signature(BigInteger r, BigInteger s) { + this(r, s, null); + } + + /** + * @param r r + * @param s s + * @param recoveryId recoveryId + */ + public Signature(BigInteger r, BigInteger s, Integer recoveryId) { this.r = r; this.s = s; + this.recoveryId = recoveryId; } /** - * - * @return ByteString + * @return ByteString (DER encoded, without recovery ID) */ public ByteString toDer() { - return Der.encodeSequence(Der.encodeInteger(r), Der.encodeInteger(s)); + return toDer(false); } /** - * - * @return String + * @param withRecoveryId whether to prepend the recovery ID byte + * @return ByteString + */ + public ByteString toDer(boolean withRecoveryId) { + ByteString encodedSequence = Der.encodeSequence(Der.encodeInteger(r), Der.encodeInteger(s)); + if (!withRecoveryId) { + return encodedSequence; + } + byte recoveryByte = (byte) (27 + this.recoveryId); + byte[] seqBytes = encodedSequence.getBytes(); + byte[] result = new byte[1 + seqBytes.length]; + result[0] = recoveryByte; + System.arraycopy(seqBytes, 0, result, 1, seqBytes.length); + return new ByteString(result); + } + + /** + * @return String (base64 encoded, without recovery ID) */ public String toBase64() { - return Base64.encodeBytes(toDer().getBytes()); + return toBase64(false); + } + + /** + * @param withRecoveryId whether to include the recovery ID byte + * @return String (base64 encoded) + */ + public String toBase64(boolean withRecoveryId) { + return Base64.encodeBytes(toDer(withRecoveryId).getBytes()); } /** - * * @param string byteString * @return Signature */ public static Signature fromDer(ByteString string) { + return fromDer(string, false); + } + + /** + * @param string byteString + * @param recoveryByte whether the first byte is a recovery ID + * @return Signature + */ + public static Signature fromDer(ByteString string, boolean recoveryByte) { + Integer recoveryId = null; + if (recoveryByte) { + recoveryId = (string.getShort(0)) - 27; + string = string.substring(1); + } + ByteString[] str = Der.removeSequence(string); ByteString rs = str[0]; ByteString empty = str[1]; @@ -59,21 +106,50 @@ public static Signature fromDer(ByteString string) { if (!empty.isEmpty()) { throw new RuntimeException(String.format("trailing junk after DER numbers: %s", BinaryAscii.hexFromBinary(empty))); } - return new Signature(r, s); + return new Signature(r, s, recoveryId); } /** - * - * @param string byteString + * @param string base64 encoded string * @return Signature */ public static Signature fromBase64(ByteString string) { - ByteString der = null; + return fromBase64(string, false); + } + + /** + * @param string base64 encoded string + * @param recoveryByte whether the first byte is a recovery ID + * @return Signature + */ + public static Signature fromBase64(ByteString string, boolean recoveryByte) { + ByteString der; try { der = new ByteString(Base64.decode(string.getBytes())); } catch (IOException e) { throw new IllegalArgumentException("Corrupted base64 string! Could not decode base64 from it"); } - return fromDer(der); + return fromDer(der, recoveryByte); + } + + /** + * Convenience: fromBase64 from a plain String + * + * @param string base64 encoded string + * @return Signature + */ + public static Signature fromBase64(String string) { + return fromBase64(new ByteString(string.getBytes()), false); + } + + /** + * Convenience: fromBase64 from a plain String with recovery byte option + * + * @param string base64 encoded string + * @param recoveryByte whether the first byte is a recovery ID + * @return Signature + */ + public static Signature fromBase64(String string, boolean recoveryByte) { + return fromBase64(new ByteString(string.getBytes()), recoveryByte); } } diff --git a/src/main/java/com/starkbank/ellipticcurve/utils/BinaryAscii.java b/src/main/java/com/starkbank/ellipticcurve/utils/BinaryAscii.java index a67c704..afd3d07 100644 --- a/src/main/java/com/starkbank/ellipticcurve/utils/BinaryAscii.java +++ b/src/main/java/com/starkbank/ellipticcurve/utils/BinaryAscii.java @@ -6,7 +6,6 @@ public final class BinaryAscii { /** - * * @param string byteString * @return String */ @@ -15,13 +14,11 @@ public static String hexFromBinary(ByteString string) { } /** - * * @param bytes byte[] * @return String */ public static String hexFromBinary(byte[] bytes) { StringBuilder hexString = new StringBuilder(); - for (byte aByte : bytes) { String hex = Integer.toHexString(0xFF & aByte); if (hex.length() == 1) { @@ -33,21 +30,23 @@ public static String hexFromBinary(byte[] bytes) { } /** - * * @param string string * @return byte[] */ public static byte[] binaryFromHex(String string) { - byte[] bytes = new BigInteger(string, 16).toByteArray(); - int i = 0; - while (i < bytes.length && bytes[i] == 0) { - i++; + if (string.length() % 2 != 0) { + string = "0" + string; + } + int len = string.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(string.charAt(i), 16) << 4) + + Character.digit(string.charAt(i + 1), 16)); } - return Arrays.copyOfRange(bytes, i, bytes.length); + return data; } /** - * * @param c c * @return byte[] */ @@ -56,25 +55,25 @@ public static byte[] toBytes(int c) { } /** - * Get a number representation of a string + * Get a number representation of a byte array * - * @param string String to be converted in a number - * @return Number in hex from string + * @param string byte[] to be converted to a number + * @return BigInteger */ public static BigInteger numberFromString(byte[] string) { - return new BigInteger(BinaryAscii.hexFromBinary(string), 16); + return new BigInteger(1, string); } /** * Get a string representation of a number * * @param number number to be converted in a string - * @param length length max number of character for the string - * @return hexadecimal string + * @param length length max number of bytes for the string + * @return ByteString */ public static ByteString stringFromNumber(BigInteger number, int length) { String fmtStr = "%0" + String.valueOf(2 * length) + "x"; String hexString = String.format(fmtStr, number); - return new ByteString(BinaryAscii.binaryFromHex(hexString)); + return new ByteString(binaryFromHex(hexString)); } } diff --git a/src/main/java/com/starkbank/ellipticcurve/utils/RandomInteger.java b/src/main/java/com/starkbank/ellipticcurve/utils/RandomInteger.java index 06f978c..cd2fe31 100644 --- a/src/main/java/com/starkbank/ellipticcurve/utils/RandomInteger.java +++ b/src/main/java/com/starkbank/ellipticcurve/utils/RandomInteger.java @@ -1,19 +1,177 @@ package com.starkbank.ellipticcurve.utils; + +import com.starkbank.ellipticcurve.Curve; import java.math.BigInteger; import java.security.SecureRandom; -import java.util.Random; +import java.util.Arrays; +import java.util.Iterator; +import javax.crypto.Mac; +import javax.crypto.spec.SecretKeySpec; public class RandomInteger { + private static final SecureRandom secureRandom = new SecureRandom(); + /** + * Return integer x in the range: start <= x <= end * - * @param start start - * @param end end + * @param start minimum value of the integer + * @param end maximum value of the integer * @return BigInteger */ public static BigInteger between(BigInteger start, BigInteger end) { - Random random = new SecureRandom(); - return new BigInteger(end.toByteArray().length * 8 - 1, random).abs().add(start); + BigInteger range = end.subtract(start).add(BigInteger.ONE); + int bits = range.bitLength(); + BigInteger result; + do { + result = new BigInteger(bits, secureRandom); + } while (result.compareTo(range) >= 0); + return result.add(start); + } + + /** + * Generate nonce values per hedged RFC 6979: deterministic k derivation + * with fresh random entropy mixed into K-init (RFC 6979 §3.6). Same message + * and key yield different signatures, while preserving RFC 6979's protection + * against RNG failures. + * + * @param hashBytes the hash of the message + * @param secret the private key secret + * @param curve the curve + * @param algorithm the HMAC algorithm name (e.g., "HmacSHA256") + * @return an iterator of candidate k values + */ + public static Iterator rfc6979(byte[] hashBytes, BigInteger secret, Curve curve, String algorithm) { + int orderBitLen = curve.nBitLength; + int orderByteLen = (orderBitLen + 7) / 8; + + // Secret bytes, zero-padded to orderByteLen + byte[] secretBytes = bigIntToFixedBytes(secret, orderByteLen); + + // Hash reduced mod N, then zero-padded to orderByteLen + BigInteger hashReduced = numberFromByteString(hashBytes, orderBitLen).mod(curve.N); + byte[] hashOctets = bigIntToFixedBytes(hashReduced, orderByteLen); + + // Fresh random entropy mixed into K-init per RFC 6979 §3.6 (hedged). + byte[] extraEntropy = new byte[orderByteLen]; + secureRandom.nextBytes(extraEntropy); + + int hLen = getHmacLength(algorithm); + + byte[] V = new byte[hLen]; + Arrays.fill(V, (byte) 0x01); + byte[] K = new byte[hLen]; + Arrays.fill(K, (byte) 0x00); + + // K = HMAC(K, V || 0x00 || secretBytes || hashOctets || extraEntropy) + K = hmac(algorithm, K, concat(V, new byte[]{0x00}, secretBytes, hashOctets, extraEntropy)); + V = hmac(algorithm, K, V); + // K = HMAC(K, V || 0x01 || secretBytes || hashOctets || extraEntropy) + K = hmac(algorithm, K, concat(V, new byte[]{0x01}, secretBytes, hashOctets, extraEntropy)); + V = hmac(algorithm, K, V); + + final byte[] finalK = K; + final byte[] finalV = V; + final BigInteger curveN = curve.N; + final String algo = algorithm; + final int bitLen = orderBitLen; + + return new Iterator() { + private byte[] k = finalK; + private byte[] v = finalV; + + @Override + public boolean hasNext() { + return true; + } + + @Override + public BigInteger next() { + while (true) { + byte[] T = new byte[0]; + while (T.length * 8 < bitLen) { + v = hmac(algo, k, v); + T = concat(T, v); + } + + BigInteger candidate = numberFromByteString(T, bitLen); + + if (candidate.compareTo(BigInteger.ONE) >= 0 && candidate.compareTo(curveN.subtract(BigInteger.ONE)) <= 0) { + // Prepare for next call + k = hmac(algo, k, concat(v, new byte[]{0x00})); + v = hmac(algo, k, v); + return candidate; + } + + k = hmac(algo, k, concat(v, new byte[]{0x00})); + v = hmac(algo, k, v); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + public static BigInteger numberFromByteString(byte[] bytes, int bitLength) { + BigInteger number = new BigInteger(1, bytes); + int hashBitLen = bytes.length * 8; + if (bitLength > 0 && hashBitLen > bitLength) { + number = number.shiftRight(hashBitLen - bitLength); + } + return number; + } + + private static byte[] bigIntToFixedBytes(BigInteger value, int length) { + String hex = value.toString(16); + while (hex.length() < length * 2) { + hex = "0" + hex; + } + return hexToBytes(hex); + } + + private static byte[] hexToBytes(String hex) { + int len = hex.length(); + byte[] data = new byte[len / 2]; + for (int i = 0; i < len; i += 2) { + data[i / 2] = (byte) ((Character.digit(hex.charAt(i), 16) << 4) + + Character.digit(hex.charAt(i + 1), 16)); + } + return data; + } + + static byte[] hmac(String algorithm, byte[] key, byte[] data) { + try { + Mac mac = Mac.getInstance(algorithm); + mac.init(new SecretKeySpec(key, algorithm)); + return mac.doFinal(data); + } catch (Exception e) { + throw new RuntimeException("HMAC computation failed", e); + } + } + + static byte[] concat(byte[]... arrays) { + int totalLen = 0; + for (byte[] a : arrays) totalLen += a.length; + byte[] result = new byte[totalLen]; + int offset = 0; + for (byte[] a : arrays) { + System.arraycopy(a, 0, result, offset, a.length); + offset += a.length; + } + return result; + } + + private static int getHmacLength(String algorithm) { + try { + Mac mac = Mac.getInstance(algorithm); + mac.init(new SecretKeySpec(new byte[1], algorithm)); + return mac.getMacLength(); + } catch (Exception e) { + throw new RuntimeException("Could not determine HMAC length for " + algorithm, e); + } } } diff --git a/src/test/java/com/starkbank/ellipticcurve/CompPubKeyTest.java b/src/test/java/com/starkbank/ellipticcurve/CompPubKeyTest.java new file mode 100644 index 0000000..0eba2c6 --- /dev/null +++ b/src/test/java/com/starkbank/ellipticcurve/CompPubKeyTest.java @@ -0,0 +1,74 @@ +package com.starkbank.ellipticcurve; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; + + +public class CompPubKeyTest { + + @Test + public void testBatch() { + for (int i = 0; i < 100; i++) { + PrivateKey privateKey = new PrivateKey(); + PublicKey publicKey = privateKey.publicKey(); + String publicKeyString = publicKey.toCompressed(); + + PublicKey recoveredPublicKey = PublicKey.fromCompressed(publicKeyString, publicKey.curve); + + assertEquals(publicKey.point.x, recoveredPublicKey.point.x); + assertEquals(publicKey.point.y, recoveredPublicKey.point.y); + } + } + + @Test + public void testFromCompressedEven() { + String publicKeyCompressed = "0252972572d465d016d4c501887b8df303eee3ed602c056b1eb09260dfa0da0ab2"; + PublicKey publicKey = PublicKey.fromCompressed(publicKeyCompressed); + String pem = publicKey.toPem(); + assertEquals( + "-----BEGIN PUBLIC KEY-----\n" + + "MFYwEAYHKoZIzj0CAQYFK4EEAAoDQgAEUpclctRl0BbUxQGIe43zA+7j7WAsBWse\n" + + "sJJg36DaCrKIdC9NyX2e22/ZRrq8AC/fsG8myvEXuUBe15J1dj/bHA==\n" + + "-----END PUBLIC KEY-----\n", + pem + ); + } + + @Test + public void testFromCompressedOdd() { + String publicKeyCompressed = "0318ed2e1ec629e2d3dae7be1103d4f911c24e0c80e70038f5eb5548245c475f50"; + PublicKey publicKey = PublicKey.fromCompressed(publicKeyCompressed); + String pem = publicKey.toPem(); + assertEquals( + "-----BEGIN PUBLIC KEY-----\n" + + "MFYwEAYHKoZIzj0CAQYFK4EEAAoDQgAEGO0uHsYp4tPa574RA9T5EcJODIDnADj1\n" + + "61VIJFxHX1BMIg0B4cpBnLG6SzOTthXpndIKpr8HEHj3D9lJAI50EQ==\n" + + "-----END PUBLIC KEY-----\n", + pem + ); + } + + @Test + public void testToCompressedEven() { + PublicKey publicKey = PublicKey.fromPem( + "-----BEGIN PUBLIC KEY-----\n" + + "MFYwEAYHKoZIzj0CAQYFK4EEAAoDQgAEUpclctRl0BbUxQGIe43zA+7j7WAsBWse\n" + + "sJJg36DaCrKIdC9NyX2e22/ZRrq8AC/fsG8myvEXuUBe15J1dj/bHA==\n" + + "-----END PUBLIC KEY-----" + ); + String publicKeyCompressed = publicKey.toCompressed(); + assertEquals("0252972572d465d016d4c501887b8df303eee3ed602c056b1eb09260dfa0da0ab2", publicKeyCompressed); + } + + @Test + public void testToCompressedOdd() { + PublicKey publicKey = PublicKey.fromPem( + "-----BEGIN PUBLIC KEY-----\n" + + "MFYwEAYHKoZIzj0CAQYFK4EEAAoDQgAEGO0uHsYp4tPa574RA9T5EcJODIDnADj1\n" + + "61VIJFxHX1BMIg0B4cpBnLG6SzOTthXpndIKpr8HEHj3D9lJAI50EQ==\n" + + "-----END PUBLIC KEY-----" + ); + String publicKeyCompressed = publicKey.toCompressed(); + assertEquals("0318ed2e1ec629e2d3dae7be1103d4f911c24e0c80e70038f5eb5548245c475f50", publicKeyCompressed); + } +} diff --git a/src/test/java/com/starkbank/ellipticcurve/CurveTest.java b/src/test/java/com/starkbank/ellipticcurve/CurveTest.java new file mode 100644 index 0000000..d1bbfc0 --- /dev/null +++ b/src/test/java/com/starkbank/ellipticcurve/CurveTest.java @@ -0,0 +1,86 @@ +package com.starkbank.ellipticcurve; + +import org.junit.Test; +import java.math.BigInteger; +import static org.junit.Assert.assertTrue; + + +public class CurveTest { + + @Test + public void testSupportedCurve() { + Curve newCurve = new Curve( + BigInteger.ZERO, + BigInteger.valueOf(7), + new BigInteger("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f", 16), + new BigInteger("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141", 16), + new BigInteger("79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798", 16), + new BigInteger("483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8", 16), + "secp256k1", + new long[]{1, 3, 132, 0, 10} + ); + PrivateKey privateKey1 = new PrivateKey(newCurve, null); + PublicKey publicKey1 = privateKey1.publicKey(); + + String privateKeyPem = privateKey1.toPem(); + String publicKeyPem = publicKey1.toPem(); + + PrivateKey privateKey2 = PrivateKey.fromPem(privateKeyPem); + PublicKey publicKey2 = PublicKey.fromPem(publicKeyPem); + + String message = "test"; + + String signatureBase64 = Ecdsa.sign(message, privateKey2).toBase64(); + Signature signature = Signature.fromBase64(signatureBase64); + + assertTrue(Ecdsa.verify(message, signature, publicKey2)); + } + + @Test + public void testAddNewCurve() { + Curve newCurve = new Curve( + new BigInteger("f1fd178c0b3ad58f10126de8ce42435b3961adbcabc8ca6de8fcf353d86e9c00", 16), + new BigInteger("ee353fca5428a9300d4aba754a44c00fdfec0c9ae4b1a1803075ed967b7bb73f", 16), + new BigInteger("f1fd178c0b3ad58f10126de8ce42435b3961adbcabc8ca6de8fcf353d86e9c03", 16), + new BigInteger("f1fd178c0b3ad58f10126de8ce42435b53dc67e140d2bf941ffdd459c6d655e1", 16), + new BigInteger("b6b3d4c356c139eb31183d4749d423958c27d2dcaf98b70164c97a2dd98f5cff", 16), + new BigInteger("6142e0f7c8b204911f9271f0f3ecef8c2701c307e8e4c9e183115a1554062cfb", 16), + "frp256v1", + new long[]{1, 2, 250, 1, 223, 101, 256, 1} + ); + Curve.add(newCurve); + PrivateKey privateKey1 = new PrivateKey(newCurve, null); + PublicKey publicKey1 = privateKey1.publicKey(); + + String privateKeyPem = privateKey1.toPem(); + String publicKeyPem = publicKey1.toPem(); + + PrivateKey privateKey2 = PrivateKey.fromPem(privateKeyPem); + PublicKey publicKey2 = PublicKey.fromPem(publicKeyPem); + + String message = "test"; + + String signatureBase64 = Ecdsa.sign(message, privateKey2).toBase64(); + Signature signature = Signature.fromBase64(signatureBase64); + + assertTrue(Ecdsa.verify(message, signature, publicKey2)); + } + + @Test(expected = RuntimeException.class) + public void testUnsupportedCurve() { + Curve newCurve = new Curve( + new BigInteger("a9fb57dba1eea9bc3e660a909d838d726e3bf623d52620282013481d1f6e5374", 16), + new BigInteger("662c61c430d84ea4fe66a7733d0b76b7bf93ebc4af2f49256ae58101fee92b04", 16), + new BigInteger("a9fb57dba1eea9bc3e660a909d838d726e3bf623d52620282013481d1f6e5377", 16), + new BigInteger("a9fb57dba1eea9bc3e660a909d838d718c397aa3b561a6f7901e0e82974856a7", 16), + new BigInteger("a3e8eb3cc1cfe7b7732213b23a656149afa142c47aafbc2b79a191562e1305f4", 16), + new BigInteger("2d996c823439c56d7f7b22e14644417e69bcb6de39d027001dabe8f35b25c9be", 16), + "brainpoolP256t1", + new long[]{1, 3, 36, 3, 3, 2, 8, 1, 1, 8} + ); + + String privateKeyPem = new PrivateKey(newCurve, null).toPem(); + // This should throw because brainpoolP256t1 is not registered + PrivateKey.fromPem(privateKeyPem); + } +} diff --git a/src/test/java/com/starkbank/ellipticcurve/EcdsaTest.java b/src/test/java/com/starkbank/ellipticcurve/EcdsaTest.java index 4283664..4744e9c 100644 --- a/src/test/java/com/starkbank/ellipticcurve/EcdsaTest.java +++ b/src/test/java/com/starkbank/ellipticcurve/EcdsaTest.java @@ -1,8 +1,7 @@ package com.starkbank.ellipticcurve; -import org.junit.Test; +import org.junit.Test; import java.math.BigInteger; - import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -38,8 +37,8 @@ public void testZeroSignature() { PrivateKey privateKey = new PrivateKey(); PublicKey publicKey = privateKey.publicKey(); - String message = "This is the right message"; + String message2 = "This is the wrong message"; - assertFalse(Ecdsa.verify(message, new Signature(BigInteger.ZERO, BigInteger.ZERO), publicKey)); + assertFalse(Ecdsa.verify(message2, new Signature(BigInteger.ZERO, BigInteger.ZERO), publicKey)); } } diff --git a/src/test/java/com/starkbank/ellipticcurve/OpenSSLTest.java b/src/test/java/com/starkbank/ellipticcurve/OpenSSLTest.java index 85354b0..ee305fa 100644 --- a/src/test/java/com/starkbank/ellipticcurve/OpenSSLTest.java +++ b/src/test/java/com/starkbank/ellipticcurve/OpenSSLTest.java @@ -1,4 +1,5 @@ package com.starkbank.ellipticcurve; + import com.starkbank.ellipticcurve.utils.ByteString; import org.junit.Test; import java.io.IOException; @@ -10,7 +11,7 @@ public class OpenSSLTest { @Test public void testAssign() throws URISyntaxException, IOException { - // Generated by:openssl ecparam -name secp256k1 - genkey - out privateKey.pem + // Generated by: openssl ecparam -name secp256k1 -genkey -out privateKey.pem String privateKeyPem = Utils.readFileAsString("privateKey.pem"); PrivateKey privateKey = PrivateKey.fromPem(privateKeyPem); @@ -26,7 +27,7 @@ public void testAssign() throws URISyntaxException, IOException { @Test public void testVerifySignature() throws IOException, URISyntaxException { - // openssl ec -in privateKey.pem - pubout - out publicKey.pem + // openssl ec -in privateKey.pem -pubout -out publicKey.pem String publicKeyPem = Utils.readFileAsString("publicKey.pem"); // openssl dgst -sha256 -sign privateKey.pem -out signature.binary message.txt ByteString signatureBin = new ByteString(Utils.readFileAsBytes("signature.binary")); diff --git a/src/test/java/com/starkbank/ellipticcurve/PrivateKeyTest.java b/src/test/java/com/starkbank/ellipticcurve/PrivateKeyTest.java index a3fa823..1613b39 100644 --- a/src/test/java/com/starkbank/ellipticcurve/PrivateKeyTest.java +++ b/src/test/java/com/starkbank/ellipticcurve/PrivateKeyTest.java @@ -1,4 +1,5 @@ package com.starkbank.ellipticcurve; + import com.starkbank.ellipticcurve.utils.ByteString; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -7,7 +8,7 @@ public class PrivateKeyTest { @Test - public void testPemConversion() { + public void testPemConversion() { PrivateKey privateKey1 = new PrivateKey(); String pem = privateKey1.toPem(); PrivateKey privateKey2 = PrivateKey.fromPem(pem); @@ -25,7 +26,7 @@ public void testDerConversion() { } @Test - public void testStringConversion() { + public void testStringConversion() { PrivateKey privateKey1 = new PrivateKey(); ByteString string = privateKey1.toByteString(); PrivateKey privateKey2 = PrivateKey.fromString(string); diff --git a/src/test/java/com/starkbank/ellipticcurve/PublicKeyTest.java b/src/test/java/com/starkbank/ellipticcurve/PublicKeyTest.java index 638b790..e479276 100644 --- a/src/test/java/com/starkbank/ellipticcurve/PublicKeyTest.java +++ b/src/test/java/com/starkbank/ellipticcurve/PublicKeyTest.java @@ -1,4 +1,5 @@ package com.starkbank.ellipticcurve; + import com.starkbank.ellipticcurve.utils.ByteString; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/src/test/java/com/starkbank/ellipticcurve/RandomTest.java b/src/test/java/com/starkbank/ellipticcurve/RandomTest.java new file mode 100644 index 0000000..8346912 --- /dev/null +++ b/src/test/java/com/starkbank/ellipticcurve/RandomTest.java @@ -0,0 +1,29 @@ +package com.starkbank.ellipticcurve; + +import org.junit.Test; +import static org.junit.Assert.assertTrue; + + +public class RandomTest { + + @Test + public void testMany() { + for (int i = 0; i < 100; i++) { + PrivateKey privateKey1 = new PrivateKey(); + PublicKey publicKey1 = privateKey1.publicKey(); + + String privateKeyPem = privateKey1.toPem(); + String publicKeyPem = publicKey1.toPem(); + + PrivateKey privateKey2 = PrivateKey.fromPem(privateKeyPem); + PublicKey publicKey2 = PublicKey.fromPem(publicKeyPem); + + String message = "test"; + + String signatureBase64 = Ecdsa.sign(message, privateKey2).toBase64(); + Signature signature = Signature.fromBase64(signatureBase64); + + assertTrue(Ecdsa.verify(message, signature, publicKey2)); + } + } +} diff --git a/src/test/java/com/starkbank/ellipticcurve/SecurityTest.java b/src/test/java/com/starkbank/ellipticcurve/SecurityTest.java new file mode 100644 index 0000000..306f9bf --- /dev/null +++ b/src/test/java/com/starkbank/ellipticcurve/SecurityTest.java @@ -0,0 +1,542 @@ +package com.starkbank.ellipticcurve; + +import org.junit.Before; +import org.junit.Test; +import java.math.BigInteger; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; + +import static org.junit.Assert.*; + + +public class SecurityTest { + + // ===== Prime256v1PublicKeyDerivationTest (prime256v1/SHA-256) ===== + // RFC 6979 A.2.5 public key derivation. Signatures are hedged, so r/s + // no longer match fixed test vectors, but pubkey derivation is unchanged. + + public static class Prime256v1PublicKeyDerivationTest { + private PrivateKey privateKey; + private PublicKey publicKey; + + @Before + public void setUp() { + privateKey = new PrivateKey( + Curve.prime256v1, + new BigInteger("C9AFA9D845BA75166B5C215767B1D6934E50C3DB36E89B127B8A622B120F6721", 16) + ); + publicKey = privateKey.publicKey(); + } + + @Test + public void testPublicKeyMatchesRfc() { + assertEquals( + new BigInteger("60FED4BA255A9D31C961EB74C6356D68C049B8923B61FA6CE669622E60F29FB6", 16), + publicKey.point.x + ); + assertEquals( + new BigInteger("7903FE1008B8BC99A41AE9E95628BC64F2F1B20C2D7E9F5177A3C294D4462299", 16), + publicKey.point.y + ); + } + + @Test + public void testSampleMessageRoundTrip() { + Signature sig = Ecdsa.sign("sample", privateKey); + assertTrue(sig.s.compareTo(Curve.prime256v1.N.shiftRight(1)) <= 0); + assertTrue(Ecdsa.verify("sample", sig, publicKey)); + } + + @Test + public void testTestMessageRoundTrip() { + Signature sig = Ecdsa.sign("test", privateKey); + assertTrue(sig.s.compareTo(Curve.prime256v1.N.shiftRight(1)) <= 0); + assertTrue(Ecdsa.verify("test", sig, publicKey)); + } + } + + // ===== Secp256k1PublicKeyDerivationTest ===== + // secp256k1 with secret=1 (pubkey = generator G). + + public static class Secp256k1PublicKeyDerivationTest { + private PrivateKey privateKey; + private PublicKey publicKey; + + @Before + public void setUp() { + privateKey = new PrivateKey(Curve.secp256k1, BigInteger.ONE); + publicKey = privateKey.publicKey(); + } + + @Test + public void testPublicKeyIsGenerator() { + assertEquals(Curve.secp256k1.G.x, publicKey.point.x); + assertEquals(Curve.secp256k1.G.y, publicKey.point.y); + } + + @Test + public void testSampleMessageRoundTrip() { + Signature sig = Ecdsa.sign("sample", privateKey); + assertTrue(Ecdsa.verify("sample", sig, publicKey)); + } + + @Test + public void testTestMessageRoundTrip() { + Signature sig = Ecdsa.sign("test", privateKey); + assertTrue(Ecdsa.verify("test", sig, publicKey)); + } + } + + // ===== MalleabilityTest ===== + + public static class MalleabilityTest { + + @Test + public void testSignAlwaysProducesLowS() { + for (int i = 0; i < 100; i++) { + PrivateKey privateKey = new PrivateKey(); + Signature signature = Ecdsa.sign("test message", privateKey); + assertTrue(signature.s.compareTo(privateKey.curve.N.shiftRight(1)) <= 0); + } + } + + @Test + public void testHighSSignatureStillVerifies() { + PrivateKey privateKey = new PrivateKey(); + PublicKey publicKey = privateKey.publicKey(); + String message = "test message"; + + Signature signature = Ecdsa.sign(message, privateKey); + Signature highS = new Signature(signature.r, privateKey.curve.N.subtract(signature.s)); + + assertTrue(Ecdsa.verify(message, signature, publicKey)); + assertTrue(Ecdsa.verify(message, highS, publicKey)); + } + } + + // ===== PublicKeyValidationTest ===== + + public static class PublicKeyValidationTest { + + @Test + public void testRejectOffCurvePublicKey() { + PrivateKey privateKey = new PrivateKey(); + PublicKey publicKey = privateKey.publicKey(); + String message = "test message"; + + Signature signature = Ecdsa.sign(message, privateKey); + + Point offCurvePoint = new Point(publicKey.point.x, publicKey.point.y.add(BigInteger.ONE)); + PublicKey offCurveKey = new PublicKey(offCurvePoint, publicKey.curve); + + assertFalse(Ecdsa.verify(message, signature, offCurveKey)); + } + + @Test(expected = RuntimeException.class) + public void testFromStringRejectsOffCurvePoint() { + PublicKey p = new PrivateKey().publicKey(); + int baseLength = 2 * p.curve.length(); + String badY = leftPad(p.point.y.add(BigInteger.ONE).toString(16), baseLength); + String badHex = leftPad(p.point.x.toString(16), baseLength) + badY; + PublicKey.fromString(badHex, p.curve); + } + + @Test(expected = RuntimeException.class) + public void testFromStringRejectsInfinityPoint() { + int baseLength = 2 * Curve.secp256k1.length(); + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < baseLength * 2; i++) sb.append("0"); + String zeroHex = sb.toString(); + PublicKey.fromString(zeroHex, Curve.secp256k1); + } + + private static String leftPad(String s, int length) { + while (s.length() < length) s = "0" + s; + return s; + } + } + + // ===== ForgeryAttemptTest ===== + + public static class ForgeryAttemptTest { + private PrivateKey privateKey; + private PublicKey publicKey; + private String message; + private Signature signature; + + @Before + public void setUp() { + privateKey = new PrivateKey(); + publicKey = privateKey.publicKey(); + message = "authentic message"; + signature = Ecdsa.sign(message, privateKey); + } + + @Test + public void testRejectZeroSignature() { + assertFalse(Ecdsa.verify(message, new Signature(BigInteger.ZERO, BigInteger.ZERO), publicKey)); + } + + @Test + public void testRejectREqualsZero() { + assertFalse(Ecdsa.verify(message, new Signature(BigInteger.ZERO, signature.s), publicKey)); + } + + @Test + public void testRejectSEqualsZero() { + assertFalse(Ecdsa.verify(message, new Signature(signature.r, BigInteger.ZERO), publicKey)); + } + + @Test + public void testRejectREqualsN() { + BigInteger N = publicKey.curve.N; + assertFalse(Ecdsa.verify(message, new Signature(N, signature.s), publicKey)); + } + + @Test + public void testRejectSEqualsN() { + BigInteger N = publicKey.curve.N; + assertFalse(Ecdsa.verify(message, new Signature(signature.r, N), publicKey)); + } + + @Test + public void testRejectRExceedsN() { + BigInteger N = publicKey.curve.N; + assertFalse(Ecdsa.verify(message, new Signature(N.add(BigInteger.ONE), signature.s), publicKey)); + } + + @Test + public void testRejectArbitrarySignature() { + assertFalse(Ecdsa.verify(message, new Signature(BigInteger.ONE, BigInteger.ONE), publicKey)); + } + + @Test + public void testRejectBoundarySignature() { + BigInteger N = publicKey.curve.N; + assertFalse(Ecdsa.verify(message, new Signature(N.subtract(BigInteger.ONE), N.subtract(BigInteger.ONE)), publicKey)); + } + + @Test + public void testWrongKeyRejected() { + PublicKey otherKey = new PrivateKey().publicKey(); + assertFalse(Ecdsa.verify(message, signature, otherKey)); + } + } + + // ===== HedgedSignatureTest ===== + + public static class HedgedSignatureTest { + + @Test + public void testSameInputsProduceDifferentSignatures() { + PrivateKey privateKey = new PrivateKey(); + String message = "test message"; + + Signature signature1 = Ecdsa.sign(message, privateKey); + Signature signature2 = Ecdsa.sign(message, privateKey); + + assertTrue(!signature1.r.equals(signature2.r) || !signature1.s.equals(signature2.s)); + } + + @Test + public void testDifferentMessagesDifferentSignatures() { + PrivateKey privateKey = new PrivateKey(); + + Signature signature1 = Ecdsa.sign("message 1", privateKey); + Signature signature2 = Ecdsa.sign("message 2", privateKey); + + assertTrue(!signature1.r.equals(signature2.r) || !signature1.s.equals(signature2.s)); + } + + @Test + public void testDifferentKeysDifferentSignatures() { + String message = "test message"; + + Signature signature1 = Ecdsa.sign(message, new PrivateKey()); + Signature signature2 = Ecdsa.sign(message, new PrivateKey()); + + assertTrue(!signature1.r.equals(signature2.r) || !signature1.s.equals(signature2.s)); + } + } + + // ===== EdgeCaseMessageTest ===== + + public static class EdgeCaseMessageTest { + private PrivateKey privateKey; + private PublicKey publicKey; + + @Before + public void setUp() { + privateKey = new PrivateKey(); + publicKey = privateKey.publicKey(); + } + + private void signAndVerify(String message) { + Signature sig = Ecdsa.sign(message, privateKey); + assertTrue(Ecdsa.verify(message, sig, publicKey)); + assertFalse(Ecdsa.verify(message + "x", sig, publicKey)); + } + + @Test + public void testEmptyMessage() { + signAndVerify(""); + } + + @Test + public void testSingleCharMessage() { + signAndVerify("a"); + } + + @Test + public void testUnicodeMessage() { + signAndVerify("\u00e9\u00e8\u00ea\u00eb"); + } + + @Test + public void testEmojiMessage() { + signAndVerify("\uD83D\uDD12\uD83D\uDD11"); + } + + @Test + public void testNullByteMessage() { + signAndVerify("before\0after"); + } + + @Test + public void testLongMessage() { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 10000; i++) sb.append("a"); + signAndVerify(sb.toString()); + } + + @Test + public void testNewlinesAndWhitespace() { + signAndVerify(" line1\n\tline2\r\n "); + } + } + + // ===== SerializationRoundTripTest ===== + + public static class SerializationRoundTripTest { + private PrivateKey privateKey; + private PublicKey publicKey; + private String message; + private Signature signature; + + @Before + public void setUp() { + privateKey = new PrivateKey(); + publicKey = privateKey.publicKey(); + message = "round-trip test"; + signature = Ecdsa.sign(message, privateKey); + } + + @Test + public void testSignatureDerRoundTrip() { + com.starkbank.ellipticcurve.utils.ByteString der = signature.toDer(); + Signature restored = Signature.fromDer(der); + assertEquals(restored.r, signature.r); + assertEquals(restored.s, signature.s); + assertTrue(Ecdsa.verify(message, restored, publicKey)); + } + + @Test + public void testSignatureBase64RoundTrip() { + String b64 = signature.toBase64(); + Signature restored = Signature.fromBase64(b64); + assertEquals(restored.r, signature.r); + assertEquals(restored.s, signature.s); + assertTrue(Ecdsa.verify(message, restored, publicKey)); + } + + @Test + public void testSignatureDerWithRecoveryIdRoundTrip() { + com.starkbank.ellipticcurve.utils.ByteString der = signature.toDer(true); + Signature restored = Signature.fromDer(der, true); + assertEquals(restored.r, signature.r); + assertEquals(restored.s, signature.s); + assertEquals(restored.recoveryId, signature.recoveryId); + } + + @Test + public void testPrivateKeyPemRoundTrip() { + String pem = privateKey.toPem(); + PrivateKey restored = PrivateKey.fromPem(pem); + assertEquals(restored.secret, privateKey.secret); + assertEquals(restored.curve.name, privateKey.curve.name); + } + + @Test + public void testPrivateKeyDerRoundTrip() { + com.starkbank.ellipticcurve.utils.ByteString der = privateKey.toDer(); + PrivateKey restored = PrivateKey.fromDer(der); + assertEquals(restored.secret, privateKey.secret); + } + + @Test + public void testPublicKeyPemRoundTrip() { + String pem = publicKey.toPem(); + PublicKey restored = PublicKey.fromPem(pem); + assertEquals(restored.point.x, publicKey.point.x); + assertEquals(restored.point.y, publicKey.point.y); + } + + @Test + public void testPublicKeyCompressedRoundTrip() { + String compressed = publicKey.toCompressed(); + PublicKey restored = PublicKey.fromCompressed(compressed, publicKey.curve); + assertEquals(restored.point.x, publicKey.point.x); + assertEquals(restored.point.y, publicKey.point.y); + assertTrue(Ecdsa.verify(message, signature, restored)); + } + + @Test + public void testPublicKeyCompressedEvenAndOdd() { + for (int i = 0; i < 20; i++) { + PrivateKey pk = new PrivateKey(); + PublicKey pub = pk.publicKey(); + String compressed = pub.toCompressed(); + PublicKey restored = PublicKey.fromCompressed(compressed, pub.curve); + assertEquals(restored.point.x, pub.point.x); + assertEquals(restored.point.y, pub.point.y); + } + } + + @Test + public void testPrime256v1KeyRoundTrip() { + PrivateKey pk = new PrivateKey(Curve.prime256v1, null); + String pem = pk.toPem(); + PrivateKey restored = PrivateKey.fromPem(pem); + assertEquals(restored.secret, pk.secret); + assertEquals("prime256v1", restored.curve.name); + } + } + + // ===== TonelliShanksTest ===== + + public static class TonelliShanksTest { + + @Test + public void testPrimeCongruent1Mod4() { + // P = 17: 17 - 1 = 16 = 2^4, S = 4, exercises full Tonelli-Shanks + BigInteger P = BigInteger.valueOf(17); + for (int value = 1; value < 17; value++) { + BigInteger val = BigInteger.valueOf(value); + BigInteger halfP = P.subtract(BigInteger.ONE).divide(BigInteger.TWO); + if (val.modPow(halfP, P).equals(BigInteger.ONE)) { + BigInteger root = Math.modularSquareRoot(val, P); + assertEquals(val, root.multiply(root).mod(P)); + } + } + } + + @Test + public void testPrimeCongruent5Mod8() { + // P = 13: 13 - 1 = 12 = 3 * 2^2, S = 2 + BigInteger P = BigInteger.valueOf(13); + for (int value = 1; value < 13; value++) { + BigInteger val = BigInteger.valueOf(value); + BigInteger halfP = P.subtract(BigInteger.ONE).divide(BigInteger.TWO); + if (val.modPow(halfP, P).equals(BigInteger.ONE)) { + BigInteger root = Math.modularSquareRoot(val, P); + assertEquals(val, root.multiply(root).mod(P)); + } + } + } + + @Test + public void testPrimeCongruent3Mod4() { + // P = 7: fast path (S = 1) + BigInteger P = BigInteger.valueOf(7); + for (int value = 1; value < 7; value++) { + BigInteger val = BigInteger.valueOf(value); + BigInteger halfP = P.subtract(BigInteger.ONE).divide(BigInteger.TWO); + if (val.modPow(halfP, P).equals(BigInteger.ONE)) { + BigInteger root = Math.modularSquareRoot(val, P); + assertEquals(val, root.multiply(root).mod(P)); + } + } + } + + @Test + public void testZeroValue() { + assertEquals(BigInteger.ZERO, Math.modularSquareRoot(BigInteger.ZERO, BigInteger.valueOf(17))); + } + } + + // ===== HashTruncationTest ===== + + public static class HashTruncationTest { + + @Test + public void testSignVerifyWithSha512() throws NoSuchAlgorithmException { + PrivateKey privateKey = new PrivateKey(); + PublicKey publicKey = privateKey.publicKey(); + String message = "test message"; + + Signature signature = Ecdsa.sign(message, privateKey, MessageDigest.getInstance("SHA-512")); + + assertTrue(Ecdsa.verify(message, signature, publicKey, MessageDigest.getInstance("SHA-512"))); + assertFalse(Ecdsa.verify("wrong message", signature, publicKey, MessageDigest.getInstance("SHA-512"))); + } + + @Test + public void testSha512SignaturesAreHedged() throws NoSuchAlgorithmException { + PrivateKey privateKey = new PrivateKey(); + String message = "test message"; + + Signature signature1 = Ecdsa.sign(message, privateKey, MessageDigest.getInstance("SHA-512")); + Signature signature2 = Ecdsa.sign(message, privateKey, MessageDigest.getInstance("SHA-512")); + + assertTrue(!signature1.r.equals(signature2.r) || !signature1.s.equals(signature2.s)); + } + + @Test + public void testHashMismatchFails() throws NoSuchAlgorithmException { + PrivateKey privateKey = new PrivateKey(); + PublicKey publicKey = privateKey.publicKey(); + String message = "test message"; + + Signature signature = Ecdsa.sign(message, privateKey, MessageDigest.getInstance("SHA-256")); + assertFalse(Ecdsa.verify(message, signature, publicKey, MessageDigest.getInstance("SHA-512"))); + } + } + + // ===== Prime256v1SecurityTest ===== + + public static class Prime256v1SecurityTest { + + @Test + public void testSignVerify() { + PrivateKey privateKey = new PrivateKey(Curve.prime256v1, null); + PublicKey publicKey = privateKey.publicKey(); + String message = "test message"; + + Signature signature = Ecdsa.sign(message, privateKey); + + assertTrue(signature.s.compareTo(Curve.prime256v1.N.shiftRight(1)) <= 0); + assertTrue(Ecdsa.verify(message, signature, publicKey)); + } + + @Test + public void testSignaturesAreHedged() { + PrivateKey privateKey = new PrivateKey(Curve.prime256v1, null); + String message = "test message"; + + Signature signature1 = Ecdsa.sign(message, privateKey); + Signature signature2 = Ecdsa.sign(message, privateKey); + + assertTrue(!signature1.r.equals(signature2.r) || !signature1.s.equals(signature2.s)); + } + + @Test + public void testWrongCurveKeyFails() { + PrivateKey k1Key = new PrivateKey(Curve.secp256k1, null); + PrivateKey p256Key = new PrivateKey(Curve.prime256v1, null); + String message = "cross-curve test"; + + Signature sig = Ecdsa.sign(message, k1Key); + assertFalse(Ecdsa.verify(message, sig, p256Key.publicKey())); + } + } +} diff --git a/src/test/java/com/starkbank/ellipticcurve/SignatureTest.java b/src/test/java/com/starkbank/ellipticcurve/SignatureTest.java index 38bca1d..6a79c7d 100644 --- a/src/test/java/com/starkbank/ellipticcurve/SignatureTest.java +++ b/src/test/java/com/starkbank/ellipticcurve/SignatureTest.java @@ -1,7 +1,9 @@ package com.starkbank.ellipticcurve; + import com.starkbank.ellipticcurve.utils.ByteString; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; public class SignatureTest { @@ -30,9 +32,20 @@ public void testBase64Conversion() { String base64 = signature1.toBase64(); - Signature signature2 = Signature.fromBase64(new ByteString(base64.getBytes())); + Signature signature2 = Signature.fromBase64(base64); assertEquals(signature1.r, signature2.r); assertEquals(signature1.s, signature2.s); } + + @Test + public void testUniqueness() { + PrivateKey privateKey = new PrivateKey(); + String message = "This is a text message"; + + Signature signature1 = Ecdsa.sign(message, privateKey); + Signature signature2 = Ecdsa.sign(message, privateKey); + + assertNotEquals(signature1.toBase64(), signature2.toBase64()); + } } diff --git a/src/test/java/com/starkbank/ellipticcurve/SignatureWithRecoveryIdTest.java b/src/test/java/com/starkbank/ellipticcurve/SignatureWithRecoveryIdTest.java new file mode 100644 index 0000000..e1445b0 --- /dev/null +++ b/src/test/java/com/starkbank/ellipticcurve/SignatureWithRecoveryIdTest.java @@ -0,0 +1,39 @@ +package com.starkbank.ellipticcurve; + +import org.junit.Test; +import static org.junit.Assert.assertEquals; + + +public class SignatureWithRecoveryIdTest { + + @Test + public void testDerConversion() { + PrivateKey privateKey = new PrivateKey(); + String message = "This is a text message"; + + Signature signature1 = Ecdsa.sign(message, privateKey); + + com.starkbank.ellipticcurve.utils.ByteString der = signature1.toDer(true); + Signature signature2 = Signature.fromDer(der, true); + + assertEquals(signature1.r, signature2.r); + assertEquals(signature1.s, signature2.s); + assertEquals(signature1.recoveryId, signature2.recoveryId); + } + + @Test + public void testBase64Conversion() { + PrivateKey privateKey = new PrivateKey(); + String message = "This is a text message"; + + Signature signature1 = Ecdsa.sign(message, privateKey); + + String base64 = signature1.toBase64(true); + + Signature signature2 = Signature.fromBase64(base64, true); + + assertEquals(signature1.r, signature2.r); + assertEquals(signature1.s, signature2.s); + assertEquals(signature1.recoveryId, signature2.recoveryId); + } +}