Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .evergreen/.evg.yml
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,16 @@ functions:
set +o xtrace
MONGODB_URI="${MONGODB_URI}" KMS_TLS_ERROR_TYPE=${KMS_TLS_ERROR_TYPE} .evergreen/run-kms-tls-tests.sh

"run-kms-retry-test":
- command: shell.exec
type: "test"
params:
working_dir: "src"
script: |
${PREPARE_SHELL}
set +o xtrace
MONGODB_URI="${MONGODB_URI}" .evergreen/run-kms-retry-tests.sh

"run-csfle-aws-from-environment-test":
- command: shell.exec
type: "test"
Expand Down Expand Up @@ -1632,6 +1642,17 @@ tasks:
AUTH: "noauth"
SSL: "nossl"

- name: "test-kms-retry-task"
tags: [ "kms-retry" ]
commands:
- func: "start-mongo-orchestration"
vars:
TOPOLOGY: "server"
AUTH: "noauth"
SSL: "nossl"
- func: "start-csfle-servers"
- func: "run-kms-retry-test"

- name: "test-csfle-aws-from-environment-task"
tags: [ "csfle-aws-from-environment" ]
commands:
Expand Down Expand Up @@ -2528,6 +2549,12 @@ buildvariants:
tasks:
- name: ".kms-tls"

- matrix_name: "kms-retry-test"
matrix_spec: { os: "linux", version: [ "5.0" ], topology: [ "standalone" ] }
display_name: "CSFLE KMS Retry"
tasks:
- name: ".kms-retry"

- matrix_name: "csfle-aws-from-environment-test"
matrix_spec: { os: "linux", version: [ "5.0" ], topology: [ "standalone" ] }
display_name: "CSFLE AWS From Environment"
Expand Down
43 changes: 43 additions & 0 deletions .evergreen/run-kms-retry-tests.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#!/bin/bash

set -o errexit # Exit the script with error if any of the commands fail

# Supported/used environment variables:
# MONGODB_URI Set the suggested connection MONGODB_URI (including credentials and topology info)

############################################
# Main Program #
############################################
RELATIVE_DIR_PATH="$(dirname "${BASH_SOURCE:-$0}")"
. "${RELATIVE_DIR_PATH}/setup-env.bash"
echo "Running KMS Retry tests"

