diff --git a/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsCbsSessionHandler.java b/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsCbsSessionHandler.java index 9c103177c2..18b2373e32 100644 --- a/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsCbsSessionHandler.java +++ b/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsCbsSessionHandler.java @@ -148,6 +148,13 @@ public void onAuthenticationFailed(String deviceId, TransportException transport this.connectionStateCallback.onAuthenticationFailed(deviceId, transportException); } + public void onAuthenticationTimedOut(String deviceId) + { + log.warn("Timed out waiting for CBS authentication response for device {}. Closing this connection...", deviceId); + this.connectionStateCallback.onCBSSessionClosedUnexpectedly(null); + this.close(); + } + public void close() { log.trace("Closing this CBS session"); diff --git a/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsSasTokenRenewalHandler.java b/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsSasTokenRenewalHandler.java index b764b1a9d9..dcd42a7595 100644 --- a/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsSasTokenRenewalHandler.java +++ b/iothub/device/iot-device-client/src/main/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsSasTokenRenewalHandler.java @@ -29,6 +29,8 @@ class AmqpsSasTokenRenewalHandler extends BaseHandler implements AuthenticationM private boolean isClosed; private AmqpsSasTokenRenewalHandler nextToAuthenticate; private Task scheduledTask; + private Task authenticationTimeoutTask; + private int currentAuthenticationRequestId; public AmqpsSasTokenRenewalHandler(AmqpsCbsSessionHandler amqpsCbsSessionHandler, AmqpsSessionHandler amqpsSessionHandler) { @@ -70,6 +72,7 @@ public void sendAuthenticationMessage(Reactor reactor) throws TransportException log.debug("Sending authentication message for device {}", amqpsSessionHandler.getDeviceId()); amqpsCbsSessionHandler.sendAuthenticationMessage(amqpsSessionHandler.getClientConfiguration(), this); + scheduleAuthenticationTimeout(reactor); scheduleRenewal(reactor); } } @@ -77,6 +80,8 @@ public void sendAuthenticationMessage(Reactor reactor) throws TransportException @Override public DeliveryState handleAuthenticationResponseMessage(int status, String description, Reactor reactor) { + cancelAuthenticationTimeout(); + try { if (nextToAuthenticate != null) @@ -104,6 +109,15 @@ public DeliveryState handleAuthenticationResponseMessage(int status, String desc return Accepted.getInstance(); } + private void onAuthenticationTimedOut(int authenticationRequestId) + { + if (!isClosed && authenticationRequestId == currentAuthenticationRequestId) + { + log.warn("Timed out waiting for CBS authentication response for device {}", this.amqpsSessionHandler.getDeviceId()); + this.amqpsCbsSessionHandler.onAuthenticationTimedOut(this.amqpsSessionHandler.getDeviceId()); + } + } + // Once closed, this handler will stop sending authentication messages for its device. This object may not be re-opened. public void close() { @@ -128,6 +142,36 @@ private void scheduleRenewalRetry(Reactor reactor) this.scheduledTask = reactor.schedule(RETRY_INTERVAL_MILLISECONDS, this); } + private void scheduleAuthenticationTimeout(Reactor reactor) + { + cancelAuthenticationTimeout(); + + currentAuthenticationRequestId++; + final int expectedAuthenticationRequestId = currentAuthenticationRequestId; + long authenticationTimeout = this.amqpsSessionHandler.getClientConfiguration().getOperationTimeout(); + + log.trace("Scheduling CBS authentication response timeout for device {} in {} milliseconds", this.amqpsSessionHandler.getDeviceId(), authenticationTimeout); + + this.authenticationTimeoutTask = reactor.schedule((int) authenticationTimeout, new BaseHandler() + { + @Override + public void onTimerTask(Event event) + { + onAuthenticationTimedOut(expectedAuthenticationRequestId); + } + }); + } + + private void cancelAuthenticationTimeout() + { + if (this.authenticationTimeoutTask != null) + { + this.authenticationTimeoutTask.cancel(); + this.authenticationTimeoutTask.attachments().clear(); + this.authenticationTimeoutTask = null; + } + } + // Removes any children of this handler (such as LoggingFlowController) and disassociates this handler // from the proton reactor. By removing the reference of the proton reactor to this handler, this handler becomes // eligible for garbage collection by the JVM. This is important for multiplexed connections where SAS token renewal @@ -140,6 +184,8 @@ private void clearHandlers() this.scheduledTask.attachments().clear(); } + cancelAuthenticationTimeout(); + // an instance of this class shouldn't have any children, but other handlers may be added as this SDK // grows and this protects against potential memory leaks Iterator childrenIterator = this.children(); diff --git a/iothub/device/iot-device-client/src/test/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsSasTokenRenewalHandlerTest.java b/iothub/device/iot-device-client/src/test/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsSasTokenRenewalHandlerTest.java new file mode 100644 index 0000000000..dffcdc495d --- /dev/null +++ b/iothub/device/iot-device-client/src/test/java/com/microsoft/azure/sdk/iot/device/transport/amqps/AmqpsSasTokenRenewalHandlerTest.java @@ -0,0 +1,183 @@ +/* + * Copyright (c) Microsoft. All rights reserved. + * Licensed under the MIT license. See LICENSE file in the project root for full license information. + */ + +package com.microsoft.azure.sdk.iot.device.transport.amqps; + +import com.microsoft.azure.sdk.iot.device.ClientConfiguration; +import com.microsoft.azure.sdk.iot.device.auth.IotHubSasTokenAuthenticationProvider; +import com.microsoft.azure.sdk.iot.device.transport.TransportException; +import mockit.Delegate; +import mockit.Expectations; +import mockit.Mocked; +import mockit.Verifications; +import org.apache.qpid.proton.engine.BaseHandler; +import org.apache.qpid.proton.engine.Event; +import org.apache.qpid.proton.engine.Handler; +import org.apache.qpid.proton.engine.Record; +import org.apache.qpid.proton.reactor.Reactor; +import org.apache.qpid.proton.reactor.Task; +import org.junit.Test; + +/** + * Unit tests for AmqpsSasTokenRenewalHandler. + */ +@SuppressWarnings("ThrowableNotThrown") +public class AmqpsSasTokenRenewalHandlerTest +{ + @Mocked + AmqpsCbsSessionHandler mockedCbsSessionHandler; + + @Mocked + AmqpsSessionHandler mockedSessionHandler; + + @Mocked + ClientConfiguration mockedConfig; + + @Mocked + IotHubSasTokenAuthenticationProvider mockedSasTokenAuthenticationProvider; + + @Mocked + Reactor mockedReactor; + + @Mocked + Task mockedAuthenticationTimeoutTask; + + @Mocked + Task mockedRenewalTask; + + @Mocked + Event mockedEvent; + + @Mocked + Record mockedRecord; + + // Tests_SRS_AMQPSSASTOKENRENEWALHANDLER_34_001: [If no CBS authentication response is received before the operation timeout, this function shall notify the CBS session that authentication timed out.] + @Test + public void authenticationResponseTimeoutNotifiesCbsSession() throws TransportException + { + //arrange + final String deviceId = "someDevice"; + final int authenticationTimeout = 1000; + final int renewalPeriod = 2000; + final Handler[] authenticationTimeoutHandler = new Handler[1]; + + new Expectations() + { + { + mockedSessionHandler.getDeviceId(); + result = deviceId; + + mockedSessionHandler.getClientConfiguration(); + result = mockedConfig; + + mockedConfig.getOperationTimeout(); + result = authenticationTimeout; + + mockedConfig.getSasTokenAuthentication(); + result = mockedSasTokenAuthenticationProvider; + + mockedSasTokenAuthenticationProvider.getMillisecondsBeforeProactiveRenewal(); + result = renewalPeriod; + + mockedCbsSessionHandler.sendAuthenticationMessage(mockedConfig, (AuthenticationMessageCallback) any); + + mockedReactor.schedule(anyInt, (Handler) any); + result = new Delegate() + { + @SuppressWarnings("unused") + Task schedule(int delay, Handler handler) + { + if (delay == authenticationTimeout) + { + authenticationTimeoutHandler[0] = handler; + return mockedAuthenticationTimeoutTask; + } + + return mockedRenewalTask; + } + }; + } + }; + + AmqpsSasTokenRenewalHandler sasTokenRenewalHandler = new AmqpsSasTokenRenewalHandler(mockedCbsSessionHandler, mockedSessionHandler); + sasTokenRenewalHandler.sendAuthenticationMessage(mockedReactor); + + //act + ((BaseHandler) authenticationTimeoutHandler[0]).onTimerTask(mockedEvent); + + //assert + new Verifications() + { + { + mockedCbsSessionHandler.onAuthenticationTimedOut(deviceId); + times = 1; + } + }; + } + + // Tests_SRS_AMQPSSASTOKENRENEWALHANDLER_34_002: [If a CBS authentication response is received before the operation timeout, this function shall cancel the authentication response timeout.] + @Test + public void authenticationResponseCancelsResponseTimeoutTask() throws TransportException + { + //arrange + final String deviceId = "someDevice"; + final int authenticationTimeout = 1000; + final int renewalPeriod = 2000; + + new Expectations() + { + { + mockedSessionHandler.getDeviceId(); + result = deviceId; + + mockedSessionHandler.getClientConfiguration(); + result = mockedConfig; + + mockedConfig.getOperationTimeout(); + result = authenticationTimeout; + + mockedConfig.getSasTokenAuthentication(); + result = mockedSasTokenAuthenticationProvider; + + mockedSasTokenAuthenticationProvider.getMillisecondsBeforeProactiveRenewal(); + result = renewalPeriod; + + mockedCbsSessionHandler.sendAuthenticationMessage(mockedConfig, (AuthenticationMessageCallback) any); + + mockedReactor.schedule(authenticationTimeout, (Handler) any); + result = mockedAuthenticationTimeoutTask; + + mockedReactor.schedule(renewalPeriod, (Handler) any); + result = mockedRenewalTask; + + mockedAuthenticationTimeoutTask.attachments(); + result = mockedRecord; + + mockedSessionHandler.openLinks(); + } + }; + + AmqpsSasTokenRenewalHandler sasTokenRenewalHandler = new AmqpsSasTokenRenewalHandler(mockedCbsSessionHandler, mockedSessionHandler); + sasTokenRenewalHandler.sendAuthenticationMessage(mockedReactor); + + //act + sasTokenRenewalHandler.handleAuthenticationResponseMessage(200, "", mockedReactor); + + //assert + new Verifications() + { + { + mockedAuthenticationTimeoutTask.cancel(); + times = 1; + + mockedRecord.clear(); + times = 1; + + mockedCbsSessionHandler.onAuthenticationTimedOut(anyString); + times = 0; + } + }; + } +}