Skip to content

Commit ae75385

Browse files
committed
INTERNAL: Add ScramSaslClient
1 parent 5c13575 commit ae75385

File tree

4 files changed

+254
-0
lines changed

4 files changed

+254
-0
lines changed

pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,11 @@
118118
</exclusion>
119119
</exclusions>
120120
</dependency>
121+
<dependency>
122+
<groupId>com.bolyartech.scram_sasl</groupId>
123+
<artifactId>scram_sasl</artifactId>
124+
<version>2.0.2</version>
125+
</dependency>
121126

122127
<!-- TEST -->
123128
<dependency>
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package net.spy.memcached.auth;
2+
3+
import java.util.Collection;
4+
import java.util.Collections;
5+
import java.util.HashMap;
6+
import java.util.Map;
7+
8+
public enum ScramMechanism {
9+
SCRAM_SHA_256("SHA-256", "HmacSHA256");
10+
11+
private static final Map<String, ScramMechanism> MECHANISMS_MAP;
12+
13+
private final String mechanismName;
14+
private final String hashAlgorithm;
15+
private final String macAlgorithm;
16+
17+
static {
18+
Map<String, ScramMechanism> map = new HashMap<>();
19+
for (ScramMechanism mech : values()) {
20+
map.put(mech.mechanismName, mech);
21+
}
22+
MECHANISMS_MAP = Collections.unmodifiableMap(map);
23+
}
24+
25+
private ScramMechanism(String hashAlgorithm, String macAlgorithm) {
26+
this.mechanismName = "SCRAM-" + hashAlgorithm;
27+
this.hashAlgorithm = hashAlgorithm;
28+
this.macAlgorithm = macAlgorithm;
29+
}
30+
31+
public final String mechanismName() {
32+
return this.mechanismName;
33+
}
34+
public String hashAlgorithm() {
35+
return hashAlgorithm;
36+
}
37+
38+
public String macAlgorithm() {
39+
return macAlgorithm;
40+
}
41+
42+
public static ScramMechanism forMechanismName(String mechanismName) {
43+
return MECHANISMS_MAP.get(mechanismName);
44+
}
45+
46+
public static Collection<String> mechanismNames() {
47+
return MECHANISMS_MAP.keySet();
48+
}
49+
}
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
package net.spy.memcached.auth;
2+
3+
import java.nio.charset.StandardCharsets;
4+
import java.util.Arrays;
5+
import java.util.Collection;
6+
import java.util.Map;
7+
8+
import javax.security.auth.callback.Callback;
9+
import javax.security.auth.callback.CallbackHandler;
10+
import javax.security.auth.callback.NameCallback;
11+
import javax.security.auth.callback.PasswordCallback;
12+
import javax.security.sasl.SaslClient;
13+
import javax.security.sasl.SaslClientFactory;
14+
import javax.security.sasl.SaslException;
15+
16+
import com.bolyartech.scram_sasl.client.ScramClientFunctionality;
17+
import com.bolyartech.scram_sasl.client.ScramClientFunctionalityImpl;
18+
import com.bolyartech.scram_sasl.common.ScramException;
19+
20+
public class ScramSaslClient implements SaslClient {
21+
22+
enum State {
23+
SEND_CLIENT_FIRST_MESSAGE,
24+
RECEIVE_SERVER_FIRST_MESSAGE,
25+
RECEIVE_SERVER_FINAL_MESSAGE,
26+
COMPLETE,
27+
FAILED
28+
}
29+
30+
private final ScramMechanism mechanism;
31+
private final CallbackHandler callbackHandler;
32+
private final ScramClientFunctionality scf;
33+
private State state;
34+
35+
public ScramSaslClient(ScramMechanism mechanism, CallbackHandler cbh) {
36+
this.callbackHandler = cbh;
37+
this.mechanism = mechanism;
38+
this.scf = new ScramClientFunctionalityImpl(
39+
mechanism.hashAlgorithm(), mechanism.macAlgorithm());
40+
this.state = State.SEND_CLIENT_FIRST_MESSAGE;
41+
}
42+
43+
@Override
44+
public String getMechanismName() {
45+
return this.mechanism.mechanismName();
46+
}
47+
48+
@Override
49+
public boolean hasInitialResponse() {
50+
return true;
51+
}
52+
53+
@Override
54+
public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
55+
try {
56+
switch (this.state) {
57+
case SEND_CLIENT_FIRST_MESSAGE:
58+
if (challenge != null && challenge.length != 0) {
59+
throw new SaslException("Expected empty challenge");
60+
}
61+
62+
NameCallback nameCallback = new NameCallback("Name: ");
63+
64+
try {
65+
callbackHandler.handle(new Callback[]{nameCallback});
66+
} catch (Throwable e) {
67+
throw new SaslException("User name could not be obtained", e);
68+
}
69+
70+
String username = nameCallback.getName();
71+
byte[] clientFirstMessage = this.scf.prepareFirstMessage(username).getBytes();
72+
this.state = State.RECEIVE_SERVER_FIRST_MESSAGE;
73+
return clientFirstMessage;
74+
75+
case RECEIVE_SERVER_FIRST_MESSAGE:
76+
String serverFirstMessage = new String(challenge, StandardCharsets.UTF_8);
77+
78+
PasswordCallback passwordCallback = new PasswordCallback("Password: ", false);
79+
try {
80+
callbackHandler.handle(new Callback[]{passwordCallback});
81+
} catch (Throwable e) {
82+
throw new SaslException("Password could not be obtained", e);
83+
}
84+
85+
String password = String.valueOf(passwordCallback.getPassword());
86+
byte[] clientFinalMessage = this.scf.prepareFinalMessage(
87+
password, serverFirstMessage).getBytes();
88+
if (clientFinalMessage == null) {
89+
throw new SaslException("clientFinalMessage should not be null");
90+
}
91+
this.state = State.RECEIVE_SERVER_FINAL_MESSAGE;
92+
return clientFinalMessage;
93+
94+
case RECEIVE_SERVER_FINAL_MESSAGE:
95+
String serverFinalMessage = new String(challenge, StandardCharsets.UTF_8);
96+
if (!this.scf.checkServerFinalMessage(serverFinalMessage)) {
97+
throw new SaslException("Sasl authentication using " + this.mechanism +
98+
" failed with error: invalid server final message");
99+
}
100+
this.state = State.COMPLETE;
101+
return new byte[]{};
102+
103+
default:
104+
throw new SaslException("Unexpected challenge in Sasl client state " + this.state);
105+
}
106+
} catch (ScramException e) {
107+
this.state = State.FAILED;
108+
throw new SaslException("ScramException", e);
109+
} catch (SaslException e) {
110+
this.state = State.FAILED;
111+
throw e;
112+
}
113+
}
114+
115+
@Override
116+
public boolean isComplete() {
117+
return this.state == State.COMPLETE;
118+
}
119+
120+
@Override
121+
public byte[] unwrap(byte[] incoming, int offset, int len) throws SaslException {
122+
if (!isComplete()) {
123+
throw new IllegalStateException("Authentication exchange has not completed");
124+
}
125+
return Arrays.copyOfRange(incoming, offset, offset + len);
126+
}
127+
128+
@Override
129+
public byte[] wrap(byte[] outgoing, int offset, int len) throws SaslException {
130+
if (!isComplete()) {
131+
throw new IllegalStateException("Authentication exchange has not completed");
132+
}
133+
return Arrays.copyOfRange(outgoing, offset, offset + len);
134+
}
135+
136+
@Override
137+
public Object getNegotiatedProperty(String propName) {
138+
if (!isComplete()) {
139+
throw new IllegalStateException("Authentication exchange has not completed");
140+
}
141+
return null;
142+
}
143+
144+
@Override
145+
public void dispose() throws SaslException {
146+
}
147+
148+
public static class ScramSaslClientFactory implements SaslClientFactory {
149+
@Override
150+
public SaslClient createSaslClient(String[] mechanisms,
151+
String authorizationId,
152+
String protocol,
153+
String serverName,
154+
Map<String, ?> props,
155+
CallbackHandler cbh) throws SaslException {
156+
157+
ScramMechanism mechanism = null;
158+
for (String mech : mechanisms) {
159+
mechanism = ScramMechanism.forMechanismName(mech);
160+
if (mechanism != null) {
161+
break;
162+
}
163+
}
164+
if (mechanism == null) {
165+
throw new SaslException(String.format("Requested mechanisms '%s' not supported."
166+
+ " Supported mechanisms are '%s'.",
167+
Arrays.asList(mechanisms), ScramMechanism.mechanismNames()));
168+
}
169+
170+
return new ScramSaslClient(mechanism, cbh);
171+
}
172+
173+
@Override
174+
public String[] getMechanismNames(Map<String, ?> props) {
175+
Collection<String> mechanisms = ScramMechanism.mechanismNames();
176+
return mechanisms.toArray(new String[0]);
177+
}
178+
}
179+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package net.spy.memcached.auth;
2+
3+
import java.security.Provider;
4+
import java.security.Security;
5+
6+
import net.spy.memcached.auth.ScramSaslClient.ScramSaslClientFactory;
7+
8+
public final class ScramSaslClientProvider extends Provider {
9+
10+
private static final long serialVersionUID = 1L;
11+
12+
@SuppressWarnings("deprecation")
13+
private ScramSaslClientProvider() {
14+
super("SASL/SCRAM Client Provider", 1.0, "SASL/SCRAM Client Provider for Arcus");
15+
put("SaslClientFactory.SCRAM-SHA-256", ScramSaslClientFactory.class.getName());
16+
}
17+
18+
public static void initialize() {
19+
Security.addProvider(new ScramSaslClientProvider());
20+
}
21+
}

0 commit comments

Comments
 (0)