cp ${JAVA_HOME}/lib/security/cacerts mongo-truststore
${JAVA_HOME}/bin/keytool -importcert -trustcacerts -file ${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem -keystore mongo-truststore -storepass changeit -storetype JKS -noprompt

export GRADLE_EXTRA_VARS="-Pssl.enabled=true -Pssl.trustStoreType=jks -Pssl.trustStore=`pwd`/mongo-truststore -Pssl.trustStorePassword=changeit"

./gradlew -version

# Disable errexit so both suites run and their exit codes can be captured below.
set +o errexit

./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \
-Dorg.mongodb.test.kms.retry.ca.path="${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem" \
driver-sync:cleanTest driver-sync:test --tests ClientSideEncryptionKmsRetryProseTest
first=$?
echo "sync exit code: $first"

./gradlew --stacktrace --info ${GRADLE_EXTRA_VARS} -Dorg.mongodb.test.uri=${MONGODB_URI} \
-Dorg.mongodb.test.kms.retry.ca.path="${DRIVERS_TOOLS}/.evergreen/x509gen/ca.pem" \
driver-reactive-streams:cleanTest driver-reactive-streams:test --tests ClientSideEncryptionKmsRetryProseTest
second=$?
echo "reactive exit code: $second"

if [ $first -ne 0 ]; then
exit $first
elif [ $second -ne 0 ]; then
exit $second
else
exit 0
fi
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
import com.mongodb.MongoClientException;
import com.mongodb.MongoClientSettings;
import com.mongodb.MongoConfigurationException;
import com.mongodb.MongoOperationTimeoutException;
import com.mongodb.client.model.vault.RewrapManyDataKeyOptions;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.authentication.AwsCredentialHelper;
import com.mongodb.internal.authentication.AzureCredentialHelper;
import com.mongodb.internal.authentication.GcpCredentialHelper;
import com.mongodb.internal.crypt.capi.MongoCryptOptions;
import com.mongodb.internal.time.Timeout;
import com.mongodb.lang.Nullable;
import org.bson.BsonDocument;
import org.bson.BsonDocumentWrapper;
Expand All @@ -52,6 +55,32 @@
*/
public final class MongoCryptHelper {

public static final String KMS_TIMEOUT_ERROR_MESSAGE = "KMS key decryption exceeded the timeout limit.";

/**
* Throws a {@code MongoOperationTimeoutException} if the operation timeout has expired or the
* KMS retry backoff would exceed the remaining operation time.
*
* @param operationTimeout the operation timeout, or null if none
* @param backoffMicros the backoff to sleep before the next KMS attempt, in microseconds
*/
public static void checkKmsRetryBackoff(@Nullable final Timeout operationTimeout, final long backoffMicros) {
if (operationTimeout == null) {
return;
}
operationTimeout.run(TimeUnit.MICROSECONDS,
// infinite timeout: no CSOT budget to enforce; libmongocrypt's retry count is the only limit
() -> { },
remainingMicros -> {
if (remainingMicros < backoffMicros) {
throw TimeoutContext.createMongoTimeoutException(KMS_TIMEOUT_ERROR_MESSAGE);
}
},
() -> {
throw TimeoutContext.createMongoTimeoutException(KMS_TIMEOUT_ERROR_MESSAGE);
});
}

public static MongoCryptOptions createMongoCryptOptions(final ClientEncryptionSettings settings) {
return createMongoCryptOptions(settings.getKmsProviders(), false, emptyList(), emptyMap(), null, null,
settings.getKeyExpiration(TimeUnit.MILLISECONDS));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,32 @@
import java.io.Closeable;
import java.nio.channels.CompletionHandler;
import java.nio.channels.InterruptedByTimeoutException;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Map;

import static com.mongodb.internal.capi.MongoCryptHelper.KMS_TIMEOUT_ERROR_MESSAGE;
import static com.mongodb.internal.capi.MongoCryptHelper.checkKmsRetryBackoff;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.bson.assertions.Assertions.assertTrue;

class KeyManagementService implements Closeable {
private static final Logger LOGGER = Loggers.getLogger("client");
private static final String TIMEOUT_ERROR_MESSAGE = "KMS key decryption exceeded the timeout limit.";
private final Map<String, SSLContext> kmsProviderSslContextMap;
private final int timeoutMillis;
private final TlsChannelStreamFactoryFactory tlsChannelStreamFactoryFactory;

KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis) {
this(kmsProviderSslContextMap, timeoutMillis, new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver()));
}

KeyManagementService(final Map<String, SSLContext> kmsProviderSslContextMap, final int timeoutMillis,
final TlsChannelStreamFactoryFactory tlsChannelStreamFactoryFactory) {
assertTrue("timeoutMillis > 0", timeoutMillis > 0);
this.kmsProviderSslContextMap = kmsProviderSslContextMap;
this.tlsChannelStreamFactoryFactory = new TlsChannelStreamFactoryFactory(new DefaultInetAddressResolver());
this.tlsChannelStreamFactoryFactory = tlsChannelStreamFactoryFactory;
this.timeoutMillis = timeoutMillis;
}

Expand All @@ -74,6 +82,18 @@ public void close() {
}

Mono<Void> decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) {
return Mono.defer(() -> {
long sleepMicros = keyDecryptor.sleepMicroseconds();
if (sleepMicros > 0) {
checkKmsRetryBackoff(operationTimeout, sleepMicros);
return Mono.delay(Duration.of(sleepMicros, ChronoUnit.MICROS))
.then(attemptDecryptKey(keyDecryptor, operationTimeout));
}
return attemptDecryptKey(keyDecryptor, operationTimeout);
}).onErrorMap(this::unWrapException);
}

private Mono<Void> attemptDecryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Timeout operationTimeout) {
SocketSettings socketSettings = SocketSettings.builder()
.connectTimeout(timeoutMillis, MILLISECONDS)
.readTimeout(timeoutMillis, MILLISECONDS)
Expand All @@ -86,88 +106,135 @@ Mono<Void> decryptKey(final MongoKeyDecryptor keyDecryptor, @Nullable final Time
LOGGER.info("Connecting to KMS server at " + serverAddress);

return Mono.<Void>create(sink -> {
Stream stream = streamFactory.create(serverAddress);
OperationContext operationContext = createOperationContext(operationTimeout, socketSettings);
Stream stream = streamFactory.create(serverAddress);
stream.openAsync(operationContext, new AsyncCompletionHandler<Void>() {
@Override
public void completed(@Nullable final Void ignored) {
streamWrite(stream, keyDecryptor, operationContext, sink);
try {
streamWrite(stream, keyDecryptor, operationContext, operationTimeout, sink);
} catch (Throwable t) {
stream.close();
sink.error(t);
}
}

@Override
public void failed(final Throwable t) {
stream.close();
handleError(t, operationContext, sink);
failOrHandleError(t, keyDecryptor, operationTimeout, sink);
}
});
}).onErrorMap(this::unWrapException);
});
}

