Skip to content

Commit 6301728

Browse files
committed
refactor + restore raising specific PKey errors
1 parent 4e1cbbd commit 6301728

File tree

7 files changed

+69
-62
lines changed

7 files changed

+69
-62
lines changed

src/main/java/org/jruby/ext/openssl/PKeyDSA.java

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -144,24 +144,26 @@ public IRubyObject initialize_copy(final IRubyObject original) {
144144
public String getAlgorithm() { return "DSA"; }
145145

146146
@JRubyMethod(name = "generate", meta = true)
147-
public static IRubyObject generate(IRubyObject self, IRubyObject arg) {
148-
final Ruby runtime = self.getRuntime();
147+
public static IRubyObject generate(ThreadContext context, IRubyObject self, IRubyObject arg) {
149148
final int keySize = RubyNumeric.fix2int(arg);
150-
return dsaGenerate(runtime, new PKeyDSA(runtime, (RubyClass) self), keySize);
149+
return dsaGenerate(context.runtime, new PKeyDSA(context.runtime, (RubyClass) self), keySize);
150+
}
151+
152+
static PKeyDSA generateImpl(final Ruby runtime, PKeyDSA dsa, int keySize) throws NoSuchAlgorithmException {
153+
KeyPairGenerator gen = SecurityHelper.getKeyPairGenerator("DSA");
154+
gen.initialize(keySize, getSecureRandom(runtime));
155+
KeyPair pair = gen.generateKeyPair();
156+
dsa.privateKey = (DSAPrivateKey) pair.getPrivate();
157+
dsa.publicKey = (DSAPublicKey) pair.getPublic();
158+
return dsa;
151159
}
152160

153161
/*
154162
* c: dsa_generate
155163
*/
156-
private static PKeyDSA dsaGenerate(final Ruby runtime,
157-
PKeyDSA dsa, int keySize) throws RaiseException {
164+
private static PKeyDSA dsaGenerate(final Ruby runtime, PKeyDSA dsa, int keySize) throws RaiseException {
158165
try {
159-
KeyPairGenerator gen = SecurityHelper.getKeyPairGenerator("DSA");
160-
gen.initialize(keySize, getSecureRandom(runtime));
161-
KeyPair pair = gen.generateKeyPair();
162-
dsa.privateKey = (DSAPrivateKey) pair.getPrivate();
163-
dsa.publicKey = (DSAPublicKey) pair.getPublic();
164-
return dsa;
166+
return generateImpl(runtime, dsa, keySize);
165167
}
166168
catch (NoSuchAlgorithmException e) {
167169
throw newDSAError(runtime, e.getMessage());
@@ -303,7 +305,7 @@ public RubyBoolean private_p() {
303305
@JRubyMethod(name = "public_to_der")
304306
public RubyString public_to_der(ThreadContext context) {
305307
if (publicKey == null) {
306-
throw newPKeyError(context.runtime, "incompletely initialized DSA key");
308+
throw newDSAError(context.runtime, "incompletely initialized DSA key");
307309
}
308310
final byte[] bytes;
309311
try {
@@ -329,7 +331,7 @@ public RubyString to_der() {
329331
throw newDSAError(getRuntime(), bcExceptionMessage(e));
330332
}
331333
catch (IllegalArgumentException e) {
332-
throw newPKeyError(getRuntime(), e.getMessage());
334+
throw newDSAError(getRuntime(), e.getMessage());
333335
}
334336
catch (IOException e) {
335337
throw newDSAError(getRuntime(), e.getMessage(), e);
@@ -486,13 +488,13 @@ public IRubyObject verify_raw(IRubyObject digest, IRubyObject sign, IRubyObject
486488
return runtime.newBoolean(verify("NONEwithDSA", getPublicKey(), dataBytes, sigBytes));
487489
}
488490
catch (NoSuchAlgorithmException e) {
489-
throw newPKeyError(runtime, e.getMessage());
491+
throw newDSAError(runtime, e.getMessage());
490492
}
491493
catch (SignatureException e) {
492-
throw newPKeyError(runtime, "invalid signature");
494+
throw newDSAError(runtime, "invalid signature");
493495
}
494496
catch (InvalidKeyException e) {
495-
throw newPKeyError(runtime, "invalid key");
497+
throw newDSAError(runtime, "invalid key");
496498
}
497499
}
498500

src/main/java/org/jruby/ext/openssl/PKeyEC.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ public PKeyEC generate_key(final ThreadContext context) {
440440
this.privateKey = pair.getPrivate();
441441
}
442442
catch (NoSuchAlgorithmException | InvalidAlgorithmParameterException ex) {
443-
throw PKey.newPKeyError(context.runtime, ex.getMessage());
443+
throw newECError(context.runtime, ex.getMessage());
444444
}
445445
catch (GeneralSecurityException ex) {
446446
throw (RaiseException) newECError(context.runtime, ex.toString()).initCause(ex);
@@ -745,7 +745,7 @@ public RubyString to_der() {
745745
return public_to_der(runtime);
746746
}
747747
if (privateKey == null) {
748-
throw PKey.newPKeyError(runtime, "can't export - no public key set");
748+
throw newECError(runtime, "can't export - no public key set");
749749
}
750750

751751
try {

src/main/java/org/jruby/ext/openssl/PKeyRSA.java

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -220,32 +220,33 @@ public static IRubyObject generate(IRubyObject self, IRubyObject[] args) {
220220
return rsaGenerate(runtime, new PKeyRSA(runtime, (RubyClass) self), keySize, exp);
221221
}
222222

223+
static PKeyRSA generateImpl(final Ruby runtime, PKeyRSA rsa, int keySize, BigInteger exp)
224+
throws NoSuchAlgorithmException, InvalidAlgorithmParameterException {
225+
KeyPairGenerator gen = SecurityHelper.getKeyPairGenerator("RSA");
226+
if ( "IBMJCEFIPS".equals( gen.getProvider().getName() ) ) {
227+
gen.initialize(keySize); // IBMJCEFIPS does not support parameters
228+
} else {
229+
gen.initialize(new RSAKeyGenParameterSpec(keySize, exp), getSecureRandom(runtime));
230+
}
231+
KeyPair pair = gen.generateKeyPair();
232+
rsa.privateKey = (RSAPrivateCrtKey) pair.getPrivate();
233+
rsa.publicKey = (RSAPublicKey) pair.getPublic();
234+
return rsa;
235+
}
236+
223237
/*
224238
* c: rsa_generate
225239
*/
226-
private static PKeyRSA rsaGenerate(final Ruby runtime,
227-
PKeyRSA rsa, int keySize, BigInteger exp) throws RaiseException {
240+
static PKeyRSA rsaGenerate(final Ruby runtime, PKeyRSA rsa, int keySize, BigInteger exp) throws RaiseException {
228241
try {
229-
KeyPairGenerator gen = SecurityHelper.getKeyPairGenerator("RSA");
230-
if ( "IBMJCEFIPS".equals( gen.getProvider().getName() ) ) {
231-
gen.initialize(keySize); // IBMJCEFIPS does not support parameters
232-
} else {
233-
gen.initialize(new RSAKeyGenParameterSpec(keySize, exp), getSecureRandom(runtime));
234-
}
235-
KeyPair pair = gen.generateKeyPair();
236-
rsa.privateKey = (RSAPrivateCrtKey) pair.getPrivate();
237-
rsa.publicKey = (RSAPublicKey) pair.getPublic();
238-
}
239-
catch (NoSuchAlgorithmException e) {
240-
throw newRSAError(runtime, e.getMessage());
242+
return generateImpl(runtime, rsa, keySize, exp);
241243
}
242-
catch (InvalidAlgorithmParameterException e) {
244+
catch (NoSuchAlgorithmException|InvalidAlgorithmParameterException e) {
243245
throw newRSAError(runtime, e.getMessage());
244246
}
245247
catch (RuntimeException e) {
246-
throw newRSAError(rsa.getRuntime(), e);
248+
throw newRSAError(runtime, e);
247249
}
248-
return rsa;
249250
}
250251

251252
@JRubyMethod(rest = true, visibility = Visibility.PRIVATE)
@@ -336,7 +337,7 @@ public IRubyObject initialize(final ThreadContext context, final IRubyObject[] a
336337
if ( key == null ) key = tryPKCS8EncodedKey(runtime, rsaFactory, str.getBytes());
337338
if ( key == null ) key = tryX509EncodedKey(runtime, rsaFactory, str.getBytes());
338339

339-
if ( key == null ) throw newPKeyError(runtime, "Neither PUB key nor PRIV key:");
340+
if ( key == null ) throw newRSAError(runtime, "Neither PUB key nor PRIV key:");
340341

341342
if ( key instanceof KeyPair ) {
342343
PublicKey publicKey = ((KeyPair) key).getPublic();
@@ -615,7 +616,7 @@ private static ASN1ObjectIdentifier osslNameToCipherOid(final String osslName) {
615616

616617
private String getPadding(final int padding) {
617618
if ( padding < 1 || padding > 4 ) {
618-
throw newPKeyError(getRuntime(), "");
619+
throw newRSAError(getRuntime(), "");
619620
}
620621
// BC accepts "/NONE/*" but SunJCE doesn't. use "/ECB/*"
621622
String p = "/ECB/PKCS1Padding";
@@ -635,7 +636,7 @@ public IRubyObject private_encrypt(final ThreadContext context, final IRubyObjec
635636
if ( Arity.checkArgumentCount(context.runtime, args, 1, 2) == 2 && ! args[1].isNil() ) {
636637
padding = RubyNumeric.fix2int(args[1]);
637638
}
638-
if ( privateKey == null ) throw newPKeyError(context.runtime, "incomplete RSA");
639+
if ( privateKey == null ) throw newRSAError(context.runtime, "incomplete RSA");
639640
return doCipherRSA(context.runtime, args[0], padding, ENCRYPT_MODE, privateKey);
640641
}
641642

@@ -645,7 +646,7 @@ public IRubyObject private_decrypt(final ThreadContext context, final IRubyObjec
645646
if ( Arity.checkArgumentCount(context.runtime, args, 1, 2) == 2 && ! args[1].isNil()) {
646647
padding = RubyNumeric.fix2int(args[1]);
647648
}
648-
if ( privateKey == null ) throw newPKeyError(context.runtime, "incomplete RSA");
649+
if ( privateKey == null ) throw newRSAError(context.runtime, "incomplete RSA");
649650
return doCipherRSA(context.runtime, args[0], padding, DECRYPT_MODE, privateKey);
650651
}
651652

@@ -655,7 +656,7 @@ public IRubyObject public_encrypt(final ThreadContext context, final IRubyObject
655656
if ( Arity.checkArgumentCount(context.runtime, args, 1, 2) == 2 && ! args[1].isNil()) {
656657
padding = RubyNumeric.fix2int(args[1]);
657658
}
658-
if ( publicKey == null ) throw newPKeyError(context.runtime, "incomplete RSA");
659+
if ( publicKey == null ) throw newRSAError(context.runtime, "incomplete RSA");
659660
return doCipherRSA(context.runtime, args[0], padding, ENCRYPT_MODE, publicKey);
660661
}
661662

@@ -665,7 +666,7 @@ public IRubyObject public_decrypt(final ThreadContext context, final IRubyObject
665666
if ( Arity.checkArgumentCount(context.runtime, args, 1, 2) == 2 && ! args[1].isNil() ) {
666667
padding = RubyNumeric.fix2int(args[1]);
667668
}
668-
if ( publicKey == null ) throw newPKeyError(context.runtime, "incomplete RSA");
669+
if ( publicKey == null ) throw newRSAError(context.runtime, "incomplete RSA");
669670
return doCipherRSA(context.runtime, args[0], padding, DECRYPT_MODE, publicKey);
670671
}
671672

@@ -699,7 +700,7 @@ public IRubyObject oid() {
699700
@JRubyMethod(name = "sign_raw", required = 2, optional = 1)
700701
public IRubyObject sign_raw(ThreadContext context, IRubyObject[] args) {
701702
final Ruby runtime = context.runtime;
702-
if (privateKey == null) throw newPKeyError(runtime, "Private RSA key needed!");
703+
if (privateKey == null) throw newRSAError(runtime, "Private RSA key needed!");
703704

704705
final String digestAlg = getDigestAlgName(args[0]);
705706
final byte[] hashBytes = args[1].convertToString().getBytes();
@@ -715,7 +716,7 @@ public IRubyObject sign_raw(ThreadContext context, IRubyObject[] args) {
715716
try {
716717
return StringHelper.newString(runtime, signWithPSS(hashBytes, digestAlg, mgf1Alg, saltLen));
717718
} catch (IllegalArgumentException | CryptoException e) {
718-
throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e);
719+
throw (RaiseException) newRSAError(runtime, e.getMessage()).initCause(e);
719720
}
720721
}
721722
}
@@ -726,13 +727,13 @@ public IRubyObject sign_raw(ThreadContext context, IRubyObject[] args) {
726727
ByteList signed = sign("NONEwithRSA", privateKey, new ByteList(digestInfoBytes, false));
727728
return RubyString.newString(runtime, signed);
728729
} catch (IOException e) {
729-
throw newPKeyError(runtime, "failed to encode DigestInfo: " + e.getMessage());
730+
throw newRSAError(runtime, "failed to encode DigestInfo: " + e.getMessage());
730731
} catch (NoSuchAlgorithmException e) {
731-
throw newPKeyError(runtime, "unsupported algorithm: NONEwithRSA");
732+
throw newRSAError(runtime, "unsupported algorithm: NONEwithRSA");
732733
} catch (InvalidKeyException e) {
733-
throw newPKeyError(runtime, "invalid key");
734+
throw newRSAError(runtime, "invalid key");
734735
} catch (SignatureException e) {
735-
throw newPKeyError(runtime, e.getMessage());
736+
throw newRSAError(runtime, e.getMessage());
736737
}
737738
}
738739

@@ -765,11 +766,11 @@ public IRubyObject verify_raw(ThreadContext context, IRubyObject[] args) {
765766
new ByteList(sigBytes, false));
766767
return runtime.newBoolean(ok);
767768
} catch (IOException e) {
768-
throw newPKeyError(runtime, "failed to encode DigestInfo: " + e.getMessage());
769+
throw newRSAError(runtime, "failed to encode DigestInfo: " + e.getMessage());
769770
} catch (NoSuchAlgorithmException e) {
770-
throw newPKeyError(runtime, "unsupported algorithm: NONEwithRSA");
771+
throw newRSAError(runtime, "unsupported algorithm: NONEwithRSA");
771772
} catch (InvalidKeyException e) {
772-
throw newPKeyError(runtime, "invalid key");
773+
throw newRSAError(runtime, "invalid key");
773774
} catch (SignatureException e) {
774775
return runtime.getFalse();
775776
}
@@ -819,7 +820,7 @@ public IRubyObject sign(ThreadContext context, IRubyObject[] args) {
819820
if (!(opts instanceof RubyHash)) throw runtime.newTypeError("expected Hash");
820821
String paddingMode = Utils.extractStringOpt(context, opts, "rsa_padding_mode", true);
821822
if ("pss".equalsIgnoreCase(paddingMode)) {
822-
if (privateKey == null) throw newPKeyError(runtime, "Private RSA key needed!");
823+
if (privateKey == null) throw newRSAError(runtime, "Private RSA key needed!");
823824
final String digestAlg = getDigestAlgName(digest);
824825
int saltLen = Utils.extractIntOpt(context, opts, "rsa_pss_saltlen", -1, true);
825826
String mgf1Alg = Utils.extractStringOpt(context, opts, "rsa_mgf1_md", true);
@@ -830,7 +831,7 @@ public IRubyObject sign(ThreadContext context, IRubyObject[] args) {
830831
try {
831832
signedData = signDataWithPSS(runtime, data.convertToString(), digestAlg, mgf1Alg, saltLen);
832833
} catch (IllegalArgumentException | DataLengthException | CryptoException e) {
833-
throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e);
834+
throw (RaiseException) newRSAError(runtime, e.getMessage()).initCause(e);
834835
}
835836
return StringHelper.newString(runtime, signedData);
836837
}
@@ -843,7 +844,7 @@ public IRubyObject sign(ThreadContext context, IRubyObject[] args) {
843844
@JRubyMethod(name = "sign_pss", required = 2, optional = 1)
844845
public IRubyObject sign_pss(ThreadContext context, IRubyObject[] args) {
845846
final Ruby runtime = context.runtime;
846-
if (privateKey == null) throw newPKeyError(runtime, "Private RSA key needed!");
847+
if (privateKey == null) throw newRSAError(runtime, "Private RSA key needed!");
847848
final String digestAlg = getDigestAlgName(args[0]);
848849
final IRubyObject opts = args.length > 2 ? args[2] : context.nil;
849850
final int maxSalt = maxPSSSaltLength(digestAlg, privateKey.getModulus().bitLength());
@@ -869,7 +870,7 @@ public IRubyObject sign_pss(ThreadContext context, IRubyObject[] args) {
869870
try {
870871
signedData = signDataWithPSS(runtime, args[1].convertToString(), digestAlg, mgf1Alg, saltLen);
871872
} catch (IllegalArgumentException | DataLengthException | CryptoException e) {
872-
throw (RaiseException) newPKeyError(runtime, e.getMessage()).initCause(e);
873+
throw (RaiseException) newRSAError(runtime, e.getMessage()).initCause(e);
873874
}
874875
return StringHelper.newString(runtime, signedData);
875876
}

src/test/ruby/dsa/test_dsa.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_new
4141
def test_new_empty
4242
key = OpenSSL::PKey::DSA.new
4343
assert_nil(key.p)
44-
assert_raise(OpenSSL::PKey::PKeyError) { key.to_der }
44+
assert_pkey_error { key.to_der }
4545
end
4646

4747
def test_dup

src/test/ruby/ec/test_ec.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_ec_key
4949
end
5050

5151
def test_generate
52-
assert_raise(OpenSSL::PKey::PKeyError) { OpenSSL::PKey::EC.generate("non-existent") }
52+
assert_pkey_error { OpenSSL::PKey::EC.generate("non-existent") }
5353
g = OpenSSL::PKey::EC::Group.new("prime256v1")
5454
ec = OpenSSL::PKey::EC.generate(g)
5555
assert_equal(true, ec.private?)

src/test/ruby/rsa/test_rsa.rb

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def test_no_private_exp
1414
rsa = Fixtures.pkey("rsa-1.pem")
1515
key.set_key(rsa.n, rsa.e, nil)
1616
key.set_factors(rsa.p, rsa.q)
17-
assert_raise(OpenSSL::PKey::PKeyError){ key.private_encrypt("foo") }
18-
assert_raise(OpenSSL::PKey::PKeyError){ key.private_decrypt("foo") }
17+
assert_pkey_error { key.private_encrypt("foo") }
18+
assert_pkey_error { key.private_decrypt("foo") }
1919
end
2020

2121
def test_private
@@ -201,7 +201,7 @@ def test_sign_verify_raw_legacy
201201
# Failure cases
202202
assert_raise(ArgumentError){ key.private_encrypt() }
203203
assert_raise(ArgumentError){ key.private_encrypt("hi", 1, nil) }
204-
assert_raise(OpenSSL::PKey::PKeyError){ key.private_encrypt(plain0, 666) }
204+
assert_pkey_error { key.private_encrypt(plain0, 666) }
205205
end
206206

207207
def test_sign_verify_raw
@@ -219,7 +219,7 @@ def test_sign_verify_raw
219219
assert_equal false, key.verify_raw("SHA256", signature, wrong_hash)
220220

221221
# Data exceeding the key modulus must raise PKeyError
222-
assert_raise(OpenSSL::PKey::PKeyError) {
222+
assert_pkey_error {
223223
key.sign_raw("SHA1", "x" * (key.n.num_bytes + 1))
224224
}
225225

@@ -283,7 +283,7 @@ def test_sign_verify_pss
283283
key.verify_pss("SHA256", signature, data, salt_length: :auto, mgf1_hash: "SHA256")
284284
end
285285

286-
assert_raise(OpenSSL::PKey::PKeyError) {
286+
assert_pkey_error {
287287
key.sign_pss("SHA256", data, salt_length: 223, mgf1_hash: "SHA256")
288288
}
289289
end
@@ -437,7 +437,7 @@ def test_RSAPrivateKey_encrypted
437437
cipher = OpenSSL::Cipher.new("aes-128-cbc")
438438
exported = rsa1024.to_pem(cipher, "abcdef\0\1")
439439
assert_same_rsa rsa1024, OpenSSL::PKey::RSA.new(exported, "abcdef\0\1")
440-
assert_raise(OpenSSL::PKey::PKeyError) {
440+
assert_pkey_error {
441441
OpenSSL::PKey::RSA.new(exported, "abcdef")
442442
}
443443
end

src/test/ruby/test_helper.rb

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def setup; require 'openssl' end
6969

7070
alias assert_raise assert_raises unless method_defined?(:assert_raise)
7171

72+
def assert_pkey_error(&block)
73+
assert_raise_kind_of(OpenSSL::PKey::PKeyError, &block)
74+
end
75+
7276
unless method_defined?(:skip)
7377
if method_defined?(:omit)
7478
alias skip omit

0 commit comments

Comments
 (0)