private void streamWrite(final Stream stream, final MongoKeyDecryptor keyDecryptor,
final OperationContext operationContext, final MonoSink<Void> sink) {
final OperationContext operationContext, @Nullable final Timeout operationTimeout,
final MonoSink<Void> sink) {
List<ByteBuf> byteBufs = singletonList(new ByteBufNIO(keyDecryptor.getMessage()));
stream.writeAsync(byteBufs, operationContext, new AsyncCompletionHandler<Void>() {
@Override
public void completed(@Nullable final Void aVoid) {
streamRead(stream, keyDecryptor, operationContext, sink);
try {
streamRead(stream, keyDecryptor, operationContext, operationTimeout, sink);
} catch (Throwable t) {
stream.close();
sink.error(t);
}
}

@Override
public void failed(final Throwable t) {
stream.close();
handleError(t, operationContext, sink);
failOrHandleError(t, keyDecryptor, operationTimeout, sink);
}
});
}

private void streamRead(final Stream stream, final MongoKeyDecryptor keyDecryptor,
final OperationContext operationContext, final MonoSink<Void> sink) {
final OperationContext operationContext, @Nullable final Timeout operationTimeout,
final MonoSink<Void> sink) {
int bytesNeeded = keyDecryptor.bytesNeeded();
if (bytesNeeded > 0) {
AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream;
ByteBuf buffer = asyncStream.getBuffer(bytesNeeded);
long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS();
asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, MILLISECONDS, null,
new CompletionHandler<Integer, Void>() {

@Override
public void completed(final Integer integer, final Void aVoid) {
if (integer == -1) {
sink.error(new MongoException(
"Unexpected end of stream from KMS provider " + keyDecryptor.getKmsProvider()));
return;
}
buffer.flip();
try {
keyDecryptor.feed(buffer.asNIO());
buffer.release();
streamRead(stream, keyDecryptor, operationContext, sink);
} catch (Throwable t) {
sink.error(t);
}
}

@Override
public void failed(final Throwable t, final Void aVoid) {
buffer.release();
stream.close();
handleError(t, operationContext, sink);
}
});
} else {
if (bytesNeeded <= 0) {
stream.close();
sink.success();
return;
}
AsynchronousChannelStream asyncStream = (AsynchronousChannelStream) stream;
ByteBuf buffer = asyncStream.getBuffer(bytesNeeded);
CompletionHandler<Integer, Void> readHandler = new CompletionHandler<Integer, Void>() {

@Override
public void completed(final Integer integer, final Void aVoid) {
try {
if (integer == -1) {
buffer.release();
stream.close();
// Treat an unexpected end of stream (the KMS server closed the connection) as a retryable
// transient network error: hand it to failOrHandleError so the context is retried if budget allows.
MongoException eof = new MongoException("Unexpected end of stream from KMS provider "
+ keyDecryptor.getKmsProvider());
failOrHandleError(eof, keyDecryptor, operationTimeout, sink);
return;
}
buffer.flip();
boolean shouldRetry;
try {
shouldRetry = keyDecryptor.feedWithRetry(buffer.asNIO());
} finally {
buffer.release();
}
if (shouldRetry) {
// libmongocrypt marked the context for retry; complete this attempt and let the state machine re-present it
stream.close();
sink.success();
} else {
streamRead(stream, keyDecryptor, operationContext, operationTimeout, sink);
}
} catch (Throwable t) {
stream.close();
sink.error(t);
}
}

@Override
public void failed(final Throwable t, final Void aVoid) {
buffer.release();
stream.close();
failOrHandleError(t, keyDecryptor, operationTimeout, sink);
}
};
try {
long readTimeoutMS = operationContext.getTimeoutContext().getReadTimeoutMS();
asyncStream.getChannel().read(buffer.asNIO(), readTimeoutMS, MILLISECONDS, null, readHandler);
} catch (RuntimeException | Error e) {
// the handler was not invoked, so the buffer must be released here
buffer.release();
throw e;
}
}

private static void handleError(final Throwable t, final OperationContext operationContext, final MonoSink<Void> sink) {
if (isTimeoutException(t) && operationContext.getTimeoutContext().hasTimeoutMS()) {
sink.error(TimeoutContext.createMongoTimeoutException(TIMEOUT_ERROR_MESSAGE, t));
private static void failOrHandleError(final Throwable t, final MongoKeyDecryptor keyDecryptor,
@Nullable final Timeout operationTimeout, final MonoSink<Void> sink) {
if (isTimeoutException(t) && hasExpired(operationTimeout)) {
sink.error(TimeoutContext.createMongoTimeoutException(KMS_TIMEOUT_ERROR_MESSAGE, t));
return;
}
if (keyDecryptor.fail()) {
LOGGER.debug("Retrying KMS request after transient error", t);
sink.success();
} else {
sink.error(t);
}
}

private static boolean hasExpired(@Nullable final Timeout operationTimeout) {
return operationTimeout != null && operationTimeout.call(MILLISECONDS,
() -> false,
remainingMillis -> false,
() -> true);
}

private OperationContext createOperationContext(@Nullable final Timeout operationTimeout, final SocketSettings socketSettings) {
TimeoutSettings timeoutSettings;
if (operationTimeout == null) {
Expand All @@ -179,7 +246,7 @@ private OperationContext createOperationContext(@Nullable final Timeout operatio
},
(ms) -> createTimeoutSettings(socketSettings, ms),
() -> {
throw new MongoOperationTimeoutException(TIMEOUT_ERROR_MESSAGE);
throw new MongoOperationTimeoutException(KMS_TIMEOUT_ERROR_MESSAGE);
});
}
return OperationContext.simpleOperationContext(new TimeoutContext(timeoutSettings));
Expand All @@ -197,7 +264,13 @@ private static TimeoutSettings createTimeoutSettings(final SocketSettings socket
}

private Throwable unWrapException(final Throwable t) {
return t instanceof MongoSocketException ? t.getCause() : t;
// Unwrap the IOException the async stream layer wraps in a MongoSocketException, to match the sync path.
// Socket timeout subclasses are meaningful MongoClientExceptions, so preserve them rather than unwrapping.
if (t instanceof MongoSocketReadTimeoutException || t instanceof MongoSocketWriteTimeoutException) {
return t;
}
Throwable cause = t.getCause();
return t instanceof MongoSocketException && cause != null ? cause : t;
}

private static boolean isTimeoutException(final Throwable t) {
Expand Down
Loading