diff --git a/.github/lock.yml b/.github/lock.yml deleted file mode 100644 index 119e4840bee..00000000000 --- a/.github/lock.yml +++ /dev/null @@ -1,2 +0,0 @@ -daysUntilLock: 90 -lockComment: false diff --git a/.github/workflows/gradle-wrapper-validation.yml b/.github/workflows/gradle-wrapper-validation.yml index 0bde8aeca18..a5070c937c6 100644 --- a/.github/workflows/gradle-wrapper-validation.yml +++ b/.github/workflows/gradle-wrapper-validation.yml @@ -1,6 +1,9 @@ name: "Validate Gradle Wrapper" on: [push, pull_request] +permissions: + contents: read + jobs: validation: name: "Gradle wrapper validation" diff --git a/.github/workflows/lock.yml b/.github/workflows/lock.yml new file mode 100644 index 00000000000..cb648e7fc73 --- /dev/null +++ b/.github/workflows/lock.yml @@ -0,0 +1,20 @@ +name: 'Lock Threads' + +on: + workflow_dispatch: + schedule: + - cron: '37 3 * * *' + +permissions: + issues: write + pull-requests: write + +jobs: + lock: + runs-on: ubuntu-latest + steps: + - uses: dessant/lock-threads@v2 + with: + github-token: ${{ github.token }} + issue-lock-inactive-days: 90 + pr-lock-inactive-days: 90 diff --git a/COMPILING.md b/COMPILING.md index c37d45bda4b..208f05e25cb 100644 --- a/COMPILING.md +++ b/COMPILING.md @@ -43,11 +43,11 @@ This section is only necessary if you are making changes to the code generation. Most users only need to use `skipCodegen=true` as discussed above. ### Build Protobuf -The codegen plugin is C++ code and requires protobuf 3.12.0 or later. +The codegen plugin is C++ code and requires protobuf 3.17.2 or later. For Linux, Mac and MinGW: ``` -$ PROTOBUF_VERSION=3.12.0 +$ PROTOBUF_VERSION=3.17.2 $ curl -LO https://github.com/protocolbuffers/protobuf/releases/download/v$PROTOBUF_VERSION/protobuf-all-$PROTOBUF_VERSION.tar.gz $ tar xzf protobuf-all-$PROTOBUF_VERSION.tar.gz $ cd protobuf-$PROTOBUF_VERSION @@ -80,16 +80,16 @@ When building on Windows and VC++, you need to specify project properties for Gradle to find protobuf: ``` .\gradlew publishToMavenLocal ^ - -PvcProtobufInclude=C:\path\to\protobuf-3.12.0\src ^ - -PvcProtobufLibs=C:\path\to\protobuf-3.12.0\vsprojects\Release ^ + -PvcProtobufInclude=C:\path\to\protobuf\src ^ + -PvcProtobufLibs=C:\path\to\protobuf\vsprojects\Release ^ -PtargetArch=x86_32 ``` Since specifying those properties every build is bothersome, you can instead create ``\gradle.properties`` with contents like: ``` -vcProtobufInclude=C:\\path\\to\\protobuf-3.12.0\\src -vcProtobufLibs=C:\\path\\to\\protobuf-3.12.0\\vsprojects\\Release +vcProtobufInclude=C:\\path\\to\\protobuf\\src +vcProtobufLibs=C:\\path\\to\\protobuf\\vsprojects\\Release targetArch=x86_32 ``` diff --git a/README.md b/README.md index 412c97172dc..e7d93b7319b 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ For a guided tour, take a look at the [quick start guide](https://grpc.io/docs/languages/java/quickstart) or the more explanatory [gRPC basics](https://grpc.io/docs/languages/java/basics). -The [examples](https://github.com/grpc/grpc-java/tree/v1.37.0/examples) and the -[Android example](https://github.com/grpc/grpc-java/tree/v1.37.0/examples/android) +The [examples](https://github.com/grpc/grpc-java/tree/v1.39.0/examples) and the +[Android example](https://github.com/grpc/grpc-java/tree/v1.39.0/examples/android) are standalone projects that showcase the usage of gRPC. Download @@ -43,17 +43,17 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: io.grpc grpc-netty-shaded - 1.37.0 + 1.39.0 io.grpc grpc-protobuf - 1.37.0 + 1.39.0 io.grpc grpc-stub - 1.37.0 + 1.39.0 org.apache.tomcat @@ -65,23 +65,23 @@ Download [the JARs][]. Or for Maven with non-Android, add to your `pom.xml`: Or for Gradle with non-Android, add to your dependencies: ```gradle -implementation 'io.grpc:grpc-netty-shaded:1.37.0' -implementation 'io.grpc:grpc-protobuf:1.37.0' -implementation 'io.grpc:grpc-stub:1.37.0' +implementation 'io.grpc:grpc-netty-shaded:1.39.0' +implementation 'io.grpc:grpc-protobuf:1.39.0' +implementation 'io.grpc:grpc-stub:1.39.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` For Android client, use `grpc-okhttp` instead of `grpc-netty-shaded` and `grpc-protobuf-lite` instead of `grpc-protobuf`: ```gradle -implementation 'io.grpc:grpc-okhttp:1.37.0' -implementation 'io.grpc:grpc-protobuf-lite:1.37.0' -implementation 'io.grpc:grpc-stub:1.37.0' +implementation 'io.grpc:grpc-okhttp:1.39.0' +implementation 'io.grpc:grpc-protobuf-lite:1.39.0' +implementation 'io.grpc:grpc-stub:1.39.0' compileOnly 'org.apache.tomcat:annotations-api:6.0.53' // necessary for Java 9+ ``` [the JARs]: -https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.37.0 +https://search.maven.org/search?q=g:io.grpc%20AND%20v:1.39.0 Development snapshots are available in [Sonatypes's snapshot repository](https://oss.sonatype.org/content/repositories/snapshots/). @@ -111,9 +111,9 @@ For protobuf-based codegen integrated with the Maven build system, you can use protobuf-maven-plugin 0.6.1 - com.google.protobuf:protoc:3.12.0:exe:${os.detected.classifier} + com.google.protobuf:protoc:3.17.2:exe:${os.detected.classifier} grpc-java - io.grpc:protoc-gen-grpc-java:1.37.0:exe:${os.detected.classifier} + io.grpc:protoc-gen-grpc-java:1.39.0:exe:${os.detected.classifier} @@ -134,16 +134,16 @@ For non-Android protobuf-based codegen integrated with the Gradle build system, you can use [protobuf-gradle-plugin][]: ```gradle plugins { - id 'com.google.protobuf' version '0.8.15' + id 'com.google.protobuf' version '0.8.16' } protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.12.0" + artifact = "com.google.protobuf:protoc:3.17.2" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.37.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' } } generateProtoTasks { @@ -167,16 +167,16 @@ use protobuf-gradle-plugin but specify the 'lite' options: ```gradle plugins { - id 'com.google.protobuf' version '0.8.15' + id 'com.google.protobuf' version '0.8.16' } protobuf { protoc { - artifact = "com.google.protobuf:protoc:3.12.0" + artifact = "com.google.protobuf:protoc:3.17.2" } plugins { grpc { - artifact = 'io.grpc:protoc-gen-grpc-java:1.37.0' + artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' } } generateProtoTasks { diff --git a/SECURITY.md b/SECURITY.md index 9bebf5b709d..d2482e18cbd 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -60,6 +60,10 @@ BoringSSL. It includes pre-built libraries for 64 bit Windows, OS X, and 64 bit Linux. For 32 bit Windows, Conscrypt is an option. For all other platforms, Java 9+ is required. +For users of xDS management protocol, the grpc-netty-shaded transport is +particularly appropriate since it is already used internally for the xDS +protocol and is a runtime dependency of grpc-xds. + For users of grpc-netty we recommend [netty-tcnative with BoringSSL](#tls-with-netty-tcnative-on-boringssl), although using the built-in JDK support in Java 9+, [Conscrypt](#tls-with-conscrypt), and [netty-tcnative diff --git a/alts/build.gradle b/alts/build.gradle index 7236eee0311..8c467f51c12 100644 --- a/alts/build.gradle +++ b/alts/build.gradle @@ -21,10 +21,10 @@ dependencies { project(':grpc-protobuf'), project(':grpc-stub'), libraries.protobuf, - libraries.conscrypt + libraries.conscrypt, + libraries.guava, + libraries.google_auth_oauth2_http def nettyDependency = implementation project(':grpc-netty') - googleOauth2Dependency 'implementation' - guavaDependency 'implementation' compileOnly libraries.javax_annotation shadow configurations.implementation.getDependencies().minus(nettyDependency) diff --git a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java index 8be903f247b..ad2edd4e988 100644 --- a/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java +++ b/alts/src/main/java/io/grpc/alts/internal/AltsProtocolNegotiator.java @@ -65,6 +65,8 @@ public final class AltsProtocolNegotiator { private static final AsciiString SCHEME = AsciiString.of("https"); + private static final String DIRECT_PATH_SERVICE_CFE_CLUSTER_PREFIX = "google_cfe_"; + /** * ClientAltsProtocolNegotiatorFactory is a factory for doing client side negotiation of an ALTS * channel. @@ -282,7 +284,8 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { boolean isXdsDirectPath = false; if (clusterNameAttrKey != null) { String clusterName = grpcHandler.getEagAttributes().get(clusterNameAttrKey); - if (clusterName != null && !clusterName.equals("google_cfe")) { + if (clusterName != null + && !clusterName.startsWith(DIRECT_PATH_SERVICE_CFE_CLUSTER_PREFIX)) { isXdsDirectPath = true; } } diff --git a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java index 740462be62a..f149c4306c6 100644 --- a/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java +++ b/alts/src/test/java/io/grpc/alts/internal/GoogleDefaultProtocolNegotiatorTest.java @@ -181,8 +181,8 @@ public void altsHandler_xdsCluster() { @Test public void tlsHandler_googleCfe() { - Attributes attrs = - Attributes.newBuilder().set(XDS_CLUSTER_NAME_ATTR_KEY, "google_cfe").build(); + Attributes attrs = Attributes.newBuilder().set( + XDS_CLUSTER_NAME_ATTR_KEY, "google_cfe_api.googleapis.com").build(); subtest_tlsHandler(attrs); } } diff --git a/android-interop-testing/build.gradle b/android-interop-testing/build.gradle index 0434eb8f29c..f1c1d491233 100644 --- a/android-interop-testing/build.gradle +++ b/android-interop-testing/build.gradle @@ -63,12 +63,12 @@ dependencies { project(':grpc-stub'), project(':grpc-testing'), libraries.junit, - libraries.truth + libraries.truth, + libraries.opencensus_contrib_grpc_metrics implementation (libraries.google_auth_oauth2_http) { exclude group: 'org.apache.httpcomponents' } - censusGrpcMetricDependency 'implementation' compileOnly libraries.javax_annotation diff --git a/android/build.gradle b/android/build.gradle index b9942b0745f..a50f7b85592 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -31,7 +31,7 @@ repositories { dependencies { api project(':grpc-core') - guavaDependency 'implementation' + implementation libraries.guava testImplementation project('::grpc-okhttp') testImplementation libraries.androidx_test testImplementation libraries.junit diff --git a/api/BUILD.bazel b/api/BUILD.bazel index b3ff0e3536d..4b74e6a836f 100644 --- a/api/BUILD.bazel +++ b/api/BUILD.bazel @@ -3,10 +3,12 @@ java_library( srcs = glob([ "src/main/java/**/*.java", ]), + javacopts = ["-Xep:DoNotCall:OFF"], # Remove once requiring Bazel 3.4.0+; allows non-final visibility = ["//visibility:public"], deps = [ "//context", "@com_google_code_findbugs_jsr305//jar", + "@com_google_errorprone_error_prone_annotations//jar", "@com_google_guava_failureaccess//jar", # future transitive dep of Guava. See #5214 "@com_google_guava_guava//jar", "@com_google_j2objc_j2objc_annotations//jar", diff --git a/api/build.gradle b/api/build.gradle index 1574cdd8b4a..a959fed4ea9 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -13,7 +13,8 @@ evaluationDependsOn(project(':grpc-context').path) dependencies { api project(':grpc-context'), libraries.jsr305 - guavaDependency 'implementation' + implementation libraries.guava, + libraries.errorprone testImplementation project(':grpc-context').sourceSets.test.output, project(':grpc-testing'), diff --git a/api/src/main/java/io/grpc/Detachable.java b/api/src/main/java/io/grpc/Detachable.java new file mode 100644 index 00000000000..c0cbf016f5b --- /dev/null +++ b/api/src/main/java/io/grpc/Detachable.java @@ -0,0 +1,45 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.io.InputStream; + +/** + * An extension of {@link InputStream} that allows the underlying data source to be detached and + * transferred to a new instance of the same kind. The detached InputStream takes over the + * ownership of the underlying data source. That's said, the detached InputStream is responsible + * for releasing its resources after use. The detached InputStream preserves internal states of + * the underlying data source. Data can be consumed through the detached InputStream as if being + * continually consumed through the original instance. The original instance discards internal + * states of detached data source and is no longer consumable as if the data source is exhausted. + * + *

A normal usage of this API is to extend the lifetime of the data source owned by the + * original instance for doing extra processing before releasing it. For example, when combined + * with {@link HasByteBuffer}, a custom {@link io.grpc.MethodDescriptor.Marshaller} can take + * over the ownership of buffers containing inbound data and perform delayed deserialization. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/7387") +public interface Detachable { + + /** + * Detaches the underlying data source from this instance and transfers to an {@link + * InputStream}. Detaching data from an already-detached instance gives an InputStream with + * zero bytes of data. + * + */ + InputStream detach(); +} diff --git a/api/src/main/java/io/grpc/ForwardingChannelBuilder.java b/api/src/main/java/io/grpc/ForwardingChannelBuilder.java index 5fe0464bbc9..b584e7132e9 100644 --- a/api/src/main/java/io/grpc/ForwardingChannelBuilder.java +++ b/api/src/main/java/io/grpc/ForwardingChannelBuilder.java @@ -17,6 +17,7 @@ package io.grpc; import com.google.common.base.MoreObjects; +import com.google.errorprone.annotations.DoNotCall; import java.util.List; import java.util.Map; import java.util.concurrent.Executor; @@ -43,6 +44,7 @@ protected ForwardingChannelBuilder() {} /** * This method serves to force sub classes to "hide" this static factory. */ + @DoNotCall("Unsupported") public static ManagedChannelBuilder forAddress(String name, int port) { throw new UnsupportedOperationException("Subclass failed to hide static factory"); } @@ -50,6 +52,7 @@ public static ManagedChannelBuilder forAddress(String name, int port) { /** * This method serves to force sub classes to "hide" this static factory. */ + @DoNotCall("Unsupported") public static ManagedChannelBuilder forTarget(String target) { throw new UnsupportedOperationException("Subclass failed to hide static factory"); } diff --git a/api/src/main/java/io/grpc/ForwardingServerBuilder.java b/api/src/main/java/io/grpc/ForwardingServerBuilder.java index 27358d06992..696a441e9b6 100644 --- a/api/src/main/java/io/grpc/ForwardingServerBuilder.java +++ b/api/src/main/java/io/grpc/ForwardingServerBuilder.java @@ -17,6 +17,7 @@ package io.grpc; import com.google.common.base.MoreObjects; +import com.google.errorprone.annotations.DoNotCall; import java.io.File; import java.io.InputStream; import java.util.concurrent.Executor; @@ -38,6 +39,7 @@ protected ForwardingServerBuilder() {} /** * This method serves to force sub classes to "hide" this static factory. */ + @DoNotCall("Unsupported") public static ServerBuilder forPort(int port) { throw new UnsupportedOperationException("Subclass failed to hide static factory"); } diff --git a/api/src/main/java/io/grpc/HasByteBuffer.java b/api/src/main/java/io/grpc/HasByteBuffer.java new file mode 100644 index 00000000000..97f2435524a --- /dev/null +++ b/api/src/main/java/io/grpc/HasByteBuffer.java @@ -0,0 +1,52 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc; + +import java.nio.ByteBuffer; +import javax.annotation.Nullable; + +/** + * Extension to an {@link java.io.InputStream} whose content can be accessed as {@link + * ByteBuffer}s. + * + *

This can be used for optimizing the case for the consumer of a {@link ByteBuffer}-backed + * input stream supports efficient reading from {@link ByteBuffer}s directly. This turns the reader + * interface from an {@link java.io.InputStream} to {@link ByteBuffer}s, without copying the + * content to a byte array and read from it. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/7387") +public interface HasByteBuffer { + + /** + * Indicates whether or not {@link #getByteBuffer} operation is supported. + */ + boolean byteBufferSupported(); + + /** + * Gets a {@link ByteBuffer} containing some bytes of the content next to be read, or {@code + * null} if has reached end of the content. The number of bytes contained in the returned buffer + * is implementation specific. Calling this method does not change the position of the input + * stream. The returned buffer's content should not be modified, but the position, limit, and + * mark may be changed. Operations for changing the position, limit, and mark of the returned + * buffer does not affect the position, limit, and mark of this input stream. This is an optional + * method, so callers should first check {@link #byteBufferSupported}. + * + * @throws UnsupportedOperationException if this operation is not supported. + */ + @Nullable + ByteBuffer getByteBuffer(); +} diff --git a/api/src/main/java/io/grpc/InternalServerInterceptors.java b/api/src/main/java/io/grpc/InternalServerInterceptors.java index 26c0b352b3a..e48c1468f1e 100644 --- a/api/src/main/java/io/grpc/InternalServerInterceptors.java +++ b/api/src/main/java/io/grpc/InternalServerInterceptors.java @@ -21,11 +21,6 @@ */ @Internal public final class InternalServerInterceptors { - public static ServerCallHandler interceptCallHandler( - ServerInterceptor interceptor, ServerCallHandler callHandler) { - return ServerInterceptors.InterceptCallHandler.create(interceptor, callHandler); - } - public static ServerMethodDefinition wrapMethod( final ServerMethodDefinition definition, diff --git a/auth/build.gradle b/auth/build.gradle index 1f5e4b41953..233de359b49 100644 --- a/auth/build.gradle +++ b/auth/build.gradle @@ -10,9 +10,9 @@ description = "gRPC: Auth" dependencies { api project(':grpc-api'), libraries.google_auth_credentials - guavaDependency 'implementation' - testImplementation project(':grpc-testing') - googleOauth2Dependency 'testImplementation' + implementation libraries.guava + testImplementation project(':grpc-testing'), + libraries.google_auth_oauth2_http signature "org.codehaus.mojo.signature:java17:1.0@signature" signature "net.sf.androidscents.signature:android-api-level-14:4.0_r4@signature" } diff --git a/binder/build.gradle b/binder/build.gradle index 27fd9fec637..537c23a0092 100644 --- a/binder/build.gradle +++ b/binder/build.gradle @@ -7,6 +7,22 @@ plugins { description = 'gRPC BinderChannel' android { + sourceSets { + test { + java { + srcDirs += "${projectDir}/../core/src/test/java/" + setIncludes(["io/grpc/internal/FakeClock.java", + "io/grpc/binder/**"]) + } + } + androidTest { + java { + srcDirs += "${projectDir}/../core/src/test/java/" + setIncludes(["io/grpc/internal/AbstractTransportTest.java", + "io/grpc/binder/**"]) + } + } + } compileSdkVersion 29 compileOptions { sourceCompatibility 1.8 @@ -14,10 +30,11 @@ android { } defaultConfig { minSdkVersion 16 - targetSdkVersion 28 + targetSdkVersion 29 versionCode 1 versionName "1.0" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" + multiDexEnabled true } lintOptions { abortOnError false } } @@ -29,8 +46,10 @@ repositories { dependencies { api project(':grpc-core') - guavaDependency 'implementation' + implementation libraries.androidx_annotation + implementation libraries.androidx_core + implementation libraries.guava testImplementation libraries.androidx_core testImplementation libraries.androidx_test testImplementation libraries.junit @@ -39,9 +58,37 @@ dependencies { // Unreleased change: https://github.com/robolectric/robolectric/pull/5432 exclude group: 'com.google.auto.service', module: 'auto-service' } + testImplementation (libraries.guava_testlib) { + exclude group: 'junit', module: 'junit' + } testImplementation libraries.truth + + androidTestAnnotationProcessor libraries.autovalue + androidTestImplementation project(':grpc-testing') + androidTestImplementation project(':grpc-protobuf-lite') + androidTestImplementation libraries.autovalue_annotation + androidTestImplementation libraries.junit + androidTestImplementation libraries.androidx_core + androidTestImplementation libraries.androidx_test androidTestImplementation libraries.androidx_test_rules androidTestImplementation libraries.androidx_test_ext_junit androidTestImplementation libraries.truth + androidTestImplementation libraries.mockito_android + androidTestImplementation libraries.androidx_lifecycle_service + androidTestImplementation (libraries.guava_testlib) { + exclude group: 'junit', module: 'junit' + } } + +import net.ltgt.gradle.errorprone.CheckSeverity + +tasks.withType(JavaCompile) { + options.compilerArgs += [ + "-Xlint:-cast" + ] + appendToProperty(it.options.errorprone.excludedPaths, ".*/R.java", "|") + // Reuses source code from grpc-core, which targets Java 7 (no method references) + options.errorprone.check("UnnecessaryAnonymousClass", CheckSeverity.OFF) +} + [publishMavenPublicationToMavenRepository]*.onlyIf { false } diff --git a/binder/src/androidTest/AndroidManifest.xml b/binder/src/androidTest/AndroidManifest.xml new file mode 100644 index 00000000000..da2cdcb75ea --- /dev/null +++ b/binder/src/androidTest/AndroidManifest.xml @@ -0,0 +1,8 @@ + + + + + + + diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java new file mode 100644 index 00000000000..13dd0acf6d9 --- /dev/null +++ b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java @@ -0,0 +1,176 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; + +import android.app.Application; +import android.app.Service; +import android.content.Context; +import android.content.Intent; +import android.os.IBinder; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.io.ByteStreams; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.CallOptions; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerCallHandler; +import io.grpc.ServerServiceDefinition; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.ServerCalls; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** + * Basic tests for Binder Channel, covering some of the edge cases not exercised by + * AbstractTransportTest. + */ +@RunWith(AndroidJUnit4.class) +public final class BinderChannelSmokeTest { + private final Context appContext = ApplicationProvider.getApplicationContext(); + + private static final int SLIGHTLY_MORE_THAN_ONE_BLOCK = 16 * 1024 + 100; + private static final String MSG = "Some text which will be repeated many many times"; + + final MethodDescriptor method = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/method") + .setType(MethodDescriptor.MethodType.UNARY) + .build(); + + final MethodDescriptor singleLargeResultMethod = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/noResultMethod") + .setType(MethodDescriptor.MethodType.SERVER_STREAMING) + .build(); + + AndroidComponentAddress serverAddress; + ManagedChannel channel; + + @Before + public void setUp() throws Exception { + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + respObserver.onNext(req); + respObserver.onCompleted(); + }); + + ServerCallHandler singleLargeResultCallHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + respObserver.onNext(createLargeString(SLIGHTLY_MORE_THAN_ONE_BLOCK)); + respObserver.onCompleted(); + }); + + ServerServiceDefinition serviceDef = + ServerServiceDefinition.builder("test") + .addMethod(method, callHandler) + .addMethod(singleLargeResultMethod, singleLargeResultCallHandler) + .build(); + + AndroidComponentAddress serverAddress = HostServices.allocateService(appContext); + HostServices.configureService(serverAddress, + HostServices.serviceParamsBuilder() + .setServerFactory((service, receiver) -> + BinderServerBuilder.forService(service, receiver) + .addService(serviceDef) + .build()) + .build()); + + channel = BinderChannelBuilder.forAddress(serverAddress, appContext).build(); + } + + @After + public void tearDown() throws Exception { + channel.shutdownNow(); + HostServices.awaitServiceShutdown(); + } + + private ListenableFuture doCall(String request) { + return doCall(method, request); + } + + private ListenableFuture doCall( + MethodDescriptor methodDesc, String request) { + ListenableFuture future = + ClientCalls.futureUnaryCall(channel.newCall(methodDesc, CallOptions.DEFAULT), request); + return Futures.withTimeout( + future, 5L, TimeUnit.SECONDS, Executors.newSingleThreadScheduledExecutor()); + } + + @Test + public void testBasicCall() throws Exception { + assertThat(doCall("Hello").get()).isEqualTo("Hello"); + } + + @Test + public void testEmptyMessage() throws Exception { + assertThat(doCall("").get()).isEmpty(); + } + + @Test + public void test100kString() throws Exception { + String fullMsg = createLargeString(100000); + assertThat(doCall(fullMsg).get()).isEqualTo(fullMsg); + } + + @Test + public void testSingleLargeResultCall() throws Exception { + String res = doCall(singleLargeResultMethod, "hello").get(); + assertThat(res.length()).isEqualTo(SLIGHTLY_MORE_THAN_ONE_BLOCK); + } + + private static String createLargeString(int size) { + StringBuilder sb = new StringBuilder(); + while (sb.length() < size) { + sb.append(MSG); + } + sb.setLength(size); + return sb.toString(); + } + + private static class StringMarshaller implements MethodDescriptor.Marshaller { + public static final StringMarshaller INSTANCE = new StringMarshaller(); + + @Override + public InputStream stream(String value) { + return new ByteArrayInputStream(value.getBytes(UTF_8)); + } + + @Override + public String parse(InputStream stream) { + try { + return new String(ByteStreams.toByteArray(stream), UTF_8); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } +} diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java b/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java new file mode 100644 index 00000000000..9b9ce2b1825 --- /dev/null +++ b/binder/src/androidTest/java/io/grpc/binder/BinderSecurityTest.java @@ -0,0 +1,193 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import android.app.Service; +import android.content.Context; +import android.content.Intent; +import android.os.IBinder; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.base.Function; +import com.google.protobuf.Empty; +import io.grpc.CallOptions; +import io.grpc.ManagedChannel; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerCallHandler; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; +import io.grpc.protobuf.lite.ProtoLiteUtils; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.ServerCalls; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import javax.annotation.Nullable; +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; + +@RunWith(AndroidJUnit4.class) +public final class BinderSecurityTest { + private final Context appContext = ApplicationProvider.getApplicationContext(); + + String[] serviceNames = new String[] {"foo", "bar", "baz"}; + + @Nullable ManagedChannel channel; + Map> methods = new HashMap<>(); + List> calls = new ArrayList<>(); + + @After + public void tearDown() throws Exception { + if (channel != null) { + channel.shutdownNow(); + } + HostServices.awaitServiceShutdown(); + } + + private void createChannel() throws Exception { + createChannel(SecurityPolicies.serverInternalOnly(), SecurityPolicies.internalOnly()); + } + + private void createChannel(ServerSecurityPolicy serverPolicy, SecurityPolicy channelPolicy) + throws Exception { + AndroidComponentAddress addr = HostServices.allocateService(appContext); + HostServices.configureService(addr, + HostServices.serviceParamsBuilder() + .setServerFactory((service, receiver) -> buildServer(service, receiver, serverPolicy)) + .build()); + + channel = + BinderChannelBuilder.forAddress(addr, appContext) + .securityPolicy(channelPolicy) + .build(); + } + + private Server buildServer( + Service service, + IBinderReceiver receiver, + ServerSecurityPolicy serverPolicy) { + BinderServerBuilder serverBuilder = BinderServerBuilder.forService(service, receiver); + serverBuilder.securityPolicy(serverPolicy); + + MethodDescriptor.Marshaller marshaller = + ProtoLiteUtils.marshaller(Empty.getDefaultInstance()); + for (String serviceName : serviceNames) { + ServerServiceDefinition.Builder builder = ServerServiceDefinition.builder(serviceName); + for (int i = 0; i < 2; i++) { + // Add two methods to the service. + String name = serviceName + "/method" + i; + MethodDescriptor method = + MethodDescriptor.newBuilder(marshaller, marshaller) + .setFullMethodName(name) + .setType(MethodDescriptor.MethodType.UNARY) + .build(); + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + calls.add(method); + respObserver.onNext(req); + respObserver.onCompleted(); + }); + builder.addMethod(method, callHandler); + methods.put(name, method); + } + serverBuilder.addService(builder.build()); + } + return serverBuilder.build(); + } + + private void assertCallSuccess(MethodDescriptor method) { + assertThat( + ClientCalls.blockingUnaryCall( + channel, method, CallOptions.DEFAULT, Empty.getDefaultInstance())) + .isNotNull(); + } + + private void assertCallFailure(MethodDescriptor method, Status status) { + try { + ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, null); + fail(); + } catch (StatusRuntimeException sre) { + assertThat(sre.getStatus().getCode()).isEqualTo(status.getCode()); + } + } + + @Test + public void testAllowedCall() throws Exception { + createChannel(); + for (MethodDescriptor method : methods.values()) { + assertCallSuccess(method); + } + } + + @Test + public void testServerDisllowsCalls() throws Exception { + createChannel( + ServerSecurityPolicy.newBuilder() + .servicePolicy("foo", policy((uid) -> false)) + .servicePolicy("bar", policy((uid) -> false)) + .servicePolicy("baz", policy((uid) -> false)) + .build(), + SecurityPolicies.internalOnly()); + for (MethodDescriptor method : methods.values()) { + assertCallFailure(method, Status.PERMISSION_DENIED); + } + } + + @Test + public void testClientDoesntTrustServer() throws Exception { + createChannel(SecurityPolicies.serverInternalOnly(), policy((uid) -> false)); + for (MethodDescriptor method : methods.values()) { + assertCallFailure(method, Status.PERMISSION_DENIED); + } + } + + @Test + public void testPerServicePolicy() throws Exception { + createChannel( + ServerSecurityPolicy.newBuilder() + .servicePolicy("foo", policy((uid) -> true)) + .servicePolicy("bar", policy((uid) -> false)) + .build(), + SecurityPolicies.internalOnly()); + + for (MethodDescriptor method : methods.values()) { + if (method.getServiceName().equals("bar")) { + assertCallFailure(method, Status.PERMISSION_DENIED); + } else { + assertCallSuccess(method); + } + } + } + + private static SecurityPolicy policy(Function func) { + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return func.apply(uid) ? Status.OK : Status.PERMISSION_DENIED; + } + }; + } +} diff --git a/binder/src/androidTest/java/io/grpc/binder/HostServices.java b/binder/src/androidTest/java/io/grpc/binder/HostServices.java new file mode 100644 index 00000000000..92b232f1ff0 --- /dev/null +++ b/binder/src/androidTest/java/io/grpc/binder/HostServices.java @@ -0,0 +1,295 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static java.util.concurrent.TimeUnit.SECONDS; + +import android.app.Service; +import android.content.Context; +import android.content.Intent; +import android.os.Binder; +import android.os.IBinder; +import android.os.Parcel; +import android.os.RemoteException; +import androidx.lifecycle.LifecycleService; +import com.google.auto.value.AutoValue; +import com.google.common.base.Supplier; +import com.google.common.collect.ImmutableList; +import io.grpc.NameResolver; +import io.grpc.Server; +import io.grpc.ServerServiceDefinition; +import io.grpc.ServerStreamTracer; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.internal.InternalServer; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; + +/** + * A test helper class for creating android services to host gRPC servers. + * + *

Currently only supports two servers at a time. If more are required, define a new class, add + * it to the manifest, and the hostServiceClasses array. + */ +public final class HostServices { + + private static final Logger logger = Logger.getLogger(HostServices.class.getName()); + + private static final Class[] hostServiceClasses = + new Class[] { + HostService1.class, HostService2.class, + }; + + + public interface ServerFactory { + Server createServer(Service service, IBinderReceiver receiver); + } + + @AutoValue + public abstract static class ServiceParams { + @Nullable + abstract Executor transactionExecutor(); + + @Nullable + abstract Supplier rawBinderSupplier(); + + @Nullable + abstract ServerFactory serverFactory(); + + public abstract Builder toBuilder(); + + public static Builder builder() { + return new AutoValue_HostServices_ServiceParams.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setRawBinderSupplier(Supplier binderSupplier); + + public abstract Builder setServerFactory(ServerFactory serverFactory); + + /** + * If set, this executor will be used to pass any inbound transactions to the server. This can + * be used to simulate delayed, re-ordered, or dropped packets. + */ + public abstract Builder setTransactionExecutor(Executor transactionExecutor); + + public abstract ServiceParams build(); + } + } + + @GuardedBy("HostServices.class") + private static final Map, AndroidComponentAddress> serviceAddresses = new HashMap<>(); + + @GuardedBy("HostServices.class") + private static final Map, ServiceParams> serviceParams = new HashMap<>(); + + @GuardedBy("HostServices.class") + private static final Map, HostService> activeServices = new HashMap<>(); + + @Nullable + @GuardedBy("HostServices.class") + private static CountDownLatch serviceShutdownLatch; + + private HostServices() {} + + /** Create a new {@link ServiceParams} builder. */ + public static ServiceParams.Builder serviceParamsBuilder() { + return ServiceParams.builder(); + } + + /** + * Wait for all services to shutdown. This should be called from a test's tearDown method, to + * ensure the next test is able to use this class again (since Android itself is in control of the + * services). + */ + public static void awaitServiceShutdown() throws InterruptedException { + CountDownLatch latch = null; + synchronized (HostServices.class) { + if (serviceShutdownLatch == null && !activeServices.isEmpty()) { + latch = new CountDownLatch(activeServices.size()); + serviceShutdownLatch = latch; + } + serviceParams.clear(); + serviceAddresses.clear(); + } + if (latch != null) { + if (!latch.await(10, SECONDS)) { + throw new AssertionError("Failed to shut down services"); + } + } + synchronized (HostServices.class) { + checkState(activeServices.isEmpty()); + checkState(serviceParams.isEmpty()); + checkState(serviceAddresses.isEmpty()); + serviceShutdownLatch = null; + } + } + + /** Create the address for a host-service. */ + private static AndroidComponentAddress hostServiceAddress(Context appContext, Class cls) { + // NOTE: Even though we have a context object, we intentionally don't use a "local", + // address, since doing so would mark the address with our UID for security purposes, + // and that would limit the effectiveness of tests. + // Using this API forces us to rely on Binder.getCallingUid. + return AndroidComponentAddress.forRemoteComponent(appContext.getPackageName(), cls.getName()); + } + + /** + * Allocate a new host service. + * + * @param appContext The application context. + * @return The AndroidComponentAddress of the service. + */ + public static synchronized AndroidComponentAddress allocateService(Context appContext) { + for (Class cls : hostServiceClasses) { + if (!serviceAddresses.containsKey(cls)) { + AndroidComponentAddress address = hostServiceAddress(appContext, cls); + serviceAddresses.put(cls, address); + return address; + } + } + throw new AssertionError("This test helper only supports two services at a time."); + } + + /** + * Configure an allocated hosting service. + * + * @param androidComponentAddress The address of the service. + * @param params The parameters used to build the service. + */ + public static synchronized void configureService( + AndroidComponentAddress androidComponentAddress, ServiceParams params) { + for (Class cls : hostServiceClasses) { + if (serviceAddresses.get(cls).equals(androidComponentAddress)) { + checkState(!serviceParams.containsKey(cls)); + serviceParams.put(cls, params); + return; + } + } + throw new AssertionError("Unable to find service for address " + androidComponentAddress); + } + + /** An Android Service to host each gRPC server. */ + private abstract static class HostService extends LifecycleService { + + @Nullable private ServiceParams params; + @Nullable private Supplier binderSupplier; + @Nullable private Server server; + + @Override + public final void onCreate() { + super.onCreate(); + Class cls = getClass(); + synchronized (HostServices.class) { + checkState(!activeServices.containsKey(cls)); + activeServices.put(cls, this); + checkState(serviceParams.containsKey(cls)); + params = serviceParams.get(cls); + ServerFactory factory = params.serverFactory(); + if (factory != null) { + IBinderReceiver receiver = new IBinderReceiver(); + server = factory.createServer(this, receiver); + try { + server.start(); + } catch (IOException ioe) { + throw new AssertionError("Failed to start server", ioe); + } + binderSupplier = () -> receiver.get(); + } else { + binderSupplier = params.rawBinderSupplier(); + if (binderSupplier == null) { + throw new AssertionError("Insufficient params for host service"); + } + } + } + } + + @Override + public final IBinder onBind(Intent intent) { + // Calling super here is a little weird (it returns null), but there's a @CallSuper + // annotation. + super.onBind(intent); + synchronized (HostServices.class) { + Executor executor = params.transactionExecutor(); + if (executor != null) { + return new ProxyBinder(binderSupplier.get(), executor); + } else { + return binderSupplier.get(); + } + } + } + + @Override + public final void onDestroy() { + synchronized (HostServices.class) { + if (server != null) { + server.shutdown(); + server = null; + } + HostService removed = activeServices.remove(getClass()); + checkState(removed == this); + serviceAddresses.remove(getClass()); + serviceParams.remove(getClass()); + if (serviceShutdownLatch != null) { + serviceShutdownLatch.countDown(); + } + } + super.onDestroy(); + } + } + + /** The first concrete host service */ + public static final class HostService1 extends HostService {} + + /** The second concrete host service */ + public static final class HostService2 extends HostService {} + + /** Wraps an IBinder to send incoming transactions to a different thread. */ + private static class ProxyBinder extends Binder { + private final IBinder delegate; + private final Executor executor; + + ProxyBinder(IBinder delegate, Executor executor) { + this.delegate = delegate; + this.executor = executor; + } + + @Override + protected boolean onTransact(int code, Parcel parcel, Parcel reply, int flags) { + executor.execute( + () -> { + try { + delegate.transact(code, parcel, reply, flags); + } catch (RemoteException re) { + logger.log(Level.WARNING, "Exception in proxybinder", re); + } + }); + return true; + } + } +} diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java new file mode 100644 index 00000000000..a7ee6b764d8 --- /dev/null +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -0,0 +1,338 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; + +import android.app.Service; +import android.content.Context; +import android.content.Intent; +import android.os.IBinder; +import android.os.Parcel; +import androidx.core.content.ContextCompat; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.util.concurrent.testing.TestingExecutors; +import com.google.protobuf.Empty; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerCallHandler; +import io.grpc.ServerServiceDefinition; +import io.grpc.Status; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.BinderServerBuilder; +import io.grpc.binder.BindServiceFlags; +import io.grpc.binder.HostServices; +import io.grpc.binder.IBinderReceiver; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicies; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.ClientStreamListener.RpcProgress; +import io.grpc.internal.FixedObjectPool; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.StreamListener; +import io.grpc.protobuf.lite.ProtoLiteUtils; +import io.grpc.stub.ServerCalls; +import java.io.IOException; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.Nullable; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** + * Client-side transport tests for binder channel. Like BinderChannelSmokeTest, this covers edge + * cases not exercised by AbstractTransportTest, but in this case we're dealing with rare ordering + * issues at the transport level, so we use a BinderTransport.BinderClientTransport directly, rather + * than a channel. + */ +@RunWith(AndroidJUnit4.class) +public final class BinderClientTransportTest { + private final Context appContext = ApplicationProvider.getApplicationContext(); + + MethodDescriptor.Marshaller marshaller = + ProtoLiteUtils.marshaller(Empty.getDefaultInstance()); + + MethodDescriptor methodDesc = + MethodDescriptor.newBuilder(marshaller, marshaller) + .setFullMethodName("test/method") + .setType(MethodDescriptor.MethodType.UNARY) + .build(); + + MethodDescriptor streamingMethodDesc = + MethodDescriptor.newBuilder(marshaller, marshaller) + .setFullMethodName("test/methodServerStreaming") + .setType(MethodDescriptor.MethodType.SERVER_STREAMING) + .build(); + + AndroidComponentAddress serverAddress; + BinderTransport.BinderClientTransport transport; + + private final ObjectPool executorServicePool = + new FixedObjectPool<>(TestingExecutors.sameThreadScheduledExecutor()); + private final TestTransportListener transportListener = new TestTransportListener(); + private final TestStreamListener streamListener = new TestStreamListener(); + + private int serverCallsCompleted; + + @Before + public void setUp() throws Exception { + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + respObserver.onNext(req); + respObserver.onCompleted(); + serverCallsCompleted += 1; + }); + + ServerCallHandler streamingCallHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + for (int i = 0; i < 100; i++) { + respObserver.onNext(req); + } + respObserver.onCompleted(); + serverCallsCompleted += 1; + }); + + ServerServiceDefinition serviceDef = + ServerServiceDefinition.builder("test") + .addMethod(methodDesc, callHandler) + .addMethod(streamingMethodDesc, streamingCallHandler) + .build(); + + serverAddress = HostServices.allocateService(appContext); + HostServices.configureService(serverAddress, + HostServices.serviceParamsBuilder() + .setServerFactory((service, receiver) -> + BinderServerBuilder.forService(service, receiver) + .addService(serviceDef) + .build()) + .build()); + + transport = + new BinderTransport.BinderClientTransport( + appContext, + serverAddress, + BindServiceFlags.DEFAULTS, + ContextCompat.getMainExecutor(appContext), + executorServicePool, + executorServicePool, + SecurityPolicies.internalOnly(), + InboundParcelablePolicy.DEFAULT, + Attributes.EMPTY); + + Runnable r = transport.start(transportListener); + r.run(); + transportListener.awaitReady(); + } + + @After + public void tearDown() throws Exception { + transport.shutdownNow(Status.OK); + HostServices.awaitServiceShutdown(); + } + + @Test + public void testShutdownBeforeStreamStart_b153326034() throws Exception { + ClientStream stream = transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT); + transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); + + // This shouldn't throw an exception. + stream.start(streamListener); + } + + @Test + public void testRequestWhileStreamIsWaitingOnCall_b154088869() throws Exception { + ClientStream stream = + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + + stream.start(streamListener); + stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); + stream.halfClose(); + stream.request(3); + + streamListener.awaitMessages(); + streamListener.messageProducer.next(); + streamListener.messageProducer.next(); + + // Without the fix, this loops forever. + stream.request(2); + } + + @Test + public void testTransactionForDiscardedCall_b155244043() throws Exception { + ClientStream stream = + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + + stream.start(streamListener); + stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); + + assertThat(transport.getOngoingCalls()).hasSize(1); + int callId = transport.getOngoingCalls().keySet().iterator().next(); + stream.cancel(Status.UNKNOWN); + + // Send a transaction to the no-longer present call ID. It should be silently ignored. + Parcel p = Parcel.obtain(); + transport.handleTransaction(callId, p); + p.recycle(); + } + + @Test + public void testBadTransactionStreamThroughput_b163053382() throws Exception { + ClientStream stream = + transport.newStream(streamingMethodDesc, new Metadata(), CallOptions.DEFAULT); + + stream.start(streamListener); + stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); + stream.halfClose(); + stream.request(1000); + + // Wait until we receive the first message. + streamListener.awaitMessages(); + // Wait until the server actually provides all messages and completes the call. + awaitServerCallsCompleted(1); + + // Now we should be able to receive all messages on a single message producer. + assertThat(streamListener.drainMessages()).isEqualTo(100); + } + + @Test + public void testMessageProducerClosedAfterStream_b169313545() { + ClientStream stream = + transport.newStream(methodDesc, new Metadata(), CallOptions.DEFAULT); + + stream.start(streamListener); + stream.writeMessage(marshaller.stream(Empty.getDefaultInstance())); + stream.halfClose(); + stream.request(2); + + // Wait until we receive the first message. + streamListener.awaitMessages(); + + // Now cancel the stream, forcing it to close. + stream.cancel(Status.CANCELLED); + + // The message producer shouldn't throw an exception if we drain it now. + streamListener.drainMessages(); + } + + private synchronized void awaitServerCallsCompleted(int calls) { + while (serverCallsCompleted < calls) { + try { + wait(100); + } catch (InterruptedException inte) { + throw new AssertionError("Interrupted waiting for servercalls"); + } + } + } + + private static final class TestTransportListener implements ManagedClientTransport.Listener { + public boolean ready; + public boolean inUse; + @Nullable public Status shutdownStatus; + public boolean terminated; + + @Override + public void transportShutdown(Status shutdownStatus) { + this.shutdownStatus = shutdownStatus; + } + + @Override + public void transportTerminated() { + terminated = true; + } + + @Override + public synchronized void transportReady() { + ready = true; + notify(); + } + + public synchronized void awaitReady() { + while (!ready) { + try { + wait(100); + } catch (InterruptedException inte) { + throw new AssertionError("Interrupted waiting for ready"); + } + } + } + + @Override + public void transportInUse(boolean inUse) { + this.inUse = inUse; + } + } + + private static final class TestStreamListener implements ClientStreamListener { + + public StreamListener.MessageProducer messageProducer; + public boolean ready; + public Metadata headers; + @Nullable public Status closedStatus; + + @Override + public void messagesAvailable(StreamListener.MessageProducer messageProducer) { + this.messageProducer = messageProducer; + } + + public synchronized void awaitMessages() { + while (messageProducer == null) { + try { + wait(100); + } catch (InterruptedException inte) { + throw new AssertionError("Interrupted waiting for messages"); + } + } + } + + public int drainMessages() { + int n = 0; + while (messageProducer.next() != null) { + n += 1; + } + return n; + } + + @Override + public void onReady() { + ready = true; + } + + @Override + public void headersRead(Metadata headers) { + this.headers = headers; + } + + @Override + public void closed(Status status, Metadata trailers) { + this.closedStatus = status; + } + + @Override + public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) { + this.closedStatus = status; + } + } +} diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java new file mode 100644 index 00000000000..24af04d4d61 --- /dev/null +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderTransportTest.java @@ -0,0 +1,136 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import android.content.Context; +import androidx.core.content.ContextCompat; +import androidx.test.core.app.ApplicationProvider; +import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.util.concurrent.MoreExecutors; +import io.grpc.ServerStreamTracer; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.BindServiceFlags; +import io.grpc.binder.HostServices; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicies; +import io.grpc.internal.AbstractTransportTest; +import io.grpc.internal.FixedObjectPool; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.After; +import org.junit.Ignore; +import org.junit.Test; +import org.junit.runner.RunWith; + +/** + * A test for the Android binder based transport. + * + *

This class really just sets up the test environment. All of the actual tests are defined in + * AbstractTransportTest. + */ +@RunWith(AndroidJUnit4.class) +public final class BinderTransportTest extends AbstractTransportTest { + + private final Context appContext = ApplicationProvider.getApplicationContext(); + private final ObjectPool executorServicePool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + + @Override + @After + public void tearDown() throws InterruptedException { + super.tearDown(); + HostServices.awaitServiceShutdown(); + } + + @Override + protected InternalServer newServer(List streamTracerFactories) { + AndroidComponentAddress addr = HostServices.allocateService(appContext); + + BinderServer binderServer = new BinderServer(addr, + executorServicePool, + streamTracerFactories, + SecurityPolicies.serverInternalOnly(), + InboundParcelablePolicy.DEFAULT); + + HostServices.configureService(addr, + HostServices.serviceParamsBuilder() + .setRawBinderSupplier(() -> binderServer.getHostBinder()) + .build()); + + return binderServer; + } + + @Override + protected InternalServer newServer( + int port, List streamTracerFactories) { + return newServer(streamTracerFactories); + } + + @Override + protected String testAuthority(InternalServer server) { + return ((AndroidComponentAddress) server.getListenSocketAddress()).getAuthority(); + } + + @Override + protected ManagedClientTransport newClientTransport(InternalServer server) { + AndroidComponentAddress addr = (AndroidComponentAddress) server.getListenSocketAddress(); + return new BinderTransport.BinderClientTransport( + appContext, + addr, + BindServiceFlags.DEFAULTS, + ContextCompat.getMainExecutor(appContext), + executorServicePool, + new FixedObjectPool<>(MoreExecutors.directExecutor()), + SecurityPolicies.internalOnly(), + InboundParcelablePolicy.DEFAULT, + eagAttrs()); + } + + @Test + @Ignore("BinderTransport doesn't report socket stats yet.") + @Override + public void socketStats() throws Exception {} + + @Test + @Ignore("BinderTransport doesn't do message-level flow control yet.") + @Override + public void flowControlPushBack() throws Exception {} + + @Test + @Ignore("This test isn't appropriate for BinderTransport.") + @Override + public void serverAlreadyListening() throws Exception { + // This test asserts that two Servers can't listen on the same SocketAddress. For a regular + // network server, that address refers to a network port, and for a BinderServer it + // refers to an Android Service class declared in an applications manifest. + // + // However, unlike a regular network server, which is responsible for listening on its port, a + // BinderServier is not responsible for the creation of its host Service. The opposite is + // the case, with the host Android Service (itself created by the Android platform in + // response to a connection) building the gRPC server. + // + // Passing this test would require us to manually check that two Server instances aren't, + // created with the same Android Service class, but due to the "inversion of control" described + // above, we would actually be testing (and making assumptions about) the precise lifecycle of + // Android Services, which is arguably not our concern. + } +} diff --git a/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java b/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java new file mode 100644 index 00000000000..b13840746d4 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/AndroidComponentAddress.java @@ -0,0 +1,81 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import android.content.ComponentName; +import android.content.Context; +import io.grpc.ExperimentalApi; +import java.net.SocketAddress; + +/** Custom SocketAddress class referencing an Android Component. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class AndroidComponentAddress extends SocketAddress { + + private static final long serialVersionUID = 0L; + + private final ComponentName component; + + private AndroidComponentAddress(ComponentName component) { + this.component = component; + } + + /** Create an address for the given context instance. */ + public static AndroidComponentAddress forContext(Context context) { + return forLocalComponent(context, context.getClass()); + } + + /** Create an address referencing a component within this application. */ + public static AndroidComponentAddress forLocalComponent(Context context, Class cls) { + return forComponent(new ComponentName(context, cls)); + } + + /** Create an address referencing a component (potentially) in another application. */ + public static AndroidComponentAddress forRemoteComponent(String packageName, String className) { + return forComponent(new ComponentName(packageName, className)); + } + + public static AndroidComponentAddress forComponent(ComponentName component) { + return new AndroidComponentAddress(component); + } + + public String getAuthority() { + return component.getPackageName(); + } + + public ComponentName getComponent() { + return component; + } + + @Override + public int hashCode() { + return component.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof AndroidComponentAddress) { + AndroidComponentAddress that = (AndroidComponentAddress) obj; + return component.equals(that.component); + } + return false; + } + + @Override + public String toString() { + return "AndroidComponentAddress[" + component + "]"; + } +} diff --git a/binder/src/main/java/io/grpc/binder/ApiConstants.java b/binder/src/main/java/io/grpc/binder/ApiConstants.java new file mode 100644 index 00000000000..42006014419 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/ApiConstants.java @@ -0,0 +1,31 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import android.content.Intent; +import io.grpc.ExperimentalApi; + +/** Constant parts of the gRPC binder transport public API. */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class ApiConstants { + private ApiConstants() {} + + /** + * Service Action: Identifies gRPC clients in a {@link android.app.Service#onBind(Intent)} call. + */ + public static final String ACTION_BIND = "grpc.io.action.BIND"; +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BindServiceFlags.java b/binder/src/main/java/io/grpc/binder/BindServiceFlags.java similarity index 99% rename from binder/src/main/java/io/grpc/binder/internal/BindServiceFlags.java rename to binder/src/main/java/io/grpc/binder/BindServiceFlags.java index 645476801f7..ae0736357bb 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BindServiceFlags.java +++ b/binder/src/main/java/io/grpc/binder/BindServiceFlags.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.binder.internal; +package io.grpc.binder; import static android.content.Context.BIND_ABOVE_CLIENT; import static android.content.Context.BIND_ADJUST_WITH_ACTIVITY; diff --git a/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java new file mode 100644 index 00000000000..99191cfad3c --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/BinderChannelBuilder.java @@ -0,0 +1,264 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static com.google.common.base.Preconditions.checkNotNull; + +import android.app.Application; +import android.content.ComponentName; +import android.content.Context; +import androidx.core.content.ContextCompat; +import com.google.errorprone.annotations.DoNotCall; +import io.grpc.ChannelCredentials; +import io.grpc.ChannelLogger; +import io.grpc.CompressorRegistry; +import io.grpc.DecompressorRegistry; +import io.grpc.ExperimentalApi; +import io.grpc.ForwardingChannelBuilder; +import io.grpc.ManagedChannel; +import io.grpc.ManagedChannelBuilder; +import io.grpc.binder.internal.BinderTransport; +import io.grpc.internal.ClientTransportFactory; +import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.FixedObjectPool; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ManagedChannelImplBuilder; +import io.grpc.internal.ManagedChannelImplBuilder.ClientTransportFactoryBuilder; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import java.net.SocketAddress; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import javax.annotation.Nullable; + +/** + * Builder for a gRPC channel which communicates with an Android bound service. + * + * @see Bound + * Services + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class BinderChannelBuilder + extends ForwardingChannelBuilder { + + /** + * Creates a channel builder that will bind to a remote Android service. + * + *

The underlying Android binding will be torn down when the channel becomes idle. This happens + * after 30 minutes without use by default but can be configured via {@link + * ManagedChannelBuilder#idleTimeout(long, TimeUnit)} or triggered manually with {@link + * ManagedChannel#enterIdle()}. + * + *

You the caller are responsible for managing the lifecycle of any channels built by the + * resulting builder. They will not be shut down automatically. + * + * @param targetAddress the {@link AndroidComponentAddress} referencing the service to bind to. + * @param sourceContext the context to bind from (e.g. The current Activity or Application). + * @return a new builder + */ + public static BinderChannelBuilder forAddress( + AndroidComponentAddress targetAddress, Context sourceContext) { + return new BinderChannelBuilder(targetAddress, sourceContext); + } + + /** + * Always fails. Call {@link #forAddress(AndroidComponentAddress, Context)} instead. + */ + @DoNotCall("Unsupported. Use forAddress(AndroidComponentAddress, Context) instead") + public static BinderChannelBuilder forAddress(String name, int port) { + throw new UnsupportedOperationException( + "call forAddress(AndroidComponentAddress, Context) instead"); + } + + /** + * Always fails. Call {@link #forAddress(AndroidComponentAddress, Context)} instead. + */ + @DoNotCall("Unsupported. Use forAddress(AndroidComponentAddress, Context) instead") + public static BinderChannelBuilder forTarget(String target) { + throw new UnsupportedOperationException( + "call forAddress(AndroidComponentAddress, Context) instead"); + } + + private final ManagedChannelImplBuilder managedChannelImplBuilder; + + private Executor mainThreadExecutor; + private ObjectPool schedulerPool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + private SecurityPolicy securityPolicy; + private InboundParcelablePolicy inboundParcelablePolicy; + private BindServiceFlags bindServiceFlags; + + private BinderChannelBuilder( + AndroidComponentAddress targetAddress, + Context sourceContext) { + mainThreadExecutor = ContextCompat.getMainExecutor(sourceContext); + securityPolicy = SecurityPolicies.internalOnly(); + inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; + bindServiceFlags = BindServiceFlags.DEFAULTS; + + final class BinderChannelTransportFactoryBuilder + implements ClientTransportFactoryBuilder { + @Override + public ClientTransportFactory buildClientTransportFactory() { + return new TransportFactory( + sourceContext, + mainThreadExecutor, + schedulerPool, + managedChannelImplBuilder.getOffloadExecutorPool(), + securityPolicy, + bindServiceFlags, + inboundParcelablePolicy); + } + } + + managedChannelImplBuilder = + new ManagedChannelImplBuilder( + targetAddress, + targetAddress.getAuthority(), + new BinderChannelTransportFactoryBuilder(), + null); + } + + @Override + protected ManagedChannelBuilder delegate() { + return managedChannelImplBuilder; + } + + /** Specifies certain optional aspects of the underlying Android Service binding. */ + public BinderChannelBuilder setBindServiceFlags(BindServiceFlags bindServiceFlags) { + this.bindServiceFlags = bindServiceFlags; + return this; + } + + /** + * Provides a custom scheduled executor service. + * + *

This is an optional parameter. If the user has not provided a scheduled executor service + * when the channel is built, the builder will use a static cached thread pool. + * + * @return this + */ + public BinderChannelBuilder scheduledExecutorService( + ScheduledExecutorService scheduledExecutorService) { + schedulerPool = + new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService")); + return this; + } + + /** + * Provides a custom {@link Executor} for accessing this application's main thread. + * + *

Optional. A default implementation will be used if no custom Executor is provided. + * + * @return this + */ + public BinderChannelBuilder mainThreadExecutor(Executor mainThreadExecutor) { + this.mainThreadExecutor = mainThreadExecutor; + return this; + } + + /** + * Provides a custom security policy. + * + *

This is optional. If the user has not provided a security policy, this channel will only + * communicate with the same application UID. + * + * @return this + */ + public BinderChannelBuilder securityPolicy(SecurityPolicy securityPolicy) { + this.securityPolicy = checkNotNull(securityPolicy, "securityPolicy"); + return this; + } + + /** Sets the policy for inbound parcelable objects. */ + public BinderChannelBuilder inboundParcelablePolicy( + InboundParcelablePolicy inboundParcelablePolicy) { + this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + return this; + } + + /** Creates new binder transports. */ + private static final class TransportFactory implements ClientTransportFactory { + private final Context sourceContext; + private final Executor mainThreadExecutor; + private final ObjectPool scheduledExecutorPool; + private final ObjectPool offloadExecutorPool; + private final SecurityPolicy securityPolicy; + private final InboundParcelablePolicy inboundParcelablePolicy; + private final BindServiceFlags bindServiceFlags; + + private ScheduledExecutorService executorService; + private Executor offloadExecutor; + private boolean closed; + + TransportFactory( + Context sourceContext, + Executor mainThreadExecutor, + ObjectPool scheduledExecutorPool, + ObjectPool offloadExecutorPool, + SecurityPolicy securityPolicy, + BindServiceFlags bindServiceFlags, + InboundParcelablePolicy inboundParcelablePolicy) { + this.sourceContext = sourceContext; + this.mainThreadExecutor = mainThreadExecutor; + this.scheduledExecutorPool = scheduledExecutorPool; + this.offloadExecutorPool = offloadExecutorPool; + this.securityPolicy = securityPolicy; + this.bindServiceFlags = bindServiceFlags; + this.inboundParcelablePolicy = inboundParcelablePolicy; + + executorService = scheduledExecutorPool.getObject(); + offloadExecutor = offloadExecutorPool.getObject(); + } + + @Override + public ConnectionClientTransport newClientTransport( + SocketAddress addr, ClientTransportOptions options, ChannelLogger channelLogger) { + if (closed) { + throw new IllegalStateException("The transport factory is closed."); + } + return new BinderTransport.BinderClientTransport( + sourceContext, + (AndroidComponentAddress) addr, + bindServiceFlags, + mainThreadExecutor, + scheduledExecutorPool, + offloadExecutorPool, + securityPolicy, + inboundParcelablePolicy, + options.getEagAttributes()); + } + + @Override + public ScheduledExecutorService getScheduledExecutorService() { + return executorService; + } + + @Override + public SwapChannelCredentialsResult swapChannelCredentials(ChannelCredentials channelCreds) { + return null; + } + + @Override + public void close() { + closed = true; + executorService = scheduledExecutorPool.returnObject(executorService); + offloadExecutor = offloadExecutorPool.returnObject(offloadExecutor); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java new file mode 100644 index 00000000000..9189b3935a9 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/BinderServerBuilder.java @@ -0,0 +1,178 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import android.app.Service; +import android.os.IBinder; +import com.google.common.base.Supplier; +import com.google.errorprone.annotations.DoNotCall; +import io.grpc.CompressorRegistry; +import io.grpc.DecompressorRegistry; +import io.grpc.ExperimentalApi; +import io.grpc.Server; +import io.grpc.ServerBuilder; +import io.grpc.ServerStreamTracer; +import io.grpc.binder.internal.BinderServer; +import io.grpc.binder.internal.BinderTransportSecurity; +import io.grpc.ForwardingServerBuilder; +import io.grpc.internal.FixedObjectPool; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ServerImplBuilder; +import io.grpc.internal.ServerImplBuilder.ClientTransportServersBuilder; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.SharedResourcePool; +import java.io.File; +import java.io.IOException; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.Nullable; + +/** + * Builder for a server that services requests from an Android Service. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class BinderServerBuilder + extends ForwardingServerBuilder { + + /** + * Creates a server builder that will bind with the given name. + * + *

The listening {@link IBinder} associated with new {@link Server}s will be stored in {@code + * binderReceiver} upon {@link #build()}. + * + * @param service the concrete Android Service that will host this server. + * @param receiver an "out param" for the new {@link Server}'s listening {@link IBinder} + * @return a new builder + */ + public static BinderServerBuilder forService(Service service, IBinderReceiver receiver) { + return new BinderServerBuilder(service, receiver); + } + + /** + * Always fails. Call {@link #forService(Service, IBinderReceiver)} instead. + */ + @DoNotCall("Unsupported. Use forService() instead") + public static BinderServerBuilder forPort(int port) { + throw new UnsupportedOperationException("call forService() instead"); + } + + private final ServerImplBuilder serverImplBuilder; + private ObjectPool schedulerPool = + SharedResourcePool.forResource(GrpcUtil.TIMER_SERVICE); + private ServerSecurityPolicy securityPolicy; + private InboundParcelablePolicy inboundParcelablePolicy; + + private BinderServerBuilder(Service service, IBinderReceiver binderReceiver) { + securityPolicy = SecurityPolicies.serverInternalOnly(); + inboundParcelablePolicy = InboundParcelablePolicy.DEFAULT; + + serverImplBuilder = new ServerImplBuilder(streamTracerFactories -> { + BinderServer server = new BinderServer( + AndroidComponentAddress.forContext(service), + schedulerPool, + streamTracerFactories, + securityPolicy, + inboundParcelablePolicy); + binderReceiver.set(server.getHostBinder()); + return server; + }); + + // Disable compression by default, since there's little benefit when all communication is + // on-device, and it means sending supported-encoding headers with every call. + decompressorRegistry(DecompressorRegistry.emptyInstance()); + compressorRegistry(CompressorRegistry.newEmptyInstance()); + + // Disable stats and tracing by default. + serverImplBuilder.setStatsEnabled(false); + serverImplBuilder.setTracingEnabled(false); + + BinderTransportSecurity.installAuthInterceptor(this); + } + + @Override + protected ServerBuilder delegate() { + return serverImplBuilder; + } + + /** Enable stats collection using census. */ + public BinderServerBuilder enableStats() { + serverImplBuilder.setStatsEnabled(true); + return this; + } + + /** Enable tracing using census. */ + public BinderServerBuilder enableTracing() { + serverImplBuilder.setTracingEnabled(true); + return this; + } + + /** + * Provides a custom scheduled executor service. + * + *

It's an optional parameter. If the user has not provided a scheduled executor service when + * the channel is built, the builder will use a static cached thread pool. + * + * @return this + */ + public BinderServerBuilder scheduledExecutorService( + ScheduledExecutorService scheduledExecutorService) { + schedulerPool = + new FixedObjectPool<>(checkNotNull(scheduledExecutorService, "scheduledExecutorService")); + return this; + } + + /** + * Provides a custom security policy. + * + *

This is optional. If the user has not provided a security policy, the server will default to + * only accepting calls from the same application UID. + * + * @return this + */ + public BinderServerBuilder securityPolicy(ServerSecurityPolicy securityPolicy) { + this.securityPolicy = checkNotNull(securityPolicy, "securityPolicy"); + return this; + } + + /** Sets the policy for inbound parcelable objects. */ + public BinderServerBuilder inboundParcelablePolicy( + InboundParcelablePolicy inboundParcelablePolicy) { + this.inboundParcelablePolicy = checkNotNull(inboundParcelablePolicy, "inboundParcelablePolicy"); + return this; + } + + @Override + public BinderServerBuilder useTransportSecurity(File certChain, File privateKey) { + throw new UnsupportedOperationException("TLS not supported in BinderServer"); + } + + /** + * Builds a {@link Server} according to this builder's parameters and stores its listening {@link + * IBinder} in the {@link IBinderReceiver} passed to {@link #forService(Service, + * IBinderReceiver)}. + * + * @return the new Server + */ + @Override // For javadoc refinement only. + public Server build() { + return super.build(); + } +} diff --git a/binder/src/main/java/io/grpc/binder/IBinderReceiver.java b/binder/src/main/java/io/grpc/binder/IBinderReceiver.java new file mode 100644 index 00000000000..bd8e1f50af9 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/IBinderReceiver.java @@ -0,0 +1,40 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import android.os.IBinder; +import io.grpc.ExperimentalApi; +import javax.annotation.Nullable; + +/** A container for at most one instance of {@link IBinder}, useful as an "out parameter". */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class IBinderReceiver { + @Nullable private IBinder value; + + /** Constructs a new, initially empty, container. */ + public IBinderReceiver() {} + + /** Returns the contents of this container or null if it is empty. */ + @Nullable + public synchronized IBinder get() { + return value; + } + + public synchronized void set(IBinder value) { + this.value = value; + } +} diff --git a/binder/src/main/java/io/grpc/binder/InboundParcelablePolicy.java b/binder/src/main/java/io/grpc/binder/InboundParcelablePolicy.java new file mode 100644 index 00000000000..23470829d74 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/InboundParcelablePolicy.java @@ -0,0 +1,107 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import com.google.common.base.Preconditions; +import io.grpc.ExperimentalApi; + +/** + * Contains the policy for accepting inbound parcelable objects. + * + *

Since parcelables are generally error prone and parsing a parcelable can have unspecified + * side-effects, their use is generally discouraged. Some use cases require them though (E.g. when + * dealing with some platform-defined objects), so this policy allows them to be supported. + * + *

Parcelables can arrive as RPC messages, or as metadata values (in headers or footers). The + * default is to reject both cases, failing the RPC with a PERMISSION_DENED status code. This policy + * can be updated to accept one or both cases. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class InboundParcelablePolicy { + + /** The maximum allowed total size of Parcelables in metadata. */ + public static final int MAX_PARCELABLE_METADATA_SIZE = 32 * 1024; + + public static final InboundParcelablePolicy DEFAULT = + new InboundParcelablePolicy(false, false, MAX_PARCELABLE_METADATA_SIZE); + + private final boolean acceptParcelableMetadataValues; + private final boolean acceptParcelableMessages; + private final int maxParcelableMetadataSize; + + private InboundParcelablePolicy( + boolean acceptParcelableMetadataValues, + boolean acceptParcelableMessages, + int maxParcelableMetadataSize) { + this.acceptParcelableMetadataValues = acceptParcelableMetadataValues; + this.acceptParcelableMessages = acceptParcelableMessages; + this.maxParcelableMetadataSize = maxParcelableMetadataSize; + } + + public boolean shouldAcceptParcelableMetadataValues() { + return acceptParcelableMetadataValues; + } + + public boolean shouldAcceptParcelableMessages() { + return acceptParcelableMessages; + } + + public int getMaxParcelableMetadataSize() { + return maxParcelableMetadataSize; + } + + public static Builder newBuilder() { + return new Builder(); + } + + /** A builder for InboundParcelablePolicy. */ + public static final class Builder { + private boolean acceptParcelableMetadataValues = DEFAULT.acceptParcelableMetadataValues; + private boolean acceptParcelableMessages = DEFAULT.acceptParcelableMessages; + private int maxParcelableMetadataSize = DEFAULT.maxParcelableMetadataSize; + + /** Sets whether the policy should accept parcelable metadata values. */ + public Builder setAcceptParcelableMetadataValues(boolean acceptParcelableMetadataValues) { + this.acceptParcelableMetadataValues = acceptParcelableMetadataValues; + return this; + } + + /** Sets whether the policy should accept parcelable messages. */ + public Builder setAcceptParcelableMessages(boolean acceptParcelableMessages) { + this.acceptParcelableMessages = acceptParcelableMessages; + return this; + } + + /** + * Sets the maximum allowed total size of parcelables in metadata. + * + * @param maxParcelableMetadataSize must not exceed {@link #MAX_PARCELABLE_METADATA_SIZE} + */ + public Builder setMaxParcelableMetadataSize(int maxParcelableMetadataSize) { + Preconditions.checkArgument( + maxParcelableMetadataSize <= MAX_PARCELABLE_METADATA_SIZE, + "Parcelable metadata size can't exceed MAX_PARCELABLE_METADATA_SIZE."); + this.maxParcelableMetadataSize = maxParcelableMetadataSize; + return this; + } + + public InboundParcelablePolicy build() { + return new InboundParcelablePolicy( + acceptParcelableMetadataValues, acceptParcelableMessages, maxParcelableMetadataSize); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/ParcelableUtils.java b/binder/src/main/java/io/grpc/binder/ParcelableUtils.java new file mode 100644 index 00000000000..164de7de8b8 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/ParcelableUtils.java @@ -0,0 +1,59 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import android.os.Parcelable; +import io.grpc.ExperimentalApi; +import io.grpc.Metadata; +import io.grpc.binder.internal.MetadataHelper; + +/** + * Utility methods for using Android Parcelable objects with gRPC. + * + *

This class models the same pattern as the {@code ProtoLiteUtils} class. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class ParcelableUtils { + + private ParcelableUtils() {} + + /** + * Create a {@link Metadata.Key} for passing a Parcelable object in the metadata of an RPC, + * treating instances as mutable. + * + *

Note:Parcelables can only be sent across in-process and binder channels. + */ + public static

Metadata.Key

metadataKey( + String name, Parcelable.Creator

creator) { + return Metadata.Key.of( + name, new MetadataHelper.ParcelableMetadataMarshaller

(creator, false)); + } + + /** + * Create a {@link Metadata.Key} for passing a Parcelable object in the metadata of an RPC, + * treating instances as immutable. Immutability may be used for optimization purposes (e.g. Not + * copying for in-process calls). + * + *

Note:Parcelables can only be sent across in-process and binder channels. + */ + public static

Metadata.Key

metadataKeyForImmutableType( + String name, Parcelable.Creator

creator) { + return Metadata.Key.of( + name, new MetadataHelper.ParcelableMetadataMarshaller

(creator, true)); + } +} + diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicies.java b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java new file mode 100644 index 00000000000..be46b9e3e54 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicies.java @@ -0,0 +1,58 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import android.os.Process; +import io.grpc.ExperimentalApi; +import io.grpc.Status; +import javax.annotation.CheckReturnValue; + +/** Static factory methods for creating standard security policies. */ +@CheckReturnValue +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class SecurityPolicies { + + private static final int MY_UID = Process.myUid(); + + private SecurityPolicies() {} + + public static ServerSecurityPolicy serverInternalOnly() { + return new ServerSecurityPolicy(); + } + + public static SecurityPolicy internalOnly() { + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return uid == MY_UID + ? Status.OK + : Status.PERMISSION_DENIED.withDescription( + "Rejected by (internal-only) security policy"); + } + }; + } + + public static SecurityPolicy permissionDenied(String description) { + Status denied = Status.PERMISSION_DENIED.withDescription(description); + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return denied; + } + }; + } +} diff --git a/binder/src/main/java/io/grpc/binder/SecurityPolicy.java b/binder/src/main/java/io/grpc/binder/SecurityPolicy.java new file mode 100644 index 00000000000..55aa33e0216 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/SecurityPolicy.java @@ -0,0 +1,53 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import io.grpc.ExperimentalApi; +import io.grpc.Status; +import javax.annotation.CheckReturnValue; + +/** + * Decides whether a given Android UID is authorized to access some resource. + * + *

IMPORTANT For any concrete extensions of this class, it's assumed that the + * authorization status of a given UID will not change as long as a process with that UID is + * alive. + * + *

In order words, we expect the security policy for a given transport to remain constant for the + * lifetime of that transport. This is considered acceptable because no transport will survive the + * re-installation of the applications involved. + */ +@CheckReturnValue +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public abstract class SecurityPolicy { + + public SecurityPolicy() {} + + /** + * Decides whether the given Android UID is authorized. (Validity is implementation dependent). + * + *

IMPORTANT: This method may block for extended periods of time. + * + *

As long as any given UID has active processes, this method should return the same value for + * that UID. In order words, policy changes which occur while a transport instance is active, will + * have no effect on that transport instance. + * + * @param uid The Android UID to authenticate. + * @return A gRPC {@link Status} object, with OK indicating authorized. + */ + public abstract Status checkAuthorization(int uid); +} diff --git a/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java b/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java new file mode 100644 index 00000000000..46a124e1f47 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/ServerSecurityPolicy.java @@ -0,0 +1,86 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import com.google.common.collect.ImmutableMap; +import io.grpc.ExperimentalApi; +import io.grpc.Status; +import java.util.HashMap; +import java.util.Map; +import javax.annotation.CheckReturnValue; + +/** + * A security policy for a gRPC server. + * + * Contains a default policy, and optional policies for each server. + */ +@ExperimentalApi("https://github.com/grpc/grpc-java/issues/8022") +public final class ServerSecurityPolicy { + + private final SecurityPolicy defaultPolicy; + private final ImmutableMap perServicePolicies; + + ServerSecurityPolicy() { + this(ImmutableMap.of()); + } + + private ServerSecurityPolicy(ImmutableMap perServicePolicies) { + this.defaultPolicy = SecurityPolicies.internalOnly(); + this.perServicePolicies = perServicePolicies; + } + + /** + * Return whether the given Android UID is authorized to access a particular service. + * + * IMPORTANT: This method may block for extended periods of time. + * + * @param uid The Android UID to authenticate. + * @param serviceName The name of the gRPC service being called. + */ + @CheckReturnValue + public Status checkAuthorizationForService(int uid, String serviceName) { + return perServicePolicies.getOrDefault(serviceName, defaultPolicy).checkAuthorization(uid); + } + + public static Builder newBuilder() { + return new Builder(); + } + + /** Builder for an AndroidServiceSecurityPolicy. */ + public static final class Builder { + private final Map grpcServicePolicies; + + private Builder() { + grpcServicePolicies = new HashMap<>(); + } + + /** + * Specify a policy specific to a particular gRPC service. + * + * @param serviceName The fully qualified name of the gRPC service (from the proto). + * @param policy The security policy to apply to the service. + */ + public Builder servicePolicy(String serviceName, SecurityPolicy policy) { + grpcServicePolicies.put(serviceName, policy); + return this; + } + + public ServerSecurityPolicy build() { + return new ServerSecurityPolicy(ImmutableMap.copyOf(grpcServicePolicies)); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServer.java b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java new file mode 100644 index 00000000000..74ed5caceea --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServer.java @@ -0,0 +1,165 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; + +import android.os.Binder; +import android.os.IBinder; +import android.os.Parcel; +import com.google.common.collect.ImmutableList; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.InternalInstrumented; +import io.grpc.ServerStreamTracer; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.ServerSecurityPolicy; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.InternalServer; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.ServerListener; +import io.grpc.internal.SharedResourceHolder; +import java.io.IOException; +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +/** + * A gRPC InternalServer which accepts connections via a host AndroidService. + * + *

Multiple incoming connections transports may be active at a time. + * + * IMPORTANT: This implementation must comply with this published wire format. + * https://github.com/grpc/proposal/blob/master/L73-java-binderchannel/wireformat.md + */ +@ThreadSafe +public final class BinderServer implements InternalServer, LeakSafeOneWayBinder.TransactionHandler { + + private final ObjectPool executorServicePool; + private final ImmutableList streamTracerFactories; + private final AndroidComponentAddress listenAddress; + private final LeakSafeOneWayBinder hostServiceBinder; + private final ServerSecurityPolicy serverSecurityPolicy; + private final InboundParcelablePolicy inboundParcelablePolicy; + + @GuardedBy("this") + private ServerListener listener; + + @GuardedBy("this") + private ScheduledExecutorService executorService; + + @GuardedBy("this") + private boolean shutdown; + + public BinderServer( + AndroidComponentAddress listenAddress, + ObjectPool executorServicePool, + List streamTracerFactories, + ServerSecurityPolicy serverSecurityPolicy, + InboundParcelablePolicy inboundParcelablePolicy) { + this.listenAddress = listenAddress; + this.executorServicePool = executorServicePool; + this.streamTracerFactories = + ImmutableList.copyOf(checkNotNull(streamTracerFactories, "streamTracerFactories")); + this.serverSecurityPolicy = checkNotNull(serverSecurityPolicy, "serverSecurityPolicy"); + this.inboundParcelablePolicy = inboundParcelablePolicy; + hostServiceBinder = new LeakSafeOneWayBinder(this); + } + + /** Return the binder we're listening on. */ + public IBinder getHostBinder() { + return hostServiceBinder; + } + + @Override + public synchronized void start(ServerListener serverListener) throws IOException { + this.listener = serverListener; + executorService = executorServicePool.getObject(); + } + + @Override + public SocketAddress getListenSocketAddress() { + return listenAddress; + } + + @Override + public List getListenSocketAddresses() { + return ImmutableList.of(listenAddress); + } + + @Override + public InternalInstrumented getListenSocketStats() { + return null; + } + + @Override + @Nullable + public List> getListenSocketStatsList() { + return null; + } + + @Override + public synchronized void shutdown() { + if (!shutdown) { + shutdown = true; + // Break the connection to the binder. We'll receive no more transactions. + hostServiceBinder.detach(); + listener.serverShutdown(); + executorService = executorServicePool.returnObject(executorService); + } + } + + @Override + public String toString() { + return "BinderServer[" + listenAddress + "]"; + } + + @Override + public synchronized boolean handleTransaction(int code, Parcel parcel) { + if (code == BinderTransport.SETUP_TRANSPORT) { + int version = parcel.readInt(); + // If the client-provided version is more recent, we accept the connection, + // but specify the older version which we support. + if (version >= BinderTransport.EARLIEST_SUPPORTED_WIRE_FORMAT_VERSION) { + IBinder callbackBinder = parcel.readStrongBinder(); + if (callbackBinder != null) { + int callingUid = Binder.getCallingUid(); + Attributes.Builder attrsBuilder = + Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, listenAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, new BoundClientAddress(callingUid)) + .set(BinderTransport.REMOTE_UID, callingUid) + .set(BinderTransport.SERVER_AUTHORITY, listenAddress.getAuthority()) + .set(BinderTransport.INBOUND_PARCELABLE_POLICY, inboundParcelablePolicy); + BinderTransportSecurity.attachAuthAttrs(attrsBuilder, callingUid, serverSecurityPolicy); + // Create a new transport and let our listener know about it. + BinderTransport.BinderServerTransport transport = + new BinderTransport.BinderServerTransport( + executorServicePool, attrsBuilder.build(), streamTracerFactories, callbackBinder); + transport.setServerTransportListener(listener.transportCreated(transport)); + return true; + } + } + } + return false; + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java new file mode 100644 index 00000000000..508f0351b0e --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -0,0 +1,920 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.util.concurrent.Futures.immediateFuture; + +import android.content.Context; +import android.os.Binder; +import android.os.DeadObjectException; +import android.os.IBinder; +import android.os.Parcel; +import android.os.Process; +import android.os.RemoteException; +import android.os.TransactionTooLargeException; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.util.concurrent.ListenableFuture; +import io.grpc.Attributes; +import io.grpc.CallOptions; +import io.grpc.Grpc; +import io.grpc.InternalChannelz.SocketStats; +import io.grpc.InternalLogId; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.ServerStreamTracer; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.ApiConstants; +import io.grpc.binder.BindServiceFlags; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.binder.SecurityPolicy; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ConnectionClientTransport; +import io.grpc.internal.FailingClientStream; +import io.grpc.internal.GrpcAttributes; +import io.grpc.internal.GrpcUtil; +import io.grpc.internal.ManagedClientTransport; +import io.grpc.internal.ObjectPool; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerTransport; +import io.grpc.internal.ServerTransportListener; +import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.TimeProvider; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.Executor; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.CheckReturnValue; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; +import javax.annotation.concurrent.ThreadSafe; + +/** + * Base class for binder-based gRPC transport. + * + *

This is used on both the client and service sides of the transport. + * + *

A note on synchronization. The nature of this class's interaction with each stream + * (bi-directional communication between gRPC calls and binder transactions) means that acquiring + * multiple locks in two different orders can happen easily. E.g. binder transactions will arrive in + * this class, and need to passed to a stream instance, whereas gRPC calls on a stream instance will + * need to call into this class to send a transaction (possibly waiting for the transport to become + * ready). + * + *

The split between Outbound & Inbound helps reduce this risk, but not entirely remove it. + * + *

For this reason, while most state within this class is guarded by this instance, methods + * exposed to individual stream instances need to use atomic or volatile types, since those calls + * will already be synchronized on the individual RPC objects. + * + *

IMPORTANT: This implementation must comply with this published wire format. + * https://github.com/grpc/proposal/blob/master/L73-java-binderchannel/wireformat.md + */ +@ThreadSafe +public abstract class BinderTransport + implements LeakSafeOneWayBinder.TransactionHandler, IBinder.DeathRecipient { + + private static final Logger logger = Logger.getLogger(BinderTransport.class.getName()); + + /** + * Attribute used to store the Android UID of the remote app. This is guaranteed to be set on any + * active transport. + */ + static final Attributes.Key REMOTE_UID = Attributes.Key.create("remote-uid"); + + /** The authority of the server. */ + static final Attributes.Key SERVER_AUTHORITY = Attributes.Key.create("server-authority"); + + /** A transport attribute to hold the {@link InboundParcelablePolicy}. */ + static final Attributes.Key INBOUND_PARCELABLE_POLICY = + Attributes.Key.create("inbound-parcelable-policy"); + + /** + * Version code for this wire format. + * + *

Should this change, we should still endeavor to support earlier wire-format versions. If + * that's not possible, {@link EARLIEST_SUPPORTED_WIRE_FORMAT_VERSION} should be updated below. + */ + static final int WIRE_FORMAT_VERSION = 1; + + /** The version code of the earliest wire format we support. */ + static final int EARLIEST_SUPPORTED_WIRE_FORMAT_VERSION = 1; + + /** The max number of "in-flight" bytes before we start buffering transactions. */ + private static final int TRANSACTION_BYTES_WINDOW = 128 * 1024; + + /** The number of in-flight bytes we should receive between sendings acks to our peer. */ + private static final int TRANSACTION_BYTES_WINDOW_FORCE_ACK = 16 * 1024; + + /** + * Sent from the client to host service binder to initiate a new transport, and from the host to + * the binder. and from the host s Followed by: int wire_protocol_version IBinder + * client_transports_callback_binder + */ + static final int SETUP_TRANSPORT = IBinder.FIRST_CALL_TRANSACTION; + + /** Send to shutdown the transport from either end. */ + static final int SHUTDOWN_TRANSPORT = IBinder.FIRST_CALL_TRANSACTION + 1; + + /** Send to acknowledge receipt of rpc bytes, for flow control. */ + static final int ACKNOWLEDGE_BYTES = IBinder.FIRST_CALL_TRANSACTION + 2; + + /** A ping request. */ + private static final int PING = IBinder.FIRST_CALL_TRANSACTION + 3; + + /** A response to a ping. */ + private static final int PING_RESPONSE = IBinder.FIRST_CALL_TRANSACTION + 4; + + /** Reserved transaction IDs for any special events we might need. */ + private static final int RESERVED_TRANSACTIONS = 1000; + + /** The first call ID we can use. */ + private static final int FIRST_CALL_ID = IBinder.FIRST_CALL_TRANSACTION + RESERVED_TRANSACTIONS; + + /** The last call ID we can use. */ + private static final int LAST_CALL_ID = IBinder.LAST_CALL_TRANSACTION; + + /** The states of this transport. */ + protected enum TransportState { + NOT_STARTED, // We haven't been started yet. + SETUP, // We're setting up the connection. + READY, // The transport is ready. + SHUTDOWN, // We've been shutdown and won't accept any additional calls (thought existing calls + // may continue). + SHUTDOWN_TERMINATED // We've been shutdown completely (or we failed to start). We can't send or + // receive any data. + } + + private final ObjectPool executorServicePool; + private final ScheduledExecutorService scheduledExecutorService; + private final InternalLogId logId; + private final LeakSafeOneWayBinder incomingBinder; + + protected final ConcurrentHashMap> ongoingCalls; + + @GuardedBy("this") + protected Attributes attributes; + + @GuardedBy("this") + private TransportState transportState = TransportState.NOT_STARTED; + + @GuardedBy("this") + @Nullable + protected Status shutdownStatus; + + @Nullable private IBinder outgoingBinder; + + /** The number of outgoing bytes we've transmitted. */ + private final AtomicLong numOutgoingBytes; + + /** The number of incoming bytes we've received. */ + private final AtomicLong numIncomingBytes; + + /** The number of our outgoing bytes our peer has told us it received. */ + private long acknowledgedOutgoingBytes; + + /** The number of incoming bytes we've told our peer we've received. */ + private long acknowledgedIncomingBytes; + + /** + * Whether there are too many unacknowledged outgoing bytes to allow more RPCs right now. This is + * volatile because it'll be read without holding the lock. + */ + private volatile boolean transmitWindowFull; + + private BinderTransport( + ObjectPool executorServicePool, + Attributes attributes, + InternalLogId logId) { + this.executorServicePool = executorServicePool; + this.attributes = attributes; + this.logId = logId; + scheduledExecutorService = executorServicePool.getObject(); + incomingBinder = new LeakSafeOneWayBinder(this); + ongoingCalls = new ConcurrentHashMap<>(); + numOutgoingBytes = new AtomicLong(); + numIncomingBytes = new AtomicLong(); + } + + // Override in child class. + public final ScheduledExecutorService getScheduledExecutorService() { + return scheduledExecutorService; + } + + // Override in child class. + public final ListenableFuture getStats() { + return immediateFuture(null); + } + + // Override in child class. + public final InternalLogId getLogId() { + return logId; + } + + // Override in child class. + public final synchronized Attributes getAttributes() { + return attributes; + } + + /** + * Returns whether this transport is able to send rpc transactions. Intentionally unsynchronized + * since this will be called while Outbound is held. + */ + final boolean isReady() { + return !transmitWindowFull; + } + + abstract void notifyShutdown(Status shutdownStatus); + + abstract void notifyTerminated(); + + void releaseExecutors() { + executorServicePool.returnObject(scheduledExecutorService); + } + + @GuardedBy("this") + boolean inState(TransportState transportState) { + return this.transportState == transportState; + } + + @GuardedBy("this") + boolean isShutdown() { + return inState(TransportState.SHUTDOWN) || inState(TransportState.SHUTDOWN_TERMINATED); + } + + @GuardedBy("this") + final void setState(TransportState newState) { + checkTransition(transportState, newState); + transportState = newState; + } + + @GuardedBy("this") + protected boolean setOutgoingBinder(IBinder binder) { + this.outgoingBinder = binder; + try { + binder.linkToDeath(this, 0); + return true; + } catch (RemoteException re) { + return false; + } + } + + @Override + public synchronized void binderDied() { + shutdownInternal(Status.UNAVAILABLE.withDescription("binderDied"), true); + } + + @GuardedBy("this") + final void shutdownInternal(Status shutdownStatus, boolean forceTerminate) { + if (!isShutdown()) { + this.shutdownStatus = shutdownStatus; + setState(TransportState.SHUTDOWN); + notifyShutdown(shutdownStatus); + } + if (!inState(TransportState.SHUTDOWN_TERMINATED) + && (forceTerminate || ongoingCalls.isEmpty())) { + incomingBinder.detach(); + setState(TransportState.SHUTDOWN_TERMINATED); + sendShutdownTransaction(); + ArrayList> calls = new ArrayList<>(ongoingCalls.values()); + ongoingCalls.clear(); + scheduledExecutorService.execute( + () -> { + for (Inbound inbound : calls) { + synchronized (inbound) { + inbound.closeAbnormal(shutdownStatus); + } + } + notifyTerminated(); + releaseExecutors(); + }); + } + } + + @GuardedBy("this") + final void sendSetupTransaction() { + sendSetupTransaction(checkNotNull(outgoingBinder)); + } + + @GuardedBy("this") + final void sendSetupTransaction(IBinder iBinder) { + Parcel parcel = Parcel.obtain(); + parcel.writeInt(WIRE_FORMAT_VERSION); + parcel.writeStrongBinder(incomingBinder); + try { + if (!iBinder.transact(SETUP_TRANSPORT, parcel, null, IBinder.FLAG_ONEWAY)) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Failed sending SETUP_TRANSPORT transaction"), true); + } + } catch (RemoteException re) { + shutdownInternal(statusFromRemoteException(re), true); + } + parcel.recycle(); + } + + @GuardedBy("this") + private final void sendShutdownTransaction() { + if (outgoingBinder != null) { + try { + outgoingBinder.unlinkToDeath(this, 0); + } catch (NoSuchElementException e) { + // Ignore. + } + Parcel parcel = Parcel.obtain(); + try { + outgoingBinder.transact(SHUTDOWN_TRANSPORT, parcel, null, IBinder.FLAG_ONEWAY); + } catch (RemoteException re) { + // Ignore. + } + parcel.recycle(); + } + } + + protected synchronized void sendPing(int id) throws StatusException { + if (inState(TransportState.SHUTDOWN_TERMINATED)) { + throw shutdownStatus.asException(); + } else if (outgoingBinder == null) { + throw Status.FAILED_PRECONDITION.withDescription("Transport not ready.").asException(); + } else { + Parcel parcel = Parcel.obtain(); + parcel.writeInt(id); + try { + outgoingBinder.transact(PING, parcel, null, IBinder.FLAG_ONEWAY); + } catch (RemoteException re) { + throw statusFromRemoteException(re).asException(); + } finally { + parcel.recycle(); + } + } + } + + protected void unregisterInbound(Inbound inbound) { + unregisterCall(inbound.callId); + } + + final void unregisterCall(int callId) { + boolean removed = (ongoingCalls.remove(callId) != null); + if (removed && ongoingCalls.isEmpty()) { + // Possibly shutdown (not synchronously, since inbound is held). + scheduledExecutorService.execute( + () -> { + synchronized (this) { + if (inState(TransportState.SHUTDOWN)) { + // No more ongoing calls, and we're shutdown. Finish the shutdown. + shutdownInternal(shutdownStatus, true); + } + } + }); + } + } + + final void sendTransaction(int callId, Parcel parcel) throws StatusException { + int dataSize = parcel.dataSize(); + try { + if (!outgoingBinder.transact(callId, parcel, null, IBinder.FLAG_ONEWAY)) { + throw Status.UNAVAILABLE.withDescription("Failed sending transaction").asException(); + } + } catch (RemoteException re) { + throw statusFromRemoteException(re).asException(); + } + long nob = numOutgoingBytes.addAndGet(dataSize); + if ((nob - acknowledgedOutgoingBytes) > TRANSACTION_BYTES_WINDOW) { + logger.log(Level.FINE, "transmist window full. Outgoing=" + nob + " Ack'd Outgoing=" + + acknowledgedOutgoingBytes + " " + this); + transmitWindowFull = true; + } + } + + final void sendOutOfBandClose(int callId, Status status) { + Parcel parcel = Parcel.obtain(); + parcel.writeInt(0); // Placeholder for flags. Will be filled in below. + int flags = TransactionUtils.writeStatus(parcel, status); + TransactionUtils.fillInFlags(parcel, flags | TransactionUtils.FLAG_OUT_OF_BAND_CLOSE); + try { + sendTransaction(callId, parcel); + } catch (StatusException e) { + logger.log(Level.WARNING, "Failed sending oob close transaction", e); + } + parcel.recycle(); + } + + @Override + public final boolean handleTransaction(int code, Parcel parcel) { + if (code < FIRST_CALL_ID) { + synchronized (this) { + switch (code) { + case ACKNOWLEDGE_BYTES: + handleAcknowledgedBytes(parcel.readLong()); + break; + case SHUTDOWN_TRANSPORT: + shutdownInternal( + Status.UNAVAILABLE.withDescription("transport shutdown by peer"), true); + break; + case SETUP_TRANSPORT: + handleSetupTransport(parcel); + break; + case PING: + handlePing(parcel); + break; + case PING_RESPONSE: + handlePingResponse(parcel); + break; + default: + return false; + } + return true; + } + } else { + int size = parcel.dataSize(); + Inbound inbound = ongoingCalls.get(code); + if (inbound == null) { + synchronized (this) { + if (!isShutdown()) { + // Create a new inbound. Strictly speaking we could end up doing this twice on + // two threads, hence the need to use putIfAbsent, and check its result. + inbound = createInbound(code); + if (inbound != null) { + Inbound inbound2 = ongoingCalls.putIfAbsent(code, inbound); + if (inbound2 != null) { + inbound = inbound2; + } + } + } + } + } + if (inbound != null) { + inbound.handleTransaction(parcel); + } + long nib = numIncomingBytes.addAndGet(size); + if ((nib - acknowledgedIncomingBytes) > TRANSACTION_BYTES_WINDOW_FORCE_ACK) { + synchronized (this) { + sendAcknowledgeBytes(checkNotNull(outgoingBinder)); + } + } + return true; + } + } + + @Nullable + @GuardedBy("this") + protected Inbound createInbound(int callId) { + return null; + } + + @GuardedBy("this") + protected void handleSetupTransport(Parcel parcel) {} + + @GuardedBy("this") + private final void handlePing(Parcel parcel) { + if (transportState == TransportState.READY) { + try { + outgoingBinder.transact(PING_RESPONSE, parcel, null, IBinder.FLAG_ONEWAY); + } catch (RemoteException re) { + // Ignore. + } + } + } + + @GuardedBy("this") + protected void handlePingResponse(Parcel parcel) {} + + @GuardedBy("this") + private void sendAcknowledgeBytes(IBinder iBinder) { + // Send a transaction to acknowledge reception of incoming data. + long n = numIncomingBytes.get(); + acknowledgedIncomingBytes = n; + Parcel parcel = Parcel.obtain(); + parcel.writeLong(n); + try { + if (!iBinder.transact(ACKNOWLEDGE_BYTES, parcel, null, IBinder.FLAG_ONEWAY)) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Failed sending ack bytes transaction"), true); + } + } catch (RemoteException re) { + shutdownInternal(statusFromRemoteException(re), true); + } + parcel.recycle(); + } + + @GuardedBy("this") + final void handleAcknowledgedBytes(long numBytes) { + // The remote side has acknowledged reception of rpc data. + // (update with Math.max in case transactions are delivered out of order). + acknowledgedOutgoingBytes = wrapAwareMax(acknowledgedOutgoingBytes, numBytes); + if ((numOutgoingBytes.get() - acknowledgedOutgoingBytes) < TRANSACTION_BYTES_WINDOW + && transmitWindowFull) { + logger.log(Level.FINE, + "handleAcknowledgedBytes: Transmit Window No-Longer Full. Unblock calls: " + this); + // We're ready again, and need to poke any waiting transactions. + transmitWindowFull = false; + for (Inbound inbound : ongoingCalls.values()) { + inbound.onTransportReady(); + } + } + } + + private static final long wrapAwareMax(long a, long b) { + return a - b < 0 ? b : a; + } + + /** Concrete client-side transport implementation. */ + @ThreadSafe + public static final class BinderClientTransport extends BinderTransport + implements ConnectionClientTransport, Bindable.Observer { + + private final ObjectPool offloadExecutorPool; + private final Executor offloadExecutor; + private final SecurityPolicy securityPolicy; + private final Bindable serviceBinding; + /** Number of ongoing calls which keep this transport "in-use". */ + private final AtomicInteger numInUseStreams; + + private final PingTracker pingTracker; + + @Nullable private ManagedClientTransport.Listener clientTransportListener; + + @GuardedBy("this") + private int latestCallId = FIRST_CALL_ID; + + public BinderClientTransport( + Context sourceContext, + AndroidComponentAddress targetAddress, + BindServiceFlags bindServiceFlags, + Executor mainThreadExecutor, + ObjectPool executorServicePool, + ObjectPool offloadExecutorPool, + SecurityPolicy securityPolicy, + InboundParcelablePolicy inboundParcelablePolicy, + Attributes eagAttrs) { + super( + executorServicePool, + buildClientAttributes(eagAttrs, sourceContext, targetAddress, inboundParcelablePolicy), + buildLogId(sourceContext, targetAddress)); + this.offloadExecutorPool = offloadExecutorPool; + this.securityPolicy = securityPolicy; + this.offloadExecutor = offloadExecutorPool.getObject(); + numInUseStreams = new AtomicInteger(); + pingTracker = new PingTracker(TimeProvider.SYSTEM_TIME_PROVIDER, (id) -> sendPing(id)); + + serviceBinding = + new ServiceBinding( + mainThreadExecutor, + sourceContext, + targetAddress.getComponent(), + ApiConstants.ACTION_BIND, + bindServiceFlags.toInteger(), + this); + } + + @Override + void releaseExecutors() { + super.releaseExecutors(); + offloadExecutorPool.returnObject(offloadExecutor); + } + + @Override + public synchronized void onBound(IBinder binder) { + sendSetupTransaction(binder); + } + + @Override + public synchronized void onUnbound(Status reason) { + shutdownInternal(reason, true); + } + + @CheckReturnValue + @Override + public synchronized Runnable start(ManagedClientTransport.Listener clientTransportListener) { + this.clientTransportListener = checkNotNull(clientTransportListener); + return () -> { + synchronized (BinderClientTransport.this) { + if (inState(TransportState.NOT_STARTED)) { + setState(TransportState.SETUP); + serviceBinding.bind(); + } + } + }; + } + + @Override + public synchronized ClientStream newStream( + final MethodDescriptor method, + final Metadata headers, + final CallOptions callOptions) { + if (isShutdown()) { + return newFailingClientStream(shutdownStatus, callOptions, attributes, headers); + } else { + int callId = latestCallId++; + if (latestCallId == LAST_CALL_ID) { + latestCallId = FIRST_CALL_ID; + } + Inbound.ClientInbound inbound = + new Inbound.ClientInbound( + this, attributes, callId, GrpcUtil.shouldBeCountedForInUse(callOptions)); + if (ongoingCalls.putIfAbsent(callId, inbound) != null) { + Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); + shutdownInternal(failure, true); + return newFailingClientStream(failure, callOptions, attributes, headers); + } else { + if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { + clientTransportListener.transportInUse(true); + } + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(callOptions, attributes, headers); + + Outbound.ClientOutbound outbound = + new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); + if (method.getType().clientSendsOneMessage()) { + return new SingleMessageClientStream(inbound, outbound, attributes); + } else { + return new MultiMessageClientStream(inbound, outbound, attributes); + } + } + } + } + + @Override + protected void unregisterInbound(Inbound inbound) { + if (inbound.countsForInUse() && numInUseStreams.decrementAndGet() == 0) { + clientTransportListener.transportInUse(false); + } + super.unregisterInbound(inbound); + } + + @Override + public void ping(final PingCallback callback, Executor executor) { + pingTracker.startPing(callback, executor); + } + + @Override + public synchronized void shutdown(Status reason) { + checkNotNull(reason, "reason"); + shutdownInternal(reason, false); + } + + @Override + public synchronized void shutdownNow(Status reason) { + checkNotNull(reason, "reason"); + shutdownInternal(reason, true); + } + + @Override + @GuardedBy("this") + public void notifyShutdown(Status status) { + clientTransportListener.transportShutdown(status); + } + + @Override + @GuardedBy("this") + public void notifyTerminated() { + if (numInUseStreams.getAndSet(0) > 0) { + clientTransportListener.transportInUse(false); + } + serviceBinding.unbind(); + clientTransportListener.transportTerminated(); + } + + @Override + @GuardedBy("this") + protected void handleSetupTransport(Parcel parcel) { + // Add the remote uid to our attributes. + attributes = setSecurityAttrs(attributes, Binder.getCallingUid()); + if (inState(TransportState.SETUP)) { + int version = parcel.readInt(); + IBinder binder = parcel.readStrongBinder(); + if (version != WIRE_FORMAT_VERSION) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Wire format version mismatch"), true); + } else if (binder == null) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Malformed SETUP_TRANSPORT data"), true); + } else { + offloadExecutor.execute(() -> checkSecurityPolicy(binder)); + } + } + } + + private void checkSecurityPolicy(IBinder binder) { + Status authorization; + Integer remoteUid; + synchronized (this) { + remoteUid = attributes.get(REMOTE_UID); + } + if (remoteUid == null) { + authorization = Status.UNAUTHENTICATED.withDescription("No remote UID available"); + } else { + authorization = securityPolicy.checkAuthorization(remoteUid); + } + synchronized (this) { + if (inState(TransportState.SETUP)) { + if (!authorization.isOk()) { + shutdownInternal(authorization, true); + } else if (!setOutgoingBinder(binder)) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); + } else { + // Check state again, since a failure inside setOutgoingBinder (or a callback it + // triggers), could have shut us down. + if (!isShutdown()) { + setState(TransportState.READY); + clientTransportListener.transportReady(); + } + } + } + } + } + + @GuardedBy("this") + @Override + protected void handlePingResponse(Parcel parcel) { + pingTracker.onPingResponse(parcel.readInt()); + } + + private static ClientStream newFailingClientStream( + Status failure, CallOptions callOptions, Attributes attributes, Metadata headers) { + StatsTraceContext statsTraceContext = + StatsTraceContext.newClientContext(callOptions, attributes, headers); + statsTraceContext.clientOutboundHeaders(); + statsTraceContext.streamClosed(failure); + return new FailingClientStream(failure); + } + + private static InternalLogId buildLogId( + Context sourceContext, AndroidComponentAddress targetAddress) { + return InternalLogId.allocate( + BinderClientTransport.class, + sourceContext.getClass().getSimpleName() + + "->" + + targetAddress.getComponent().toShortString()); + } + + private static Attributes buildClientAttributes( + Attributes eagAttrs, + Context sourceContext, + AndroidComponentAddress targetAddress, + InboundParcelablePolicy inboundParcelablePolicy) { + return Attributes.newBuilder() + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.NONE) // Trust noone for now. + .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, AndroidComponentAddress.forContext(sourceContext)) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, targetAddress) + .set(INBOUND_PARCELABLE_POLICY, inboundParcelablePolicy) + .build(); + } + + private static Attributes setSecurityAttrs(Attributes attributes, int uid) { + return attributes.toBuilder() + .set(REMOTE_UID, uid) + .set( + GrpcAttributes.ATTR_SECURITY_LEVEL, + uid == Process.myUid() + ? SecurityLevel.PRIVACY_AND_INTEGRITY + : SecurityLevel.INTEGRITY) // TODO: Have the SecrityPolicy decide this. + .build(); + } + } + + /** Concrete server-side transport implementation. */ + static final class BinderServerTransport extends BinderTransport implements ServerTransport { + + private final List streamTracerFactories; + @Nullable private ServerTransportListener serverTransportListener; + + BinderServerTransport( + ObjectPool executorServicePool, + Attributes attributes, + List streamTracerFactories, + IBinder callbackBinder) { + super(executorServicePool, attributes, buildLogId(attributes)); + this.streamTracerFactories = streamTracerFactories; + setOutgoingBinder(callbackBinder); + } + + synchronized void setServerTransportListener(ServerTransportListener serverTransportListener) { + this.serverTransportListener = serverTransportListener; + if (isShutdown()) { + setState(TransportState.SHUTDOWN_TERMINATED); + notifyTerminated(); + releaseExecutors(); + } else { + sendSetupTransaction(); + // Check we're not shutdown again, since a failure inside sendSetupTransaction (or a + // callback it triggers), could have shut us down. + if (!isShutdown()) { + setState(TransportState.READY); + attributes = serverTransportListener.transportReady(attributes); + } + } + } + + StatsTraceContext createStatsTraceContext(String methodName, Metadata headers) { + return StatsTraceContext.newServerContext(streamTracerFactories, methodName, headers); + } + + synchronized Status startStream(ServerStream stream, String methodName, Metadata headers) { + if (isShutdown()) { + return Status.UNAVAILABLE.withDescription("transport is shutdown"); + } else { + serverTransportListener.streamCreated(stream, methodName, headers); + return Status.OK; + } + } + + @Override + @GuardedBy("this") + public void notifyShutdown(Status status) { + // Nothing to do. + } + + @Override + @GuardedBy("this") + public void notifyTerminated() { + if (serverTransportListener != null) { + serverTransportListener.transportTerminated(); + } + } + + @Override + public synchronized void shutdown() { + shutdownInternal(Status.OK, false); + } + + @Override + public synchronized void shutdownNow(Status reason) { + shutdownInternal(reason, true); + } + + @Override + @Nullable + @GuardedBy("this") + protected Inbound createInbound(int callId) { + return new Inbound.ServerInbound(this, attributes, callId); + } + + private static InternalLogId buildLogId(Attributes attributes) { + return InternalLogId.allocate( + BinderServerTransport.class, "from " + attributes.get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR)); + } + } + + private static void checkTransition(TransportState current, TransportState next) { + switch (next) { + case SETUP: + checkState(current == TransportState.NOT_STARTED); + break; + case READY: + checkState(current == TransportState.NOT_STARTED || current == TransportState.SETUP); + break; + case SHUTDOWN: + checkState( + current == TransportState.NOT_STARTED + || current == TransportState.SETUP + || current == TransportState.READY); + break; + case SHUTDOWN_TERMINATED: + checkState(current == TransportState.SHUTDOWN); + break; + default: + throw new AssertionError(); + } + } + + @VisibleForTesting + Map> getOngoingCalls() { + return ongoingCalls; + } + + private static Status statusFromRemoteException(RemoteException e) { + if (e instanceof DeadObjectException || e instanceof TransactionTooLargeException) { + // These are to be expected from time to time and can simply be retried. + return Status.UNAVAILABLE.withCause(e); + } + // Otherwise, this exception from transact is unexpected. + return Status.INTERNAL.withCause(e); + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java new file mode 100644 index 00000000000..3a06aa1c120 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransportSecurity.java @@ -0,0 +1,130 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import io.grpc.Attributes; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.SecurityLevel; +import io.grpc.ServerBuilder; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.binder.ServerSecurityPolicy; +import io.grpc.internal.GrpcAttributes; +import java.util.concurrent.ConcurrentHashMap; +import javax.annotation.CheckReturnValue; + +/** + * Manages security for an Android Service hosted gRPC server. + * + *

Attaches authorization state to a newly-created transport, and contains a + * ServerInterceptor which ensures calls are authorized before allowing them to proceed. + */ +public final class BinderTransportSecurity { + + private static final Attributes.Key TRANSPORT_AUTHORIZATION_STATE = + Attributes.Key.create("transport-authorization-state"); + + private BinderTransportSecurity() {} + + /** + * Install a security policy on an about-to-be created server. + * + * @param serverBuilder The ServerBuilder being used to create the server. + */ + public static void installAuthInterceptor(ServerBuilder serverBuilder) { + serverBuilder.intercept(new ServerAuthInterceptor()); + } + + /** + * Attach the given security policy to the transport attributes being built. Will be used by the + * auth interceptor to confirm accept or reject calls. + * + * @param builder The {@link Attributes.Builder} for the transport being created. + * @param remoteUid The remote UID of the transport. + * @param securityPolicy The policy to enforce on this transport. + */ + static void attachAuthAttrs( + Attributes.Builder builder, int remoteUid, ServerSecurityPolicy securityPolicy) { + builder + .set( + TRANSPORT_AUTHORIZATION_STATE, + new TransportAuthorizationState(remoteUid, securityPolicy)) + .set(GrpcAttributes.ATTR_SECURITY_LEVEL, SecurityLevel.PRIVACY_AND_INTEGRITY); + } + + /** + * Intercepts server calls and ensures they're authorized before allowing them to proceed. + * Authentication state is fetched from the call attributes, inherited from the transport. + */ + private static final class ServerAuthInterceptor implements ServerInterceptor { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + Status authStatus = + call.getAttributes() + .get(TRANSPORT_AUTHORIZATION_STATE) + .checkAuthorization(call.getMethodDescriptor()); + if (authStatus.isOk()) { + return next.startCall(call, headers); + } else { + call.close(authStatus, new Metadata()); + return new ServerCall.Listener() {}; + } + } + } + + /** + * Maintaines the authorization state for a single transport instance. This class lives for the + * lifetime of a single transport. + */ + private static final class TransportAuthorizationState { + private final int uid; + private final ServerSecurityPolicy policy; + private final ConcurrentHashMap serviceAuthorization; + + TransportAuthorizationState(int uid, ServerSecurityPolicy policy) { + this.uid = uid; + this.policy = policy; + serviceAuthorization = new ConcurrentHashMap<>(8); + } + + /** Get whether we're authorized to make this call. */ + @CheckReturnValue + Status checkAuthorization(MethodDescriptor method) { + String serviceName = method.getServiceName(); + // Only cache decisions if the method can be sampled for tracing, + // which is true for all generated methods. Otherwise, programatically + // created methods could casue this cahe to grow unbounded. + boolean useCache = method.isSampledToLocalTracing(); + Status authorization; + if (useCache) { + authorization = serviceAuthorization.get(serviceName); + if (authorization != null) { + return authorization; + } + } + authorization = policy.checkAuthorizationForService(uid, serviceName); + if (useCache) { + serviceAuthorization.putIfAbsent(serviceName, authorization); + } + return authorization; + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BlockInputStream.java b/binder/src/main/java/io/grpc/binder/internal/BlockInputStream.java new file mode 100644 index 00000000000..1ac1531da18 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BlockInputStream.java @@ -0,0 +1,159 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import com.google.common.primitives.Ints; +import io.grpc.Drainable; +import io.grpc.KnownLength; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import javax.annotation.Nullable; +import javax.annotation.concurrent.NotThreadSafe; + +/** + * A simple InputStream from a 2-dimensional byte array. + * + * Used to provide message data from incoming blocks of data. It is assumed that + * all byte arrays passed in the constructor of this this class are owned by the new + * instance. + * + * This also assumes byte arrays are created by the BlockPool class, and should + * be returned to it when this class is closed. + */ +@NotThreadSafe +final class BlockInputStream extends InputStream implements KnownLength, Drainable { + + @Nullable + private byte[][] blocks; + @Nullable + private byte[] currentBlock; + private int blockIndex; + private int blockOffset; + private int available; + private boolean closed; + + /** + * Creates a new stream with a single block. + * + * @param block The single byte array block, ownership of which is + * passed to this instance. + */ + BlockInputStream(byte[] block) { + this.blocks = null; + currentBlock = block.length > 0 ? block : null; + available = block.length; + } + + /** + * Creates a new stream from a sequence of blocks. + * + * @param blocks A two dimensional byte array containing the data. Ownership + * of all blocks is passed to this instance. + * @param available The number of bytes available in total. This may be + * less than (but never more than) the total size of all byte arrays in blocks. + */ + BlockInputStream(byte[][] blocks, int available) { + this.blocks = blocks; + this.available = available; + if (blocks.length > 0) { + currentBlock = blocks[0]; + } + } + + @Override + public int read() throws IOException { + if (currentBlock != null) { + int res = currentBlock[blockOffset++]; + available -= 1; + if (blockOffset == currentBlock.length) { + nextBlock(); + } + return res; + } + return -1; + } + + @Override + public int read(byte[] data, int off, int len) throws IOException { + int stillToRead = len; + while (currentBlock != null) { + int n = Ints.min(stillToRead, currentBlock.length - blockOffset, available); + System.arraycopy(currentBlock, blockOffset, data, off, n); + off += n; + stillToRead -= n; + available -= n; + if (stillToRead == 0) { + blockOffset += n; + if (blockOffset == currentBlock.length) { + nextBlock(); + } + break; + } else { + nextBlock(); + } + } + int bytesRead = len - stillToRead; + if (bytesRead > 0 || available > 0) { + return bytesRead; + } + return -1; + } + + private void nextBlock() { + blockIndex += 1; + blockOffset = 0; + if (blocks != null && blockIndex < blocks.length) { + currentBlock = blocks[blockIndex]; + } else { + currentBlock = null; + } + } + + @Override + public int available() { + return available; + } + + @Override + public int drainTo(OutputStream output) throws IOException { + int res = available; + while (available > 0) { + int n = Math.min(currentBlock.length - blockOffset, available); + output.write(currentBlock, blockOffset, n); + available -= n; + nextBlock(); + } + return res; + } + + @Override + public void close() { + if (!closed) { + closed = true; + if (blocks != null) { + for (byte[] block : blocks) { + BlockPool.releaseBlock(block); + } + } else if (currentBlock != null) { + BlockPool.releaseBlock(currentBlock); + } + currentBlock = null; + blocks = null; + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BlockPool.java b/binder/src/main/java/io/grpc/binder/internal/BlockPool.java new file mode 100644 index 00000000000..9ca766791bf --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BlockPool.java @@ -0,0 +1,80 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import io.grpc.internal.GrpcUtil; +import java.util.Queue; +import java.util.concurrent.LinkedBlockingQueue; + +/** + * Manages a pool of byte-array blocks. + * + *

Unfortunately, the Android Parcel api only allws us to read a block of N bytes when we have a + * byte array of size N. This means we can't simply read into a large block and be done with it, we + * need to allocate a new buffer specifically. Boo, Android. + * + *

When writing data though, we can use a fixed-size buffer, so when large messages are + * split into standard-sized blocks, we only need a byte array allocation to read the last + * block. + * + *

This class maintains a pool of blocks of standard size, but also provides smaller blocks when + * requested. Currently, blocks of standard size are retained in the pool, when released, but we + * could chose to change this strategy. + */ +final class BlockPool { + + /** + * The size of each standard block. (Currently 16k) + * The block size must be at least as large as the maximum header list size. + */ + private static final int BLOCK_SIZE = Math.max(16 * 1024, GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE); + + /** + * Maximum number of blocks to keep around. (Max 128k). This limit is a judgement call. 128k is + * small enough that it shouldn't significantly affect the memory usage of a large app, but large + * enough that it should to reduce allocation churn while gRPC is in use. + */ + private static final int BLOCK_POOL_SIZE = 128 * 1024 / BLOCK_SIZE; + + /** A pool of byte arrays of standard size. We don't use any blocking methods of this instance. */ + private static final Queue blockPool = new LinkedBlockingQueue<>(BLOCK_POOL_SIZE); + + private BlockPool() {} + + /** Acquire a block of standard size. */ + static byte[] acquireBlock() { + return acquireBlock(BLOCK_SIZE); + } + + /** Acquire a block of the specified size. */ + static byte[] acquireBlock(int size) { + if (size == BLOCK_SIZE) { + byte[] block = blockPool.poll(); + if (block != null) { + return block; + } + } + return new byte[size]; + } + + /** Release a now-unused block. */ + static void releaseBlock(byte[] block) { + if (block.length == BLOCK_SIZE) { + blockPool.offer(block); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/BoundClientAddress.java b/binder/src/main/java/io/grpc/binder/internal/BoundClientAddress.java new file mode 100644 index 00000000000..d54ad9c8106 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/BoundClientAddress.java @@ -0,0 +1,51 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import java.net.SocketAddress; + +/** An address to represent a binding from a remote client. */ +final class BoundClientAddress extends SocketAddress { + + private static final long serialVersionUID = 0L; + + /** The UID of the address. For incoming binder transactions, this is all the info we have. */ + private final int uid; + + BoundClientAddress(int uid) { + this.uid = uid; + } + + @Override + public int hashCode() { + return uid; + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof BoundClientAddress) { + BoundClientAddress that = (BoundClientAddress) obj; + return uid == that.uid; + } + return false; + } + + @Override + public String toString() { + return "BoundClientAddress[" + uid + "]"; + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/Inbound.java b/binder/src/main/java/io/grpc/binder/internal/Inbound.java new file mode 100644 index 00000000000..f28b9bfb29a --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/Inbound.java @@ -0,0 +1,731 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import android.os.Parcel; +import io.grpc.Attributes; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.ClientStreamListener.RpcProgress; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerStreamListener; +import io.grpc.internal.StatsTraceContext; +import io.grpc.internal.StreamListener; +import java.io.InputStream; +import java.util.ArrayList; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; + +/** + * Handles incoming binder transactions for a single stream, turning those transactions into calls + * to the stream listener. + * + *

Out-of-order messages are reassembled into their correct order. + */ +abstract class Inbound implements StreamListener.MessageProducer { + + protected final BinderTransport transport; + protected final Attributes attributes; + final int callId; + + // ========================== + // Values set when we're initialized. + + @Nullable + @GuardedBy("this") + protected Outbound outbound; + + @Nullable + @GuardedBy("this") + protected StatsTraceContext statsTraceContext; + + @Nullable + @GuardedBy("this") + protected L listener; + + // ========================== + // State of inbound data. + + @Nullable + @GuardedBy("this") + private InputStream firstMessage; + + @GuardedBy("this") + private int firstQueuedTransactionIndex; + + @GuardedBy("this") + private int nextCompleteMessageEnd; + + @Nullable + @GuardedBy("this") + private ArrayList queuedTransactionData; + + @GuardedBy("this") + private boolean suffixAvailable; + + @GuardedBy("this") + private int suffixTransactionIndex; + + @GuardedBy("this") + private int inboundDataSize; + + // ========================== + // State of what we've delivered to gRPC. + + /** + * Each rpc transmits (or receives) a prefix (including headers and possibly a method name), the + * data of zero or more request (or response) messages, and a suffix (possibly including a close + * status and trailers). + * + *

This enum represents those stages, for both availability (what we've been given), and + * delivery what we've sent. + */ + enum State { + // We aren't yet connected to a BinderStream instance and listener. Due to potentially + // out-of-order messages, a server-side instance can remain in this state for multiple + // transactions. + UNINITIALIZED, + + // We're attached to a BinderStream instance and we have a listener we can report to. + // On the client-side, this happens as soon as the start() method is called (almost + // immediately), and on the server side, this happens as soon as we receive the prefix + // (so we know which method is being called). + INITIALIZED, + + // We've delivered the prefix data to the listener. On the client side, this means we've + // delivered the response headers, and on the server side this state is effectively the same + // as INITIALIZED (since we initialize only by delivering the prefix). + PREFIX_DELIVERED, + + // All messages have been received, and delivered to the listener. + ALL_MESSAGES_DELIVERED, + + // We've delivered the suffix. + SUFFIX_DELIVERED, + + // The stream is closed. + CLOSED + } + + /* + * Represents which data we've delivered to the gRPC listener. + */ + @GuardedBy("this") + private State deliveryState = State.UNINITIALIZED; + + @GuardedBy("this") + private int numReceivedMessages; + + @GuardedBy("this") + private int numRequestedMessages; + + @GuardedBy("this") + private boolean delivering; + + @GuardedBy("this") + private boolean producingMessages; + + private Inbound(BinderTransport transport, Attributes attributes, int callId) { + this.transport = transport; + this.attributes = attributes; + this.callId = callId; + } + + @GuardedBy("this") + final void init(Outbound outbound, L listener) { + this.outbound = outbound; + this.statsTraceContext = outbound.getStatsTraceContext(); + this.listener = listener; + if (!isClosed()) { + onDeliveryState(State.INITIALIZED); + } + } + + final void unregister() { + transport.unregisterInbound(this); + } + + boolean countsForInUse() { + return false; + } + + // ===================== + // Updates to delivery. + + @GuardedBy("this") + protected final void onDeliveryState(State deliveryState) { + checkTransition(this.deliveryState, deliveryState); + this.deliveryState = deliveryState; + } + + @GuardedBy("this") + protected final boolean isClosed() { + return deliveryState == State.CLOSED; + } + + @GuardedBy("this") + private final boolean messageAvailable() { + return firstMessage != null || nextCompleteMessageEnd > 0; + } + + @GuardedBy("this") + private boolean receivedAllTransactions() { + return suffixAvailable && firstQueuedTransactionIndex >= suffixTransactionIndex; + } + + // =================== + // Internals. + + @GuardedBy("this") + final void deliver() { + if (delivering) { + // Don't re-enter. + return; + } + delivering = true; + while (canDeliver()) { + deliverInternal(); + } + delivering = false; + } + + @GuardedBy("this") + private final boolean canDeliver() { + switch (deliveryState) { + case PREFIX_DELIVERED: + if (listener != null) { + if (producingMessages) { + // We're waiting for the listener to consume messages. Nothing to do. + return false; + } else if (messageAvailable()) { + // There's a message. We can deliver if we've been asked for messages, and we haven't + // already given the listener a MessageProducer. + return numRequestedMessages != 0; + } else { + // There are no messages available. Return true if that's the last of them, because we + // can send the suffix. + return receivedAllTransactions(); + } + } + return false; + case ALL_MESSAGES_DELIVERED: + return listener != null && suffixAvailable; + default: + return false; + } + } + + @GuardedBy("this") + @SuppressWarnings("fallthrough") + private final void deliverInternal() { + switch (deliveryState) { + case PREFIX_DELIVERED: + if (producingMessages) { + break; + } else if (messageAvailable()) { + producingMessages = true; + listener.messagesAvailable(this); + break; + } else if (!suffixAvailable) { + break; + } + onDeliveryState(State.ALL_MESSAGES_DELIVERED); + // Fall-through. + case ALL_MESSAGES_DELIVERED: + if (suffixAvailable) { + onDeliveryState(State.SUFFIX_DELIVERED); + deliverSuffix(); + } + break; + default: + throw new AssertionError(); + } + } + + /** Deliver the suffix to gRPC. */ + protected abstract void deliverSuffix(); + + @GuardedBy("this") + final void closeOnCancel(Status status) { + closeAbnormal(Status.CANCELLED, status, false); + } + + @GuardedBy("this") + private final void closeOutOfBand(Status status) { + closeAbnormal(status, status, true); + } + + @GuardedBy("this") + final void closeAbnormal(Status status) { + closeAbnormal(status, status, false); + } + + @GuardedBy("this") + private final void closeAbnormal( + Status outboundStatus, Status internalStatus, boolean isOobFromRemote) { + if (!isClosed()) { + boolean wasInitialized = (deliveryState != State.UNINITIALIZED); + onDeliveryState(State.CLOSED); + if (wasInitialized) { + statsTraceContext.streamClosed(internalStatus); + } + if (!isOobFromRemote) { + transport.sendOutOfBandClose(callId, outboundStatus); + } + if (wasInitialized) { + deliverCloseAbnormal(internalStatus); + } + unregister(); + } + } + + @GuardedBy("this") + protected abstract void deliverCloseAbnormal(Status status); + + final void onTransportReady() { + // Report transport readiness to the listener, and the outbound data. + Outbound outbound = null; + StreamListener listener = null; + synchronized (this) { + outbound = this.outbound; + listener = this.listener; + } + if (listener != null) { + listener.onReady(); + } + if (outbound != null) { + try { + synchronized (outbound) { + outbound.onTransportReady(); + } + } catch (StatusException se) { + synchronized (this) { + closeAbnormal(se.getStatus()); + } + } + } + } + + @GuardedBy("this") + public void requestMessages(int num) { + numRequestedMessages += num; + deliver(); + } + + final synchronized void handleTransaction(Parcel parcel) { + if (isClosed()) { + return; + } + try { + int flags = parcel.readInt(); + if (TransactionUtils.hasFlag(flags, TransactionUtils.FLAG_OUT_OF_BAND_CLOSE)) { + closeOutOfBand(TransactionUtils.readStatus(flags, parcel)); + return; + } + int index = parcel.readInt(); + boolean hasPrefix = TransactionUtils.hasFlag(flags, TransactionUtils.FLAG_PREFIX); + boolean hasMessageData = + (TransactionUtils.hasFlag(flags, TransactionUtils.FLAG_MESSAGE_DATA)); + boolean hasSuffix = (TransactionUtils.hasFlag(flags, TransactionUtils.FLAG_SUFFIX)); + if (hasPrefix) { + handlePrefix(flags, parcel); + onDeliveryState(State.PREFIX_DELIVERED); + } + if (hasMessageData) { + handleMessageData(flags, index, parcel); + } + if (hasSuffix) { + handleSuffix(flags, parcel); + suffixTransactionIndex = index; + suffixAvailable = true; + } + if (index == firstQueuedTransactionIndex) { + if (queuedTransactionData == null) { + // This message was in order, and we haven't needed to queue anything yet. + firstQueuedTransactionIndex += 1; + } else if (!hasMessageData && !hasSuffix) { + // The first transaction arrived, but it contained no message data. + queuedTransactionData.remove(0); + firstQueuedTransactionIndex += 1; + } + } + reportInboundSize(parcel.dataSize()); + deliver(); + } catch (StatusException se) { + closeAbnormal(se.getStatus()); + } + } + + @GuardedBy("this") + abstract void handlePrefix(int flags, Parcel parcel) throws StatusException; + + @GuardedBy("this") + abstract void handleSuffix(int flags, Parcel parcel) throws StatusException; + + @GuardedBy("this") + private void handleMessageData(int flags, int index, Parcel parcel) throws StatusException { + InputStream stream = null; + byte[] block = null; + boolean lastBlockOfMessage = true; + int numBytes = 0; + if ((flags & TransactionUtils.FLAG_MESSAGE_DATA_IS_PARCELABLE) != 0) { + InboundParcelablePolicy policy = attributes.get(BinderTransport.INBOUND_PARCELABLE_POLICY); + if (policy == null || !policy.shouldAcceptParcelableMessages()) { + throw Status.PERMISSION_DENIED + .withDescription("Parcelable messages not allowed") + .asException(); + } + int startPos = parcel.dataPosition(); + stream = ParcelableInputStream.readFromParcel(parcel, getClass().getClassLoader()); + numBytes = parcel.dataPosition() - startPos; + } else { + numBytes = parcel.readInt(); + block = BlockPool.acquireBlock(numBytes); + if (numBytes > 0) { + parcel.readByteArray(block); + } + if ((flags & TransactionUtils.FLAG_MESSAGE_DATA_IS_PARTIAL) != 0) { + // Partial message. Ensure we have a message assembler. + lastBlockOfMessage = false; + } + } + if (queuedTransactionData == null) { + if (numReceivedMessages == 0 && lastBlockOfMessage && index == firstQueuedTransactionIndex) { + // Shortcut for when we receive a single message in one transaction. + checkState(firstMessage == null); + firstMessage = (stream != null) ? stream : new BlockInputStream(block); + reportInboundMessage(numBytes); + return; + } + queuedTransactionData = new ArrayList<>(16); + } + enqueueTransactionData(index, new TransactionData(stream, block, numBytes, lastBlockOfMessage)); + } + + @GuardedBy("this") + private void enqueueTransactionData(int index, TransactionData data) { + int offset = index - firstQueuedTransactionIndex; + if (offset < queuedTransactionData.size()) { + queuedTransactionData.set(offset, data); + lookForCompleteMessage(); + } else if (offset > queuedTransactionData.size()) { + do { + queuedTransactionData.add(null); + } while (offset > queuedTransactionData.size()); + queuedTransactionData.add(data); + } else { + queuedTransactionData.add(data); + lookForCompleteMessage(); + } + } + + @GuardedBy("this") + private void lookForCompleteMessage() { + int numBytes = 0; + if (nextCompleteMessageEnd == 0) { + for (int i = 0; i < queuedTransactionData.size(); i++) { + TransactionData data = queuedTransactionData.get(i); + if (data == null) { + // Missing block. + return; + } else { + numBytes += data.numBytes; + if (data.lastBlockOfMessage) { + // Found a complete message. + nextCompleteMessageEnd = i + 1; + reportInboundMessage(numBytes); + return; + } + } + } + } + } + + @Override + @Nullable + public final synchronized InputStream next() { + InputStream stream = null; + if (firstMessage != null) { + stream = firstMessage; + firstMessage = null; + } else if (messageAvailable()) { + stream = assembleNextMessage(); + } + if (stream != null) { + numRequestedMessages -= 1; + } else { + producingMessages = false; + if (receivedAllTransactions()) { + // That's the last of the messages delivered. + if (!isClosed()) { + onDeliveryState(State.ALL_MESSAGES_DELIVERED); + deliver(); + } + } + } + return stream; + } + + @GuardedBy("this") + private InputStream assembleNextMessage() { + InputStream message; + int numBlocks = nextCompleteMessageEnd; + nextCompleteMessageEnd = 0; + int numBytes = 0; + if (numBlocks == 1) { + // Single block. + TransactionData data = queuedTransactionData.remove(0); + numBytes = data.numBytes; + if (data.stream != null) { + message = data.stream; + } else { + message = new BlockInputStream(data.block); + } + } else { + byte[][] blocks = new byte[numBlocks][]; + for (int i = 0; i < numBlocks; i++) { + TransactionData data = queuedTransactionData.remove(0); + blocks[i] = checkNotNull(data.block); + numBytes += blocks[i].length; + } + message = new BlockInputStream(blocks, numBytes); + } + firstQueuedTransactionIndex += numBlocks; + lookForCompleteMessage(); + return message; + } + + // ------------------------------------ + // stats collection. + + @GuardedBy("this") + private void reportInboundSize(int size) { + inboundDataSize += size; + if (statsTraceContext != null && inboundDataSize != 0) { + statsTraceContext.inboundWireSize(inboundDataSize); + statsTraceContext.inboundUncompressedSize(inboundDataSize); + inboundDataSize = 0; + } + } + + @GuardedBy("this") + private void reportInboundMessage(int numBytes) { + checkNotNull(statsTraceContext); + statsTraceContext.inboundMessage(numReceivedMessages); + statsTraceContext.inboundMessageRead(numReceivedMessages, numBytes, numBytes); + numReceivedMessages += 1; + } + + @Override + public synchronized String toString() { + return getClass().getSimpleName() + + "[SfxA=" + + suffixAvailable + + "/De=" + + deliveryState + + "/Msg=" + + messageAvailable() + + "/Lis=" + + (listener != null) + + "]"; + } + + // ====================================== + // Client-side inbound transactions. + static final class ClientInbound extends Inbound { + + private final boolean countsForInUse; + + @Nullable + @GuardedBy("this") + private Status closeStatus; + + @Nullable + @GuardedBy("this") + private Metadata trailers; + + ClientInbound( + BinderTransport transport, Attributes attributes, int callId, boolean countsForInUse) { + super(transport, attributes, callId); + this.countsForInUse = countsForInUse; + } + + @Override + boolean countsForInUse() { + return countsForInUse; + } + + @Override + @GuardedBy("this") + protected void handlePrefix(int flags, Parcel parcel) throws StatusException { + Metadata headers = MetadataHelper.readMetadata(parcel, attributes); + statsTraceContext.clientInboundHeaders(); + listener.headersRead(headers); + } + + @Override + @GuardedBy("this") + protected void handleSuffix(int flags, Parcel parcel) throws StatusException { + closeStatus = TransactionUtils.readStatus(flags, parcel); + trailers = MetadataHelper.readMetadata(parcel, attributes); + } + + @Override + @GuardedBy("this") + protected void deliverSuffix() { + statsTraceContext.clientInboundTrailers(trailers); + statsTraceContext.streamClosed(closeStatus); + onDeliveryState(State.CLOSED); + listener.closed(closeStatus, RpcProgress.PROCESSED, trailers); + unregister(); + } + + @Override + @GuardedBy("this") + protected void deliverCloseAbnormal(Status status) { + listener.closed(status, RpcProgress.PROCESSED, new Metadata()); + } + } + + // ====================================== + // Server-side inbound transactions. + static final class ServerInbound extends Inbound { + + private final BinderTransport.BinderServerTransport serverTransport; + + ServerInbound( + BinderTransport.BinderServerTransport transport, Attributes attributes, int callId) { + super(transport, attributes, callId); + this.serverTransport = transport; + } + + @GuardedBy("this") + @Override + protected void handlePrefix(int flags, Parcel parcel) throws StatusException { + String methodName = parcel.readString(); + Metadata headers = MetadataHelper.readMetadata(parcel, attributes); + + StatsTraceContext statsTraceContext = + serverTransport.createStatsTraceContext(methodName, headers); + Outbound.ServerOutbound outbound = + new Outbound.ServerOutbound(serverTransport, callId, statsTraceContext); + ServerStream stream; + if ((flags & TransactionUtils.FLAG_EXPECT_SINGLE_MESSAGE) != 0) { + stream = new SingleMessageServerStream(this, outbound, attributes); + } else { + stream = new MultiMessageServerStream(this, outbound, attributes); + } + Status status = serverTransport.startStream(stream, methodName, headers); + if (status.isOk()) { + checkNotNull(listener); // Is it ok to assume this will happen synchronously? + if (transport.isReady()) { + listener.onReady(); + } + } else { + closeAbnormal(status); + } + } + + @GuardedBy("this") + @Override + protected void handleSuffix(int flags, Parcel parcel) { + // Nothing to read. + } + + @Override + @GuardedBy("this") + protected void deliverSuffix() { + listener.halfClosed(); + } + + @Override + @GuardedBy("this") + protected void deliverCloseAbnormal(Status status) { + listener.closed(status); + } + + @GuardedBy("this") + void onCloseSent(Status status) { + if (!isClosed()) { + onDeliveryState(State.CLOSED); + statsTraceContext.streamClosed(status); + listener.closed(Status.OK); + } + } + } + + // ====================================== + // Helper methods. + + private static void checkTransition(State current, State next) { + switch (next) { + case INITIALIZED: + checkState(current == State.UNINITIALIZED, "%s -> %s", current, next); + break; + case PREFIX_DELIVERED: + checkState( + current == State.INITIALIZED || current == State.UNINITIALIZED, + "%s -> %s", + current, + next); + break; + case ALL_MESSAGES_DELIVERED: + checkState(current == State.PREFIX_DELIVERED, "%s -> %s", current, next); + break; + case SUFFIX_DELIVERED: + checkState(current == State.ALL_MESSAGES_DELIVERED, "%s -> %s", current, next); + break; + case CLOSED: + break; + default: + throw new AssertionError(); + } + } + + // ====================================== + // Message reassembly. + + /** Part of an unconsumed message. */ + private static final class TransactionData { + @Nullable final InputStream stream; + @Nullable final byte[] block; + final int numBytes; + final boolean lastBlockOfMessage; + + TransactionData(InputStream stream, byte[] block, int numBytes, boolean lastBlockOfMessage) { + this.stream = stream; + this.block = block; + this.numBytes = numBytes; + this.lastBlockOfMessage = lastBlockOfMessage; + } + + @Override + public String toString() { + return "TransactionData[" + + numBytes + + "b " + + (stream != null ? "stream" : "array") + + (lastBlockOfMessage ? "(last)]" : "]"); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/MetadataHelper.java b/binder/src/main/java/io/grpc/binder/internal/MetadataHelper.java new file mode 100644 index 00000000000..211768f2948 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/MetadataHelper.java @@ -0,0 +1,224 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import android.os.Parcel; +import android.os.Parcelable; +import android.util.AndroidRuntimeException; +import io.grpc.Attributes; +import io.grpc.InternalMetadata; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.binder.InboundParcelablePolicy; +import io.grpc.internal.GrpcUtil; +import java.io.IOException; +import java.io.InputStream; +import javax.annotation.Nullable; + +/** + * Helper class for reading & writing metadata to parcels. + * + *

Metadata is written to a parcel as a single int for the number of name/value pairs, followed + * by the following pattern for each pair. + * + *

    + *
  1. name length (int) + *
  2. name (byte[]) + *
  3. value length OR sentinel (int) + *
  4. value (byte[] OR Parcelable) + *
+ * + * The sentinel int at the start of a value may indicate bad metadata. When this happens, no more + * data follows the sentinel. + */ +public final class MetadataHelper { + + /** The generic metadata marshaller we use for reading parcelables from the transport. */ + private static final Metadata.BinaryStreamMarshaller TRANSPORT_INBOUND_MARSHALLER = + new ParcelableMetadataMarshaller<>(null, true); + + /** Indicates the following value is a parcelable. */ + private static final int PARCELABLE_SENTINEL = -1; + + private MetadataHelper() {} + + /** + * Write a Metadata instance to a Parcel. + * + * @param parcel The {@link Parcel} to write to. + * @param metadata The {@link Metadata} to write. + */ + public static void writeMetadata(Parcel parcel, @Nullable Metadata metadata) + throws StatusException, IOException { + int n = metadata != null ? InternalMetadata.headerCount(metadata) : 0; + if (n == 0) { + parcel.writeInt(0); + return; + } + Object[] serialized = InternalMetadata.serializePartial(metadata); + parcel.writeInt(n); + for (int i = 0; i < n; i++) { + byte[] name = (byte[]) serialized[i * 2]; + parcel.writeInt(name.length); + parcel.writeByteArray(name); + Object value = serialized[i * 2 + 1]; + if (value instanceof byte[]) { + byte[] valueBytes = (byte[]) value; + parcel.writeInt(valueBytes.length); + parcel.writeByteArray(valueBytes); + } else if (value instanceof ParcelableInputStream) { + parcel.writeInt(PARCELABLE_SENTINEL); + ((ParcelableInputStream) value).writeToParcel(parcel); + } else { + // An InputStream which wasn't created by ParcelableUtils, which means there's another use + // of Metadata.BinaryStreamMarshaller. Just read the bytes. + // + // We know that BlockPool will give us a buffer at least as large as the max space for all + // names and values so it'll certainly be large enough (and the limit is only 8k so this + // is fine). + byte[] buffer = BlockPool.acquireBlock(); + try { + InputStream stream = (InputStream) value; + int total = 0; + while (total < buffer.length) { + int read = stream.read(buffer, total, buffer.length - total); + if (read == -1) { + break; + } + total += read; + } + if (total == buffer.length) { + throw Status.RESOURCE_EXHAUSTED.withDescription("Metadata value too large").asException(); + } + parcel.writeInt(total); + if (total > 0) { + parcel.writeByteArray(buffer, 0, total); + } + } finally { + BlockPool.releaseBlock(buffer); + } + } + } + } + + /** + * Read a Metadata instance from a Parcel. + * + * @param parcel The {@link Parcel} to read from. + */ + public static Metadata readMetadata(Parcel parcel, Attributes attributes) throws StatusException { + int n = parcel.readInt(); + if (n == 0) { + return new Metadata(); + } + // For enforcing the header-size limit. Doesn't include parcelable data. + int bytesRead = 0; + // For enforcing the maximum allowed parcelable data (see InboundParcelablePolicy). + int parcelableBytesRead = 0; + Object[] serialized = new Object[n * 2]; + for (int i = 0; i < n; i++) { + int numNameBytes = parcel.readInt(); + bytesRead += 4; + byte[] name = readBytesChecked(parcel, numNameBytes, bytesRead); + bytesRead += numNameBytes; + serialized[i * 2] = name; + int numValueBytes = parcel.readInt(); + bytesRead += 4; + if (numValueBytes == PARCELABLE_SENTINEL) { + InboundParcelablePolicy policy = attributes.get(BinderTransport.INBOUND_PARCELABLE_POLICY); + if (!policy.shouldAcceptParcelableMetadataValues()) { + throw Status.PERMISSION_DENIED + .withDescription("Parcelable metadata values not allowed") + .asException(); + } + int parcelableStartPos = parcel.dataPosition(); + try { + Parcelable value = parcel.readParcelable(MetadataHelper.class.getClassLoader()); + if (value == null) { + throw Status.INTERNAL.withDescription("Read null parcelable in metadata").asException(); + } + serialized[i * 2 + 1] = InternalMetadata.parsedValue(TRANSPORT_INBOUND_MARSHALLER, value); + } catch (AndroidRuntimeException are) { + throw Status.INTERNAL + .withCause(are) + .withDescription("Failure reading parcelable in metadata") + .asException(); + } + int parcelableSize = parcel.dataPosition() - parcelableStartPos; + parcelableBytesRead += parcelableSize; + if (parcelableBytesRead > policy.getMaxParcelableMetadataSize()) { + throw Status.RESOURCE_EXHAUSTED + .withDescription( + "Inbound Parcelables too large according to policy (see InboundParcelablePolicy)") + .asException(); + } + } else if (numValueBytes < 0) { + throw Status.INTERNAL.withDescription("Unrecognized metadata sentinel").asException(); + } else { + byte[] value = readBytesChecked(parcel, numValueBytes, bytesRead); + bytesRead += numValueBytes; + serialized[i * 2 + 1] = value; + } + } + return InternalMetadata.newMetadataWithParsedValues(n, serialized); + } + + /** Read a byte array checking that we're not reading too much. */ + private static byte[] readBytesChecked( + Parcel parcel, + int numBytes, + int bytesRead) throws StatusException { + if (bytesRead + numBytes > GrpcUtil.DEFAULT_MAX_HEADER_LIST_SIZE) { + throw Status.RESOURCE_EXHAUSTED.withDescription("Metadata too large").asException(); + } + byte[] res = new byte[numBytes]; + if (numBytes > 0) { + parcel.readByteArray(res); + } + return res; + } + + /** A marshaller for passing parcelables in gRPC {@link Metadata} */ + public static final class ParcelableMetadataMarshaller

+ implements Metadata.BinaryStreamMarshaller

{ + + @Nullable private final Parcelable.Creator

creator; + private final boolean immutableType; + + public ParcelableMetadataMarshaller(@Nullable Parcelable.Creator

creator, boolean immutableType) { + this.creator = creator; + this.immutableType = immutableType; + } + + @Override + public InputStream toStream(P value) { + return new ParcelableInputStream<>(creator, value, immutableType); + } + + @Override + @SuppressWarnings("unchecked") + public P parseStream(InputStream stream) { + if (stream instanceof ParcelableInputStream) { + return ((ParcelableInputStream

) stream).getParcelable(); + } else { + throw new UnsupportedOperationException( + "Can't unmarshall a parcelable from a regular byte stream"); + } + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/MultiMessageClientStream.java b/binder/src/main/java/io/grpc/binder/internal/MultiMessageClientStream.java new file mode 100644 index 00000000000..317925e2d0b --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/MultiMessageClientStream.java @@ -0,0 +1,192 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import io.grpc.Attributes; +import io.grpc.Compressor; +import io.grpc.Deadline; +import io.grpc.DecompressorRegistry; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.InsightBuilder; +import java.io.InputStream; +import javax.annotation.Nonnull; + +/** + * The client side of a single RPC, which sends a stream of request messages. + * + *

An instance of this class is effectively a go-between, receiving messages from the gRPC + * ClientCall instance (via calls on the ClientStream interface we implement), and sending them out + * on the transport, as well as receiving messages from the transport, and passing the resultant + * data back to the gRPC ClientCall instance (via calls on the ClientStreamListener instance we're + * given). + * + *

These two communication directions are largely independent of each other, with the {@link + * Outbound} handling the gRPC to transport direction, and the {@link Inbound} class handling + * transport to gRPC direction. + * + *

Since the Inbound and Outbound halves are largely independent, their state is also + * synchronized independently. + */ +final class MultiMessageClientStream implements ClientStream { + + private final Inbound.ClientInbound inbound; + private final Outbound.ClientOutbound outbound; + private final Attributes attributes; + + MultiMessageClientStream( + Inbound.ClientInbound inbound, Outbound.ClientOutbound outbound, Attributes attributes) { + this.inbound = inbound; + this.outbound = outbound; + this.attributes = attributes; + } + + @Override + public void start(ClientStreamListener listener) { + synchronized (inbound) { + inbound.init(outbound, listener); + } + if (outbound.isReady()) { + listener.onReady(); + try { + synchronized (outbound) { + outbound.send(); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); + } + } + } + } + + @Override + public void request(int numMessages) { + synchronized (inbound) { + inbound.requestMessages(numMessages); + } + } + + @Override + public boolean isReady() { + return outbound.isReady(); + } + + @Override + public void writeMessage(InputStream message) { + try { + synchronized (outbound) { + outbound.addMessage(message); + outbound.send(); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); + } + } + } + + @Override + public void halfClose() { + try { + synchronized (outbound) { + outbound.sendHalfClose(); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); + } + } + } + + @Override + public void cancel(Status status) { + synchronized (inbound) { + inbound.closeOnCancel(status); + } + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public final String toString() { + return "MultiMessageClientStream[" + inbound + "/" + outbound + "]"; + } + + // ===================== + // Misc stubbed & unsupported methods. + + @Override + public final void flush() { + // Ignore. + } + + @Override + public final void setCompressor(Compressor compressor) { + // Ignore. + } + + @Override + public final void setMessageCompression(boolean enable) { + // Ignore. + } + + @Override + public void setDeadline(@Nonnull Deadline deadline) { + // Ignore. (Deadlines should still work at a higher level). + } + + @Override + public void setAuthority(String authority) { + // Ignore. + } + + @Override + public void setMaxInboundMessageSize(int maxSize) { + // Ignore. + } + + @Override + public void setMaxOutboundMessageSize(int maxSize) { + // Ignore. + } + + @Override + public void appendTimeoutInsight(InsightBuilder insight) { + // Ignore + } + + @Override + public void setFullStreamDecompression(boolean fullStreamDecompression) { + // Ignore. + } + + @Override + public void setDecompressorRegistry(DecompressorRegistry decompressorRegistry) { + // Ignore. + } + + @Override + public void optimizeForDirectExecutor() { + // Ignore. + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/MultiMessageServerStream.java b/binder/src/main/java/io/grpc/binder/internal/MultiMessageServerStream.java new file mode 100644 index 00000000000..f86cea2fe35 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/MultiMessageServerStream.java @@ -0,0 +1,182 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import io.grpc.Attributes; +import io.grpc.Compressor; +import io.grpc.Decompressor; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerStreamListener; +import io.grpc.internal.StatsTraceContext; +import java.io.InputStream; +import javax.annotation.Nullable; + +/** + * The server side of a single RPC, which sends a stream of response messages. + * + *

An instance of this class is effectively a go-between, receiving messages from the gRPC + * ServerCall instance (via calls on the ServerStream interface we implement), and sending them out + * on the transport, as well as receiving messages from the transport, and passing the resultant + * data back to the gRPC ServerCall instance (via calls on the ServerStreamListener instance we're + * given). + * + *

These two communication directions are largely independent of each other, with the {@link + * Outbound} handling the gRPC to transport direction, and the {@link Inbound} class handling + * transport to gRPC direction. + * + *

Since the Inbound and Outbound halves are largely independent, their state is also + * synchronized independently. + */ +final class MultiMessageServerStream implements ServerStream { + + private final Inbound.ServerInbound inbound; + private final Outbound.ServerOutbound outbound; + private final Attributes attributes; + + MultiMessageServerStream( + Inbound.ServerInbound inbound, Outbound.ServerOutbound outbound, Attributes attributes) { + this.inbound = inbound; + this.outbound = outbound; + this.attributes = attributes; + } + + @Override + public void setListener(ServerStreamListener listener) { + synchronized (inbound) { + inbound.init(outbound, listener); + } + } + + @Override + public boolean isReady() { + return outbound.isReady(); + } + + @Override + public void request(int numMessages) { + synchronized (inbound) { + inbound.requestMessages(numMessages); + } + } + + @Override + public void writeHeaders(Metadata headers) { + try { + synchronized (outbound) { + outbound.sendHeaders(headers); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); + } + } + } + + @Override + public void writeMessage(InputStream message) { + try { + synchronized (outbound) { + outbound.addMessage(message); + outbound.send(); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); + } + } + } + + @Override + public void close(Status status, Metadata trailers) { + try { + synchronized (outbound) { + outbound.sendClose(status, trailers); + } + synchronized (inbound) { + inbound.onCloseSent(status); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); + } + } + } + + @Override + public void cancel(Status status) { + synchronized (inbound) { + inbound.closeOnCancel(status); + } + } + + @Override + public StatsTraceContext statsTraceContext() { + return outbound.getStatsTraceContext(); + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Nullable + @Override + public String getAuthority() { + return attributes.get(BinderTransport.SERVER_AUTHORITY); + } + + @Override + public String toString() { + return "MultiMessageServerStream[" + inbound + "/" + outbound + "]"; + } + + // ===================== + // Misc stubbed & unsupported methods. + + @Override + public final void flush() { + // Ignore. + } + + @Override + public final void setCompressor(Compressor compressor) { + // Ignore. + } + + @Override + public final void setMessageCompression(boolean enable) { + // Ignore. + } + + @Override + public void setDecompressor(Decompressor decompressor) { + // Ignore. + } + + @Override + public void optimizeForDirectExecutor() { + // Ignore. + } + + @Override + public int streamId() { + return -1; + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/Outbound.java b/binder/src/main/java/io/grpc/binder/internal/Outbound.java new file mode 100644 index 00000000000..56874619987 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/Outbound.java @@ -0,0 +1,495 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkNotNull; +import static com.google.common.base.Preconditions.checkState; + +import android.os.Parcel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.StatsTraceContext; +import java.io.IOException; +import java.io.InputStream; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; + +/** + * Sends the set of outbound transactions for a single BinderStream (rpc). + * + *

Handles buffering internally for flow control, and splitting large messages into multiple + * transactions where necessary. + * + *

Also handles reporting to the {@link StatsTraceContext}. + * + *

A note on threading: All calls into this class are expected to hold this object as a lock. + * However, since calls from gRPC are serialized already, the only reason we need to care about + * threading is the onTransportReady() call (when flow-control unblocks us). + * + *

To reduce the cost of locking, BinderStream endeavors to make only a single call to this class + * for single-message calls (the most common). + * + *

IMPORTANT: To avoid potential deadlocks, this class may only call unsynchronized + * methods of the BinderTransport class. + */ +abstract class Outbound { + + private final BinderTransport transport; + private final int callId; + private final StatsTraceContext statsTraceContext; + + enum State { + INITIAL, + PREFIX_SENT, + ALL_MESSAGES_SENT, + SUFFIX_SENT, + CLOSED, + } + + /* + * Represents the state of data we've sent in binder transactions. + */ + @GuardedBy("this") + private State outboundState = State.INITIAL; // Represents what we've delivered. + + // ---------------------------------- + // For reporting to StatsTraceContext. + /** Indicates we're ready to send the prefix. */ + private boolean prefixReady; + + @Nullable private InputStream firstMessage; + + @Nullable private Queue messageQueue; + + /** + * Indicates we have everything ready to send the suffix. This implies we have all outgoing + * messages, and any additional data which needs to be send after the last message. (e.g. + * trailers). + */ + private boolean suffixReady; + + /** + * The index of the next transaction we'll send, allowing the receiver to re-assemble out-of-order + * messages. + */ + @GuardedBy("this") + private int transactionIndex; + + // ---------------------------------- + // For reporting to StatsTraceContext. + private int numDeliveredMessages; + private int messageSize; + + private Outbound(BinderTransport transport, int callId, StatsTraceContext statsTraceContext) { + this.transport = transport; + this.callId = callId; + this.statsTraceContext = statsTraceContext; + } + + final StatsTraceContext getStatsTraceContext() { + return statsTraceContext; + } + + /** Call to add a message to be delivered. */ + @GuardedBy("this") + final void addMessage(InputStream message) throws StatusException { + onPrefixReady(); // This is implied. + if (messageQueue != null) { + messageQueue.add(message); + } else if (firstMessage == null) { + firstMessage = message; + } else { + messageQueue = new ConcurrentLinkedQueue<>(); + messageQueue.add(message); + } + } + + @GuardedBy("this") + protected final void onPrefixReady() { + this.prefixReady = true; + } + + @GuardedBy("this") + protected final void onSuffixReady() { + this.suffixReady = true; + } + + // ===================== + // Updates to delivery. + @GuardedBy("this") + private void onOutboundState(State outboundState) { + checkTransition(this.outboundState, outboundState); + this.outboundState = outboundState; + } + + // =================== + // Internals. + @GuardedBy("this") + protected final boolean messageAvailable() { + if (messageQueue != null) { + return !messageQueue.isEmpty(); + } else if (firstMessage != null) { + return numDeliveredMessages == 0; + } else { + return false; + } + } + + @Nullable + @GuardedBy("this") + private final InputStream peekNextMessage() { + if (numDeliveredMessages == 0) { + return firstMessage; + } else if (messageQueue != null) { + return messageQueue.peek(); + } + return null; + } + + @GuardedBy("this") + private final boolean canSend() { + switch (outboundState) { + case INITIAL: + if (!prefixReady) { + return false; + } + break; + case PREFIX_SENT: + // We can only send something if we have messages or the suffix. + // Note that if we have the suffix but no messages in this state, it means we've been closed + // early. + if (!messageAvailable() && !suffixReady) { + return false; + } + break; + case ALL_MESSAGES_SENT: + if (!suffixReady) { + return false; + } + break; + default: + return false; + } + return isReady(); + } + + final boolean isReady() { + return transport.isReady(); + } + + @GuardedBy("this") + final void onTransportReady() throws StatusException { + // The transport has become ready, attempt sending. + send(); + } + + @GuardedBy("this") + final void send() throws StatusException { + while (canSend()) { + try { + sendInternal(); + } catch (StatusException se) { + // Ensure we don't send anything else and rethrow. + onOutboundState(State.CLOSED); + throw se; + } + } + } + + @GuardedBy("this") + @SuppressWarnings("fallthrough") + protected final void sendInternal() throws StatusException { + Parcel parcel = Parcel.obtain(); + int flags = 0; + parcel.writeInt(0); // Placeholder for flags. Will be filled in below. + parcel.writeInt(transactionIndex++); + try { + switch (outboundState) { + case INITIAL: + flags |= TransactionUtils.FLAG_PREFIX; + flags |= writePrefix(parcel); + onOutboundState(State.PREFIX_SENT); + if (!messageAvailable() && !suffixReady) { + break; + } + // Fall-through. + case PREFIX_SENT: + InputStream messageStream = peekNextMessage(); + if (messageStream != null) { + flags |= TransactionUtils.FLAG_MESSAGE_DATA; + flags |= writeMessageData(parcel, messageStream); + } else { + checkState(suffixReady); + } + if (suffixReady && !messageAvailable()) { + onOutboundState(State.ALL_MESSAGES_SENT); + } else { + // There's still more message data to deliver, break out. + break; + } + // Fall-through. + case ALL_MESSAGES_SENT: + flags |= TransactionUtils.FLAG_SUFFIX; + flags |= writeSuffix(parcel); + onOutboundState(State.SUFFIX_SENT); + break; + default: + throw new AssertionError(); + } + TransactionUtils.fillInFlags(parcel, flags); + transport.sendTransaction(callId, parcel); + statsTraceContext.outboundWireSize(parcel.dataSize()); + statsTraceContext.outboundUncompressedSize(parcel.dataSize()); + } catch (IOException e) { + throw Status.INTERNAL.withCause(e).asException(); + } finally { + parcel.recycle(); + } + } + + protected final void unregister() { + transport.unregisterCall(callId); + } + + @Override + public synchronized String toString() { + return getClass().getSimpleName() + + "[S=" + + outboundState + + "/NDM=" + + numDeliveredMessages + + "]"; + } + + /** + * Write prefix data to the given {@link Parcel}. + * + * @param parcel the transaction parcel to write to. + * @return any additional flags to be set on the transaction. + */ + @GuardedBy("this") + protected abstract int writePrefix(Parcel parcel) throws IOException, StatusException; + + /** + * Write suffix data to the given {@link Parcel}. + * + * @param parcel the transaction parcel to write to. + * @return any additional flags to be set on the transaction. + */ + @GuardedBy("this") + protected abstract int writeSuffix(Parcel parcel) throws IOException, StatusException; + + @GuardedBy("this") + private final int writeMessageData(Parcel parcel, InputStream stream) throws IOException { + int flags = 0; + boolean dataRemaining = false; + if (stream instanceof ParcelableInputStream) { + flags |= TransactionUtils.FLAG_MESSAGE_DATA_IS_PARCELABLE; + messageSize = ((ParcelableInputStream) stream).writeToParcel(parcel); + } else { + byte[] block = BlockPool.acquireBlock(); + try { + int size = stream.read(block); + if (size <= 0) { + parcel.writeInt(0); + } else { + parcel.writeInt(size); + parcel.writeByteArray(block, 0, size); + messageSize += size; + if (size == block.length) { + flags |= TransactionUtils.FLAG_MESSAGE_DATA_IS_PARTIAL; + dataRemaining = true; + } + } + } finally { + BlockPool.releaseBlock(block); + } + } + if (!dataRemaining) { + stream.close(); + int index = numDeliveredMessages++; + if (index > 0) { + checkNotNull(messageQueue).poll(); + } + statsTraceContext.outboundMessage(index); + statsTraceContext.outboundMessageSent(index, messageSize, messageSize); + messageSize = 0; + } + return flags; + } + + // ====================================== + // Client-side outbound transactions. + static final class ClientOutbound extends Outbound { + + private final MethodDescriptor method; + private final Metadata headers; + private final StatsTraceContext statsTraceContext; + + ClientOutbound( + BinderTransport transport, + int callId, + MethodDescriptor method, + Metadata headers, + StatsTraceContext statsTraceContext) { + super(transport, callId, statsTraceContext); + this.method = method; + this.headers = headers; + this.statsTraceContext = statsTraceContext; + onPrefixReady(); // Client prefix is available immediately. + } + + @Override + @GuardedBy("this") + protected int writePrefix(Parcel parcel) throws IOException, StatusException { + parcel.writeString(method.getFullMethodName()); + MetadataHelper.writeMetadata(parcel, headers); + statsTraceContext.clientOutboundHeaders(); + if (method.getType().serverSendsOneMessage()) { + return TransactionUtils.FLAG_EXPECT_SINGLE_MESSAGE; + } + return 0; + } + + @GuardedBy("this") + void sendSingleMessageAndHalfClose(@Nullable InputStream singleMessage) throws StatusException { + if (singleMessage != null) { + addMessage(singleMessage); + } + onSuffixReady(); + send(); + } + + @GuardedBy("this") + void sendHalfClose() throws StatusException { + onSuffixReady(); + send(); + } + + @Override + @GuardedBy("this") + protected int writeSuffix(Parcel parcel) throws IOException { + // Client doesn't include anything in the suffix. + return 0; + } + } + + // ====================================== + // Server-side outbound transactions. + static final class ServerOutbound extends Outbound { + @GuardedBy("this") + @Nullable + private Metadata headers; + + @GuardedBy("this") + @Nullable + private Status closeStatus; + + @GuardedBy("this") + @Nullable + private Metadata trailers; + + ServerOutbound(BinderTransport transport, int callId, StatsTraceContext statsTraceContext) { + super(transport, callId, statsTraceContext); + } + + @GuardedBy("this") + void sendHeaders(Metadata headers) throws StatusException { + this.headers = headers; + onPrefixReady(); + send(); + } + + @Override + @GuardedBy("this") + protected int writePrefix(Parcel parcel) throws IOException, StatusException { + MetadataHelper.writeMetadata(parcel, headers); + return 0; + } + + @GuardedBy("this") + void sendSingleMessageAndClose( + @Nullable Metadata pendingHeaders, + @Nullable InputStream pendingSingleMessage, + Status closeStatus, + Metadata trailers) + throws StatusException { + if (this.closeStatus != null) { + return; + } + if (pendingHeaders != null) { + this.headers = pendingHeaders; + } + onPrefixReady(); + if (pendingSingleMessage != null) { + addMessage(pendingSingleMessage); + } + checkState(this.trailers == null); + this.closeStatus = closeStatus; + this.trailers = trailers; + onSuffixReady(); + send(); + } + + @GuardedBy("this") + void sendClose(Status closeStatus, Metadata trailers) throws StatusException { + if (this.closeStatus != null) { + return; + } + checkState(this.trailers == null); + this.closeStatus = closeStatus; + this.trailers = trailers; + onPrefixReady(); + onSuffixReady(); + send(); + } + + @Override + @GuardedBy("this") + protected int writeSuffix(Parcel parcel) throws IOException, StatusException { + int flags = TransactionUtils.writeStatus(parcel, closeStatus); + MetadataHelper.writeMetadata(parcel, trailers); + // TODO: This is an ugly place for this side-effect. + unregister(); + return flags; + } + } + + // ====================================== + // Helper methods. + private static void checkTransition(State current, State next) { + switch (next) { + case PREFIX_SENT: + checkState(current == State.INITIAL); + break; + case ALL_MESSAGES_SENT: + checkState(current == State.PREFIX_SENT); + break; + case SUFFIX_SENT: + checkState(current == State.ALL_MESSAGES_SENT); + break; + case CLOSED: // hah. + break; + default: + throw new AssertionError(); + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/ParcelableInputStream.java b/binder/src/main/java/io/grpc/binder/internal/ParcelableInputStream.java new file mode 100644 index 00000000000..09b8cfc43f9 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/ParcelableInputStream.java @@ -0,0 +1,210 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkNotNull; + +import android.os.Parcel; +import android.os.Parcelable; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import javax.annotation.Nullable; + +/** + * An inputstream to serialize a single Android Parcelable object for gRPC calls, with support for + * serializing to a native Android Parcel, and for when a Parcelable is sent in-process. + * + *

Important: It's not actually possible to marshall a parcelable to raw bytes without + * losing data, since a parcelable may contain file descriptors. While this class does + * support marshalling into bytes, this is only supported for the purposes of debugging/logging, and + * we intentionally don't support unmarshalling back to a parcelable. + * + *

This class really just wraps a Parcelable instance and masquerardes as an inputstream. See + * {@code ProtoLiteUtils} for a similar example of this pattern. + * + *

An instance of this class maybe be created from two sources. + * + *

    + *
  • To wrap a Parcelable instance we plan to send. + *
  • To wrap a Parcelable instance we've just received (and read from a Parcel). + *
+ * + *

In the first case, we expect to serialize to a {@link Parcel}, with a call to {@link + * #writeToParcel}. + * + *

In the second case, we only expect the Parcelable to be fetched (and not re-serialized). + * + *

For in-process gRPC calls, the same InputStream used to send the Parcelable (the first case), + * will also be used to parse the parcelable from the stream, in which case we shortcut serializing + * internally (possibly skipping it entirely if the instance is immutable). + */ +final class ParcelableInputStream

extends InputStream { + @Nullable private final Parcelable.Creator

creator; + private final boolean safeToReturnValue; + private final P value; + + @Nullable InputStream delegateStream; + + @Nullable P sharableValue; + + ParcelableInputStream( + @Nullable Parcelable.Creator

creator, P value, boolean safeToReturnValue) { + this.creator = creator; + this.value = value; + this.safeToReturnValue = safeToReturnValue; + // If we're not given a creator, the value must be safe to return unchanged. + checkArgument(creator != null || safeToReturnValue); + } + + /** + * Create a stream from a {@link Parcel} object. Note that this immediately reads the Parcelable + * object, allowing the Parcel to be recycled after calling this method. + */ + @SuppressWarnings("unchecked") + static

ParcelableInputStream

readFromParcel( + Parcel parcel, ClassLoader classLoader) { + P value = (P) parcel.readParcelable(classLoader); + return new ParcelableInputStream<>(null, value, true); + } + + /** Create a stream for a Parcelable object. */ + static

ParcelableInputStream

forInstance( + P value, Parcelable.Creator

creator) { + return new ParcelableInputStream<>(creator, value, false); + } + + /** Create a stream for a Parcelable object, treating the object as immutable. */ + static

ParcelableInputStream

forImmutableInstance( + P value, Parcelable.Creator

creator) { + return new ParcelableInputStream<>(creator, value, true); + } + + private InputStream getDelegateStream() { + if (delegateStream == null) { + Parcel parcel = Parcel.obtain(); + parcel.writeParcelable(value, 0); + byte[] res = parcel.marshall(); + parcel.recycle(); + delegateStream = new ByteArrayInputStream(res); + } + return delegateStream; + } + + @Override + public int read() throws IOException { + return getDelegateStream().read(); + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + return getDelegateStream().read(b, off, len); + } + + @Override + public long skip(long n) throws IOException { + if (n <= 0) { + return 0; + } + return getDelegateStream().skip(n); + } + + @Override + public int available() throws IOException { + return getDelegateStream().available(); + } + + @Override + public void close() throws IOException { + if (delegateStream != null) { + delegateStream.close(); + } + } + + @Override + public void mark(int readLimit) { + // If there's no delegate stream yet, the current position is 0. That's the same + // as the default mark position, so there's nothing to do. + if (delegateStream != null) { + delegateStream.mark(readLimit); + } + } + + @Override + public void reset() throws IOException { + if (delegateStream != null) { + delegateStream.reset(); + } + } + + @Override + public boolean markSupported() { + // We know our delegate (ByteArrayInputStream) supports mark/reset. + return true; + } + + /** + * Write the {@link Parcelable} this stream wraps to the given {@link Parcel}. + * + *

This will retain any android-specific data (e.g. file descriptors) which can't simply be + * serialized to bytes. + * + * @return The number of bytes written to the parcel. + */ + int writeToParcel(Parcel parcel) { + int startPos = parcel.dataPosition(); + parcel.writeParcelable(value, value.describeContents()); + return parcel.dataPosition() - startPos; + } + + /** + * Get the parcelable as if it had been serialized/de-serialized. + * + *

If the parcelable is immutable, or it was already de-serialized from a Parcel (I.e. this + * instance was created with #readFromParcel), the value will be returned directly. + */ + P getParcelable() { + if (safeToReturnValue) { + // We can just return the value directly. + return value; + } else { + // We need to serialize/de-serialize to a parcel internally. + if (sharableValue == null) { + sharableValue = marshallUnmarshall(value, checkNotNull(creator)); + } + return sharableValue; + } + } + + private static

P marshallUnmarshall( + P value, Parcelable.Creator

creator) { + // Serialize/de-serialize the object directly instead of using Parcel.writeParcelable, + // since there's no need to write out the class name. + Parcel parcel = Parcel.obtain(); + value.writeToParcel(parcel, 0); + parcel.setDataPosition(0); + P result = creator.createFromParcel(parcel); + parcel.recycle(); + return result; + } + + @Override + public String toString() { + return "ParcelableInputStream[V: " + value + "]"; + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/PingTracker.java b/binder/src/main/java/io/grpc/binder/internal/PingTracker.java new file mode 100644 index 00000000000..911412d0b66 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/PingTracker.java @@ -0,0 +1,114 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.ClientTransport.PingCallback; +import io.grpc.internal.TimeProvider; +import java.util.concurrent.Executor; +import javax.annotation.Nullable; +import javax.annotation.concurrent.GuardedBy; + +/** + * Tracks an ongoing ping request for a client-side binder transport. We only handle a single active + * ping at a time, since that's all gRPC appears to need. + */ +final class PingTracker { + + interface PingSender { + /** + * Send a ping to the remote endpoint. We expect a subsequent call to {@link #onPingResponse} + * with the same ID (assuming the ping succeeds). + */ + void sendPing(int id) throws StatusException; + } + + private final TimeProvider timeProvider; + private final PingSender pingSender; + + @GuardedBy("this") + @Nullable + private Ping pendingPing; + + @GuardedBy("this") + private int nextPingId; + + PingTracker(TimeProvider timeProvider, PingSender pingSender) { + this.timeProvider = timeProvider; + this.pingSender = pingSender; + } + + /** + * Start a ping. + * + *

See also {@link ClientTransport#ping}. + * + * @param callback The callback to report the ping result on. + * @param executor An executor to call callbacks on. + *

Note that only one ping callback will be active at a time. + */ + synchronized void startPing(PingCallback callback, Executor executor) { + pendingPing = new Ping(callback, executor, nextPingId++); + try { + pingSender.sendPing(pendingPing.id); + } catch (StatusException se) { + pendingPing.fail(se.getStatus()); + pendingPing = null; + } + } + + /** Callback when a ping response with the given ID is received. */ + synchronized void onPingResponse(int id) { + if (pendingPing != null && pendingPing.id == id) { + pendingPing.success(); + pendingPing = null; + } + } + + private final class Ping { + private final PingCallback callback; + private final Executor executor; + private final int id; + private final long startTimeNanos; + + @GuardedBy("this") + private boolean done; + + Ping(PingCallback callback, Executor executor, int id) { + this.callback = callback; + this.executor = executor; + this.id = id; + this.startTimeNanos = timeProvider.currentTimeNanos(); + } + + private synchronized void fail(Status status) { + if (!done) { + done = true; + executor.execute(() -> callback.onFailure(status.asException())); + } + } + + private synchronized void success() { + if (!done) { + done = true; + executor.execute( + () -> callback.onSuccess(timeProvider.currentTimeNanos() - startTimeNanos)); + } + } + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/SingleMessageClientStream.java b/binder/src/main/java/io/grpc/binder/internal/SingleMessageClientStream.java new file mode 100644 index 00000000000..9a1fa0ee6e4 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/SingleMessageClientStream.java @@ -0,0 +1,183 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import io.grpc.Attributes; +import io.grpc.Compressor; +import io.grpc.Deadline; +import io.grpc.DecompressorRegistry; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.ClientStream; +import io.grpc.internal.ClientStreamListener; +import io.grpc.internal.InsightBuilder; +import java.io.InputStream; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +/** + * The client side of a single RPC, which sends a single request message. + * + *

An instance of this class is effectively a go-between, receiving messages from the gRPC + * ClientCall instance (via calls on the ClientStream interface we implement), and sending them out + * on the transport, as well as receiving messages from the transport, and passing the resultant + * data back to the gRPC ClientCall instance (via calls on the ClientStreamListener instance we're + * given). + * + *

These two communication directions are largely independent of each other, with the {@link + * Outbound} handling the gRPC to transport direction, and the {@link Inbound} class handling + * transport to gRPC direction. + * + *

Since the Inbound and Outbound halves are largely independent, their state is also + * synchronized independently. + */ +final class SingleMessageClientStream implements ClientStream { + + private final Inbound.ClientInbound inbound; + private final Outbound.ClientOutbound outbound; + private final Attributes attributes; + + @Nullable private InputStream pendingSingleMessage; + + SingleMessageClientStream( + Inbound.ClientInbound inbound, Outbound.ClientOutbound outbound, Attributes attributes) { + this.inbound = inbound; + this.outbound = outbound; + this.attributes = attributes; + } + + @Override + public void start(ClientStreamListener listener) { + synchronized (inbound) { + inbound.init(outbound, listener); + } + if (outbound.isReady()) { + listener.onReady(); + } + } + + @Override + public boolean isReady() { + return outbound.isReady(); + } + + @Override + public void request(int numMessages) { + synchronized (inbound) { + inbound.requestMessages(numMessages); + } + } + + @Override + public void writeMessage(InputStream message) { + if (pendingSingleMessage != null) { + synchronized (inbound) { + inbound.closeAbnormal(Status.INTERNAL.withDescription("too many messages")); + } + } else { + pendingSingleMessage = message; + } + } + + @Override + public void halfClose() { + try { + synchronized (outbound) { + outbound.sendSingleMessageAndHalfClose(pendingSingleMessage); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); + } + } + } + + @Override + public void cancel(Status status) { + synchronized (inbound) { + inbound.closeOnCancel(status); + } + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Override + public final String toString() { + return "SingleMessageClientStream[" + inbound + "/" + outbound + "]"; + } + + // ===================== + // Misc stubbed & unsupported methods. + + @Override + public final void flush() { + // Ignore. + } + + @Override + public final void setCompressor(Compressor compressor) { + // Ignore. + } + + @Override + public final void setMessageCompression(boolean enable) { + // Ignore. + } + + @Override + public void setDeadline(@Nonnull Deadline deadline) { + // Ignore. (Deadlines should still work at a higher level). + } + + @Override + public void setAuthority(String authority) { + // Ignore. + } + + @Override + public void setMaxInboundMessageSize(int maxSize) { + // Ignore. + } + + @Override + public void setMaxOutboundMessageSize(int maxSize) { + // Ignore. + } + + @Override + public void appendTimeoutInsight(InsightBuilder insight) { + // Ignore + } + + @Override + public void setFullStreamDecompression(boolean fullStreamDecompression) { + // Ignore. + } + + @Override + public void setDecompressorRegistry(DecompressorRegistry decompressorRegistry) { + // Ignore. + } + + @Override + public void optimizeForDirectExecutor() { + // Ignore. + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/SingleMessageServerStream.java b/binder/src/main/java/io/grpc/binder/internal/SingleMessageServerStream.java new file mode 100644 index 00000000000..ac7a76054e7 --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/SingleMessageServerStream.java @@ -0,0 +1,174 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import io.grpc.Attributes; +import io.grpc.Compressor; +import io.grpc.Decompressor; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerStreamListener; +import io.grpc.internal.StatsTraceContext; +import java.io.InputStream; +import javax.annotation.Nullable; + +/** + * The server side of a single RPC, which sends a single response message. + * + *

An instance of this class is effectively a go-between, receiving messages from the gRPC + * ServerCall instance (via calls on the ServerStream interface we implement), and sending them out + * on the transport, as well as receiving messages from the transport, and passing the resultant + * data back to the gRPC ServerCall instance (via calls on the ServerStreamListener instance we're + * given). + * + *

These two communication directions are largely independent of each other, with the {@link + * Outbound} handling the gRPC to transport direction, and the {@link Inbound} class handling + * transport to gRPC direction. + * + *

Since the Inbound and Outbound halves are largely independent, their state is also + * synchronized independently. + */ +final class SingleMessageServerStream implements ServerStream { + + private final Inbound.ServerInbound inbound; + private final Outbound.ServerOutbound outbound; + private final Attributes attributes; + + @Nullable private Metadata pendingHeaders; + @Nullable private InputStream pendingSingleMessage; + + SingleMessageServerStream( + Inbound.ServerInbound inbound, Outbound.ServerOutbound outbound, Attributes attributes) { + this.inbound = inbound; + this.outbound = outbound; + this.attributes = attributes; + } + + @Override + public void setListener(ServerStreamListener listener) { + synchronized (inbound) { + inbound.init(outbound, listener); + } + } + + @Override + public boolean isReady() { + return outbound.isReady(); + } + + @Override + public void request(int numMessages) { + synchronized (inbound) { + inbound.requestMessages(numMessages); + } + } + + @Override + public void writeHeaders(Metadata headers) { + pendingHeaders = headers; + } + + @Override + public void writeMessage(InputStream message) { + if (pendingSingleMessage != null) { + synchronized (inbound) { + inbound.closeAbnormal(Status.INTERNAL.withDescription("too many messages")); + } + } else { + pendingSingleMessage = message; + } + } + + @Override + public void close(Status status, Metadata trailers) { + try { + synchronized (outbound) { + outbound.sendSingleMessageAndClose(pendingHeaders, pendingSingleMessage, status, trailers); + } + synchronized (inbound) { + inbound.onCloseSent(status); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); + } + } + } + + @Override + public void cancel(Status status) { + synchronized (inbound) { + inbound.closeOnCancel(status); + } + } + + @Override + public StatsTraceContext statsTraceContext() { + return outbound.getStatsTraceContext(); + } + + @Override + public Attributes getAttributes() { + return attributes; + } + + @Nullable + @Override + public String getAuthority() { + return attributes.get(BinderTransport.SERVER_AUTHORITY); + } + + @Override + public String toString() { + return "SingleMessageServerStream[" + inbound + "/" + outbound + "]"; + } + + // ===================== + // Misc stubbed & unsupported methods. + + @Override + public final void flush() { + // Ignore. + } + + @Override + public final void setCompressor(Compressor compressor) { + // Ignore. + } + + @Override + public final void setMessageCompression(boolean enable) { + // Ignore. + } + + @Override + public void setDecompressor(Decompressor decompressor) { + // Ignore. + } + + @Override + public void optimizeForDirectExecutor() { + // Ignore. + } + + @Override + public int streamId() { + return -1; + } +} diff --git a/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java b/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java new file mode 100644 index 00000000000..91f7fb8028f --- /dev/null +++ b/binder/src/main/java/io/grpc/binder/internal/TransactionUtils.java @@ -0,0 +1,99 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import android.os.Parcel; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.Status; +import javax.annotation.Nullable; + +/** Constants and helpers for managing inbound / outbound transactions. */ +final class TransactionUtils { + /** Set when the transaction contains rpc prefix data. */ + static final int FLAG_PREFIX = 0x1; + /** Set when the transaction contains some message data. */ + static final int FLAG_MESSAGE_DATA = 0x2; + /** Set when the transaction contains rpc suffix data. */ + static final int FLAG_SUFFIX = 0x4; + /** Set when the transaction is an out-of-band close event. */ + static final int FLAG_OUT_OF_BAND_CLOSE = 0x8; + + /** + * When a transaction contains client prefix data, this will be set if the rpc being made is + * expected to return a single message. (I.e the method type is either {@link MethodType#UNARY}, + * or {@link MethodType#CLIENT_STREAMING}). + */ + static final int FLAG_EXPECT_SINGLE_MESSAGE = 0x10; + + /** Set when the included status data includes a description string. */ + static final int FLAG_STATUS_DESCRIPTION = 0x20; + + /** When a transaction contains message data, this will be set if the message is a parcelable. */ + static final int FLAG_MESSAGE_DATA_IS_PARCELABLE = 0x40; + + /** + * When a transaction contains message data, this will be set if the message is only partial, and + * further transactions are required. + */ + static final int FLAG_MESSAGE_DATA_IS_PARTIAL = 0x80; + + static final int STATUS_CODE_SHIFT = 16; + static final int STATUS_CODE_MASK = 0xff0000; + + /** The maximum string length for a status description. */ + private static final int MAX_STATUS_DESCRIPTION_LENGTH = 1000; + + private TransactionUtils() {} + + static boolean hasFlag(int flags, int flag) { + return (flags & flag) != 0; + } + + @Nullable + private static String getTruncatedDescription(Status status) { + String desc = status.getDescription(); + if (desc != null && desc.length() > MAX_STATUS_DESCRIPTION_LENGTH) { + desc = desc.substring(0, MAX_STATUS_DESCRIPTION_LENGTH); + } + return desc; + } + + static Status readStatus(int flags, Parcel parcel) { + Status status = Status.fromCodeValue((flags & STATUS_CODE_MASK) >> STATUS_CODE_SHIFT); + if ((flags & FLAG_STATUS_DESCRIPTION) != 0) { + status = status.withDescription(parcel.readString()); + } + return status; + } + + static int writeStatus(Parcel parcel, Status status) { + int flags = status.getCode().value() << STATUS_CODE_SHIFT; + String desc = getTruncatedDescription(status); + if (desc != null) { + flags |= FLAG_STATUS_DESCRIPTION; + parcel.writeString(desc); + } + return flags; + } + + static void fillInFlags(Parcel parcel, int flags) { + int pos = parcel.dataPosition(); + parcel.setDataPosition(0); + parcel.writeInt(flags); + parcel.setDataPosition(pos); + } +} diff --git a/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java b/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java new file mode 100644 index 00000000000..6d15907c7b9 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/AndroidComponentAddressTest.java @@ -0,0 +1,61 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static com.google.common.truth.Truth.assertThat; + +import android.content.ComponentName; +import android.content.Context; +import androidx.test.core.app.ApplicationProvider; +import com.google.common.testing.EqualsTester; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public final class AndroidComponentAddressTest { + + private final Context appContext = ApplicationProvider.getApplicationContext(); + private final ComponentName hostComponent = new ComponentName(appContext, appContext.getClass()); + + @Test + public void testAuthority() { + AndroidComponentAddress addr = AndroidComponentAddress.forContext(appContext); + assertThat(addr.getAuthority()).isEqualTo(appContext.getPackageName()); + } + + @Test + public void testComponent() { + AndroidComponentAddress addr = AndroidComponentAddress.forComponent(hostComponent); + assertThat(addr.getComponent()).isSameInstanceAs(hostComponent); + } + + @Test + public void testEquality() { + new EqualsTester() + .addEqualityGroup( + AndroidComponentAddress.forComponent(hostComponent), + AndroidComponentAddress.forContext(appContext), + AndroidComponentAddress.forLocalComponent(appContext, appContext.getClass()), + AndroidComponentAddress.forRemoteComponent( + appContext.getPackageName(), appContext.getClass().getName())) + .addEqualityGroup( + AndroidComponentAddress.forRemoteComponent("appy.mcappface", ".McActivity")) + .addEqualityGroup(AndroidComponentAddress.forLocalComponent(appContext, getClass())) + .testEquals(); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/BindServiceFlagsTest.java b/binder/src/test/java/io/grpc/binder/BindServiceFlagsTest.java similarity index 98% rename from binder/src/test/java/io/grpc/binder/internal/BindServiceFlagsTest.java rename to binder/src/test/java/io/grpc/binder/BindServiceFlagsTest.java index 7bdf8a3675c..0aebb79ee8e 100644 --- a/binder/src/test/java/io/grpc/binder/internal/BindServiceFlagsTest.java +++ b/binder/src/test/java/io/grpc/binder/BindServiceFlagsTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.grpc.binder.internal; +package io.grpc.binder; import static com.google.common.truth.Truth.assertThat; diff --git a/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java new file mode 100644 index 00000000000..604addb3b09 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/SecurityPoliciesTest.java @@ -0,0 +1,62 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static com.google.common.truth.Truth.assertThat; + +import io.grpc.Status; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.robolectric.shadows.ShadowProcess; + +@RunWith(RobolectricTestRunner.class) +public final class SecurityPoliciesTest { + private static final int MY_UID = 1234; + private static final int OTHER_UID = MY_UID + 1; + + private static final String PERMISSION_DENIED_REASONS = "some reasons"; + + private SecurityPolicy policy; + + @Before + public void setUp() { + ShadowProcess.setUid(MY_UID); + } + + @Test + public void testInternalOnly() throws Exception { + policy = SecurityPolicies.internalOnly(); + assertThat(policy.checkAuthorization(MY_UID).getCode()).isEqualTo(Status.OK.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + public void testPermissionDenied() throws Exception { + policy = SecurityPolicies.permissionDenied(PERMISSION_DENIED_REASONS); + assertThat(policy.checkAuthorization(MY_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(MY_UID).getDescription()) + .isEqualTo(PERMISSION_DENIED_REASONS); + assertThat(policy.checkAuthorization(OTHER_UID).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorization(OTHER_UID).getDescription()) + .isEqualTo(PERMISSION_DENIED_REASONS); + } +} diff --git a/binder/src/test/java/io/grpc/binder/ServerSecurityPolicyTest.java b/binder/src/test/java/io/grpc/binder/ServerSecurityPolicyTest.java new file mode 100644 index 00000000000..8c61b6119b1 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/ServerSecurityPolicyTest.java @@ -0,0 +1,126 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.base.Function; +import io.grpc.Status; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; +import org.robolectric.shadows.ShadowProcess; + +@RunWith(RobolectricTestRunner.class) +public final class ServerSecurityPolicyTest { + + private static final String SERVICE1 = "service_one"; + private static final String SERVICE2 = "service_two"; + private static final String SERVICE3 = "service_three"; + + private static final int MY_UID = 1234; + private static final int OTHER_UID = MY_UID + 1; + + ServerSecurityPolicy policy; + + @Before + public void setUp() { + ShadowProcess.setUid(MY_UID); + } + + @Test + public void testDefaultInternalOnly() { + policy = new ServerSecurityPolicy(); + assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE1).getCode()) + .isEqualTo(Status.OK.getCode()); + assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE2).getCode()) + .isEqualTo(Status.OK.getCode()); + } + + @Test + public void testInternalOnly_AnotherUid() { + policy = new ServerSecurityPolicy(); + assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE1).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE2).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + public void testBuilderDefault() { + policy = ServerSecurityPolicy.newBuilder().build(); + assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE1).getCode()) + .isEqualTo(Status.OK.getCode()); + assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE1).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + @Test + public void testPerService() { + policy = + ServerSecurityPolicy.newBuilder() + .servicePolicy(SERVICE2, policy((uid) -> Status.OK)) + .build(); + + assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE1).getCode()) + .isEqualTo(Status.OK.getCode()); + assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE1).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE2).getCode()) + .isEqualTo(Status.OK.getCode()); + assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE2).getCode()) + .isEqualTo(Status.OK.getCode()); + } + + @Test + public void testPerServiceNoDefault() { + policy = + ServerSecurityPolicy.newBuilder() + .servicePolicy(SERVICE1, policy((uid) -> Status.INTERNAL)) + .servicePolicy( + SERVICE2, policy((uid) -> uid == OTHER_UID ? Status.OK : Status.PERMISSION_DENIED)) + .build(); + + // Uses the specified policy for service1. + assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE1).getCode()) + .isEqualTo(Status.INTERNAL.getCode()); + assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE1).getCode()) + .isEqualTo(Status.INTERNAL.getCode()); + + // Uses the specified policy for service2. + assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE2).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE2).getCode()) + .isEqualTo(Status.OK.getCode()); + + // Falls back to the default. + assertThat(policy.checkAuthorizationForService(MY_UID, SERVICE3).getCode()) + .isEqualTo(Status.OK.getCode()); + assertThat(policy.checkAuthorizationForService(OTHER_UID, SERVICE3).getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + } + + private static SecurityPolicy policy(Function func) { + return new SecurityPolicy() { + @Override + public Status checkAuthorization(int uid) { + return func.apply(uid); + } + }; + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java b/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java new file mode 100644 index 00000000000..76cf9b6a6ac --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/BinderServerTransportTest.java @@ -0,0 +1,116 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.when; +import static org.robolectric.annotation.LooperMode.Mode.PAUSED; + +import android.os.IBinder; +import android.os.Parcel; +import com.google.common.collect.ImmutableList; +import com.google.common.util.concurrent.testing.TestingExecutors; +import io.grpc.Attributes; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.FixedObjectPool; +import io.grpc.internal.ServerStream; +import io.grpc.internal.ServerTransportListener; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.robolectric.RobolectricTestRunner; +import org.robolectric.annotation.LooperMode; + +/** + * Low-level server-side transport tests for binder channel. Like BinderChannelSmokeTest, this + * convers edge cases not exercised by AbstractTransportTest, but it deals with the + * binderTransport.BinderServerTransport directly. + */ +@LooperMode(PAUSED) +@RunWith(RobolectricTestRunner.class) +public final class BinderServerTransportTest { + + @Rule public MockitoRule mocks = MockitoJUnit.rule(); + + private final ScheduledExecutorService executorService = + TestingExecutors.sameThreadScheduledExecutor(); + private final TestTransportListener transportListener = new TestTransportListener(); + + @Mock IBinder mockBinder; + + BinderTransport.BinderServerTransport transport; + + @Before + public void setUp() throws Exception { + transport = + new BinderTransport.BinderServerTransport( + new FixedObjectPool<>(executorService), + Attributes.EMPTY, + ImmutableList.of(), + mockBinder); + } + + @Test + public void testSetupTransactionFailureCausesMultipleShutdowns_b153460678() throws Exception { + // Make the binder fail the setup transaction. + when(mockBinder.transact(anyInt(), any(Parcel.class), isNull(), anyInt())).thenReturn(false); + transport.setServerTransportListener(transportListener); + + // Now shut it down. + transport.shutdownNow(Status.UNKNOWN.withDescription("reasons")); + + assertThat(transportListener.terminated).isTrue(); + } + + private static final class TestTransportListener implements ServerTransportListener { + + public boolean ready; + public boolean terminated; + + /** + * Called when a new stream was created by the remote client. + * + * @param stream the newly created stream. + * @param method the fully qualified method name being called on the server. + * @param headers containing metadata for the call. + */ + @Override + public void streamCreated(ServerStream stream, String method, Metadata headers) {} + + @Override + public Attributes transportReady(Attributes attributes) { + ready = true; + return attributes; + } + + @Override + public void transportTerminated() { + checkState(!terminated, "Terminated twice"); + terminated = true; + } + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/BlockInputStreamTest.java b/binder/src/test/java/io/grpc/binder/internal/BlockInputStreamTest.java new file mode 100644 index 00000000000..5a0279fea22 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/BlockInputStreamTest.java @@ -0,0 +1,129 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; + +import java.io.ByteArrayOutputStream; +import java.util.Arrays; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class BlockInputStreamTest { + + private final byte[] buff = new byte[1024]; + + @Test + public void testNoBytes() throws Exception { + try (BlockInputStream bis = new BlockInputStream(new byte[0])) { + assertThat(bis.read()).isEqualTo(-1); + } + } + + @Test + public void testNoBlocks() throws Exception { + try (BlockInputStream bis = new BlockInputStream(new byte[0][], 0)) { + assertThat(bis.read()).isEqualTo(-1); + } + } + + @Test + public void testSingleBlock() throws Exception { + BlockInputStream bis = + new BlockInputStream(new byte[][] {getBytes(10, 1)}, 10); + assertThat(bis.read(buff, 0, 20)).isEqualTo(10); + assertBytes(buff, 0, 10, 1); + } + + @Test + public void testMultipleBlocks() throws Exception { + BlockInputStream bis = + new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(10, 2)}, 20); + assertThat(bis.read(buff, 0, 20)).isEqualTo(20); + assertBytes(buff, 0, 10, 1); + assertBytes(buff, 10, 10, 2); + } + + @Test + public void testMultipleBlocks_drain() throws Exception { + BlockInputStream bis = + new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(10, 2)}, 20); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + bis.drainTo(baos); + byte[] data = baos.toByteArray(); + assertThat(data).hasLength(20); + assertBytes(data, 0, 10, 1); + assertBytes(data, 10, 10, 2); + } + + @Test + public void testMultipleBlocksLessData() throws Exception { + BlockInputStream bis = + new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(10, 2)}, 15); + assertThat(bis.read(buff, 0, 20)).isEqualTo(15); + assertBytes(buff, 0, 10, 1); + assertBytes(buff, 10, 5, 2); + } + + @Test + public void testMultipleBlocksLessData_drain() throws Exception { + BlockInputStream bis = + new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(10, 2)}, 15); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + bis.drainTo(baos); + byte[] data = baos.toByteArray(); + assertThat(data).hasLength(15); + assertBytes(data, 0, 10, 1); + assertBytes(data, 10, 5, 2); + } + + @Test + public void testMultipleBlocksEmptyFinalBlock() throws Exception { + BlockInputStream bis = + new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(0, 0)}, 10); + + assertThat(bis.read(buff, 0, 20)).isEqualTo(10); + assertBytes(buff, 0, 10, 1); + assertThat(bis.read(buff, 0, 20)).isEqualTo(-1); + assertThat(bis.read()).isEqualTo(-1); + } + + @Test + public void testMultipleBlocksEmptyFinalBlock_drain() throws Exception { + BlockInputStream bis = + new BlockInputStream(new byte[][] {getBytes(10, 1), getBytes(0, 0)}, 10); + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + bis.drainTo(baos); + byte[] data = baos.toByteArray(); + assertThat(data).hasLength(10); + assertBytes(data, 0, 10, 1); + } + + private static byte[] getBytes(int size, int val) { + byte[] res = new byte[size]; + Arrays.fill(res, 0, size, (byte) val); + return res; + } + + private static void assertBytes(byte[] data, int off, int len, int val) { + for (int i = off; i < off + len; i++) { + assertThat(data[i]).isEqualTo((byte) val); + } + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/BoundClientAddressTest.java b/binder/src/test/java/io/grpc/binder/internal/BoundClientAddressTest.java new file mode 100644 index 00000000000..9dc6cbc6938 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/BoundClientAddressTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import com.google.common.testing.EqualsTester; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public final class BoundClientAddressTest { + + private static final int MY_UID = 1234; + private static final int OTHER_UID = 1235; + private static final int OTHER_UID_2 = 1236; + + @Test + public void testEquality() { + new EqualsTester() + .addEqualityGroup(new BoundClientAddress(MY_UID), new BoundClientAddress(MY_UID)) + .addEqualityGroup(new BoundClientAddress(OTHER_UID)) + .addEqualityGroup(new BoundClientAddress(OTHER_UID_2)) + .testEquals(); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/ParcelableInputStreamTest.java b/binder/src/test/java/io/grpc/binder/internal/ParcelableInputStreamTest.java new file mode 100644 index 00000000000..bf90e21d046 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/ParcelableInputStreamTest.java @@ -0,0 +1,127 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; + +import android.os.Parcel; +import android.os.Parcelable; +import com.google.common.io.ByteStreams; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.robolectric.RobolectricTestRunner; + +@RunWith(RobolectricTestRunner.class) +public final class ParcelableInputStreamTest { + + private final TestParcelable testParcelable = new TestParcelable("testing"); + private final TestParcelable testParcelableWithFds = + new TestParcelable("testing_with_fds", Parcelable.CONTENTS_FILE_DESCRIPTOR); + + @Test + public void testGetParcelable() throws Exception { + ParcelableInputStream stream = + ParcelableInputStream.forInstance(testParcelable, TestParcelable.CREATOR); + + // We should serialize/deserialize the parcelable. + TestParcelable parceable = stream.getParcelable(); + assertThat(parceable).isEqualTo(testParcelable); + assertThat(parceable).isNotSameInstanceAs(testParcelable); + + // But just once. + assertThat(stream.getParcelable()).isSameInstanceAs(parceable); + } + + @Test + public void testGetParcelableWithFds() throws Exception { + ParcelableInputStream stream = + ParcelableInputStream.forInstance(testParcelableWithFds, TestParcelable.CREATOR); + + // We should serialize/deserialize the parcelable. + TestParcelable parceable = stream.getParcelable(); + assertThat(parceable).isEqualTo(testParcelableWithFds); + assertThat(parceable).isNotSameInstanceAs(testParcelableWithFds); + + // But just once. + assertThat(stream.getParcelable()).isSameInstanceAs(parceable); + } + + @Test + public void testGetParcelableImmutable() throws Exception { + ParcelableInputStream stream = + ParcelableInputStream.forImmutableInstance(testParcelable, TestParcelable.CREATOR); + + // We should return the parcelable directly. + TestParcelable parceable = stream.getParcelable(); + assertThat(parceable).isSameInstanceAs(testParcelable); + } + + @Test + public void testGetParcelableImmutableWithFds() throws Exception { + ParcelableInputStream stream = + ParcelableInputStream.forImmutableInstance(testParcelableWithFds, TestParcelable.CREATOR); + + // We should return the parcelable directly. + TestParcelable parceable = stream.getParcelable(); + assertThat(parceable).isSameInstanceAs(testParcelableWithFds); + } + + @Test + public void testWriteToParcel() throws Exception { + ParcelableInputStream stream = + ParcelableInputStream.forImmutableInstance(testParcelable, TestParcelable.CREATOR); + Parcel parcel = Parcel.obtain(); + stream.writeToParcel(parcel); + + parcel.setDataPosition(0); + assertThat((TestParcelable) parcel.readParcelable(getClass().getClassLoader())) + .isEqualTo(testParcelable); + } + + @Test + public void testCreateFromParcel() throws Exception { + Parcel parcel = Parcel.obtain(); + parcel.writeParcelable(testParcelable, 0); + parcel.setDataPosition(0); + + ParcelableInputStream stream = + ParcelableInputStream.readFromParcel(parcel, getClass().getClassLoader()); + assertThat(stream.getParcelable()).isEqualTo(testParcelable); + } + + @Test + public void testAsRegularInputStream() throws Exception { + ParcelableInputStream stream = + ParcelableInputStream.forInstance(testParcelable, TestParcelable.CREATOR); + byte[] data = ByteStreams.toByteArray(stream); + + Parcel parcel = Parcel.obtain(); + parcel.unmarshall(data, 0, data.length); + parcel.setDataPosition(0); + + assertThat((TestParcelable) parcel.readParcelable(getClass().getClassLoader())) + .isEqualTo(testParcelable); + } + + @Test + public void testAsRegularInputStreamFds() throws Exception { + ParcelableInputStream stream = + ParcelableInputStream.forInstance(testParcelableWithFds, TestParcelable.CREATOR); + byte[] data = ByteStreams.toByteArray(stream); + assertThat(data.length).isNotEqualTo(0); + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java b/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java new file mode 100644 index 00000000000..e17734baab8 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/PingTrackerTest.java @@ -0,0 +1,142 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import static com.google.common.truth.Truth.assertThat; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static java.util.concurrent.TimeUnit.SECONDS; + +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.internal.ClientTransport; +import io.grpc.internal.FakeClock; +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import javax.annotation.Nullable; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class PingTrackerTest { + + private final FakeClock clock = new FakeClock(); + + @Nullable private Status pingFailureStatus; + private List sentPings; + + private TestCallback callback; + private PingTracker pingTracker; + + @Before + public void setUp() { + sentPings = new ArrayList<>(); + callback = new TestCallback(); + pingTracker = + new PingTracker( + clock.getTimeProvider(), + (id) -> { + sentPings.add(id); + if (pingFailureStatus != null) { + throw pingFailureStatus.asException(); + } + }); + } + + @Test + public void successfulPing() throws Exception { + pingTracker.startPing(callback, directExecutor()); + assertThat(sentPings).hasSize(1); + callback.assertNotCalled(); + clock.forwardTime(3, SECONDS); + pingTracker.onPingResponse(sentPings.get(0)); + callback.assertSuccess(Duration.ofSeconds(3).toNanos()); + } + + @Test + public void failedPing() throws Exception { + pingFailureStatus = Status.INTERNAL.withDescription("Hello"); + pingTracker.startPing(callback, directExecutor()); + callback.assertFailure(pingFailureStatus); + } + + @Test + public void noSuccessAfterFailure() throws Exception { + pingFailureStatus = Status.INTERNAL.withDescription("Hello"); + pingTracker.startPing(callback, directExecutor()); + pingTracker.onPingResponse(sentPings.get(0)); + callback.assertFailure(pingFailureStatus); + } + + @Test + public void noMultiSuccess() throws Exception { + pingTracker.startPing(callback, directExecutor()); + pingTracker.onPingResponse(sentPings.get(0)); + pingTracker.onPingResponse(sentPings.get(0)); + callback.assertSuccess(); // Checks we were only called once. + } + + private static final class TestCallback implements ClientTransport.PingCallback { + private int numCallbacks; + private boolean success; + private boolean failure; + private Throwable failureException; + private long roundtripTimeNanos; + + @Override + public synchronized void onSuccess(long roundtripTimeNanos) { + numCallbacks += 1; + success = true; + this.roundtripTimeNanos = roundtripTimeNanos; + } + + @Override + public synchronized void onFailure(Throwable failureException) { + numCallbacks += 1; + failure = true; + this.failureException = failureException; + } + + public void assertNotCalled() { + assertThat(numCallbacks).isEqualTo(0); + } + + public void assertSuccess() { + assertThat(numCallbacks).isEqualTo(1); + assertThat(success).isTrue(); + } + + public void assertSuccess(long expectRoundTripTimeNanos) { + assertSuccess(); + assertThat(roundtripTimeNanos).isEqualTo(expectRoundTripTimeNanos); + } + + public void assertFailure(Status status) { + assertThat(numCallbacks).isEqualTo(1); + assertThat(failure).isTrue(); + assertThat(((StatusException) failureException).getStatus()).isSameInstanceAs(status); + } + + public void assertFailure(Status.Code statusCode) { + assertThat(numCallbacks).isEqualTo(1); + assertThat(failure).isTrue(); + assertThat(((StatusException) failureException).getStatus().getCode()).isEqualTo(statusCode); + } + } +} diff --git a/binder/src/test/java/io/grpc/binder/internal/TestParcelable.java b/binder/src/test/java/io/grpc/binder/internal/TestParcelable.java new file mode 100644 index 00000000000..a44f91162a0 --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/internal/TestParcelable.java @@ -0,0 +1,71 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.binder.internal; + +import android.os.Parcel; +import android.os.Parcelable; + +/** A parcelable for testing. */ +public class TestParcelable implements Parcelable { + private final String msg; + private final int contents; + + public TestParcelable(String msg) { + this(msg, 0); + } + + public TestParcelable(String msg, int contents) { + this.msg = msg; + this.contents = contents; + } + + @Override + public int describeContents() { + return contents; + } + + @Override + public void writeToParcel(Parcel parcel, int flags) { + parcel.writeString(msg); + } + + @Override + public int hashCode() { + return msg.hashCode(); + } + + @Override + public boolean equals(Object other) { + if (other instanceof TestParcelable) { + return msg.equals(((TestParcelable) other).msg); + } + return false; + } + + public static final Parcelable.Creator CREATOR = + new Parcelable.Creator() { + @Override + public TestParcelable createFromParcel(Parcel parcel) { + return new TestParcelable(parcel.readString(), 0); + } + + @Override + public TestParcelable[] newArray(int size) { + return new TestParcelable[size]; + } + }; +} diff --git a/build.gradle b/build.gradle index 36ccf753e0f..0fc314cf399 100644 --- a/build.gradle +++ b/build.gradle @@ -18,7 +18,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.38.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.39.0" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() @@ -57,7 +57,7 @@ subprojects { nettyVersion = '4.1.52.Final' guavaVersion = '30.1-android' googleauthVersion = '0.22.2' - protobufVersion = '3.12.0' + protobufVersion = '3.17.2' protocVersion = protobufVersion opencensusVersion = '0.28.0' autovalueVersion = '1.7.4' @@ -186,10 +186,12 @@ subprojects { // Test dependencies. junit: 'junit:junit:4.12', mockito: 'org.mockito:mockito-core:3.3.3', + mockito_android: 'org.mockito:mockito-android:3.8.0', truth: 'com.google.truth:truth:1.0.1', guava_testlib: "com.google.guava:guava-testlib:${guavaVersion}", androidx_annotation: "androidx.annotation:annotation:1.1.0", androidx_core: "androidx.core:core:1.3.0", + androidx_lifecycle_service: "androidx.lifecycle:lifecycle-service:2.3.0", androidx_test: "androidx.test:core:1.3.0", androidx_test_rules: "androidx.test:rules:1.3.0", androidx_test_ext_junit: "androidx.test.ext:junit:1.1.2", @@ -203,68 +205,6 @@ subprojects { jetty_alpn_agent: 'org.mortbay.jetty.alpn:jetty-alpn-agent:2.0.10' ] - // A util function to config guava dependency with transitive dependencies - // properly resolved for the failOnVersionConflict strategy. - guavaDependency = { configurationName -> - dependencies."$configurationName"(libraries.guava) { - exclude group: 'com.google.code.findbugs', module: 'jsr305' - exclude group: 'com.google.errorprone', module: 'error_prone_annotations' - exclude group: 'org.codehaus.mojo', module: 'animal-sniffer-annotations' - } - dependencies."$configurationName" libraries.errorprone - dependencies.runtimeOnly libraries.animalsniffer_annotations - dependencies.runtimeOnly libraries.jsr305 - } - - // A util function to config opencensus_api dependency with transitive - // dependencies properly resolved for the failOnVersionConflict strategy. - censusApiDependency = { configurationName -> - dependencies."$configurationName"(libraries.opencensus_api) { - exclude group: 'com.google.code.findbugs', module: 'jsr305' - exclude group: 'com.google.guava', module: 'guava' - // we'll always be more up-to-date - exclude group: 'io.grpc', module: 'grpc-context' - } - dependencies.runtimeOnly project(':grpc-context') - dependencies.runtimeOnly libraries.jsr305 - guavaDependency 'runtimeOnly' - } - - // A util function to config opencensus_contrib_grpc_metrics dependency - // with transitive dependencies properly resolved for the failOnVersionConflict - // strategy. - censusGrpcMetricDependency = { configurationName -> - dependencies."$configurationName"(libraries.opencensus_contrib_grpc_metrics) { - exclude group: 'com.google.code.findbugs', module: 'jsr305' - exclude group: 'com.google.guava', module: 'guava' - // we'll always be more up-to-date - exclude group: 'io.grpc', module: 'grpc-context' - } - dependencies.runtimeOnly project(':grpc-context') - dependencies.runtimeOnly libraries.jsr305 - guavaDependency 'runtimeOnly' - } - - googleOauth2Dependency = { configurationName -> - dependencies."$configurationName"(libraries.google_auth_oauth2_http) { - exclude group: 'com.google.guava', module: 'guava' - exclude group: 'io.grpc', module: 'grpc-context' - exclude group: 'io.opencensus', module: 'opencensus-api' - } - dependencies.runtimeOnly project(':grpc-context') - censusApiDependency 'runtimeOnly' - guavaDependency 'runtimeOnly' - } - - // A util function to config perfmark dependency with transitive - // dependencies properly resolved for the failOnVersionConflict strategy. - perfmarkDependency = { configurationName -> - dependencies."$configurationName"(libraries.perfmark) { - exclude group: 'com.google.errorprone', module: 'error_prone_annotations' - } - dependencies.runtimeOnly libraries.errorprone - } - appendToProperty = { Property property, String value, String separator -> if (property.present) { property.set(property.get() + separator + value) @@ -274,25 +214,6 @@ subprojects { } } - configurations { - // Detect Maven Enforcer's dependencyConvergence failures. We only - // care for artifacts used as libraries by others. - if (isAndroid && !(project.name in ['grpc-android-interop-testing'])) { - releaseRuntimeClasspath { - resolutionStrategy.failOnVersionConflict() - } - } - if (!isAndroid && !(project.name in [ - 'grpc-benchmarks', - 'grpc-interop-testing', - 'grpc-gae-interop-testing-jdk8', - ])) { - runtimeClasspath { - resolutionStrategy.failOnVersionConflict() - } - } - } - // Disable JavaDoc doclint on Java 8. It's annoying. if (JavaVersion.current().isJava8Compatible()) { allprojects { @@ -404,6 +325,19 @@ subprojects { } } + plugins.withId("java-library") { + // Detect Maven Enforcer's dependencyConvergence failures. We only care + // for artifacts used as libraries by others with Maven. + tasks.register('checkUpperBoundDeps') { + doLast { + requireUpperBoundDepsMatch(configurations.runtimeClasspath, project) + } + } + tasks.named('compileJava') { + dependsOn checkUpperBoundDeps + } + } + plugins.withId("me.champeau.gradle.jmh") { dependencies { jmh 'org.openjdk.jmh:jmh-core:1.19', @@ -580,3 +514,56 @@ subprojects { } } } + +class DepAndParents { + DependencyResult dep + List parents +} + +/** + * Make sure that Maven would select the same versions as Gradle selected. + * This is essentially the same as if we used Maven Enforcer's + * requireUpperBoundDeps for our artifacts. + */ +def requireUpperBoundDepsMatch(Configuration conf, Project project) { + // artifact name => version + Map golden = conf.resolvedConfiguration.resolvedArtifacts.collectEntries { + ResolvedArtifact it -> + ModuleVersionIdentifier id = it.moduleVersion.id + [id.group + ":" + id.name, id.version] + } + // Breadth-first search like Maven for dependency resolution + Queue queue = new ArrayDeque<>() + conf.incoming.resolutionResult.root.dependencies.each { + queue.add(new DepAndParents(dep: it, parents: [project.displayName])) + } + Set found = new HashSet<>() + while (!queue.isEmpty()) { + DepAndParents depAndParents = queue.remove() + ResolvedDependencyResult result = (ResolvedDependencyResult) depAndParents.dep + ModuleVersionIdentifier id = result.selected.moduleVersion + String artifact = id.group + ":" + id.name + if (found.contains(artifact)) + continue + found.add(artifact) + String version + if (result.requested instanceof ProjectComponentSelector) { + ProjectComponentSelector selector = (ProjectComponentSelector) result.requested + version = project.findProject(selector.projectPath).version + } else { + version = ((ModuleComponentSelector) result.requested).version + } + String goldenVersion = golden[artifact] + if (goldenVersion != version && "[$goldenVersion]" != version) { + throw new RuntimeException( + "Maven version skew: $artifact ($version != $goldenVersion) " + + "Bad version dependency path: " + depAndParents.parents + + " Run './gradlew $project.path:dependencies --configuration $conf.name' " + + "to diagnose") + } + result.selected.dependencies.each { + queue.add(new DepAndParents( + dep: it, parents: depAndParents.parents + [artifact + ":" + version])) + } + } +} diff --git a/buildscripts/kokoro/unix.sh b/buildscripts/kokoro/unix.sh index 5a720295c64..faa9a6afe1d 100755 --- a/buildscripts/kokoro/unix.sh +++ b/buildscripts/kokoro/unix.sh @@ -46,8 +46,6 @@ export LDFLAGS=-L/tmp/protobuf/lib export CXXFLAGS="-I/tmp/protobuf/include" ./gradlew clean $GRADLE_FLAGS -# Ensure dependency convergence -./gradlew :grpc-all:dependencies $GRADLE_FLAGS if [[ -z "${SKIP_TESTS:-}" ]]; then # Ensure all *.proto changes include *.java generated code diff --git a/buildscripts/kokoro/xds-k8s.sh b/buildscripts/kokoro/xds-k8s.sh index bc0da15ba81..cafd884ccaf 100755 --- a/buildscripts/kokoro/xds-k8s.sh +++ b/buildscripts/kokoro/xds-k8s.sh @@ -4,9 +4,8 @@ set -eo pipefail # Constants readonly GITHUB_REPOSITORY_NAME="grpc-java" # GKE Cluster -readonly GKE_CLUSTER_NAME="interop-test-psm-sec-testing-api" -readonly GKE_CLUSTER_ZONE="us-west1-b" -export CLOUDSDK_API_ENDPOINT_OVERRIDES_CONTAINER="https://test-container.sandbox.googleapis.com/" +readonly GKE_CLUSTER_NAME="interop-test-psm-sec-v2-us-central1-a" +readonly GKE_CLUSTER_ZONE="us-central1-a" ## xDS test server/client Docker images readonly SERVER_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-server" readonly CLIENT_IMAGE_NAME="gcr.io/grpc-testing/xds-interop/java-client" diff --git a/buildscripts/kokoro/xds.sh b/buildscripts/kokoro/xds.sh index 893010eb0a3..c43438213c3 100755 --- a/buildscripts/kokoro/xds.sh +++ b/buildscripts/kokoro/xds.sh @@ -28,7 +28,7 @@ grpc/tools/run_tests/helper_scripts/prep_xds.sh # --test_case after they are added into "all". JAVA_OPTS=-Djava.util.logging.config.file=grpc-java/buildscripts/xds_logging.properties \ python3 grpc/tools/run_tests/run_xds_tests.py \ - --test_case="all,path_matching,header_matching,circuit_breaking,timeout,fault_injection,csds" \ + --test_case="all,circuit_breaking,timeout,fault_injection,csds" \ --project_id=grpc-testing \ --project_num=830293263384 \ --source_image=projects/grpc-testing/global/images/xds-test-server-4 \ diff --git a/buildscripts/make_dependencies.bat b/buildscripts/make_dependencies.bat index c9ff2a63a43..1622daaa16b 100644 --- a/buildscripts/make_dependencies.bat +++ b/buildscripts/make_dependencies.bat @@ -1,4 +1,4 @@ -set PROTOBUF_VER=3.12.0 +set PROTOBUF_VER=3.17.2 set CMAKE_NAME=cmake-3.3.2-win32-x86 if not exist "protobuf-%PROTOBUF_VER%\cmake\build\Release\" ( diff --git a/buildscripts/make_dependencies.sh b/buildscripts/make_dependencies.sh index 3c3cc23c6a9..927f1b4be28 100755 --- a/buildscripts/make_dependencies.sh +++ b/buildscripts/make_dependencies.sh @@ -3,7 +3,7 @@ # Build protoc set -evux -o pipefail -PROTOBUF_VERSION=3.12.0 +PROTOBUF_VERSION=3.17.2 # ARCH is x86_64 bit unless otherwise specified. ARCH="${ARCH:-x86_64}" diff --git a/census/build.gradle b/census/build.gradle index af5445a218c..35973a5f016 100644 --- a/census/build.gradle +++ b/census/build.gradle @@ -9,9 +9,9 @@ evaluationDependsOn(project(':grpc-api').path) dependencies { api project(':grpc-api') - guavaDependency 'implementation' - censusApiDependency 'implementation' - censusGrpcMetricDependency 'implementation' + implementation libraries.guava, + libraries.opencensus_api, + libraries.opencensus_contrib_grpc_metrics testImplementation project(':grpc-api').sourceSets.test.output, project(':grpc-context').sourceSets.test.output, diff --git a/compiler/README.md b/compiler/README.md index 81dbff4adbc..25e557d82eb 100644 --- a/compiler/README.md +++ b/compiler/README.md @@ -14,14 +14,11 @@ binaries for common platforms are available on Maven Central: However, if the pre-compiled binaries are not compatible with your system, you may want to build your own codegen. -## System requirement +## Compiling and testing the codegen -* Linux, Mac OS X with Clang, or Windows with MSYS2 -* Java 7 or up -* [Protobuf](https://github.com/google/protobuf) 3.12.0 or up +Set up your system as described in [COMPILING.md](../COMPILING.md). -## Compiling and testing the codegen -Change to the `compiler` directory: +Then change to the `compiler` directory: ``` $ cd $GRPC_JAVA_ROOT/compiler ``` diff --git a/compiler/build.gradle b/compiler/build.gradle index 60d3a436f6c..0b5766578a8 100644 --- a/compiler/build.gradle +++ b/compiler/build.gradle @@ -100,7 +100,9 @@ model { } else if (osdetector.os == "windows") { linker.args "-static", "-lprotoc", "-lprotobuf", "-static-libgcc", "-static-libstdc++", "-s" - } else { + } else if (osdetector.arch == "ppcle_64") { + linker.args "-Wl,-Bstatic", "-lprotoc", "-lprotobuf", "-Wl,-Bdynamic", "-lpthread", "-s" + } else { // Link protoc, protobuf, libgcc and libstdc++ statically. // Link other (system) libraries dynamically. // Clang under OSX doesn't support these options. diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index df1242f42a8..1feec798549 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.38.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.39.0)", comments = "Source: grpc/testing/compiler/test.proto") @java.lang.Deprecated public final class TestDeprecatedServiceGrpc { diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 52e6b86a922..fd0fdebcd29 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.38.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.39.0)", comments = "Source: grpc/testing/compiler/test.proto") public final class TestServiceGrpc { diff --git a/compiler/src/testLite/golden/TestDeprecatedService.java.txt b/compiler/src/testLite/golden/TestDeprecatedService.java.txt index 000e0865315..8035aa9e73e 100644 --- a/compiler/src/testLite/golden/TestDeprecatedService.java.txt +++ b/compiler/src/testLite/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.38.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.39.0)", comments = "Source: grpc/testing/compiler/test.proto") @java.lang.Deprecated public final class TestDeprecatedServiceGrpc { diff --git a/compiler/src/testLite/golden/TestService.java.txt b/compiler/src/testLite/golden/TestService.java.txt index d88f7c4a96a..5fad2aea09e 100644 --- a/compiler/src/testLite/golden/TestService.java.txt +++ b/compiler/src/testLite/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.38.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.39.0)", comments = "Source: grpc/testing/compiler/test.proto") public final class TestServiceGrpc { diff --git a/core/BUILD.bazel b/core/BUILD.bazel index c6e1ffbb6d7..c50e86a511c 100644 --- a/core/BUILD.bazel +++ b/core/BUILD.bazel @@ -15,6 +15,7 @@ java_library( "//api", "//context", "@com_google_code_findbugs_jsr305//jar", + "@com_google_errorprone_error_prone_annotations//jar", "@com_google_guava_guava//jar", "@com_google_j2objc_j2objc_annotations//jar", ], @@ -25,6 +26,7 @@ java_library( srcs = glob([ "src/main/java/io/grpc/internal/*.java", ]), + javacopts = ["-Xep:DoNotCall:OFF"], # Remove once requiring Bazel 3.4.0+; allows non-final resources = glob([ "src/bazel-internal/resources/**", ]), diff --git a/core/build.gradle b/core/build.gradle index f4e0ccc3b62..14cfbb70497 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -27,9 +27,9 @@ dependencies { implementation libraries.gson, libraries.android_annotations, libraries.animalsniffer_annotations, - libraries.errorprone - guavaDependency 'implementation' - perfmarkDependency 'implementation' + libraries.errorprone, + libraries.guava, + libraries.perfmark testImplementation project(':grpc-context').sourceSets.test.output, project(':grpc-api').sourceSets.test.output, project(':grpc-testing'), diff --git a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java index 8d285897fcb..812ad5d7861 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessChannelBuilder.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; +import com.google.errorprone.annotations.DoNotCall; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.ExperimentalApi; @@ -60,6 +61,7 @@ public static InProcessChannelBuilder forName(String name) { /** * Always fails. Call {@link #forName} instead. */ + @DoNotCall("Unsupported. Use forName() instead") public static InProcessChannelBuilder forTarget(String target) { throw new UnsupportedOperationException("call forName() instead"); } @@ -67,6 +69,7 @@ public static InProcessChannelBuilder forTarget(String target) { /** * Always fails. Call {@link #forName} instead. */ + @DoNotCall("Unsupported. Use forName() instead") public static InProcessChannelBuilder forAddress(String name, int port) { throw new UnsupportedOperationException("call forName() instead"); } diff --git a/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java b/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java index 25291c28489..6c68189fcc9 100644 --- a/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java +++ b/core/src/main/java/io/grpc/inprocess/InProcessServerBuilder.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.base.Preconditions; +import com.google.errorprone.annotations.DoNotCall; import io.grpc.Deadline; import io.grpc.ExperimentalApi; import io.grpc.Internal; @@ -86,6 +87,7 @@ public static InProcessServerBuilder forName(String name) { /** * Always fails. Call {@link #forName} instead. */ + @DoNotCall("Unsupported. Use forName() instead") public static InProcessServerBuilder forPort(int port) { throw new UnsupportedOperationException("call forName() instead"); } diff --git a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java index 84a930e369f..98bbfcc7b1e 100644 --- a/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/AbstractManagedChannelImplBuilder.java @@ -17,6 +17,7 @@ package io.grpc.internal; import com.google.common.base.MoreObjects; +import com.google.errorprone.annotations.DoNotCall; import io.grpc.BinaryLog; import io.grpc.ClientInterceptor; import io.grpc.CompressorRegistry; @@ -49,6 +50,7 @@ protected AbstractManagedChannelImplBuilder() {} /** * This method serves to force sub classes to "hide" this static factory. */ + @DoNotCall("Unsupported") public static ManagedChannelBuilder forAddress(String name, int port) { throw new UnsupportedOperationException("Subclass failed to hide static factory"); } @@ -56,6 +58,7 @@ public static ManagedChannelBuilder forAddress(String name, int port) { /** * This method serves to force sub classes to "hide" this static factory. */ + @DoNotCall("Unsupported") public static ManagedChannelBuilder forTarget(String target) { throw new UnsupportedOperationException("Subclass failed to hide static factory"); } diff --git a/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java b/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java index e43b7a7cc0e..16c046dfc36 100644 --- a/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/AbstractReadableBuffer.java @@ -16,6 +16,8 @@ package io.grpc.internal; +import java.nio.ByteBuffer; + /** * Abstract base class for {@link ReadableBuffer} implementations. */ @@ -45,6 +47,29 @@ public int arrayOffset() { throw new UnsupportedOperationException(); } + @Override + public boolean markSupported() { + return false; + } + + @Override + public void mark() {} + + @Override + public void reset() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean byteBufferSupported() { + return false; + } + + @Override + public ByteBuffer getByteBuffer() { + throw new UnsupportedOperationException(); + } + @Override public void close() {} diff --git a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java index 34021d8a82b..9baec34b189 100644 --- a/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/CompositeReadableBuffer.java @@ -20,8 +20,10 @@ import java.io.OutputStream; import java.nio.Buffer; import java.nio.ByteBuffer; +import java.nio.InvalidMarkException; import java.util.ArrayDeque; -import java.util.Queue; +import java.util.Deque; +import javax.annotation.Nullable; /** * A {@link ReadableBuffer} that is composed of 0 or more {@link ReadableBuffer}s. This provides a @@ -33,15 +35,17 @@ */ public class CompositeReadableBuffer extends AbstractReadableBuffer { + private final Deque readableBuffers; + private Deque rewindableBuffers; private int readableBytes; - private final Queue buffers; + private boolean marked; public CompositeReadableBuffer(int initialCapacity) { - buffers = new ArrayDeque<>(initialCapacity); + readableBuffers = new ArrayDeque<>(initialCapacity); } public CompositeReadableBuffer() { - buffers = new ArrayDeque<>(); + readableBuffers = new ArrayDeque<>(); } /** @@ -51,16 +55,24 @@ public CompositeReadableBuffer() { * this {@code CompositeBuffer}. */ public void addBuffer(ReadableBuffer buffer) { + boolean markHead = marked && readableBuffers.isEmpty(); + enqueueBuffer(buffer); + if (markHead) { + readableBuffers.peek().mark(); + } + } + + private void enqueueBuffer(ReadableBuffer buffer) { if (!(buffer instanceof CompositeReadableBuffer)) { - buffers.add(buffer); + readableBuffers.add(buffer); readableBytes += buffer.readableBytes(); return; } CompositeReadableBuffer compositeBuffer = (CompositeReadableBuffer) buffer; - while (!compositeBuffer.buffers.isEmpty()) { - ReadableBuffer subBuffer = compositeBuffer.buffers.remove(); - buffers.add(subBuffer); + while (!compositeBuffer.readableBuffers.isEmpty()) { + ReadableBuffer subBuffer = compositeBuffer.readableBuffers.remove(); + readableBuffers.add(subBuffer); } readableBytes += compositeBuffer.readableBytes; compositeBuffer.readableBytes = 0; @@ -158,22 +170,27 @@ public ReadableBuffer readBytes(int length) { ReadableBuffer newBuffer = null; CompositeReadableBuffer newComposite = null; do { - ReadableBuffer buffer = buffers.peek(); + ReadableBuffer buffer = readableBuffers.peek(); int readable = buffer.readableBytes(); ReadableBuffer readBuffer; if (readable > length) { readBuffer = buffer.readBytes(length); length = 0; } else { - readBuffer = buffers.poll(); + if (marked) { + readBuffer = buffer.readBytes(readable); + advanceBuffer(); + } else { + readBuffer = readableBuffers.poll(); + } length -= readable; } if (newBuffer == null) { newBuffer = readBuffer; } else { if (newComposite == null) { - newComposite = - new CompositeReadableBuffer(length == 0 ? 2 : Math.min(buffers.size() + 2, 16)); + newComposite = new CompositeReadableBuffer( + length == 0 ? 2 : Math.min(readableBuffers.size() + 2, 16)); newComposite.addBuffer(newBuffer); newBuffer = newComposite; } @@ -183,10 +200,77 @@ public ReadableBuffer readBytes(int length) { return newBuffer; } + @Override + public boolean markSupported() { + for (ReadableBuffer buffer : readableBuffers) { + if (!buffer.markSupported()) { + return false; + } + } + return true; + } + + @Override + public void mark() { + if (rewindableBuffers == null) { + rewindableBuffers = new ArrayDeque<>(Math.min(readableBuffers.size(), 16)); + } + while (!rewindableBuffers.isEmpty()) { + rewindableBuffers.remove().close(); + } + marked = true; + ReadableBuffer buffer = readableBuffers.peek(); + if (buffer != null) { + buffer.mark(); + } + } + + @Override + public void reset() { + if (!marked) { + throw new InvalidMarkException(); + } + ReadableBuffer buffer; + if ((buffer = readableBuffers.peek()) != null) { + int currentRemain = buffer.readableBytes(); + buffer.reset(); + readableBytes += (buffer.readableBytes() - currentRemain); + } + while ((buffer = rewindableBuffers.pollLast()) != null) { + buffer.reset(); + readableBuffers.addFirst(buffer); + readableBytes += buffer.readableBytes(); + } + } + + @Override + public boolean byteBufferSupported() { + for (ReadableBuffer buffer : readableBuffers) { + if (!buffer.byteBufferSupported()) { + return false; + } + } + return true; + } + + @Nullable + @Override + public ByteBuffer getByteBuffer() { + if (readableBuffers.isEmpty()) { + return null; + } + return readableBuffers.peek().getByteBuffer(); + } + @Override public void close() { - while (!buffers.isEmpty()) { - buffers.remove().close(); + while (!readableBuffers.isEmpty()) { + readableBuffers.remove().close(); + } + if (rewindableBuffers != null) { + while (!rewindableBuffers.isEmpty()) { + rewindableBuffers.remove().close(); + } } } @@ -197,12 +281,12 @@ public void close() { private int execute(ReadOperation op, int length, T dest, int value) throws IOException { checkReadable(length); - if (!buffers.isEmpty()) { + if (!readableBuffers.isEmpty()) { advanceBufferIfNecessary(); } - for (; length > 0 && !buffers.isEmpty(); advanceBufferIfNecessary()) { - ReadableBuffer buffer = buffers.peek(); + for (; length > 0 && !readableBuffers.isEmpty(); advanceBufferIfNecessary()) { + ReadableBuffer buffer = readableBuffers.peek(); int lengthToCopy = Math.min(length, buffer.readableBytes()); // Perform the read operation for this buffer. @@ -232,9 +316,24 @@ private int executeNoThrow(NoThrowReadOperation op, int length, T dest, i * If the current buffer is exhausted, removes and closes it. */ private void advanceBufferIfNecessary() { - ReadableBuffer buffer = buffers.peek(); + ReadableBuffer buffer = readableBuffers.peek(); if (buffer.readableBytes() == 0) { - buffers.remove().close(); + advanceBuffer(); + } + } + + /** + * Removes one buffer from the front and closes it. + */ + private void advanceBuffer() { + if (marked) { + rewindableBuffers.add(readableBuffers.remove()); + ReadableBuffer next = readableBuffers.peek(); + if (next != null) { + next.mark(); + } + } else { + readableBuffers.remove().close(); } } diff --git a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java index 954d0ac5486..1d7b412e195 100644 --- a/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/ForwardingReadableBuffer.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; +import javax.annotation.Nullable; /** * Base class for a wrapper around another {@link ReadableBuffer}. @@ -96,6 +97,32 @@ public int arrayOffset() { return buf.arrayOffset(); } + @Override + public boolean markSupported() { + return buf.markSupported(); + } + + @Override + public void mark() { + buf.mark(); + } + + @Override + public void reset() { + buf.reset(); + } + + @Override + public boolean byteBufferSupported() { + return buf.byteBufferSupported(); + } + + @Nullable + @Override + public ByteBuffer getByteBuffer() { + return buf.getByteBuffer(); + } + @Override public void close() { buf.close(); diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 98cba477913..ac7d13a029a 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -197,7 +197,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - private static final String IMPLEMENTATION_VERSION = "1.38.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + private static final String IMPLEMENTATION_VERSION = "1.39.0"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java index c74eed6df7f..d42b3832136 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.DoNotCall; import io.grpc.Attributes; import io.grpc.BinaryLog; import io.grpc.CallCredentials; @@ -61,11 +62,13 @@ public final class ManagedChannelImplBuilder private static final Logger log = Logger.getLogger(ManagedChannelImplBuilder.class.getName()); + @DoNotCall("ClientTransportFactoryBuilder is required, use a constructor") public static ManagedChannelBuilder forAddress(String name, int port) { throw new UnsupportedOperationException( "ClientTransportFactoryBuilder is required, use a constructor"); } + @DoNotCall("ClientTransportFactoryBuilder is required, use a constructor") public static ManagedChannelBuilder forTarget(String target) { throw new UnsupportedOperationException( "ClientTransportFactoryBuilder is required, use a constructor"); @@ -280,9 +283,25 @@ static String makeTargetStringForDirectAddress(SocketAddress address) { public ManagedChannelImplBuilder(SocketAddress directServerAddress, String authority, ClientTransportFactoryBuilder clientTransportFactoryBuilder, @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { + this(directServerAddress, authority, null, null, clientTransportFactoryBuilder, + channelBuilderDefaultPortProvider); + } + + /** + * Creates a new managed channel builder with the given server address, authority string of the + * channel. Transport implementors must provide client transport factory builder, and may set + * custom channel default port provider. + * + * @param channelCreds The ChannelCredentials provided by the user. These may be used when + * creating derivative channels. + */ + public ManagedChannelImplBuilder(SocketAddress directServerAddress, String authority, + @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds, + ClientTransportFactoryBuilder clientTransportFactoryBuilder, + @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { this.target = makeTargetStringForDirectAddress(directServerAddress); - this.channelCredentials = null; - this.callCredentials = null; + this.channelCredentials = channelCreds; + this.callCredentials = callCreds; this.clientTransportFactoryBuilder = Preconditions .checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); this.directServerAddress = directServerAddress; diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffer.java b/core/src/main/java/io/grpc/internal/ReadableBuffer.java index 7d2ca7ebba5..b47501a9943 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffer.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffer.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.io.OutputStream; import java.nio.ByteBuffer; +import javax.annotation.Nullable; /** * Interface for an abstract byte buffer. Buffers are intended to be a read-only, except for the @@ -123,6 +124,44 @@ public interface ReadableBuffer extends Closeable { */ int arrayOffset(); + /** + * Indicates whether or not {@link #mark} operation is supported for this buffer. + */ + boolean markSupported(); + + /** + * Marks the current position in this buffer. A subsequent call to the {@link #reset} method + * repositions this stream at the last marked position so that subsequent reads re-read the same + * bytes. + */ + void mark(); + + /** + * Repositions this buffer to the position at the time {@link #mark} was last called on this + * buffer. + */ + void reset(); + + /** + * Indicates whether or not {@link #getByteBuffer} operation is supported for this buffer. + */ + boolean byteBufferSupported(); + + /** + * Gets a {@link ByteBuffer} that contains some bytes of the content next to be read, or {@code + * null} if this buffer has been exhausted. The number of bytes contained in the returned buffer + * is implementation specific. The position of this buffer is unchanged after calling this + * method. The returned buffer's content should not be modified, but the position, limit, and + * mark may be changed. Operations for changing the position, limit, and mark of the returned + * buffer does not affect the position, limit, and mark of this buffer. Buffers returned by this + * method have independent position, limit and mark. This is an optional method, so callers + * should first check {@link #byteBufferSupported}. + * + * @throws UnsupportedOperationException the buffer does not support this method. + */ + @Nullable + ByteBuffer getByteBuffer(); + /** * Closes this buffer and releases any resources. */ diff --git a/core/src/main/java/io/grpc/internal/ReadableBuffers.java b/core/src/main/java/io/grpc/internal/ReadableBuffers.java index cfe5542a573..c54cb0e67d0 100644 --- a/core/src/main/java/io/grpc/internal/ReadableBuffers.java +++ b/core/src/main/java/io/grpc/internal/ReadableBuffers.java @@ -19,13 +19,17 @@ import static com.google.common.base.Charsets.UTF_8; import com.google.common.base.Preconditions; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; import io.grpc.KnownLength; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.nio.Buffer; import java.nio.ByteBuffer; +import java.nio.InvalidMarkException; import java.nio.charset.Charset; +import javax.annotation.Nullable; /** * Utility methods for creating {@link ReadableBuffer} instances. @@ -128,6 +132,7 @@ private static class ByteArrayWrapper extends AbstractReadableBuffer { int offset; final int end; final byte[] bytes; + int mark = -1; ByteArrayWrapper(byte[] bytes) { this(bytes, 0, bytes.length); @@ -204,6 +209,24 @@ public byte[] array() { public int arrayOffset() { return offset; } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public void mark() { + mark = offset; + } + + @Override + public void reset() { + if (mark == -1) { + throw new InvalidMarkException(); + } + offset = mark; + } } /** @@ -291,13 +314,39 @@ public byte[] array() { public int arrayOffset() { return bytes.arrayOffset() + bytes.position(); } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public void mark() { + bytes.mark(); + } + + @Override + public void reset() { + bytes.reset(); + } + + @Override + public boolean byteBufferSupported() { + return true; + } + + @Override + public ByteBuffer getByteBuffer() { + return bytes.slice(); + } } /** * An {@link InputStream} that is backed by a {@link ReadableBuffer}. */ - private static final class BufferInputStream extends InputStream implements KnownLength { - final ReadableBuffer buffer; + private static final class BufferInputStream extends InputStream + implements KnownLength, HasByteBuffer, Detachable { + private ReadableBuffer buffer; public BufferInputStream(ReadableBuffer buffer) { this.buffer = Preconditions.checkNotNull(buffer, "buffer"); @@ -329,6 +378,46 @@ public int read(byte[] dest, int destOffset, int length) throws IOException { return length; } + @Override + public long skip(long n) throws IOException { + int length = (int) Math.min(buffer.readableBytes(), n); + buffer.skipBytes(length); + return length; + } + + @Override + public void mark(int readlimit) { + buffer.mark(); + } + + @Override + public void reset() throws IOException { + buffer.reset(); + } + + @Override + public boolean markSupported() { + return buffer.markSupported(); + } + + @Override + public boolean byteBufferSupported() { + return buffer.byteBufferSupported(); + } + + @Nullable + @Override + public ByteBuffer getByteBuffer() { + return buffer.getByteBuffer(); + } + + @Override + public InputStream detach() { + ReadableBuffer detachedBuffer = buffer; + buffer = buffer.readBytes(0); + return new BufferInputStream(detachedBuffer); + } + @Override public void close() throws IOException { buffer.close(); diff --git a/core/src/main/java/io/grpc/internal/ServerImpl.java b/core/src/main/java/io/grpc/internal/ServerImpl.java index 6c66aac07ac..21f13cf5b4d 100644 --- a/core/src/main/java/io/grpc/internal/ServerImpl.java +++ b/core/src/main/java/io/grpc/internal/ServerImpl.java @@ -604,7 +604,7 @@ private ServerStreamListener startCall(ServerStream stream, String stream.getAuthority())); ServerCallHandler handler = methodDef.getServerCallHandler(); for (ServerInterceptor interceptor : interceptors) { - handler = InternalServerInterceptors.interceptCallHandler(interceptor, handler); + handler = InternalServerInterceptors.interceptCallHandlerCreate(interceptor, handler); } ServerMethodDefinition interceptedDef = methodDef.withServerCallHandler(handler); ServerMethodDefinition wMethodDef = binlog == null diff --git a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java index 04e6059d13f..aafdc150fb2 100644 --- a/core/src/main/java/io/grpc/internal/ServerImplBuilder.java +++ b/core/src/main/java/io/grpc/internal/ServerImplBuilder.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.DoNotCall; import io.grpc.BinaryLog; import io.grpc.BindableService; import io.grpc.CompressorRegistry; @@ -55,6 +56,7 @@ public final class ServerImplBuilder extends ServerBuilder { private static final Logger log = Logger.getLogger(ServerImplBuilder.class.getName()); + @DoNotCall("ClientTransportServersBuilder is required, use a constructor") public static ServerBuilder forPort(int port) { throw new UnsupportedOperationException( "ClientTransportServersBuilder is required, use a constructor"); diff --git a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java index 660aa116317..011d83b548a 100644 --- a/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java +++ b/core/src/test/java/io/grpc/internal/CompositeReadableBufferTest.java @@ -17,14 +17,20 @@ package io.grpc.internal; import static com.google.common.base.Charsets.UTF_8; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.Buffer; import java.nio.ByteBuffer; +import java.nio.InvalidMarkException; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -154,6 +160,145 @@ public void readStreamShouldSucceed() throws IOException { assertEquals(EXPECTED_VALUE, new String(bos.toByteArray(), UTF_8)); } + @Test + public void markSupportedOnlyAllComponentsSupportMark() { + composite = new CompositeReadableBuffer(); + ReadableBuffer buffer1 = mock(ReadableBuffer.class); + ReadableBuffer buffer2 = mock(ReadableBuffer.class); + ReadableBuffer buffer3 = mock(ReadableBuffer.class); + when(buffer1.markSupported()).thenReturn(true); + when(buffer2.markSupported()).thenReturn(true); + when(buffer3.markSupported()).thenReturn(false); + composite.addBuffer(buffer1); + assertTrue(composite.markSupported()); + composite.addBuffer(buffer2); + assertTrue(composite.markSupported()); + composite.addBuffer(buffer3); + assertFalse(composite.markSupported()); + } + + @Test + public void resetUnmarkedShouldThrow() { + try { + composite.reset(); + fail(); + } catch (InvalidMarkException expected) { + } + } + + @Test + public void markAndResetWithSkipBytesShouldSucceed() { + composite.mark(); + composite.skipBytes(EXPECTED_VALUE.length() / 2); + composite.reset(); + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + } + + @Test + public void markAndResetWithReadUnsignedByteShouldSucceed() { + composite.readUnsignedByte(); + composite.mark(); + int b = composite.readUnsignedByte(); + composite.reset(); + assertEquals(EXPECTED_VALUE.length() - 1, composite.readableBytes()); + assertEquals(b, composite.readUnsignedByte()); + } + + @Test + public void markAndResetWithReadByteArrayShouldSucceed() { + composite.mark(); + byte[] first = new byte[EXPECTED_VALUE.length()]; + composite.readBytes(first, 0, EXPECTED_VALUE.length()); + composite.reset(); + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + byte[] second = new byte[EXPECTED_VALUE.length()]; + composite.readBytes(second, 0, EXPECTED_VALUE.length()); + assertArrayEquals(first, second); + } + + @Test + public void markAndResetWithReadByteBufferShouldSucceed() { + byte[] first = new byte[EXPECTED_VALUE.length()]; + composite.mark(); + composite.readBytes(ByteBuffer.wrap(first)); + composite.reset(); + byte[] second = new byte[EXPECTED_VALUE.length()]; + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + composite.readBytes(ByteBuffer.wrap(second)); + assertArrayEquals(first, second); + } + + @Test + public void markAndResetWithReadStreamShouldSucceed() throws IOException { + ByteArrayOutputStream first = new ByteArrayOutputStream(); + composite.mark(); + composite.readBytes(first, EXPECTED_VALUE.length() / 2); + composite.reset(); + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + ByteArrayOutputStream second = new ByteArrayOutputStream(); + composite.readBytes(second, EXPECTED_VALUE.length() / 2); + assertArrayEquals(first.toByteArray(), second.toByteArray()); + } + + @Test + public void markAndResetWithReadReadableBufferShouldSucceed() { + composite.readBytes(EXPECTED_VALUE.length() / 2); + int remaining = composite.readableBytes(); + composite.mark(); + ReadableBuffer first = composite.readBytes(1); + composite.reset(); + assertEquals(remaining, composite.readableBytes()); + ReadableBuffer second = composite.readBytes(1); + assertEquals(first.readUnsignedByte(), second.readUnsignedByte()); + } + + @Test + public void markAgainShouldOverwritePreviousMark() { + composite.mark(); + composite.skipBytes(EXPECTED_VALUE.length() / 2); + int remaining = composite.readableBytes(); + composite.mark(); + composite.skipBytes(1); + composite.reset(); + assertEquals(remaining, composite.readableBytes()); + } + + @Test + public void bufferAddedAfterMarkedShouldBeIncluded() { + composite = new CompositeReadableBuffer(); + composite.mark(); + splitAndAdd(EXPECTED_VALUE); + composite.skipBytes(EXPECTED_VALUE.length() / 2); + composite.reset(); + assertEquals(EXPECTED_VALUE.length(), composite.readableBytes()); + } + + @Test + public void canUseByteBufferOnlyAllComponentsSupportUsingByteBuffer() { + composite = new CompositeReadableBuffer(); + ReadableBuffer buffer1 = mock(ReadableBuffer.class); + ReadableBuffer buffer2 = mock(ReadableBuffer.class); + ReadableBuffer buffer3 = mock(ReadableBuffer.class); + when(buffer1.byteBufferSupported()).thenReturn(true); + when(buffer2.byteBufferSupported()).thenReturn(true); + when(buffer3.byteBufferSupported()).thenReturn(false); + composite.addBuffer(buffer1); + assertTrue(composite.byteBufferSupported()); + composite.addBuffer(buffer2); + assertTrue(composite.byteBufferSupported()); + composite.addBuffer(buffer3); + assertFalse(composite.byteBufferSupported()); + } + + @Test + public void getByteBufferDelegatesToComponents() { + composite = new CompositeReadableBuffer(); + ReadableBuffer buffer = mock(ReadableBuffer.class); + composite.addBuffer(buffer); + composite.getByteBuffer(); + verify(buffer).getByteBuffer(); + } + @Test public void closeShouldCloseBuffers() { composite = new CompositeReadableBuffer(); diff --git a/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java b/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java index e469b807d51..97e0df38ae7 100644 --- a/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java +++ b/core/src/test/java/io/grpc/internal/ReadableBufferTestBase.java @@ -24,6 +24,7 @@ import java.nio.Buffer; import java.nio.ByteBuffer; import java.util.Arrays; +import org.junit.Assume; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -117,6 +118,58 @@ public void partialReadToReadableBufferShouldSucceed() { assertArrayEquals(new byte[] {'h', 'e'}, Arrays.copyOfRange(array, 0, 2)); } + @Test + public void markAndResetWithReadShouldSucceed() { + ReadableBuffer buffer = buffer(); + int offset = 5; + buffer.readBytes(new byte[offset], 0, offset); + buffer.mark(); + int b = buffer.readUnsignedByte(); + assertEquals(msg.length() - offset - 1, buffer.readableBytes()); + buffer.reset(); + assertEquals(msg.length() - offset, buffer.readableBytes()); + assertEquals(b, buffer.readUnsignedByte()); + } + + @Test + public void markAndResetWithReadToReadableBufferShouldSucceed() { + ReadableBuffer buffer = buffer(); + int offset = 5; + buffer.readBytes(offset); + int testLen = 100; + buffer.mark(); + ReadableBuffer first = buffer.readBytes(testLen); + assertEquals(msg.length() - offset - testLen, buffer.readableBytes()); + buffer.reset(); + assertEquals(msg.length() - offset, buffer.readableBytes()); + ReadableBuffer second = buffer.readBytes(testLen); + byte[] array1 = new byte[testLen]; + byte[] array2 = new byte[testLen]; + first.readBytes(array1, 0, testLen); + second.readBytes(array2, 0, testLen); + assertArrayEquals(array1, array2); + } + + @Test + public void getByteBufferDoesNotAffectBufferPosition() { + ReadableBuffer buffer = buffer(); + Assume.assumeTrue(buffer.byteBufferSupported()); + ByteBuffer byteBuffer = buffer.getByteBuffer(); + assertEquals(msg.length(), buffer.readableBytes()); + byteBuffer.get(new byte[byteBuffer.remaining()]); + assertEquals(msg.length(), buffer.readableBytes()); + } + + @Test + public void getByteBufferIsNotAffectedByBufferRead() { + ReadableBuffer buffer = buffer(); + Assume.assumeTrue(buffer.byteBufferSupported()); + ByteBuffer byteBuffer = buffer.getByteBuffer(); + int initialRemaining = byteBuffer.remaining(); + buffer.readBytes(new byte[100], 0, 100); + assertEquals(initialRemaining, byteBuffer.remaining()); + } + protected abstract ReadableBuffer buffer(); private static String repeatUntilLength(String toRepeat, int length) { diff --git a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java index ea9daeed6a3..0947f65da12 100644 --- a/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java +++ b/core/src/test/java/io/grpc/internal/ReadableBuffersTest.java @@ -19,13 +19,22 @@ import static com.google.common.base.Charsets.UTF_8; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import io.grpc.Detachable; +import io.grpc.HasByteBuffer; +import java.io.IOException; import java.io.InputStream; +import java.nio.InvalidMarkException; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,6 +46,9 @@ public class ReadableBuffersTest { private static final byte[] MSG_BYTES = "hello".getBytes(UTF_8); + @Rule + public final ExpectedException thrown = ExpectedException.none(); + @Test public void empty_returnsEmptyBuffer() { ReadableBuffer buffer = ReadableBuffers.empty(); @@ -128,4 +140,108 @@ public void bufferInputStream_close_closesBuffer() throws Exception { inputStream.close(); verify(buffer, times(1)).close(); } + + @Test + public void bufferInputStream_markAndReset() throws IOException { + ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + assertTrue(inputStream.markSupported()); + inputStream.mark(2); + byte[] first = new byte[5]; + inputStream.read(first); + assertEquals(0, inputStream.available()); + inputStream.reset(); + assertEquals(5, inputStream.available()); + byte[] second = new byte[5]; + inputStream.read(second); + assertArrayEquals(first, second); + } + + @Test + public void bufferInputStream_getByteBufferDelegatesToBuffer() { + ReadableBuffer buffer = mock(ReadableBuffer.class); + when(buffer.byteBufferSupported()).thenReturn(true); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + assertTrue(((HasByteBuffer) inputStream).byteBufferSupported()); + ((HasByteBuffer) inputStream).getByteBuffer(); + verify(buffer).getByteBuffer(); + } + + @Test + public void bufferInputStream_availableAfterDetached_returnsZeroByte() throws IOException { + ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + assertEquals(5, inputStream.available()); + InputStream detachedStream = ((Detachable) inputStream).detach(); + assertEquals(0, inputStream.available()); + assertEquals(5, detachedStream.available()); + } + + @Test + public void bufferInputStream_skipAfterDetached() throws IOException { + ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + assertEquals(3, inputStream.skip(3)); + InputStream detachedStream = ((Detachable) inputStream).detach(); + assertEquals(0, inputStream.skip(2)); + assertEquals(2, detachedStream.skip(2)); + } + + @Test + public void bufferInputStream_readUnsignedByteAfterDetached() throws IOException { + ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + assertEquals((int) 'h', inputStream.read()); + InputStream detachedStream = ((Detachable) inputStream).detach(); + assertEquals(-1, inputStream.read()); + assertEquals((int) 'e', detachedStream.read()); + } + + @Test + public void bufferInputStream_partialReadAfterDetached() throws IOException { + ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + byte[] dest = new byte[3]; + assertEquals(3, inputStream.read(dest, /*destOffset*/ 0, /*length*/ 3)); + assertArrayEquals(new byte[]{'h', 'e', 'l'}, dest); + InputStream detachedStream = ((Detachable) inputStream).detach(); + byte[] newDest = new byte[2]; + assertEquals(2, detachedStream.read(newDest, /*destOffset*/ 0, /*length*/ 2)); + assertArrayEquals(new byte[]{'l', 'o'}, newDest); + } + + @Test + public void bufferInputStream_markDiscardedAfterDetached() throws IOException { + ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + inputStream.mark(5); + ((Detachable) inputStream).detach(); + thrown.expect(InvalidMarkException.class); + inputStream.reset(); + } + + @Test + public void bufferInputStream_markPreservedInForkedInputStream() throws IOException { + ReadableBuffer buffer = ReadableBuffers.wrap(MSG_BYTES); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + inputStream.skip(2); + inputStream.mark(3); + InputStream detachedStream = ((Detachable) inputStream).detach(); + detachedStream.skip(3); + assertEquals(0, detachedStream.available()); + detachedStream.reset(); + assertEquals(3, detachedStream.available()); + } + + @Test + public void bufferInputStream_closeAfterDetached() throws IOException { + ReadableBuffer buffer = mock(ReadableBuffer.class); + when(buffer.readBytes(anyInt())).thenReturn(mock(ReadableBuffer.class)); + InputStream inputStream = ReadableBuffers.openStream(buffer, true); + InputStream detachedStream = ((Detachable) inputStream).detach(); + inputStream.close(); + verify(buffer, never()).close(); + detachedStream.close(); + verify(buffer).close(); + } } diff --git a/cronet/README.md b/cronet/README.md index 5a85d37cae5..bd5329e5192 100644 --- a/cronet/README.md +++ b/cronet/README.md @@ -26,7 +26,7 @@ In your app module's `build.gradle` file, include a dependency on both `grpc-cro Google Play Services Client Library for Cronet ``` -implementation 'io.grpc:grpc-cronet:1.37.0' +implementation 'io.grpc:grpc-cronet:1.39.0' implementation 'com.google.android.gms:play-services-cronet:16.0.0' ``` diff --git a/cronet/build.gradle b/cronet/build.gradle index 95aa7e9e236..2d73cc4194f 100644 --- a/cronet/build.gradle +++ b/cronet/build.gradle @@ -35,7 +35,7 @@ android { dependencies { api project(':grpc-core'), libraries.cronet_api - guavaDependency 'implementation' + implementation libraries.guava testImplementation project(':grpc-testing') testImplementation libraries.cronet_embedded diff --git a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java index 217928ae94a..22778691a10 100644 --- a/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java +++ b/cronet/src/main/java/io/grpc/cronet/CronetChannelBuilder.java @@ -24,6 +24,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import com.google.common.util.concurrent.MoreExecutors; +import com.google.errorprone.annotations.DoNotCall; import io.grpc.ChannelCredentials; import io.grpc.ChannelLogger; import io.grpc.ExperimentalApi; @@ -71,6 +72,7 @@ public static CronetChannelBuilder forAddress(String host, int port, CronetEngin /** * Always fails. Call {@link #forAddress(String, int, CronetEngine)} instead. */ + @DoNotCall("Unsupported. Use forAddress(String, int, CronetEngine) instead") public static CronetChannelBuilder forTarget(String target) { throw new UnsupportedOperationException("call forAddress() instead"); } @@ -78,6 +80,7 @@ public static CronetChannelBuilder forTarget(String target) { /** * Always fails. Call {@link #forAddress(String, int, CronetEngine)} instead. */ + @DoNotCall("Unsupported. Use forAddress(String, int, CronetEngine) instead") public static CronetChannelBuilder forAddress(String name, int port) { throw new UnsupportedOperationException("call forAddress(String, int, CronetEngine) instead"); } diff --git a/documentation/android-channel-builder.md b/documentation/android-channel-builder.md index d516a7db342..93447639197 100644 --- a/documentation/android-channel-builder.md +++ b/documentation/android-channel-builder.md @@ -36,8 +36,8 @@ In your `build.gradle` file, include a dependency on both `grpc-android` and `grpc-okhttp`: ``` -implementation 'io.grpc:grpc-android:1.37.0' -implementation 'io.grpc:grpc-okhttp:1.37.0' +implementation 'io.grpc:grpc-android:1.39.0' +implementation 'io.grpc:grpc-okhttp:1.39.0' ``` You also need permission to access the device's network state in your diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index b13afd88411..8848e2bfc9a 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -32,9 +32,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.12.0' } + protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +54,12 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.39.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.39.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.39.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' testImplementation 'junit:junit:4.12' testImplementation 'com.google.truth:truth:1.0.1' - testImplementation 'io.grpc:grpc-testing:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.39.0' // CURRENT_GRPC_VERSION } diff --git a/examples/android/clientcache/build.gradle b/examples/android/clientcache/build.gradle index f5c361578fb..ed4e8159386 100644 --- a/examples/android/clientcache/build.gradle +++ b/examples/android/clientcache/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:4.0.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.15" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.16" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 24351d62439..6ac738807ea 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -30,9 +30,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.12.0' } + protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.39.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.39.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.39.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/helloworld/build.gradle b/examples/android/helloworld/build.gradle index f5c361578fb..ed4e8159386 100644 --- a/examples/android/helloworld/build.gradle +++ b/examples/android/helloworld/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:4.0.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.15" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.16" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index e01287c6098..2e9f87dbf71 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -30,9 +30,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.12.0' } + protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:27.0.2' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.39.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.39.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.39.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/routeguide/build.gradle b/examples/android/routeguide/build.gradle index 51cde168106..980cf32ff85 100644 --- a/examples/android/routeguide/build.gradle +++ b/examples/android/routeguide/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:4.0.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.15" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.16" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index 552b5a3bb90..0d45eac1d50 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -31,9 +31,9 @@ android { } protobuf { - protoc { artifact = 'com.google.protobuf:protoc:3.12.0' } + protoc { artifact = 'com.google.protobuf:protoc:3.17.2' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.39.0' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,8 @@ dependencies { implementation 'com.android.support:appcompat-v7:28.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.39.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.39.0' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.39.0' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/strictmode/build.gradle b/examples/android/strictmode/build.gradle index f5c361578fb..ed4e8159386 100644 --- a/examples/android/strictmode/build.gradle +++ b/examples/android/strictmode/build.gradle @@ -7,7 +7,7 @@ buildscript { } dependencies { classpath 'com.android.tools.build:gradle:4.0.0' - classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.15" + classpath "com.google.protobuf:protobuf-gradle-plugin:0.8.16" // NOTE: Do not place your application dependencies here; they belong // in the individual module build.gradle files diff --git a/examples/build.gradle b/examples/build.gradle index 594f5c74dbf..93486aa613d 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -2,7 +2,7 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions - id 'com.google.protobuf' version '0.8.15' + id 'com.google.protobuf' version '0.8.16' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } @@ -22,8 +22,8 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.12.0' +def grpcVersion = '1.39.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.17.2' def protocVersion = protobufVersion dependencies { diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index d8b8c0d4bf2..137377dbfbe 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -2,7 +2,7 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions - id 'com.google.protobuf' version '0.8.15' + id 'com.google.protobuf' version '0.8.16' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } @@ -23,8 +23,8 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.12.0' +def grpcVersion = '1.39.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.17.2' dependencies { // grpc-alts transitively depends on grpc-netty-shaded, grpc-protobuf, and grpc-stub diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 3e3730cb284..c3e98839679 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -2,7 +2,7 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions - id 'com.google.protobuf' version '0.8.15' + id 'com.google.protobuf' version '0.8.16' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } @@ -23,8 +23,8 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.12.0' +def grpcVersion = '1.39.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.17.2' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index 4b740a1997c..e06ea648a0a 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,14 +6,14 @@ jar - 1.38.0-SNAPSHOT + 1.39.0 example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.38.0-SNAPSHOT - 3.12.0 + 1.39.0 + 3.17.2 1.7 1.7 @@ -49,11 +49,6 @@ io.grpc grpc-auth - - com.google.protobuf - protobuf-java-util - ${protobuf.version} - org.apache.tomcat annotations-api diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index ca4037d7078..2607b7ea17b 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -2,7 +2,7 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. id 'java' - id "com.google.protobuf" version "0.8.15" + id "com.google.protobuf" version "0.8.16" id 'com.google.cloud.tools.jib' version '2.7.0' // For releasing to Docker Hub } @@ -21,8 +21,8 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.12.0' +def grpcVersion = '1.39.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.17.2' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index bc34657d9b8..f836fe34800 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,14 +6,14 @@ jar - 1.38.0-SNAPSHOT + 1.39.0 example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.38.0-SNAPSHOT - 3.12.0 + 1.39.0 + 3.17.2 1.7 1.7 diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index c6334185661..e53ebaa773a 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -2,7 +2,7 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions - id 'com.google.protobuf' version '0.8.15' + id 'com.google.protobuf' version '0.8.16' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } @@ -22,8 +22,8 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protobufVersion = '3.12.0' +def grpcVersion = '1.39.0' // CURRENT_GRPC_VERSION +def protobufVersion = '3.17.2' def protocVersion = protobufVersion dependencies { diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index fc16655869e..d84be4b0732 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,15 +7,15 @@ jar - 1.38.0-SNAPSHOT + 1.39.0 example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.38.0-SNAPSHOT - 3.12.0 - 3.12.0 + 1.39.0 + 3.17.2 + 3.17.2 1.7 1.7 diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index a79ca46058c..5e418d6e583 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -2,7 +2,7 @@ plugins { // Provide convenience executables for trying out the examples. id 'application' // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions - id 'com.google.protobuf' version '0.8.15' + id 'com.google.protobuf' version '0.8.16' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' } @@ -23,8 +23,8 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION -def protocVersion = '3.12.0' +def grpcVersion = '1.39.0' // CURRENT_GRPC_VERSION +def protocVersion = '3.17.2' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index c22cb2cdad5..2eda1373e54 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,14 +6,14 @@ jar - 1.38.0-SNAPSHOT + 1.39.0 example-tls https://github.com/grpc/grpc-java UTF-8 - 1.38.0-SNAPSHOT - 3.12.0 + 1.39.0 + 3.17.2 2.0.34.Final 1.7 diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index 6a3989647ab..5f73d9f4ec6 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -1,7 +1,7 @@ plugins { id 'application' // Provide convenience executables for trying out the examples. // ASSUMES GRADLE 5.6 OR HIGHER. Use plugin version 0.8.10 with earlier gradle versions - id 'com.google.protobuf' version '0.8.15' + id 'com.google.protobuf' version '0.8.16' // Generate IntelliJ IDEA's .idea & .iml project files id 'idea' id 'java' @@ -22,9 +22,9 @@ targetCompatibility = 1.7 // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.38.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.39.0' // CURRENT_GRPC_VERSION def nettyTcNativeVersion = '2.0.31.Final' -def protocVersion = '3.12.0' +def protocVersion = '3.17.2' dependencies { implementation "io.grpc:grpc-protobuf:${grpcVersion}" diff --git a/examples/pom.xml b/examples/pom.xml index ad07641e61c..e7d87629bd4 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,15 +6,15 @@ jar - 1.38.0-SNAPSHOT + 1.39.0 examples https://github.com/grpc/grpc-java UTF-8 - 1.38.0-SNAPSHOT - 3.12.0 - 3.12.0 + 1.39.0 + 3.17.2 + 3.17.2 1.7 1.7 diff --git a/grpclb/BUILD.bazel b/grpclb/BUILD.bazel index b69fb234733..a18795d0485 100644 --- a/grpclb/BUILD.bazel +++ b/grpclb/BUILD.bazel @@ -9,6 +9,7 @@ java_library( deps = [ ":load_balancer_java_grpc", "//api", + "//context", "//core:internal", "//core:util", "//stub", diff --git a/grpclb/build.gradle b/grpclb/build.gradle index 973770feb2f..58ff2f412d1 100644 --- a/grpclb/build.gradle +++ b/grpclb/build.gradle @@ -14,13 +14,9 @@ dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), project(':grpc-stub'), - libraries.protobuf - implementation (libraries.protobuf_util) { - // prefer our own versions instead of protobuf-util's dependency - exclude group: 'com.google.guava', module: 'guava' - exclude group: 'com.google.errorprone', module: 'error_prone_annotations' - } - guavaDependency 'implementation' + libraries.protobuf, + libraries.protobuf_util, + libraries.guava runtimeOnly libraries.errorprone compileOnly libraries.javax_annotation testImplementation libraries.truth, diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java index 1a8dec36e38..65293d24511 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancer.java @@ -23,6 +23,7 @@ import com.google.common.base.Stopwatch; import io.grpc.Attributes; import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.Context; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; import io.grpc.Status; @@ -45,6 +46,7 @@ class GrpclbLoadBalancer extends LoadBalancer { private static final GrpclbConfig DEFAULT_CONFIG = GrpclbConfig.create(Mode.ROUND_ROBIN); private final Helper helper; + private final Context context; private final TimeProvider time; private final Stopwatch stopwatch; private final SubchannelPool subchannelPool; @@ -58,11 +60,13 @@ class GrpclbLoadBalancer extends LoadBalancer { GrpclbLoadBalancer( Helper helper, + Context context, SubchannelPool subchannelPool, TimeProvider time, Stopwatch stopwatch, BackoffPolicy.Provider backoffPolicyProvider) { this.helper = checkNotNull(helper, "helper"); + this.context = checkNotNull(context, "context"); this.time = checkNotNull(time, "time provider"); this.stopwatch = checkNotNull(stopwatch, "stopwatch"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); @@ -131,7 +135,7 @@ private void recreateStates() { checkState(grpclbState == null, "Should've been cleared"); grpclbState = new GrpclbState( - config, helper, subchannelPool, time, stopwatch, backoffPolicyProvider); + config, helper, context, subchannelPool, time, stopwatch, backoffPolicyProvider); } @Override diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java index badcfdcec7c..fa9b6963f33 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbLoadBalancerProvider.java @@ -17,6 +17,7 @@ package io.grpc.grpclb; import com.google.common.base.Stopwatch; +import io.grpc.Context; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancerProvider; @@ -62,6 +63,7 @@ public LoadBalancer newLoadBalancer(LoadBalancer.Helper helper) { return new GrpclbLoadBalancer( helper, + Context.ROOT, new CachedSubchannelPool(helper), TimeProvider.SYSTEM_TIME_PROVIDER, Stopwatch.createUnstarted(), diff --git a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java index 8c638b979ed..19bc35b373f 100644 --- a/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java +++ b/grpclb/src/main/java/io/grpc/grpclb/GrpclbState.java @@ -35,6 +35,7 @@ import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; +import io.grpc.Context; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; @@ -132,6 +133,7 @@ enum Mode { private final String serviceName; private final Helper helper; + private final Context context; private final SynchronizationContext syncContext; @Nullable private final SubchannelPool subchannelPool; @@ -182,12 +184,14 @@ enum Mode { GrpclbState( GrpclbConfig config, Helper helper, + Context context, SubchannelPool subchannelPool, TimeProvider time, Stopwatch stopwatch, BackoffPolicy.Provider backoffPolicyProvider) { this.config = checkNotNull(config, "config"); this.helper = checkNotNull(helper, "helper"); + this.context = checkNotNull(context, "context"); this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); if (config.getMode() == Mode.ROUND_ROBIN) { this.subchannelPool = checkNotNull(subchannelPool, "subchannelPool"); @@ -255,11 +259,19 @@ void handleAddresses( serviceName, newLbAddressGroups, newBackendServers); + fallbackBackendList = newBackendServers; if (newLbAddressGroups.isEmpty()) { - // No balancer address: close existing balancer connection and enter fallback mode - // immediately. + // No balancer address: close existing balancer connection and prepare to enter fallback + // mode. If there is no successful backend connection, it enters fallback mode immediately. + // Otherwise, fallback does not happen until backend connections are lost. This behavior + // might be different from other languages (e.g., existing balancer connection is not + // closed in C-core), but we aren't changing it at this time. shutdownLbComm(); - syncContext.execute(new FallbackModeTask(NO_LB_ADDRESS_PROVIDED_STATUS)); + if (!usingFallbackBackends) { + fallbackReason = NO_LB_ADDRESS_PROVIDED_STATUS; + cancelFallbackTimer(); + maybeUseFallbackBackends(); + } } else { startLbComm(newLbAddressGroups); // Avoid creating a new RPC just because the addresses were updated, as it can cause a @@ -277,7 +289,6 @@ void handleAddresses( TimeUnit.MILLISECONDS, timerService); } } - fallbackBackendList = newBackendServers; if (usingFallbackBackends) { // Populate the new fallback backends to round-robin list. useFallbackBackends(); @@ -368,7 +379,12 @@ private void startLbRpc() { checkState(lbStream == null, "previous lbStream has not been cleared yet"); LoadBalancerGrpc.LoadBalancerStub stub = LoadBalancerGrpc.newStub(lbCommChannel); lbStream = new LbStream(stub); - lbStream.start(); + Context prevContext = context.attach(); + try { + lbStream.start(); + } finally { + context.detach(prevContext); + } stopwatch.reset().start(); LoadBalanceRequest initRequest = LoadBalanceRequest.newBuilder() diff --git a/grpclb/src/main/proto/grpc/lb/v1/load_balancer.proto b/grpclb/src/main/proto/grpc/lb/v1/load_balancer.proto index a9588b0db94..00fc7096c9c 100644 --- a/grpclb/src/main/proto/grpc/lb/v1/load_balancer.proto +++ b/grpclb/src/main/proto/grpc/lb/v1/load_balancer.proto @@ -104,12 +104,7 @@ message LoadBalanceResponse { message FallbackResponse {} message InitialLoadBalanceResponse { - // This is an application layer redirect that indicates the client should use - // the specified server for load balancing. When this field is non-empty in - // the response, the client should open a separate connection to the - // load_balancer_delegate and call the BalanceLoad method. Its length should - // be less than 64 bytes. - string load_balancer_delegate = 1; + reserved 1; // never-used load_balancer_delegate // This interval defines how often the client should send the client stats // to the load balancer. Stats should only be reported when the duration is diff --git a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java index 0c194ae84c9..f664aad0a7a 100644 --- a/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java +++ b/grpclb/src/test/java/io/grpc/grpclb/GrpclbLoadBalancerTest.java @@ -55,6 +55,8 @@ import io.grpc.ClientStreamTracer; import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; +import io.grpc.Context; +import io.grpc.Context.CancellableContext; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer.CreateSubchannelArgs; import io.grpc.LoadBalancer.Helper; @@ -229,6 +231,7 @@ public Void answer(InvocationOnMock invocation) { when(backoffPolicyProvider.get()).thenReturn(backoffPolicy1, backoffPolicy2); balancer = new GrpclbLoadBalancer( helper, + Context.ROOT, subchannelPool, fakeClock.getTimeProvider(), fakeClock.getStopwatchSupplier().get(), @@ -1449,8 +1452,11 @@ public void grpclbFallback_breakLbStreamBeforeFallbackTimerExpires() { public void grpclbFallback_noBalancerAddress() { InOrder inOrder = inOrder(helper, subchannelPool); - // Create just backend addresses - List backendList = createResolvedBackendAddresses(2); + // Create 5 distinct backends + List backends = createResolvedBackendAddresses(5); + + // Name resolver gives the first two backend addresses + List backendList = backends.subList(0, 2); deliverResolvedAddresses(backendList, Collections.emptyList()); assertThat(logs).containsAtLeast( @@ -1471,6 +1477,28 @@ public void grpclbFallback_noBalancerAddress() { .createOobChannel(ArgumentMatchers.anyList(), anyString()); logs.clear(); + ///////////////////////////////////////////////////////////////////////////////////////// + // Name resolver sends new resolution results with new backend addr but no balancer addr + ///////////////////////////////////////////////////////////////////////////////////////// + // Name resolver then gives the last three backends + backendList = backends.subList(2, 5); + deliverResolvedAddresses(backendList, Collections.emptyList()); + + assertThat(logs).containsAtLeast( + "INFO: [grpclb-] Using fallback backends", + "INFO: [grpclb-] " + + "Using RR list=[[[FakeSocketAddress-fake-address-2]/{}], " + + "[[FakeSocketAddress-fake-address-3]/{}], " + + "[[FakeSocketAddress-fake-address-4]/{}]], drop=[null, null, null]", + "INFO: [grpclb-] " + + "Update balancing state to CONNECTING: picks=[BUFFER_ENTRY], " + + "drops=[null, null, null]") + .inOrder(); + + // Shift to use updated backends + fallbackTestVerifyUseOfFallbackBackendLists(inOrder, backendList); + logs.clear(); + /////////////////////////////////////////////////////////////////////////////////////// // Name resolver sends new resolution results without any backend addr or balancer addr /////////////////////////////////////////////////////////////////////////////////////// @@ -2683,6 +2711,39 @@ public void grpclbWorking_lbSendsFallbackMessage() { .inOrder(); } + @Test + public void useIndependentRpcContext() { + // Simulates making RPCs within the context of an inbound RPC. + CancellableContext cancellableContext = Context.current().withCancellation(); + Context prevContext = cancellableContext.attach(); + try { + List backendList = createResolvedBackendAddresses(2); + List grpclbBalancerList = createResolvedBalancerAddresses(2); + deliverResolvedAddresses(backendList, grpclbBalancerList); + + verify(helper).createOobChannel(eq(xattr(grpclbBalancerList)), + eq(lbAuthority(0) + NO_USE_AUTHORITY_SUFFIX)); + verify(mockLbService).balanceLoad(lbResponseObserverCaptor.capture()); + StreamObserver lbResponseObserver = lbResponseObserverCaptor.getValue(); + assertEquals(1, lbRequestObservers.size()); + StreamObserver lbRequestObserver = lbRequestObservers.poll(); + verify(lbRequestObserver).onNext( + eq(LoadBalanceRequest.newBuilder() + .setInitialRequest( + InitialLoadBalanceRequest.newBuilder().setName(SERVICE_AUTHORITY).build()) + .build())); + lbResponseObserver.onNext(buildInitialResponse()); + + // The inbound RPC finishes and closes its context. The outbound RPC's control plane RPC + // should not be impacted (no retry). + cancellableContext.close(); + assertEquals(0, fakeClock.numPendingTasks(LB_RPC_RETRY_TASK_FILTER)); + verifyNoMoreInteractions(mockLbService); + } finally { + cancellableContext.detach(prevContext); + } + } + private void deliverSubchannelState( final Subchannel subchannel, final ConnectivityStateInfo newState) { ((FakeSubchannel) subchannel).updateState(newState); diff --git a/interop-testing/build.gradle b/interop-testing/build.gradle index 01af6529324..79aa5356ecd 100644 --- a/interop-testing/build.gradle +++ b/interop-testing/build.gradle @@ -28,9 +28,9 @@ dependencies { project(':grpc-testing'), project(path: ':grpc-xds', configuration: 'shadow'), libraries.junit, - libraries.truth - censusGrpcMetricDependency 'implementation' - googleOauth2Dependency 'implementation' + libraries.truth, + libraries.opencensus_contrib_grpc_metrics, + libraries.google_auth_oauth2_http compileOnly libraries.javax_annotation // TODO(sergiitk): replace with com.google.cloud:google-cloud-logging // Used instead of google-cloud-logging because it's failing diff --git a/netty/build.gradle b/netty/build.gradle index 20b35eb36d1..00726940cde 100644 --- a/netty/build.gradle +++ b/netty/build.gradle @@ -18,9 +18,10 @@ evaluationDependsOn(project(':grpc-core').path) dependencies { api project(':grpc-core'), libraries.netty - implementation libraries.netty_proxy_handler - guavaDependency 'implementation' - perfmarkDependency 'implementation' + implementation libraries.netty_proxy_handler, + libraries.guava, + libraries.errorprone, + libraries.perfmark // Tests depend on base class defined by core module. testImplementation project(':grpc-core').sourceSets.test.output, diff --git a/netty/shaded/build.gradle b/netty/shaded/build.gradle index 37c5428b158..521256ea13d 100644 --- a/netty/shaded/build.gradle +++ b/netty/shaded/build.gradle @@ -1,3 +1,19 @@ +import com.github.jengelman.gradle.plugins.shadow.transformers.Transformer +import com.github.jengelman.gradle.plugins.shadow.transformers.TransformerContext +import org.gradle.api.file.FileTreeElement +import shadow.org.apache.tools.zip.ZipOutputStream +import shadow.org.apache.tools.zip.ZipEntry + + +buildscript { + repositories { + jcenter() + } + dependencies { + classpath "com.github.jengelman.gradle.plugins:shadow:6.1.0" + } +} + plugins { id "java" id "maven-publish" @@ -17,7 +33,9 @@ dependencies { project(':grpc-testing-proto'), project(':grpc-testing'), libraries.truth - shadow project(':grpc-core') + shadow project(':grpc-netty').configurations.runtimeClasspath.allDependencies.matching { + it.group != 'io.netty' + } } jar { @@ -31,6 +49,7 @@ shadowJar { include(project(':grpc-netty')) include(dependency('io.netty:')) } + exclude 'META-INF/maven/**' relocate 'io.grpc.netty', 'io.grpc.netty.shaded.io.grpc.netty' relocate 'io.netty', 'io.grpc.netty.shaded.io.netty' // We have to be careful with these replacements as they must not match any @@ -38,27 +57,38 @@ shadowJar { // this includes concatenation of string literals and constants. relocate 'META-INF/native/libnetty', 'META-INF/native/libio_grpc_netty_shaded_netty' relocate 'META-INF/native/netty', 'META-INF/native/io_grpc_netty_shaded_netty' + transform(NettyResourceTransformer.class) mergeServiceFiles() } publishing { publications { maven(MavenPublication) { - // Ideally swap to project.shadow.component(it) when it isn't broken for project deps - artifact shadowJar + project.shadow.component(it) + // Empty jars are not published via withJavadocJar() and withSourcesJar() artifact javadocJar artifact sourcesJar + // Avoid confusing error message "class file for + // io.grpc.internal.AbstractServerImplBuilder not found" + // (https://github.com/grpc/grpc-java/issues/5881). This can be + // removed after https://github.com/grpc/grpc-java/issues/7211 is + // resolved. + pom.withXml { + asNode().dependencies.'*'.findAll() { dep -> + dep.artifactId.text() == 'grpc-core' + }.each() { core -> + core.scope*.value = "compile" + } + } + + // shadow.component() is run after the main build.gradle's withXml pom.withXml { - def dependencies = asNode().appendNode('dependencies') - project.configurations.shadow.allDependencies.each { dep -> - def dependencyNode = dependencies.appendNode('dependency') - dependencyNode.appendNode('groupId', dep.group) - dependencyNode.appendNode('artifactId', dep.name) - def version = (dep.name == 'grpc-core') ? '[' + dep.version + ']' : dep.version - dependencyNode.appendNode('version', version) - dependencyNode.appendNode('scope', 'compile') + asNode().dependencies.'*'.findAll() { dep -> + dep.artifactId.text() in ['grpc-api', 'grpc-core'] + }.each() { core -> + core.version*.value = "[" + core.version.text() + "]" } } } @@ -73,3 +103,41 @@ compileTestShadowJava.options.compilerArgs = compileTestJava.options.compilerArg compileTestShadowJava.options.encoding = compileTestJava.options.encoding test.dependsOn testShadow + +/** + * A Transformer which updates the Netty JAR META-INF/ resources to accurately + * reference shaded class names. + */ +class NettyResourceTransformer implements Transformer { + + // A map of resource file paths to be modified + private Map resources = [:] + + @Override + boolean canTransformResource(FileTreeElement fileTreeElement) { + fileTreeElement.name.startsWith("META-INF/native-image/io.netty") + } + + @Override + void transform(TransformerContext context) { + String updatedContent = context.is.getText().replace("io.netty", "io.grpc.netty.shaded.io.netty") + resources.put(context.path, updatedContent) + } + + @Override + boolean hasTransformedResource() { + resources.size() > 0 + } + + @Override + void modifyOutputStream(ZipOutputStream outputStream, boolean preserveFileTimestamps) { + for (resourceEntry in resources) { + ZipEntry entry = new ZipEntry(resourceEntry.key) + entry.time = TransformerContext.getEntryTimestamp(preserveFileTimestamps, entry.time) + + outputStream.putNextEntry(entry) + outputStream.write(resourceEntry.value.getBytes()) + outputStream.closeEntry() + } + } +} diff --git a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java index e4a84986976..d3bdc4394ca 100644 --- a/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java +++ b/netty/shaded/src/testShadow/java/io/grpc/netty/shaded/ShadingTest.java @@ -39,6 +39,11 @@ import io.grpc.testing.protobuf.SimpleServiceGrpc; import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceBlockingStub; import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceImplBase; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.Scanner; import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.Test; @@ -69,6 +74,19 @@ public void noNormalNetty() throws Exception { Class.forName("io.grpc.netty.NettyServerBuilder"); } + /** Verify that resources under META-INF/native-image reference shaded class names. */ + @Test + public void nettyResourcesUpdated() throws IOException { + InputStream inputStream = NettyChannelBuilder.class.getClassLoader() + .getResourceAsStream("META-INF/native-image/io.netty/transport/reflection-config.json"); + assertThat(inputStream).isNotNull(); + + Scanner s = new Scanner(inputStream, StandardCharsets.UTF_8.name()).useDelimiter("\\A"); + String reflectionConfig = s.hasNext() ? s.next() : ""; + + assertThat(reflectionConfig).contains("io.grpc.netty.shaded.io.netty"); + } + @Test public void serviceLoaderFindsNetty() throws Exception { assertThat(Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create())) diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index bb203e2906a..25338c4100d 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -127,6 +127,23 @@ public static NettyChannelBuilder forAddress(SocketAddress serverAddress) { return new NettyChannelBuilder(serverAddress); } + /** + * Creates a new builder with the given server address. This factory method is primarily intended + * for using Netty Channel types other than SocketChannel. + * {@link #forAddress(String, int, ChannelCredentials)} should generally be preferred over this + * method, since that API permits delaying DNS lookups and noticing changes to DNS. If an + * unresolved InetSocketAddress is passed in, then it will remain unresolved. + */ + @CheckReturnValue + public static NettyChannelBuilder forAddress(SocketAddress serverAddress, + ChannelCredentials creds) { + FromChannelCredentialsResult result = ProtocolNegotiators.from(creds); + if (result.error != null) { + throw new IllegalArgumentException(result.error); + } + return new NettyChannelBuilder(serverAddress, creds, result.callCredentials, result.negotiator); + } + /** * Creates a new builder with the given host and port. */ @@ -207,6 +224,18 @@ public int getDefaultPort() { this.freezeProtocolNegotiatorFactory = false; } + NettyChannelBuilder( + SocketAddress address, ChannelCredentials channelCreds, CallCredentials callCreds, + ProtocolNegotiator.ClientFactory negotiator) { + managedChannelImplBuilder = new ManagedChannelImplBuilder(address, + getAuthorityFromAddress(address), + channelCreds, callCreds, + new NettyChannelTransportFactoryBuilder(), + new NettyChannelDefaultPortProvider()); + this.protocolNegotiatorFactory = checkNotNull(negotiator, "negotiator"); + this.freezeProtocolNegotiatorFactory = true; + } + @Internal @Override protected ManagedChannelBuilder delegate() { diff --git a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java index 37caccb0eb3..cce58f1e60d 100644 --- a/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java +++ b/netty/src/main/java/io/grpc/netty/NettyReadableBuffer.java @@ -94,6 +94,31 @@ public int arrayOffset() { return buffer.arrayOffset() + buffer.readerIndex(); } + @Override + public boolean markSupported() { + return true; + } + + @Override + public void mark() { + buffer.markReaderIndex(); + } + + @Override + public void reset() { + buffer.resetReaderIndex(); + } + + @Override + public boolean byteBufferSupported() { + return buffer.nioBufferCount() > 0; + } + + @Override + public ByteBuffer getByteBuffer() { + return buffer.nioBufferCount() == 1 ? buffer.nioBuffer() : buffer.nioBuffers()[0]; + } + /** * If the first call to close, calls {@link ByteBuf#release} to release the internal Netty buffer. */ diff --git a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java index b255923abc7..032b040528b 100644 --- a/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyChannelBuilderTest.java @@ -23,6 +23,7 @@ import static org.mockito.Mockito.mock; import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.internal.ClientTransportFactory; import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult; @@ -153,6 +154,28 @@ public void failIfSslContextIsNotClient() { builder.sslContext(sslContext); } + @Test + public void failNegotiationTypeWithChannelCredentials_target() { + NettyChannelBuilder builder = NettyChannelBuilder.forTarget( + "fakeTarget", InsecureChannelCredentials.create()); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Cannot change security when using ChannelCredentials"); + + builder.negotiationType(NegotiationType.TLS); + } + + @Test + public void failNegotiationTypeWithChannelCredentials_socketAddress() { + NettyChannelBuilder builder = NettyChannelBuilder.forAddress( + new SocketAddress(){}, InsecureChannelCredentials.create()); + + thrown.expect(IllegalStateException.class); + thrown.expectMessage("Cannot change security when using ChannelCredentials"); + + builder.negotiationType(NegotiationType.TLS); + } + @Test public void createProtocolNegotiatorByType_plaintext() { ProtocolNegotiator negotiator = NettyChannelBuilder.createProtocolNegotiatorByType( diff --git a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java index 8090e601911..1a0ac229a89 100644 --- a/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyReadableBufferTest.java @@ -17,11 +17,16 @@ package io.grpc.netty; import static com.google.common.base.Charsets.UTF_8; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import com.google.common.base.Splitter; import io.grpc.internal.ReadableBuffer; import io.grpc.internal.ReadableBufferTestBase; +import io.netty.buffer.CompositeByteBuf; import io.netty.buffer.Unpooled; +import java.nio.ByteBuffer; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -52,6 +57,29 @@ public void closeMultipleTimesShouldReleaseBufferOnce() { assertEquals(0, buffer.buffer().refCnt()); } + @Test + public void getByteBufferFromSingleNioBufferBackedBuffer() { + assertTrue(buffer.byteBufferSupported()); + ByteBuffer byteBuffer = buffer.getByteBuffer(); + byte[] arr = new byte[byteBuffer.remaining()]; + byteBuffer.get(arr); + assertArrayEquals(msg.getBytes(UTF_8), arr); + } + + @Test + public void getByteBufferFromCompositeBufferReturnsOnlyFirstComponent() { + CompositeByteBuf composite = Unpooled.compositeBuffer(10); + int chunks = 4; + int chunkLen = msg.length() / chunks; + for (String chunk : Splitter.fixedLength(chunkLen).split(msg)) { + composite.addComponent(true, Unpooled.copiedBuffer(chunk.getBytes(UTF_8))); + } + buffer = new NettyReadableBuffer(composite); + byte[] array = new byte[chunkLen]; + buffer.getByteBuffer().get(array); + assertArrayEquals(msg.substring(0, chunkLen).getBytes(UTF_8), array); + } + @Override protected ReadableBuffer buffer() { return buffer; diff --git a/okhttp/build.gradle b/okhttp/build.gradle index d28d4e00ca2..999f21e7c10 100644 --- a/okhttp/build.gradle +++ b/okhttp/build.gradle @@ -11,14 +11,11 @@ description = "gRPC: OkHttp" evaluationDependsOn(project(':grpc-core').path) dependencies { - api project(':grpc-core') - api (libraries.okhttp) { - // prefer our own versions instead of okhttp's dependency - exclude group: 'com.squareup.okio', module: 'okio' - } - implementation libraries.okio - guavaDependency 'implementation' - perfmarkDependency 'implementation' + api project(':grpc-core'), + libraries.okhttp + implementation libraries.okio, + libraries.guava, + libraries.perfmark // Tests depend on base class defined by core module. testImplementation project(':grpc-core').sourceSets.test.output, project(':grpc-api').sourceSets.test.output, diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java index 0347ffc770d..a001ddb73e7 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpClientTransport.java @@ -1106,8 +1106,14 @@ public void run() { // frameReader.nextFrame() returns false when the underlying read encounters an IOException, // it may be triggered by the socket closing, in such case, the startGoAway() will do // nothing, otherwise, we finish all streams since it's a real IO issue. - startGoAway(0, ErrorCode.INTERNAL_ERROR, - Status.UNAVAILABLE.withDescription("End of stream or IOException")); + Status status; + synchronized (lock) { + status = goAwayStatus; + } + if (status == null) { + status = Status.UNAVAILABLE.withDescription("End of stream or IOException"); + } + startGoAway(0, ErrorCode.INTERNAL_ERROR, status); } catch (Throwable t) { // TODO(madongfly): Send the exception message to the server. startGoAway( diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java index 2ece98ffb97..4aeeae2fa8b 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpReadableBufferTest.java @@ -56,6 +56,18 @@ public void partialReadToByteBufferShouldSucceed() { // Not supported. } + @Override + @Test + public void markAndResetWithReadShouldSucceed() { + // Not supported. + } + + @Override + @Test + public void markAndResetWithReadToReadableBufferShouldSucceed() { + // Not supported. + } + @Override protected ReadableBuffer buffer() { return buffer; diff --git a/protobuf-lite/build.gradle b/protobuf-lite/build.gradle index 000d8c721e0..7b58309c414 100644 --- a/protobuf-lite/build.gradle +++ b/protobuf-lite/build.gradle @@ -12,8 +12,8 @@ description = 'gRPC: Protobuf Lite' dependencies { api project(':grpc-api'), libraries.protobuf_lite - implementation libraries.jsr305 - guavaDependency 'implementation' + implementation libraries.jsr305, + libraries.guava testImplementation project(':grpc-core') diff --git a/protobuf/build.gradle b/protobuf/build.gradle index af7f51c836e..bb8546dc701 100644 --- a/protobuf/build.gradle +++ b/protobuf/build.gradle @@ -13,14 +13,12 @@ dependencies { api project(':grpc-api'), libraries.jsr305, libraries.protobuf - guavaDependency 'implementation' + implementation libraries.guava api (libraries.google_api_protos) { // 'com.google.api:api-common' transitively depends on auto-value, which breaks our // annotations. exclude group: 'com.google.api', module: 'api-common' - // Prefer our more up-to-date protobuf over 3.2.0 - exclude group: 'com.google.protobuf', module: 'protobuf-java' } api (project(':grpc-protobuf-lite')) { diff --git a/repositories.bzl b/repositories.bzl index 2fd188765a5..ad50272d286 100644 --- a/repositories.bzl +++ b/repositories.bzl @@ -109,24 +109,24 @@ def com_google_protobuf(): # This statement defines the @com_google_protobuf repo. http_archive( name = "com_google_protobuf", - sha256 = "b37e96e81842af659605908a421960a5dc809acbc888f6b947bc320f8628e5b1", - strip_prefix = "protobuf-3.12.0", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.12.0.zip"], + sha256 = "f6042eef01551cee4c663a11c3f429c06360a1f51daa9f4772bf3f13d24cde1f", + strip_prefix = "protobuf-3.17.2", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.17.2.zip"], ) def com_google_protobuf_javalite(): # java_lite_proto_library rules implicitly depend on @com_google_protobuf_javalite http_archive( name = "com_google_protobuf_javalite", - sha256 = "b37e96e81842af659605908a421960a5dc809acbc888f6b947bc320f8628e5b1", - strip_prefix = "protobuf-3.12.0", - urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.12.0.zip"], + sha256 = "f6042eef01551cee4c663a11c3f429c06360a1f51daa9f4772bf3f13d24cde1f", + strip_prefix = "protobuf-3.17.2", + urls = ["https://github.com/protocolbuffers/protobuf/archive/v3.17.2.zip"], ) def io_grpc_grpc_proto(): http_archive( name = "io_grpc_grpc_proto", - sha256 = "5848a4e034126bece0c37c16554fb80625615aedf1acad4e2a3cdbaaa76944eb", - strip_prefix = "grpc-proto-cf828d0e1155e5ea58b46d7184ee5596e03ddcb8", - urls = ["https://github.com/grpc/grpc-proto/archive/cf828d0e1155e5ea58b46d7184ee5596e03ddcb8.zip"], + sha256 = "464e97a24d7d784d9c94c25fa537ba24127af5aae3edd381007b5b98705a0518", + strip_prefix = "grpc-proto-08911e9d585cbda3a55eb1dcc4b99c89aebccff8", + urls = ["https://github.com/grpc/grpc-proto/archive/08911e9d585cbda3a55eb1dcc4b99c89aebccff8.zip"], ) diff --git a/rls/build.gradle b/rls/build.gradle index 32004a7d43a..a2ebf2a62ef 100644 --- a/rls/build.gradle +++ b/rls/build.gradle @@ -13,8 +13,8 @@ evaluationDependsOn(project(':grpc-core').path) dependencies { implementation project(':grpc-core'), project(':grpc-protobuf'), - project(':grpc-stub') - guavaDependency 'implementation' + project(':grpc-stub'), + libraries.guava compileOnly libraries.javax_annotation testImplementation libraries.truth, project(':grpc-grpclb'), diff --git a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java index baf7fa0dbac..32df13c4262 100644 --- a/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java +++ b/rls/src/main/java/io/grpc/rls/RlsProtoConverters.java @@ -46,6 +46,7 @@ final class RlsProtoConverters { static final class RouteLookupRequestConverter extends Converter { + @SuppressWarnings("deprecation") @Override protected RlsProtoData.RouteLookupRequest doForward(RouteLookupRequest routeLookupRequest) { return @@ -56,6 +57,7 @@ protected RlsProtoData.RouteLookupRequest doForward(RouteLookupRequest routeLook routeLookupRequest.getKeyMapMap()); } + @SuppressWarnings("deprecation") @Override protected RouteLookupRequest doBackward(RlsProtoData.RouteLookupRequest routeLookupRequest) { return diff --git a/rls/src/main/proto/grpc/lookup/v1/rls.proto b/rls/src/main/proto/grpc/lookup/v1/rls.proto index 0d14e693a6c..d9dd6c246f2 100644 --- a/rls/src/main/proto/grpc/lookup/v1/rls.proto +++ b/rls/src/main/proto/grpc/lookup/v1/rls.proto @@ -24,32 +24,39 @@ option java_outer_classname = "RlsProto"; message RouteLookupRequest { // Full host name of the target server, e.g. firestore.googleapis.com. // Only set for gRPC requests; HTTP requests must use key_map explicitly. - string server = 1; + // Deprecated in favor of setting key_map keys with GrpcKeyBuilder.extra_keys. + string server = 1 [deprecated = true]; // Full path of the request, i.e. "/service/method". // Only set for gRPC requests; HTTP requests must use key_map explicitly. - string path = 2; + // Deprecated in favor of setting key_map keys with GrpcKeyBuilder.extra_keys. + string path = 2 [deprecated = true]; // Target type allows the client to specify what kind of target format it // would like from RLS to allow it to find the regional server, e.g. "grpc". string target_type = 3; + // Possible reasons for making a request. + enum Reason { + REASON_UNKNOWN = 0; // Unused + REASON_MISS = 1; // No data available in local cache + REASON_STALE = 2; // Data in local cache is stale + } + // Reason for making this request. + Reason reason = 5; // Map of key values extracted via key builders for the gRPC or HTTP request. map key_map = 4; } message RouteLookupResponse { - // Actual addressable entity to use for routing decision, using syntax - // requested by the request target_type. - // This field is deprecated in favor of the new "targets" field, below. - string target = 1 [deprecated=true]; // Prioritized list (best one first) of addressable entities to use // for routing, using syntax requested by the request target_type. // The targets will be tried in order until a healthy one is found. - // If present, it should be used by proxy/gRPC client code instead of - // "target" (which is deprecated). repeated string targets = 3; // Optional header value to pass along to AFE in the X-Google-RLS-Data header. // Cached with "target" and sent with all requests that match the request key. // Allows the RLS to pass its work product to the eventual target. string header_data = 2; + + reserved 1; + reserved "target"; } service RouteLookupService { diff --git a/rls/src/main/proto/grpc/lookup/v1/rls_config.proto b/rls/src/main/proto/grpc/lookup/v1/rls_config.proto index 4c02201329b..db99a8949ea 100644 --- a/rls/src/main/proto/grpc/lookup/v1/rls_config.proto +++ b/rls/src/main/proto/grpc/lookup/v1/rls_config.proto @@ -29,6 +29,9 @@ option java_outer_classname = "RlsConfigProto"; // present for the keybuilder to match. message NameMatcher { // The name that will be used in the RLS key_map to refer to this value. + // If required_match is true, you may omit this field or set it to an empty + // string, in which case the matcher will require a match, but won't update + // the key_map. string key = 1; // Ordered list of names (headers or query parameter names) that can supply @@ -52,10 +55,29 @@ message GrpcKeyBuilder { } repeated Name names = 1; + // If you wish to include the host, service, or method names as keys in the + // generated RouteLookupRequest, specify key names to use in the extra_keys + // submessage. If a key name is empty, no key will be set for that value. + // If this submessage is specified, the normal host/path fields will be left + // unset in the RouteLookupRequest. We are deprecating host/path in the + // RouteLookupRequest, so services should migrate to the ExtraKeys approach. + message ExtraKeys { + string host = 1; + string service = 2; + string method = 3; + } + ExtraKeys extra_keys = 3; + // Extract keys from all listed headers. // For gRPC, it is an error to specify "required_match" on the NameMatcher // protos. repeated NameMatcher headers = 2; + + // You can optionally set one or more specific key/value pairs to be added to + // the key_map. This can be useful to identify which builder built the key, + // for example if you are suppressing the actual method, but need to + // separately cache and request all the matched methods. + map constant_keys = 4; } // An HttpKeyBuilder applies to a given HTTP URL and headers. @@ -131,6 +153,12 @@ message HttpKeyBuilder { // to match. If a given header appears multiple times in the request we will // report it as a comma-separated string, in standard HTTP fashion. repeated NameMatcher headers = 4; + + // You can optionally set one or more specific key/value pairs to be added to + // the key_map. This can be useful to identify which builder built the key, + // for example if you are suppressing a lot of information from the URL, but + // need to separately cache and request URLs with that content. + map constant_keys = 5; } message RouteLookupConfig { @@ -176,40 +204,15 @@ message RouteLookupConfig { // This is a list of all the possible targets that can be returned by the // lookup service. If a target not on this list is returned, it will be - // treated the same as an RPC error from the RLS. + // treated the same as an unhealthy target. repeated string valid_targets = 8; - // This value provides a default target to use if needed. It will be used for - // request processing strategy SYNC_LOOKUP_DEFAULT_TARGET_ON_ERROR if RLS - // returns an error, or strategy ASYNC_LOOKUP_DEFAULT_TARGET_ON_MISS if RLS - // returns an error or there is a cache miss in the client. It will also be - // used if there are no healthy backends for an RLS target. Note that - // requests can be routed only to a subdomain of the original target, - // e.g. "us_east_1.cloudbigtable.googleapis.com". + // This value provides a default target to use if needed. If set, it will be + // used if RLS returns an error, times out, or returns an invalid response. + // Note that requests can be routed only to a subdomain of the original + // target, e.g. "us_east_1.cloudbigtable.googleapis.com". string default_target = 9; - // Specify how to process a request when not already in the cache. - enum RequestProcessingStrategy { - STRATEGY_UNSPECIFIED = 0; - - // Query the RLS and process the request using target returned by the - // lookup. The target will then be cached and used for processing - // subsequent requests for the same key. Any errors during lookup service - // processing will fall back to default target for request processing. - SYNC_LOOKUP_DEFAULT_TARGET_ON_ERROR = 1; - - // Query the RLS and process the request using target returned by the - // lookup. The target will then be cached and used for processing - // subsequent requests for the same key. Any errors during lookup service - // processing will return an error back to the client. Services with - // strict regional routing requirements should use this strategy. - SYNC_LOOKUP_CLIENT_SEES_ERROR = 2; - - // Query the RLS asynchronously but respond with the default target. The - // target in the lookup response will then be cached and used for - // subsequent requests. Services with strict latency requirements (but not - // strict regional routing requirements) should use this strategy. - ASYNC_LOOKUP_DEFAULT_TARGET_ON_MISS = 3; - } - RequestProcessingStrategy request_processing_strategy = 10; + reserved 10; + reserved "request_processing_strategy"; } diff --git a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java index 2f7d84d6b13..a50cdeb9f68 100644 --- a/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java +++ b/rls/src/test/java/io/grpc/rls/RlsProtoConvertersTest.java @@ -41,6 +41,7 @@ @RunWith(JUnit4.class) public class RlsProtoConvertersTest { + @SuppressWarnings("deprecation") @Test public void convert_toRequestProto() { Converter converter = @@ -60,6 +61,7 @@ public void convert_toRequestProto() { assertThat(object.getKeyMap()).containsExactly("key1", "val1"); } + @SuppressWarnings("deprecation") @Test public void convert_toRequestObject() { Converter converter = diff --git a/services/build.gradle b/services/build.gradle index f5132515528..2de9418c3c2 100644 --- a/services/build.gradle +++ b/services/build.gradle @@ -22,12 +22,8 @@ dependencies { api project(':grpc-protobuf'), project(':grpc-stub'), project(':grpc-core') - implementation (libraries.protobuf_util) { - // prefer our own versions instead of protobuf-util's dependency - exclude group: 'com.google.guava', module: 'guava' - exclude group: 'com.google.errorprone', module: 'error_prone_annotations' - } - guavaDependency 'implementation' + implementation libraries.protobuf_util, + libraries.guava runtimeOnly libraries.errorprone compileOnly libraries.javax_annotation diff --git a/settings.gradle b/settings.gradle index fbee65a294e..d552d957838 100644 --- a/settings.gradle +++ b/settings.gradle @@ -5,7 +5,7 @@ pluginManagement { id "com.github.johnrengelman.shadow" version "6.1.0" id "com.github.kt3k.coveralls" version "2.10.2" id "com.google.osdetector" version "1.6.2" - id "com.google.protobuf" version "0.8.15" + id "com.google.protobuf" version "0.8.16" id "digital.wup.android-maven-publish" version "3.6.2" id "me.champeau.gradle.japicmp" version "0.2.5" id "me.champeau.gradle.jmh" version "0.5.2" diff --git a/stub/build.gradle b/stub/build.gradle index ce1e742ab38..4076460377c 100644 --- a/stub/build.gradle +++ b/stub/build.gradle @@ -8,8 +8,8 @@ plugins { description = "gRPC: Stub" dependencies { - api project(':grpc-api') - guavaDependency 'api' + api project(':grpc-api'), + libraries.guava testImplementation libraries.truth, project(':grpc-testing') signature "org.codehaus.mojo.signature:java17:1.0@signature" diff --git a/testing/build.gradle b/testing/build.gradle index dc619913b6b..a0f31819916 100644 --- a/testing/build.gradle +++ b/testing/build.gradle @@ -13,8 +13,8 @@ dependencies { api project(':grpc-core'), project(':grpc-stub'), libraries.junit - - censusApiDependency 'implementation' + implementation libraries.opencensus_api + runtimeOnly project(":grpc-context") // Pull in newer version than census-api testImplementation (libraries.mockito) { // prefer our own versions instead of mockito's dependency diff --git a/xds/build.gradle b/xds/build.gradle index d8462c0de54..ae8d8d208a9 100644 --- a/xds/build.gradle +++ b/xds/build.gradle @@ -36,21 +36,11 @@ dependencies { libraries.gson, libraries.re2j, libraries.bouncycastle, - libraries.autovalue_annotation + libraries.autovalue_annotation, + libraries.opencensus_proto, + libraries.protobuf_util def nettyDependency = implementation project(':grpc-netty') - implementation (libraries.opencensus_proto) { - // prefer our own versions instead of opencensus_proto's - exclude group: 'com.google.protobuf', module: 'protobuf-java' - exclude group: 'io.grpc', module: 'grpc-protobuf' - exclude group: 'io.grpc', module: 'grpc-stub' - } - implementation (libraries.protobuf_util) { - // prefer our own versions instead of protobuf-util's dependency - exclude group: 'com.google.guava', module: 'guava' - exclude group: 'com.google.errorprone', module: 'error_prone_annotations' - } - testImplementation project(':grpc-core').sourceSets.test.output annotationProcessor libraries.autovalue @@ -107,6 +97,7 @@ javadoc { exclude 'io/grpc/xds/*LoadBalancer*' exclude 'io/grpc/xds/Bootstrapper.java' exclude 'io/grpc/xds/Envoy*' + exclude 'io/grpc/xds/TlsContextManager.java' exclude 'io/grpc/xds/XdsAttributes.java' exclude 'io/grpc/xds/XdsClientWrapperForServerSds.java' exclude 'io/grpc/xds/XdsInitializationException.java' diff --git a/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java b/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java index ee54a8befc5..357534233b0 100644 --- a/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/AbstractXdsClient.java @@ -28,6 +28,7 @@ import io.envoyproxy.envoy.service.discovery.v3.AggregatedDiscoveryServiceGrpc; import io.envoyproxy.envoy.service.discovery.v3.DiscoveryRequest; import io.envoyproxy.envoy.service.discovery.v3.DiscoveryResponse; +import io.grpc.Context; import io.grpc.InternalLogId; import io.grpc.ManagedChannel; import io.grpc.Status; @@ -81,6 +82,7 @@ public void uncaughtException(Thread t, Throwable e) { private final InternalLogId logId; private final XdsLogger logger; private final ManagedChannel channel; + private final Context context; private final ScheduledExecutorService timeService; private final BackoffPolicy.Provider backoffPolicyProvider; private final Stopwatch stopwatch; @@ -103,10 +105,11 @@ public void uncaughtException(Thread t, Throwable e) { private ScheduledHandle rpcRetryTimer; AbstractXdsClient(ManagedChannel channel, Bootstrapper.BootstrapInfo bootstrapInfo, - ScheduledExecutorService timeService, BackoffPolicy.Provider backoffPolicyProvider, - Supplier stopwatchSupplier) { + Context context, ScheduledExecutorService timeService, + BackoffPolicy.Provider backoffPolicyProvider, Supplier stopwatchSupplier) { this.channel = checkNotNull(channel, "channel"); this.bootstrapInfo = checkNotNull(bootstrapInfo, "bootstrapInfo"); + this.context = checkNotNull(context, "context"); this.timeService = checkNotNull(timeService, "timeService"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); stopwatch = checkNotNull(stopwatchSupplier, "stopwatchSupplier").get(); @@ -305,7 +308,12 @@ private void startRpcStream() { } else { adsStream = new AdsStreamV2(); } - adsStream.start(); + Context prevContext = context.attach(); + try { + adsStream.start(); + } finally { + context.detach(prevContext); + } logger.log(XdsLogLevel.INFO, "ADS stream started"); stopwatch.reset().start(); } diff --git a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java index d0286b268de..e91e76090ab 100644 --- a/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java +++ b/xds/src/main/java/io/grpc/xds/CdsLoadBalancer2.java @@ -161,8 +161,8 @@ private void handleClusterDiscovered() { clusterState.result.upstreamTlsContext()); } else { // logical DNS instance = DiscoveryMechanism.forLogicalDns( - clusterState.name, clusterState.result.lrsServerName(), - clusterState.result.maxConcurrentRequests(), + clusterState.name, clusterState.result.dnsHostName(), + clusterState.result.lrsServerName(), clusterState.result.maxConcurrentRequests(), clusterState.result.upstreamTlsContext()); } instances.add(instance); diff --git a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java index 77678f49367..c037d6c4166 100644 --- a/xds/src/main/java/io/grpc/xds/ClientXdsClient.java +++ b/xds/src/main/java/io/grpc/xds/ClientXdsClient.java @@ -41,6 +41,8 @@ import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; import io.envoyproxy.envoy.config.core.v3.HttpProtocolOptions; import io.envoyproxy.envoy.config.core.v3.RoutingPriority; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.SocketAddress.PortSpecifierCase; import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; import io.envoyproxy.envoy.config.listener.v3.Listener; import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; @@ -48,6 +50,7 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.Rds; import io.envoyproxy.envoy.type.v3.FractionalPercent; import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; +import io.grpc.Context; import io.grpc.EquivalentAddressGroup; import io.grpc.ManagedChannel; import io.grpc.Status; @@ -63,15 +66,15 @@ import io.grpc.xds.Filter.NamedFilterConfig; import io.grpc.xds.LoadStatsManager2.ClusterDropStats; import io.grpc.xds.LoadStatsManager2.ClusterLocalityStats; -import io.grpc.xds.Matchers.FractionMatcher; -import io.grpc.xds.Matchers.HeaderMatcher; -import io.grpc.xds.Matchers.PathMatcher; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteAction; import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.XdsLogger.XdsLogLevel; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher; import java.net.InetSocketAddress; import java.util.ArrayList; import java.util.Collection; @@ -95,6 +98,12 @@ final class ClientXdsClient extends AbstractXdsClient { @VisibleForTesting static final int INITIAL_RESOURCE_FETCH_TIMEOUT_SEC = 15; @VisibleForTesting + static final long DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE = 1024L; + @VisibleForTesting + static final long DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE = 8 * 1024 * 1024L; + @VisibleForTesting + static final long MAX_RING_HASH_LB_POLICY_RING_SIZE = 8 * 1024 * 1024L; + @VisibleForTesting static final String AGGREGATE_CLUSTER_TYPE_NAME = "envoy.clusters.aggregate"; @VisibleForTesting static final String HASH_POLICY_FILTER_STATE_KEY = "io.grpc.channel_id"; @@ -130,15 +139,18 @@ final class ClientXdsClient extends AbstractXdsClient { private final LoadReportClient lrsClient; private final TimeProvider timeProvider; private boolean reportingLoad; + private final TlsContextManager tlsContextManager; ClientXdsClient( - ManagedChannel channel, Bootstrapper.BootstrapInfo bootstrapInfo, + ManagedChannel channel, Bootstrapper.BootstrapInfo bootstrapInfo, Context context, ScheduledExecutorService timeService, BackoffPolicy.Provider backoffPolicyProvider, - Supplier stopwatchSupplier, TimeProvider timeProvider) { - super(channel, bootstrapInfo, timeService, backoffPolicyProvider, stopwatchSupplier); + Supplier stopwatchSupplier, TimeProvider timeProvider, + TlsContextManager tlsContextManager) { + super(channel, bootstrapInfo, context, timeService, backoffPolicyProvider, stopwatchSupplier); loadStatsManager = new LoadStatsManager2(stopwatchSupplier); this.timeProvider = timeProvider; - lrsClient = new LoadReportClient(loadStatsManager, channel, + this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager"); + lrsClient = new LoadReportClient(loadStatsManager, channel, context, bootstrapInfo.getServers().get(0).isUseProtocolV3(), bootstrapInfo.getNode(), getSyncContext(), timeService, backoffPolicyProvider, stopwatchSupplier); } @@ -281,10 +293,10 @@ private static LdsUpdate processClientSideListener(Listener listener, boolean pa "HttpConnectionManager neither has inlined route_config nor RDS."); } - private static LdsUpdate processServerSideListener(Listener listener) + private LdsUpdate processServerSideListener(Listener listener) throws ResourceInvalidException { StructOrError convertedListener = - parseServerSideListener(listener); + parseServerSideListener(listener, tlsContextManager); if (convertedListener.getErrorDetail() != null) { throw new ResourceInvalidException(convertedListener.getErrorDetail()); } @@ -363,10 +375,10 @@ private static StructOrError parseRawFilterConfig( } @VisibleForTesting static StructOrError parseServerSideListener( - Listener listener) { + Listener listener, TlsContextManager tlsContextManager) { try { return StructOrError.fromStruct( - EnvoyServerProtoData.Listener.fromEnvoyProtoListener(listener)); + EnvoyServerProtoData.Listener.fromEnvoyProtoListener(listener, tlsContextManager)); } catch (InvalidProtocolBufferException e) { return StructOrError.fromError( "Failed to unpack Listener " + listener.getName() + ":" + e.getMessage()); @@ -784,7 +796,7 @@ protected void handleCdsResponse(String versionInfo, List resources, String // Process Cluster into CdsUpdate. CdsUpdate cdsUpdate; try { - cdsUpdate = processCluster(cluster, retainedEdsResources); + cdsUpdate = parseCluster(cluster, retainedEdsResources); } catch (ResourceInvalidException e) { errors.add( "CDS response Cluster '" + clusterName + "' validation error: " + e.getMessage()); @@ -812,7 +824,8 @@ protected void handleCdsResponse(String versionInfo, List resources, String } } - private static CdsUpdate processCluster(Cluster cluster, Set retainedEdsResources) + @VisibleForTesting + static CdsUpdate parseCluster(Cluster cluster, Set retainedEdsResources) throws ResourceInvalidException { StructOrError structOrError; switch (cluster.getClusterDiscoveryTypeCase()) { @@ -824,26 +837,36 @@ private static CdsUpdate processCluster(Cluster cluster, Set retainedEds break; case CLUSTERDISCOVERYTYPE_NOT_SET: default: - throw new ResourceInvalidException("Unspecified cluster discovery type"); + throw new ResourceInvalidException( + "Cluster " + cluster.getName() + ": unspecified cluster discovery type"); } if (structOrError.getErrorDetail() != null) { throw new ResourceInvalidException(structOrError.getErrorDetail()); } - CdsUpdate.Builder updateBuilder = structOrError.getStruct(); if (cluster.getLbPolicy() == LbPolicy.RING_HASH) { RingHashLbConfig lbConfig = cluster.getRingHashLbConfig(); - if (lbConfig.getHashFunction() != RingHashLbConfig.HashFunction.XX_HASH) { + long minRingSize = + lbConfig.hasMinimumRingSize() + ? lbConfig.getMinimumRingSize().getValue() + : DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE; + long maxRingSize = + lbConfig.hasMaximumRingSize() + ? lbConfig.getMaximumRingSize().getValue() + : DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE; + if (lbConfig.getHashFunction() != RingHashLbConfig.HashFunction.XX_HASH + || minRingSize > maxRingSize + || maxRingSize > MAX_RING_HASH_LB_POLICY_RING_SIZE) { throw new ResourceInvalidException( - "Unsupported ring hash function: " + lbConfig.getHashFunction()); + "Cluster " + cluster.getName() + ": invalid ring_hash_lb_config: " + lbConfig); } - updateBuilder.lbPolicy(CdsUpdate.LbPolicy.RING_HASH, - lbConfig.getMinimumRingSize().getValue(), lbConfig.getMaximumRingSize().getValue()); + updateBuilder.ringHashLbPolicy(minRingSize, maxRingSize); } else if (cluster.getLbPolicy() == LbPolicy.ROUND_ROBIN) { - updateBuilder.lbPolicy(CdsUpdate.LbPolicy.ROUND_ROBIN); + updateBuilder.roundRobinLbPolicy(); } else { - throw new ResourceInvalidException("Unsupported lb policy: " + cluster.getLbPolicy()); + throw new ResourceInvalidException( + "Cluster " + cluster.getName() + ": unsupported lb policy: " + cluster.getLbPolicy()); } return updateBuilder.build(); @@ -925,8 +948,40 @@ private static StructOrError parseNonAggregateCluster( return StructOrError.fromStruct(CdsUpdate.forEds( clusterName, edsServiceName, lrsServerName, maxConcurrentRequests, upstreamTlsContext)); } else if (type.equals(DiscoveryType.LOGICAL_DNS)) { + if (!cluster.hasLoadAssignment()) { + return StructOrError.fromError( + "Cluster " + clusterName + ": LOGICAL_DNS clusters must have a single host"); + } + ClusterLoadAssignment assignment = cluster.getLoadAssignment(); + if (assignment.getEndpointsCount() != 1 + || assignment.getEndpoints(0).getLbEndpointsCount() != 1) { + return StructOrError.fromError( + "Cluster " + clusterName + ": LOGICAL_DNS clusters must have a single " + + "locality_lb_endpoint and a single lb_endpoint"); + } + io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint lbEndpoint = + assignment.getEndpoints(0).getLbEndpoints(0); + if (!lbEndpoint.hasEndpoint() || !lbEndpoint.getEndpoint().hasAddress() + || !lbEndpoint.getEndpoint().getAddress().hasSocketAddress()) { + return StructOrError.fromError( + "Cluster " + clusterName + + ": LOGICAL_DNS clusters must have an endpoint with address and socket_address"); + } + SocketAddress socketAddress = lbEndpoint.getEndpoint().getAddress().getSocketAddress(); + if (!socketAddress.getResolverName().isEmpty()) { + return StructOrError.fromError( + "Cluster " + clusterName + + ": LOGICAL DNS clusters must NOT have a custom resolver name set"); + } + if (socketAddress.getPortSpecifierCase() != PortSpecifierCase.PORT_VALUE) { + return StructOrError.fromError( + "Cluster " + clusterName + + ": LOGICAL DNS clusters socket_address must have port_value"); + } + String dnsHostName = + String.format("%s:%d", socketAddress.getAddress(), socketAddress.getPortValue()); return StructOrError.fromStruct(CdsUpdate.forLogicalDns( - clusterName, lrsServerName, maxConcurrentRequests, upstreamTlsContext)); + clusterName, dnsHostName, lrsServerName, maxConcurrentRequests, upstreamTlsContext)); } return StructOrError.fromError( "Cluster " + clusterName + ": unsupported built-in discovery type: " + type); @@ -1180,6 +1235,11 @@ Map getSubscribedResourcesMetadata(ResourceType type) return metadataMap; } + @Override + TlsContextManager getTlsContextManager() { + return tlsContextManager; + } + @Override void watchLdsResource(final String resourceName, final LdsResourceWatcher watcher) { getSyncContext().execute(new Runnable() { @@ -1567,14 +1627,15 @@ private void notifyWatcher(ResourceWatcher watcher, ResourceUpdate update) { } } - private static final class ResourceInvalidException extends Exception { + @VisibleForTesting + static final class ResourceInvalidException extends Exception { private static final long serialVersionUID = 0L; - public ResourceInvalidException(String message) { + private ResourceInvalidException(String message) { super(message, null, false, false); } - public ResourceInvalidException(String message, Throwable cause) { + private ResourceInvalidException(String message, Throwable cause) { super(cause != null ? message + ": " + cause.getMessage() : message, cause, false, false); } } diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 92e23b13b67..5beefc3384c 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -45,8 +45,6 @@ import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; import io.grpc.xds.XdsSubchannelPickers.ErrorPicker; import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.grpc.xds.internal.sds.TlsContextManager; -import io.grpc.xds.internal.sds.TlsContextManagerImpl; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -78,7 +76,6 @@ final class ClusterImplLoadBalancer extends LoadBalancer { private final XdsLogger logger; private final Helper helper; private final ThreadSafeRandom random; - private final TlsContextManager tlsContextManager; // The following fields are effectively final. private String cluster; @Nullable @@ -91,14 +88,12 @@ final class ClusterImplLoadBalancer extends LoadBalancer { private LoadBalancer childLb; ClusterImplLoadBalancer(Helper helper) { - this(helper, ThreadSafeRandomImpl.instance, TlsContextManagerImpl.getInstance()); + this(helper, ThreadSafeRandomImpl.instance); } - ClusterImplLoadBalancer(Helper helper, ThreadSafeRandom random, - TlsContextManager tlsContextManager) { + ClusterImplLoadBalancer(Helper helper, ThreadSafeRandom random) { this.helper = checkNotNull(helper, "helper"); this.random = checkNotNull(random, "random"); - this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager"); InternalLogId logId = InternalLogId.allocate("cluster-impl-lb", helper.getAuthority()); logger = XdsLogger.withLogId(logId); logger.log(XdsLogLevel.INFO, "Created"); @@ -158,7 +153,10 @@ public void shutdown() { } if (childLb != null) { childLb.shutdown(); - childLbHelper = null; + if (childLbHelper != null) { + childLbHelper.updateSslContextProviderSupplier(null); + childLbHelper = null; + } } if (xdsClient != null) { xdsClient = xdsClientPool.returnObject(xdsClient); @@ -274,7 +272,7 @@ private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsCo } sslContextProviderSupplier = tlsContext != null - ? new SslContextProviderSupplier(tlsContext, tlsContextManager) + ? new SslContextProviderSupplier(tlsContext, xdsClient.getTlsContextManager()) : null; } diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java index 6e1b17cf7ac..9cfbe3a753b 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java @@ -17,11 +17,9 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.ConnectivityState.CONNECTING; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.xds.XdsLbPolicies.PRIORITY_POLICY_NAME; import static io.grpc.xds.XdsLbPolicies.WEIGHTED_TARGET_POLICY_NAME; -import static io.grpc.xds.XdsSubchannelPickers.BUFFER_PICKER; import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; @@ -81,7 +79,6 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { // to an empty locality. private static final Locality LOGICAL_DNS_CLUSTER_LOCALITY = Locality.create("", "", ""); private final XdsLogger logger; - private final String authority; private final SynchronizationContext syncContext; private final ScheduledExecutorService timeService; private final LoadBalancerRegistry lbRegistry; @@ -101,7 +98,6 @@ final class ClusterResolverLoadBalancer extends LoadBalancer { BackoffPolicy.Provider backoffPolicyProvider) { this.lbRegistry = checkNotNull(lbRegistry, "lbRegistry"); this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); - this.authority = checkNotNull(checkNotNull(helper, "helper").getAuthority(), "authority"); this.syncContext = checkNotNull(helper.getSynchronizationContext(), "syncContext"); this.timeService = checkNotNull(helper.getScheduledExecutorService(), "timeService"); delegate = new GracefulSwitchLoadBalancer(helper); @@ -180,8 +176,8 @@ public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { state = new EdsClusterState(instance.cluster, instance.edsServiceName, instance.lrsServerName, instance.maxConcurrentRequests, instance.tlsContext); } else { // logical DNS - state = new LogicalDnsClusterState(instance.cluster, instance.lrsServerName, - instance.maxConcurrentRequests, instance.tlsContext); + state = new LogicalDnsClusterState(instance.cluster, instance.dnsHostName, + instance.lrsServerName, instance.maxConcurrentRequests, instance.tlsContext); } clusterStates.put(instance.cluster, state); state.start(); @@ -211,30 +207,35 @@ private void handleEndpointResourceUpdate() { List addresses = new ArrayList<>(); Map priorityChildConfigs = new HashMap<>(); List priorities = new ArrayList<>(); // totally ordered priority list - boolean allResolved = true; + Status endpointNotFound = Status.OK; for (String cluster : clusters) { ClusterState state = clusterStates.get(cluster); - if (!state.resolved) { - allResolved = false; - continue; + // Propagate endpoints to the child LB policy only after all clusters have been resolved. + if (!state.resolved && state.status.isOk()) { + return; } if (state.result != null) { addresses.addAll(state.result.addresses); priorityChildConfigs.putAll(state.result.priorityChildConfigs); priorities.addAll(state.result.priorities); + } else { + endpointNotFound = state.status; } } if (addresses.isEmpty()) { + if (endpointNotFound.isOk()) { + endpointNotFound = Status.UNAVAILABLE.withDescription( + "No usable endpoint from cluster(s): " + clusters); + } else { + endpointNotFound = + Status.UNAVAILABLE.withCause(endpointNotFound.getCause()) + .withDescription(endpointNotFound.getDescription()); + } + helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(endpointNotFound)); if (childLb != null) { childLb.shutdown(); childLb = null; } - if (allResolved) { - Status unavailable = Status.UNAVAILABLE.withDescription("No usable endpoint"); - helper.updateBalancingState(TRANSIENT_FAILURE, new ErrorPicker(unavailable)); - } else { - helper.updateBalancingState(CONNECTING, BUFFER_PICKER); - } return; } PriorityLbConfig childConfig = @@ -252,14 +253,16 @@ private void handleEndpointResourceUpdate() { private void handleEndpointResolutionError() { boolean allInError = true; - for (ClusterState state : clusterStates.values()) { + Status error = null; + for (String cluster : clusters) { + ClusterState state = clusterStates.get(cluster); if (state.status.isOk()) { allInError = false; + } else { + error = state.status; } } if (allInError) { - // Propagate the error status of the last cluster. This is the best we can do. - Status error = clusterStates.get(clusters.get(clusters.size() - 1)).status; if (childLb != null) { childLb.handleNameResolutionError(error); } else { @@ -301,10 +304,6 @@ protected Helper delegate() { private abstract class ClusterState { // Name of the cluster to be resolved. protected final String name; - // The resource name to be used for resolving endpoints via EDS. - // Always null if the cluster is a logical DNS cluster. - @Nullable - protected final String edsServiceName; @Nullable protected final String lrsServerName; @Nullable @@ -320,11 +319,9 @@ private abstract class ClusterState { protected ClusterResolutionResult result; protected boolean shutdown; - private ClusterState(String name, @Nullable String edsServiceName, - @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext) { + private ClusterState(String name, @Nullable String lrsServerName, + @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) { this.name = name; - this.edsServiceName = edsServiceName; this.lrsServerName = lrsServerName; this.maxConcurrentRequests = maxConcurrentRequests; this.tlsContext = tlsContext; @@ -338,11 +335,14 @@ void shutdown() { } private final class EdsClusterState extends ClusterState implements EdsResourceWatcher { + @Nullable + private final String edsServiceName; private EdsClusterState(String name, @Nullable String edsServiceName, @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) { - super(name, edsServiceName, lrsServerName, maxConcurrentRequests, tlsContext); + super(name, lrsServerName, maxConcurrentRequests, tlsContext); + this.edsServiceName = edsServiceName; } @Override @@ -387,13 +387,17 @@ public void run() { for (LbEndpoint endpoint : localityLbInfo.endpoints()) { if (endpoint.isHealthy()) { discard = false; - long weight = - (long) localityLbInfo.localityWeight() * endpoint.loadBalancingWeight(); - Attributes attr = endpoint.eag().getAttributes().toBuilder() - .set(InternalXdsAttributes.ATTR_LOCALITY, locality) - .set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight).build(); - EquivalentAddressGroup eag = - new EquivalentAddressGroup(endpoint.eag().getAddresses(), attr); + long weight = localityLbInfo.localityWeight(); + if (endpoint.loadBalancingWeight() != 0) { + weight *= endpoint.loadBalancingWeight(); + } + Attributes attr = + endpoint.eag().getAttributes().toBuilder() + .set(InternalXdsAttributes.ATTR_LOCALITY, locality) + .set(InternalXdsAttributes.ATTR_SERVER_WEIGHT, weight) + .build(); + EquivalentAddressGroup eag = new EquivalentAddressGroup( + endpoint.eag().getAddresses(), attr); eag = AddressFilter.setPathFilter( eag, Arrays.asList(priorityName, localityName(locality))); addresses.add(eag); @@ -465,6 +469,7 @@ public void run() { } private final class LogicalDnsClusterState extends ClusterState { + private final String dnsHostName; private final NameResolver.Factory nameResolverFactory; private final NameResolver.Args nameResolverArgs; private NameResolver resolver; @@ -473,9 +478,11 @@ private final class LogicalDnsClusterState extends ClusterState { @Nullable private ScheduledHandle scheduledRefresh; - private LogicalDnsClusterState(String name, @Nullable String lrsServerName, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) { - super(name, null, lrsServerName, maxConcurrentRequests, tlsContext); + private LogicalDnsClusterState(String name, String dnsHostName, + @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, + @Nullable UpstreamTlsContext tlsContext) { + super(name, lrsServerName, maxConcurrentRequests, tlsContext); + this.dnsHostName = checkNotNull(dnsHostName, "dnsHostName"); nameResolverFactory = checkNotNull(helper.getNameResolverRegistry().asFactory(), "nameResolverFactory"); nameResolverArgs = checkNotNull(helper.getNameResolverArgs(), "nameResolverArgs"); @@ -485,10 +492,10 @@ private LogicalDnsClusterState(String name, @Nullable String lrsServerName, void start() { URI uri; try { - uri = new URI("dns", "", "/" + authority, null); + uri = new URI("dns", "", "/" + dnsHostName, null); } catch (URISyntaxException e) { - status = - Status.INTERNAL.withDescription("Bug, invalid authority: " + authority).withCause(e); + status = Status.INTERNAL.withDescription( + "Bug, invalid URI creation: " + dnsHostName).withCause(e); handleEndpointResolutionError(); return; } @@ -560,8 +567,8 @@ public void run() { addresses.add(eag); } PriorityChildConfig priorityChildConfig = generateDnsBasedPriorityChildConfig( - name, edsServiceName, lrsServerName, maxConcurrentRequests, tlsContext, - lbRegistry, Collections.emptyList()); + name, lrsServerName, maxConcurrentRequests, tlsContext, lbRegistry, + Collections.emptyList()); status = Status.OK; resolved = true; result = new ClusterResolutionResult(addresses, priorityName, priorityChildConfig); @@ -581,10 +588,17 @@ public void run() { return; } status = error; - // NameResolver.Listener API cannot distinguish transient errors, we should avoid - // waiting for DNS addresses indefinitely. - resolved = true; - handleEndpointResolutionError(); + // NameResolver.Listener API cannot distinguish between address-not-found and + // transient errors. If the error occurs in the first resolution, treat it as + // address not found. Otherwise, either there is previously resolved addresses + // previously encountered error, propagate the error to downstream/upstream and + // let downstream/upstream handle it. + if (!resolved) { + resolved = true; + handleEndpointResourceUpdate(); + } else { + handleEndpointResolutionError(); + } if (scheduledRefresh != null && scheduledRefresh.isPending()) { return; } @@ -634,14 +648,14 @@ private static class ClusterResolutionResult { *

priority LB -> cluster_impl LB (single hardcoded priority) -> pick_first */ private static PriorityChildConfig generateDnsBasedPriorityChildConfig( - String cluster, @Nullable String edsServiceName, @Nullable String lrsServerName, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext, - LoadBalancerRegistry lbRegistry, List dropOverloads) { + String cluster, @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, + @Nullable UpstreamTlsContext tlsContext, LoadBalancerRegistry lbRegistry, + List dropOverloads) { // Override endpoint-level LB policy with pick_first for logical DNS cluster. PolicySelection endpointLbPolicy = new PolicySelection(lbRegistry.getProvider("pick_first"), null); ClusterImplConfig clusterImplConfig = - new ClusterImplConfig(cluster, edsServiceName, lrsServerName, maxConcurrentRequests, + new ClusterImplConfig(cluster, null, lrsServerName, maxConcurrentRequests, dropOverloads, endpointLbPolicy, tlsContext); LoadBalancerProvider clusterImplLbProvider = lbRegistry.getProvider(XdsLbPolicies.CLUSTER_IMPL_POLICY_NAME); diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java index e62c70cb5ec..33b150e667b 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancerProvider.java @@ -119,6 +119,9 @@ static final class DiscoveryMechanism { // Resource name for resolving endpoints via EDS. Only valid for EDS clusters. @Nullable final String edsServiceName; + // Hostname for resolving endpoints via DNS. Only valid for LOGICAL_DNS clusters. + @Nullable + final String dnsHostName; enum Type { EDS, @@ -126,11 +129,12 @@ enum Type { } private DiscoveryMechanism(String cluster, Type type, @Nullable String edsServiceName, - @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext) { + @Nullable String dnsHostName, @Nullable String lrsServerName, + @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) { this.cluster = checkNotNull(cluster, "cluster"); this.type = checkNotNull(type, "type"); this.edsServiceName = edsServiceName; + this.dnsHostName = dnsHostName; this.lrsServerName = lrsServerName; this.maxConcurrentRequests = maxConcurrentRequests; this.tlsContext = tlsContext; @@ -139,20 +143,21 @@ private DiscoveryMechanism(String cluster, Type type, @Nullable String edsServic static DiscoveryMechanism forEds(String cluster, @Nullable String edsServiceName, @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) { - return new DiscoveryMechanism(cluster, Type.EDS, edsServiceName, lrsServerName, + return new DiscoveryMechanism(cluster, Type.EDS, edsServiceName, null, lrsServerName, maxConcurrentRequests, tlsContext); } - static DiscoveryMechanism forLogicalDns(String cluster, @Nullable String lrsServerName, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) { - return new DiscoveryMechanism(cluster, Type.LOGICAL_DNS, null, lrsServerName, - maxConcurrentRequests, tlsContext); + static DiscoveryMechanism forLogicalDns(String cluster, String dnsHostName, + @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, + @Nullable UpstreamTlsContext tlsContext) { + return new DiscoveryMechanism(cluster, Type.LOGICAL_DNS, null, dnsHostName, + lrsServerName, maxConcurrentRequests, tlsContext); } @Override public int hashCode() { return Objects.hash(cluster, type, lrsServerName, maxConcurrentRequests, tlsContext, - edsServiceName); + edsServiceName, dnsHostName); } @Override @@ -167,6 +172,7 @@ public boolean equals(Object o) { return cluster.equals(that.cluster) && type == that.type && Objects.equals(edsServiceName, that.edsServiceName) + && Objects.equals(dnsHostName, that.dnsHostName) && Objects.equals(lrsServerName, that.lrsServerName) && Objects.equals(maxConcurrentRequests, that.maxConcurrentRequests) && Objects.equals(tlsContext, that.tlsContext); @@ -178,12 +184,11 @@ public String toString() { MoreObjects.toStringHelper(this) .add("cluster", cluster) .add("type", type) + .add("edsServiceName", edsServiceName) + .add("dnsHostName", dnsHostName) .add("lrsServerName", lrsServerName) // Exclude tlsContext as its string representation is cumbersome. .add("maxConcurrentRequests", maxConcurrentRequests); - if (type == Type.EDS) { - toStringHelper.add("edsServiceName", edsServiceName); - } return toStringHelper.toString(); } } diff --git a/xds/src/main/java/io/grpc/xds/Endpoints.java b/xds/src/main/java/io/grpc/xds/Endpoints.java index 5f871c9d269..8b1715731df 100644 --- a/xds/src/main/java/io/grpc/xds/Endpoints.java +++ b/xds/src/main/java/io/grpc/xds/Endpoints.java @@ -16,6 +16,8 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkArgument; + import com.google.auto.value.AutoValue; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; @@ -33,7 +35,7 @@ abstract static class LocalityLbEndpoints { // Endpoints to be load balanced. abstract ImmutableList endpoints(); - // Locality's weight for inter-locality load balancing. + // Locality's weight for inter-locality load balancing. Guaranteed to be greater than 0. abstract int localityWeight(); // Locality's priority level. @@ -41,6 +43,7 @@ abstract static class LocalityLbEndpoints { static LocalityLbEndpoints create(List endpoints, int localityWeight, int priority) { + checkArgument(localityWeight > 0, "localityWeight must be greater than 0"); return new AutoValue_Endpoints_LocalityLbEndpoints( ImmutableList.copyOf(endpoints), localityWeight, priority); } @@ -52,7 +55,7 @@ abstract static class LbEndpoint { // The endpoint address to be connected to. abstract EquivalentAddressGroup eag(); - // Endpoint's wight for load balancing. + // Endpoint's weight for load balancing. If unspecified, value of 0 is returned. abstract int loadBalancingWeight(); // Whether the endpoint is healthy. diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index 22cd597db8a..b8a5bfd290e 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -28,6 +28,7 @@ import io.envoyproxy.envoy.extensions.filters.network.http_connection_manager.v3.HttpFilter; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.grpc.Internal; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; @@ -227,6 +228,8 @@ static final class FilterChainMatch { private final List sourcePrefixRanges; private final ConnectionSourceType sourceType; private final List sourcePorts; + private final List serverNames; + private final String transportProtocol; @VisibleForTesting FilterChainMatch( @@ -235,13 +238,17 @@ static final class FilterChainMatch { List applicationProtocols, List sourcePrefixRanges, ConnectionSourceType sourceType, - List sourcePorts) { + List sourcePorts, + List serverNames, + String transportProtocol) { this.destinationPort = destinationPort; this.prefixRanges = Collections.unmodifiableList(prefixRanges); this.applicationProtocols = Collections.unmodifiableList(applicationProtocols); this.sourcePrefixRanges = sourcePrefixRanges; this.sourceType = sourceType; this.sourcePorts = sourcePorts; + this.serverNames = Collections.unmodifiableList(serverNames); + this.transportProtocol = transportProtocol; } static FilterChainMatch fromEnvoyProtoFilterChainMatch( @@ -273,13 +280,19 @@ static FilterChainMatch fromEnvoyProtoFilterChainMatch( default: throw new InvalidProtocolBufferException("Unknown source-type:" + proto.getSourceType()); } + List serverNames = new ArrayList<>(); + for (String serverName : proto.getServerNamesList()) { + serverNames.add(serverName); + } return new FilterChainMatch( proto.getDestinationPort().getValue(), prefixRanges, applicationProtocols, sourcePrefixRanges, sourceType, - proto.getSourcePortsList()); + proto.getSourcePortsList(), + serverNames, + proto.getTransportProtocol()); } public int getDestinationPort() { @@ -306,6 +319,14 @@ public List getSourcePorts() { return sourcePorts; } + public List getServerNames() { + return serverNames; + } + + public String getTransportProtocol() { + return transportProtocol; + } + @Override public boolean equals(Object o) { if (this == o) { @@ -320,7 +341,9 @@ public boolean equals(Object o) { && Objects.equals(applicationProtocols, that.applicationProtocols) && Objects.equals(sourcePrefixRanges, that.sourcePrefixRanges) && sourceType == that.sourceType - && Objects.equals(sourcePorts, that.sourcePorts); + && Objects.equals(sourcePorts, that.sourcePorts) + && Objects.equals(serverNames, that.serverNames) + && Objects.equals(transportProtocol, that.transportProtocol); } @Override @@ -331,7 +354,9 @@ public int hashCode() { applicationProtocols, sourcePrefixRanges, sourceType, - sourcePorts); + sourcePorts, + serverNames, + transportProtocol); } @Override @@ -343,6 +368,8 @@ public String toString() { .add("sourcePrefixRanges", sourcePrefixRanges) .add("sourceType", sourceType) .add("sourcePorts", sourcePorts) + .add("serverNames", serverNames) + .add("transportProtocol", transportProtocol) .toString(); } } @@ -354,17 +381,21 @@ static final class FilterChain { // TODO(sanjaypujare): flatten structure by moving FilterChainMatch class members here. private final FilterChainMatch filterChainMatch; @Nullable - private final DownstreamTlsContext downstreamTlsContext; + private final SslContextProviderSupplier sslContextProviderSupplier; @VisibleForTesting FilterChain( - FilterChainMatch filterChainMatch, @Nullable DownstreamTlsContext downstreamTlsContext) { + FilterChainMatch filterChainMatch, @Nullable DownstreamTlsContext downstreamTlsContext, + TlsContextManager tlsContextManager) { + SslContextProviderSupplier sslContextProviderSupplier1 = downstreamTlsContext == null ? null + : new SslContextProviderSupplier(downstreamTlsContext, tlsContextManager); this.filterChainMatch = filterChainMatch; - this.downstreamTlsContext = downstreamTlsContext; + this.sslContextProviderSupplier = sslContextProviderSupplier1; } static FilterChain fromEnvoyProtoFilterChain( - io.envoyproxy.envoy.config.listener.v3.FilterChain proto, boolean isDefaultFilterChain) + io.envoyproxy.envoy.config.listener.v3.FilterChain proto, + TlsContextManager tlsContextManager, boolean isDefaultFilterChain) throws InvalidProtocolBufferException { if (!isDefaultFilterChain && proto.getFiltersList().isEmpty()) { throw new IllegalArgumentException( @@ -380,7 +411,8 @@ static FilterChain fromEnvoyProtoFilterChain( } return new FilterChain( FilterChainMatch.fromEnvoyProtoFilterChainMatch(proto.getFilterChainMatch()), - getTlsContextFromFilterChain(proto) + getTlsContextFromFilterChain(proto), + tlsContextManager ); } @@ -456,9 +488,8 @@ public FilterChainMatch getFilterChainMatch() { return filterChainMatch; } - @Nullable - public DownstreamTlsContext getDownstreamTlsContext() { - return downstreamTlsContext; + public SslContextProviderSupplier getSslContextProviderSupplier() { + return sslContextProviderSupplier; } @Override @@ -471,19 +502,19 @@ public boolean equals(Object o) { } FilterChain that = (FilterChain) o; return java.util.Objects.equals(filterChainMatch, that.filterChainMatch) - && java.util.Objects.equals(downstreamTlsContext, that.downstreamTlsContext); + && java.util.Objects.equals(sslContextProviderSupplier, that.sslContextProviderSupplier); } @Override public int hashCode() { - return java.util.Objects.hash(filterChainMatch, downstreamTlsContext); + return java.util.Objects.hash(filterChainMatch, sslContextProviderSupplier); } @Override public String toString() { return "FilterChain{" + "filterChainMatch=" + filterChainMatch - + ", downstreamTlsContext=" + downstreamTlsContext + + ", sslContextProviderSupplier=" + sslContextProviderSupplier + '}'; } } @@ -524,7 +555,8 @@ private static String convertEnvoyAddressToString(Address proto) { return null; } - static Listener fromEnvoyProtoListener(io.envoyproxy.envoy.config.listener.v3.Listener proto) + static Listener fromEnvoyProtoListener(io.envoyproxy.envoy.config.listener.v3.Listener proto, + TlsContextManager tlsContextManager) throws InvalidProtocolBufferException { if (!proto.getTrafficDirection().equals(TrafficDirection.INBOUND)) { throw new IllegalArgumentException("Listener " + proto.getName() + " is not INBOUND"); @@ -537,44 +569,28 @@ static Listener fromEnvoyProtoListener(io.envoyproxy.envoy.config.listener.v3.Li throw new IllegalArgumentException( "Listener " + proto.getName() + " cannot have use_original_dst set to true"); } - List filterChains = validateAndSelectFilterChains(proto.getFilterChainsList()); + List filterChains = validateAndSelectFilterChains(proto.getFilterChainsList(), + tlsContextManager); return new Listener( proto.getName(), convertEnvoyAddressToString(proto.getAddress()), - filterChains, FilterChain.fromEnvoyProtoFilterChain(proto.getDefaultFilterChain(), true)); + filterChains, FilterChain + .fromEnvoyProtoFilterChain(proto.getDefaultFilterChain(), tlsContextManager, true)); } private static List validateAndSelectFilterChains( - List inputFilterChains) + List inputFilterChains, + TlsContextManager tlsContextManager) throws InvalidProtocolBufferException { List filterChains = new ArrayList<>(inputFilterChains.size()); for (io.envoyproxy.envoy.config.listener.v3.FilterChain filterChain : inputFilterChains) { - if (isAcceptable(filterChain.getFilterChainMatch())) { - filterChains.add(FilterChain.fromEnvoyProtoFilterChain(filterChain, false)); - } + filterChains + .add(FilterChain.fromEnvoyProtoFilterChain(filterChain, tlsContextManager, false)); } return filterChains; } - // check if a filter is acceptable for gRPC server side processing - private static boolean isAcceptable( - io.envoyproxy.envoy.config.listener.v3.FilterChainMatch filterChainMatch) { - // reject if filer-chain-match - // - has server_name - // - transport protocol is other than "raw_buffer" - // - application_protocols is non-empty - if (!filterChainMatch.getServerNamesList().isEmpty()) { - return false; - } - String transportProtocol = filterChainMatch.getTransportProtocol(); - if (!transportProtocol.isEmpty() && !"raw_buffer".equals(transportProtocol)) { - return false; - } - List appProtocols = filterChainMatch.getApplicationProtocolsList(); - return appProtocols.isEmpty(); - } - public String getName() { return name; } diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index a531407baa5..3ad3533fa73 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -35,10 +35,13 @@ import io.grpc.ClientInterceptor; import io.grpc.Context; import io.grpc.Deadline; +import io.grpc.ForwardingClientCall; +import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import io.grpc.LoadBalancer.PickSubchannelArgs; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; +import io.grpc.Status.Code; import io.grpc.internal.DelayedClientCall; import io.grpc.internal.GrpcUtil; import io.grpc.xds.FaultConfig.FaultAbort; @@ -215,23 +218,62 @@ public ClientCall interceptCall( // TODO(https://github.com/grpc/grpc-java/issues/7868) callExecutor = MoreExecutors.directExecutor(); } - if (finalDelayNanos != null && finalAbortStatus != null) { - return new DelayInjectedCall<>( - finalDelayNanos, callExecutor, scheduler, callOptions.getDeadline(), - Suppliers.ofInstance( - new FailingClientCall(finalAbortStatus, callExecutor))); - } - if (finalAbortStatus != null) { - return new FailingClientCall<>(finalAbortStatus, callExecutor); + if (finalDelayNanos != null) { + Supplier> callSupplier; + if (finalAbortStatus != null) { + callSupplier = Suppliers.ofInstance( + new FailingClientCall(finalAbortStatus, callExecutor)); + } else { + callSupplier = new Supplier>() { + @Override + public ClientCall get() { + return next.newCall(method, callOptions); + } + }; + } + final DelayInjectedCall delayInjectedCall = new DelayInjectedCall<>( + finalDelayNanos, callExecutor, scheduler, callOptions.getDeadline(), callSupplier); + + final class DeadlineInsightForwardingCall extends ForwardingClientCall { + @Override + protected ClientCall delegate() { + return delayInjectedCall; + } + + @Override + public void start(Listener listener, Metadata headers) { + Listener finalListener = + new SimpleForwardingClientCallListener(listener) { + @Override + public void onClose(Status status, Metadata trailers) { + if (status.getCode().equals(Code.DEADLINE_EXCEEDED)) { + // TODO(zdapeng:) check effective deadline locally, and + // do the following only if the local deadline is exceeded. + // (If the server sends DEADLINE_EXCEEDED for its own deadline, then the + // injected delay does not contribute to the error, because the request is + // only sent out after the delay. There could be a race between local and + // remote, but it is rather rare.) + String description = String.format( + "Deadline exceeded after up to %d ns of fault-injected delay", + finalDelayNanos); + if (status.getDescription() != null) { + description = description + ": " + status.getDescription(); + } + status = Status.DEADLINE_EXCEEDED + .withDescription(description).withCause(status.getCause()); + // Replace trailers to prevent mixing sources of status and trailers. + trailers = new Metadata(); + } + delegate().onClose(status, trailers); + } + }; + delegate().start(finalListener, headers); + } + } + + return new DeadlineInsightForwardingCall(); } else { - return new DelayInjectedCall<>( - finalDelayNanos, callExecutor, scheduler, callOptions.getDeadline(), - new Supplier>() { - @Override - public ClientCall get() { - return next.newCall(method, callOptions); - } - }); + return new FailingClientCall<>(finalAbortStatus, callExecutor); } } } diff --git a/xds/src/main/java/io/grpc/xds/Filter.java b/xds/src/main/java/io/grpc/xds/Filter.java index f8372cdd1b5..3da71fd6a4c 100644 --- a/xds/src/main/java/io/grpc/xds/Filter.java +++ b/xds/src/main/java/io/grpc/xds/Filter.java @@ -65,8 +65,9 @@ ClientInterceptor buildClientInterceptor( ScheduledExecutorService scheduler); } - // Server side filters are not currently supported, but this interface is defined for clarity. + /** Uses the FilterConfigs produced above to produce an HTTP filter interceptor for the server. */ interface ServerInterceptorBuilder { + @Nullable ServerInterceptor buildServerInterceptor( FilterConfig config, @Nullable FilterConfig overrideConfig); } diff --git a/xds/src/main/java/io/grpc/xds/FilterRegistry.java b/xds/src/main/java/io/grpc/xds/FilterRegistry.java index db4f256bce4..7f1fe82c6c3 100644 --- a/xds/src/main/java/io/grpc/xds/FilterRegistry.java +++ b/xds/src/main/java/io/grpc/xds/FilterRegistry.java @@ -34,7 +34,10 @@ private FilterRegistry() {} static synchronized FilterRegistry getDefaultRegistry() { if (instance == null) { - instance = newRegistry().register(FaultFilter.INSTANCE, RouterFilter.INSTANCE); + instance = newRegistry().register( + FaultFilter.INSTANCE, + RouterFilter.INSTANCE, + RbacFilter.INSTANCE); } return instance; } diff --git a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java b/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java index 45d146e4a87..55ec809772f 100644 --- a/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/GoogleCloudToProdNameResolver.java @@ -40,6 +40,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.net.URL; +import java.util.Random; import java.util.concurrent.Executor; /** @@ -71,6 +72,7 @@ final class GoogleCloudToProdNameResolver extends NameResolver { private final Resource executorResource; private final XdsClientPoolFactory xdsClientPoolFactory; private final NameResolver delegate; + private final Random rand; private final boolean usingExecutorResource; // It's not possible to use both PSM and DirectPath C2P in the same application. // Delegate to DNS if user-provided bootstrap is found. @@ -83,15 +85,17 @@ final class GoogleCloudToProdNameResolver extends NameResolver { GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource, XdsClientPoolFactory xdsClientPoolFactory) { - this(targetUri, args, executorResource, xdsClientPoolFactory, + this(targetUri, args, executorResource, new Random(), xdsClientPoolFactory, NameResolverRegistry.getDefaultRegistry().asFactory()); } @VisibleForTesting GoogleCloudToProdNameResolver(URI targetUri, Args args, Resource executorResource, - XdsClientPoolFactory xdsClientPoolFactory, NameResolver.Factory nameResolverFactory) { + Random rand, XdsClientPoolFactory xdsClientPoolFactory, + NameResolver.Factory nameResolverFactory) { this.executorResource = checkNotNull(executorResource, "executorResource"); this.xdsClientPoolFactory = checkNotNull(xdsClientPoolFactory, "xdsClientPoolFactory"); + this.rand = checkNotNull(rand, "rand"); String targetPath = checkNotNull(checkNotNull(targetUri, "targetUri").getPath(), "targetPath"); Preconditions.checkArgument( targetPath.startsWith("/"), @@ -169,9 +173,9 @@ public void run() { executor.execute(new Resolve()); } - private static ImmutableMap generateBootstrap(String zone, boolean supportIpv6) { + private ImmutableMap generateBootstrap(String zone, boolean supportIpv6) { ImmutableMap.Builder nodeBuilder = ImmutableMap.builder(); - nodeBuilder.put("id", "C2P"); + nodeBuilder.put("id", "C2P-" + rand.nextInt()); if (!zone.isEmpty()) { nodeBuilder.put("locality", ImmutableMap.of("zone", zone)); } diff --git a/xds/src/main/java/io/grpc/xds/LoadReportClient.java b/xds/src/main/java/io/grpc/xds/LoadReportClient.java index 603e3dcd6a8..54fa20128bc 100644 --- a/xds/src/main/java/io/grpc/xds/LoadReportClient.java +++ b/xds/src/main/java/io/grpc/xds/LoadReportClient.java @@ -28,6 +28,7 @@ import io.envoyproxy.envoy.service.load_stats.v3.LoadReportingServiceGrpc.LoadReportingServiceStub; import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsRequest; import io.envoyproxy.envoy.service.load_stats.v3.LoadStatsResponse; +import io.grpc.Context; import io.grpc.InternalLogId; import io.grpc.ManagedChannel; import io.grpc.Status; @@ -55,6 +56,7 @@ final class LoadReportClient { private final InternalLogId logId; private final XdsLogger logger; private final ManagedChannel channel; + private final Context context; private final boolean useProtocolV3; private final Node node; private final SynchronizationContext syncContext; @@ -74,6 +76,7 @@ final class LoadReportClient { LoadReportClient( LoadStatsManager2 loadStatsManager, ManagedChannel channel, + Context context, boolean useProtocolV3, Node node, SynchronizationContext syncContext, @@ -82,6 +85,7 @@ final class LoadReportClient { Supplier stopwatchSupplier) { this.loadStatsManager = checkNotNull(loadStatsManager, "loadStatsManager"); this.channel = checkNotNull(channel, "xdsChannel"); + this.context = checkNotNull(context, "context"); this.useProtocolV3 = useProtocolV3; this.syncContext = checkNotNull(syncContext, "syncContext"); this.timerService = checkNotNull(scheduledExecutorService, "timeService"); @@ -163,7 +167,12 @@ private void startLrsRpc() { lrsStream = new LrsStreamV2(); } retryStopwatch.reset().start(); - lrsStream.start(); + Context prevContext = context.attach(); + try { + lrsStream.start(); + } finally { + context.detach(prevContext); + } } private abstract class LrsStream { diff --git a/xds/src/main/java/io/grpc/xds/Matchers.java b/xds/src/main/java/io/grpc/xds/Matchers.java deleted file mode 100644 index 8018dc5ffad..00000000000 --- a/xds/src/main/java/io/grpc/xds/Matchers.java +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright 2021 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds; - -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.auto.value.AutoValue; -import com.google.re2j.Pattern; -import javax.annotation.Nullable; - -/** A group of request matchers. */ -final class Matchers { - private Matchers() {} - - /** Matcher for HTTP request path. */ - @AutoValue - abstract static class PathMatcher { - // Exact full path to be matched. - @Nullable - abstract String path(); - - // Path prefix to be matched. - @Nullable - abstract String prefix(); - - // Regular expression pattern of the path to be matched. - @Nullable - abstract Pattern regEx(); - - // Whether case sensitivity is taken into account for matching. - // Only valid for full path matching or prefix matching. - abstract boolean caseSensitive(); - - static PathMatcher fromPath(String path, boolean caseSensitive) { - checkNotNull(path, "path"); - return PathMatcher.create(path, null, null, caseSensitive); - } - - static PathMatcher fromPrefix(String prefix, boolean caseSensitive) { - checkNotNull(prefix, "prefix"); - return PathMatcher.create(null, prefix, null, caseSensitive); - } - - static PathMatcher fromRegEx(Pattern regEx) { - checkNotNull(regEx, "regEx"); - return PathMatcher.create(null, null, regEx, false /* doesn't matter */); - } - - private static PathMatcher create(@Nullable String path, @Nullable String prefix, - @Nullable Pattern regEx, boolean caseSensitive) { - return new AutoValue_Matchers_PathMatcher(path, prefix, regEx, caseSensitive); - } - } - - /** Matcher for HTTP request headers. */ - @AutoValue - abstract static class HeaderMatcher { - // Name of the header to be matched. - abstract String name(); - - // Matches exact header value. - @Nullable - abstract String exactValue(); - - // Matches header value with the regular expression pattern. - @Nullable - abstract Pattern safeRegEx(); - - // Matches header value an integer value in the range. - @Nullable - abstract Range range(); - - // Matches header presence. - @Nullable - abstract Boolean present(); - - // Matches header value with the prefix. - @Nullable - abstract String prefix(); - - // Matches header value with the suffix. - @Nullable - abstract String suffix(); - - // Whether the matching semantics is inverted. E.g., present && !inverted -> !present - abstract boolean inverted(); - - static HeaderMatcher forExactValue(String name, String exactValue, boolean inverted) { - checkNotNull(name, "name"); - checkNotNull(exactValue, "exactValue"); - return HeaderMatcher.create(name, exactValue, null, null, null, null, null, inverted); - } - - static HeaderMatcher forSafeRegEx(String name, Pattern safeRegEx, boolean inverted) { - checkNotNull(name, "name"); - checkNotNull(safeRegEx, "safeRegEx"); - return HeaderMatcher.create(name, null, safeRegEx, null, null, null, null, inverted); - } - - static HeaderMatcher forRange(String name, Range range, boolean inverted) { - checkNotNull(name, "name"); - checkNotNull(range, "range"); - return HeaderMatcher.create(name, null, null, range, null, null, null, inverted); - } - - static HeaderMatcher forPresent(String name, boolean present, boolean inverted) { - checkNotNull(name, "name"); - return HeaderMatcher.create(name, null, null, null, present, null, null, inverted); - } - - static HeaderMatcher forPrefix(String name, String prefix, boolean inverted) { - checkNotNull(name, "name"); - checkNotNull(prefix, "prefix"); - return HeaderMatcher.create(name, null, null, null, null, prefix, null, inverted); - } - - static HeaderMatcher forSuffix(String name, String suffix, boolean inverted) { - checkNotNull(name, "name"); - checkNotNull(suffix, "suffix"); - return HeaderMatcher.create(name, null, null, null, null, null, suffix, inverted); - } - - private static HeaderMatcher create(String name, @Nullable String exactValue, - @Nullable Pattern safeRegEx, @Nullable Range range, - @Nullable Boolean present, @Nullable String prefix, - @Nullable String suffix, boolean inverted) { - checkNotNull(name, "name"); - return new AutoValue_Matchers_HeaderMatcher(name, exactValue, safeRegEx, range, present, - prefix, suffix, inverted); - } - - /** Represents an integer range. */ - @AutoValue - abstract static class Range { - abstract long start(); - - abstract long end(); - - static Range create(long start, long end) { - return new AutoValue_Matchers_HeaderMatcher_Range(start, end); - } - - } - } - - /** Represents a fractional value. */ - @AutoValue - abstract static class FractionMatcher { - abstract int numerator(); - - abstract int denominator(); - - static FractionMatcher create(int numerator, int denominator) { - return new AutoValue_Matchers_FractionMatcher(numerator, denominator); - } - } -} diff --git a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java index dff1e778409..6e178c62c1b 100644 --- a/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/PriorityLoadBalancerProvider.java @@ -95,6 +95,14 @@ static final class PriorityChildConfig { this.policySelection = checkNotNull(policySelection, "policySelection"); this.ignoreReresolution = ignoreReresolution; } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("policySelection", policySelection) + .add("ignoreReresolution", ignoreReresolution) + .toString(); + } } } } diff --git a/xds/src/main/java/io/grpc/xds/RbacConfig.java b/xds/src/main/java/io/grpc/xds/RbacConfig.java new file mode 100644 index 00000000000..14f6ae33e1f --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/RbacConfig.java @@ -0,0 +1,38 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.auto.value.AutoValue; +import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthConfig; +import javax.annotation.Nullable; + +/** Rbac configuration for Rbac filter. */ +@AutoValue +abstract class RbacConfig implements FilterConfig { + @Override + public final String typeUrl() { + return RbacFilter.TYPE_URL; + } + + @Nullable + abstract AuthConfig authConfig(); + + static RbacConfig create(@Nullable AuthConfig authConfig) { + return new AutoValue_RbacConfig(authConfig); + } +} diff --git a/xds/src/main/java/io/grpc/xds/RbacFilter.java b/xds/src/main/java/io/grpc/xds/RbacFilter.java new file mode 100644 index 00000000000..48b4954767a --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/RbacFilter.java @@ -0,0 +1,325 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.common.annotations.VisibleForTesting; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.config.core.v3.CidrRange; +import io.envoyproxy.envoy.config.rbac.v3.Permission; +import io.envoyproxy.envoy.config.rbac.v3.Policy; +import io.envoyproxy.envoy.config.rbac.v3.Principal; +import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC; +import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBACPerRoute; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.xds.Filter.ServerInterceptorBuilder; +import io.grpc.xds.internal.MatcherParser; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AlwaysTrueMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AndMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthConfig; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthDecision; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthHeaderMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthenticatedMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationIpMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationPortMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.InvertMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.Matcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.OrMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.PathMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.PolicyMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.RequestedServerNameMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.SourceIpMatcher; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; + +/** RBAC Http filter implementation. */ +final class RbacFilter implements Filter, ServerInterceptorBuilder { + private static final Logger logger = Logger.getLogger(RbacFilter.class.getName()); + + static final RbacFilter INSTANCE = new RbacFilter(); + + static final String TYPE_URL = + "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBAC"; + + private static final String TYPE_URL_OVERRIDE_CONFIG = + "type.googleapis.com/envoy.extensions.filters.http.rbac.v3.RBACPerRoute"; + + RbacFilter() {} + + @Override + public String[] typeUrls() { + return new String[] { TYPE_URL, TYPE_URL_OVERRIDE_CONFIG }; + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + RBAC rbacProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + try { + rbacProto = anyMessage.unpack(RBAC.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + return parseRbacConfig(rbacProto); + } + + @VisibleForTesting + static ConfigOrError parseRbacConfig(RBAC rbac) { + if (!rbac.hasRules()) { + return ConfigOrError.fromConfig(RbacConfig.create(null)); + } + io.envoyproxy.envoy.config.rbac.v3.RBAC rbacConfig = rbac.getRules(); + GrpcAuthorizationEngine.Action authAction; + switch (rbacConfig.getAction()) { + case ALLOW: + authAction = GrpcAuthorizationEngine.Action.ALLOW; + break; + case DENY: + authAction = GrpcAuthorizationEngine.Action.DENY; + break; + case LOG: + return ConfigOrError.fromConfig(RbacConfig.create(null)); + case UNRECOGNIZED: + default: + return ConfigOrError.fromError("Unknown rbacConfig action type: " + rbacConfig.getAction()); + } + Map policyMap = rbacConfig.getPoliciesMap(); + List policyMatchers = new ArrayList<>(); + for (Map.Entry entry: policyMap.entrySet()) { + try { + Policy policy = entry.getValue(); + if (policy.hasCondition() || policy.hasCheckedCondition()) { + return ConfigOrError.fromError( + "Policy.condition and Policy.checked_condition must not set: " + entry.getKey()); + } + policyMatchers.add(new PolicyMatcher(entry.getKey(), + parsePermissionList(policy.getPermissionsList()), + parsePrincipalList(policy.getPrincipalsList()))); + } catch (Exception e) { + return ConfigOrError.fromError("Encountered error parsing policy: " + e); + } + } + return ConfigOrError.fromConfig(RbacConfig.create(new AuthConfig(policyMatchers, authAction))); + } + + @Override + public ConfigOrError parseFilterConfigOverride(Message rawProtoMessage) { + RBACPerRoute rbacPerRoute; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + try { + rbacPerRoute = anyMessage.unpack(RBACPerRoute.class); + } catch (InvalidProtocolBufferException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + if (rbacPerRoute.hasRbac()) { + return parseRbacConfig(rbacPerRoute.getRbac()); + } else { + return ConfigOrError.fromConfig(RbacConfig.create(null)); + } + } + + @Nullable + @Override + public ServerInterceptor buildServerInterceptor(FilterConfig config, + @Nullable FilterConfig overrideConfig) { + checkNotNull(config, "config"); + if (overrideConfig != null) { + config = overrideConfig; + } + AuthConfig authConfig = ((RbacConfig) config).authConfig(); + return authConfig == null ? null : generateAuthorizationInterceptor(authConfig); + } + + private ServerInterceptor generateAuthorizationInterceptor(AuthConfig config) { + checkNotNull(config, "config"); + final GrpcAuthorizationEngine authEngine = new GrpcAuthorizationEngine(config); + return new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + final ServerCall call, + final Metadata headers, ServerCallHandler next) { + AuthDecision authResult = authEngine.evaluate(headers, call); + logger.log(Level.FINE, + "Authorization result for serverCall {0}: {1}, matching policy: {2}.", + new Object[]{call, authResult.decision(), authResult.matchingPolicyName()}); + if (GrpcAuthorizationEngine.Action.DENY.equals(authResult.decision())) { + Status status = Status.UNAUTHENTICATED.withDescription( + "Access Denied, matching policy: " + authResult.matchingPolicyName()); + call.close(status, new Metadata()); + return new ServerCall.Listener(){}; + } + return next.startCall(call, headers); + } + }; + } + + private static OrMatcher parsePermissionList(List permissions) { + List anyMatch = new ArrayList<>(); + for (Permission permission : permissions) { + anyMatch.add(parsePermission(permission)); + } + return new OrMatcher(anyMatch); + } + + private static Matcher parsePermission(Permission permission) { + switch (permission.getRuleCase()) { + case AND_RULES: + List andMatch = new ArrayList<>(); + for (Permission p : permission.getAndRules().getRulesList()) { + andMatch.add(parsePermission(p)); + } + return new AndMatcher(andMatch); + case OR_RULES: + return parsePermissionList(permission.getOrRules().getRulesList()); + case ANY: + return AlwaysTrueMatcher.INSTANCE; + case HEADER: + return parseHeaderMatcher(permission.getHeader()); + case URL_PATH: + return parsePathMatcher(permission.getUrlPath()); + case DESTINATION_IP: + return createDestinationIpMatcher(permission.getDestinationIp()); + case DESTINATION_PORT: + return createDestinationPortMatcher(permission.getDestinationPort()); + case NOT_RULE: + return new InvertMatcher(parsePermission(permission.getNotRule())); + case METADATA: // hard coded, never match. + return new InvertMatcher(AlwaysTrueMatcher.INSTANCE); + case REQUESTED_SERVER_NAME: + return parseRequestedServerNameMatcher(permission.getRequestedServerName()); + case RULE_NOT_SET: + default: + throw new IllegalArgumentException( + "Unknown permission rule case: " + permission.getRuleCase()); + } + } + + private static OrMatcher parsePrincipalList(List principals) { + List anyMatch = new ArrayList<>(); + for (Principal principal: principals) { + anyMatch.add(parsePrincipal(principal)); + } + return new OrMatcher(anyMatch); + } + + private static Matcher parsePrincipal(Principal principal) { + switch (principal.getIdentifierCase()) { + case OR_IDS: + return parsePrincipalList(principal.getOrIds().getIdsList()); + case AND_IDS: + List nextMatchers = new ArrayList<>(); + for (Principal next : principal.getAndIds().getIdsList()) { + nextMatchers.add(parsePrincipal(next)); + } + return new AndMatcher(nextMatchers); + case ANY: + return AlwaysTrueMatcher.INSTANCE; + case AUTHENTICATED: + return parseAuthenticatedMatcher(principal.getAuthenticated()); + case DIRECT_REMOTE_IP: + return createSourceIpMatcher(principal.getDirectRemoteIp()); + case REMOTE_IP: + return createSourceIpMatcher(principal.getRemoteIp()); + case SOURCE_IP: + return createSourceIpMatcher(principal.getSourceIp()); + case HEADER: + return parseHeaderMatcher(principal.getHeader()); + case NOT_ID: + return new InvertMatcher(parsePrincipal(principal.getNotId())); + case URL_PATH: + return parsePathMatcher(principal.getUrlPath()); + case METADATA: // hard coded, never match. + return new InvertMatcher(AlwaysTrueMatcher.INSTANCE); + case IDENTIFIER_NOT_SET: + default: + throw new IllegalArgumentException( + "Unknown principal identifier case: " + principal.getIdentifierCase()); + } + } + + private static PathMatcher parsePathMatcher( + io.envoyproxy.envoy.type.matcher.v3.PathMatcher proto) { + switch (proto.getRuleCase()) { + case PATH: + return new PathMatcher(MatcherParser.parseStringMatcher(proto.getPath())); + case RULE_NOT_SET: + default: + throw new IllegalArgumentException( + "Unknown path matcher rule type: " + proto.getRuleCase()); + } + } + + private static RequestedServerNameMatcher parseRequestedServerNameMatcher( + io.envoyproxy.envoy.type.matcher.v3.StringMatcher proto) { + return new RequestedServerNameMatcher(MatcherParser.parseStringMatcher(proto)); + } + + private static AuthHeaderMatcher parseHeaderMatcher( + io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto) { + return new AuthHeaderMatcher(MatcherParser.parseHeaderMatcher(proto)); + } + + private static AuthenticatedMatcher parseAuthenticatedMatcher( + Principal.Authenticated proto) { + Matchers.StringMatcher matcher = MatcherParser.parseStringMatcher(proto.getPrincipalName()); + return new AuthenticatedMatcher(matcher); + } + + private static DestinationPortMatcher createDestinationPortMatcher(int port) { + return new DestinationPortMatcher(port); + } + + private static DestinationIpMatcher createDestinationIpMatcher(CidrRange cidrRange) { + return new DestinationIpMatcher(Matchers.CidrMatcher.create( + resolve(cidrRange), cidrRange.getPrefixLen().getValue())); + } + + private static SourceIpMatcher createSourceIpMatcher(CidrRange cidrRange) { + return new SourceIpMatcher(Matchers.CidrMatcher.create( + resolve(cidrRange), cidrRange.getPrefixLen().getValue())); + } + + private static InetAddress resolve(CidrRange cidrRange) { + try { + return InetAddress.getByName(cidrRange.getAddressPrefix()); + } catch (UnknownHostException ex) { + throw new IllegalArgumentException("IP address can not be found: " + ex); + } + } +} + diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java index 260b45efea5..05f29e20112 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancer.java @@ -191,38 +191,8 @@ public void shutdown() { subchannels.clear(); } - private void updateBalancingState() { - checkState(!subchannels.isEmpty(), "no subchannel has been created"); - ConnectivityState overallState = aggregateState(subchannels.values()); - RingHashPicker picker = new RingHashPicker(syncContext, ring, subchannels); - // TODO(chengyuanzhang): avoid unnecessary reprocess caused by duplicated server addr updates - helper.updateBalancingState(overallState, picker); - currentState = overallState; - } - - private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { - if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) { - return; - } - if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { - helper.refreshNameResolution(); - } - Ref subchannelStateRef = getSubchannelStateInfoRef(subchannel); - - // Don't proactively reconnect if the subchannel enters IDLE, even if previously was connected. - // If the subchannel was previously in TRANSIENT_FAILURE, it is considered to stay in - // TRANSIENT_FAILURE until it becomes READY. - if (subchannelStateRef.value.getState() == TRANSIENT_FAILURE) { - if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { - return; - } - } - subchannelStateRef.value = stateInfo; - updateBalancingState(); - } - /** - * Aggregates the connectivity states of a group of subchannels for overall connectivity state. + * Updates the overall balancing state by aggregating the connectivity states of all subchannels. * *

Aggregation rules (in order of dominance): *

    @@ -235,30 +205,70 @@ private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo *
  1. Otherwise, overall state is TRANSIENT_FAILURE
  2. *
*/ - private static ConnectivityState aggregateState(Iterable subchannels) { + private void updateBalancingState() { + checkState(!subchannels.isEmpty(), "no subchannel has been created"); int failureCount = 0; - boolean hasIdle = false; boolean hasConnecting = false; - for (Subchannel subchannel : subchannels) { + Subchannel idleSubchannel = null; + ConnectivityState overallState = null; + for (Subchannel subchannel : subchannels.values()) { ConnectivityState state = getSubchannelStateInfoRef(subchannel).value.getState(); if (state == READY) { - return state; + overallState = READY; + break; } if (state == TRANSIENT_FAILURE) { failureCount++; } else if (state == CONNECTING) { hasConnecting = true; } else if (state == IDLE) { - hasIdle = true; + if (idleSubchannel == null) { + idleSubchannel = subchannel; + } + } + } + if (overallState == null) { + if (failureCount >= 2) { + // This load balancer may not get any pick requests from the upstream if it's reporting + // TRANSIENT_FAILURE. It needs to recover by itself by attempting to connect to at least + // one subchannel that has not failed at any given time. + if (!hasConnecting && idleSubchannel != null) { + idleSubchannel.requestConnection(); + } + overallState = TRANSIENT_FAILURE; + } else if (hasConnecting) { + overallState = CONNECTING; + } else if (idleSubchannel != null) { + overallState = IDLE; + } else { + overallState = TRANSIENT_FAILURE; } } - if (failureCount >= 2) { - return TRANSIENT_FAILURE; + RingHashPicker picker = new RingHashPicker(syncContext, ring, subchannels); + // TODO(chengyuanzhang): avoid unnecessary reprocess caused by duplicated server addr updates + helper.updateBalancingState(overallState, picker); + currentState = overallState; + } + + private void processSubchannelState(Subchannel subchannel, ConnectivityStateInfo stateInfo) { + if (subchannels.get(stripAttrs(subchannel.getAddresses())) != subchannel) { + return; + } + if (stateInfo.getState() == TRANSIENT_FAILURE || stateInfo.getState() == IDLE) { + helper.refreshNameResolution(); } - if (hasConnecting) { - return CONNECTING; + Ref subchannelStateRef = getSubchannelStateInfoRef(subchannel); + + // Don't proactively reconnect if the subchannel enters IDLE, even if previously was connected. + // If the subchannel was previously in TRANSIENT_FAILURE, it is considered to stay in + // TRANSIENT_FAILURE until it becomes READY. + if (subchannelStateRef.value.getState() == TRANSIENT_FAILURE) { + if (stateInfo.getState() == CONNECTING || stateInfo.getState() == IDLE) { + return; + } } - return hasIdle ? IDLE : TRANSIENT_FAILURE; + subchannelStateRef.value = stateInfo; + updateBalancingState(); } private static void shutdownSubchannel(Subchannel subchannel) { diff --git a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java index fcbd527bf5c..af613b26078 100644 --- a/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java +++ b/xds/src/main/java/io/grpc/xds/RingHashLoadBalancerProvider.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import com.google.common.annotations.VisibleForTesting; import io.grpc.Internal; import io.grpc.LoadBalancer; import io.grpc.LoadBalancer.Helper; @@ -32,6 +33,17 @@ @Internal public final class RingHashLoadBalancerProvider extends LoadBalancerProvider { + // Same as ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE + @VisibleForTesting + static final long DEFAULT_MIN_RING_SIZE = 1024L; + // Same as ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE + @VisibleForTesting + static final long DEFAULT_MAX_RING_SIZE = 8 * 1024 * 1024L; + // Maximum number of ring entries allowed. Setting this too large can result in slow + // ring construction and OOM error. + // Same as ClientXdsClient.MAX_RING_HASH_LB_POLICY_RING_SIZE + static final long MAX_RING_SIZE = 8 * 1024 * 1024L; + private static final boolean enableRingHash = Boolean.parseBoolean(System.getenv("GRPC_XDS_EXPERIMENTAL_ENABLE_RING_HASH")); @@ -59,11 +71,14 @@ public String getPolicyName() { public ConfigOrError parseLoadBalancingPolicyConfig(Map rawLoadBalancingPolicyConfig) { Long minRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "minRingSize"); Long maxRingSize = JsonUtil.getNumberAsLong(rawLoadBalancingPolicyConfig, "maxRingSize"); - if (minRingSize == null || maxRingSize == null) { - return ConfigOrError.fromError(Status.INVALID_ARGUMENT.withDescription( - "Missing 'mingRingSize'/'maxRingSize'")); + if (minRingSize == null) { + minRingSize = DEFAULT_MIN_RING_SIZE; + } + if (maxRingSize == null) { + maxRingSize = DEFAULT_MAX_RING_SIZE; } - if (minRingSize <= 0 || maxRingSize <= 0 || minRingSize > maxRingSize) { + if (minRingSize <= 0 || maxRingSize <= 0 || minRingSize > maxRingSize + || maxRingSize > MAX_RING_SIZE) { return ConfigOrError.fromError(Status.INVALID_ARGUMENT.withDescription( "Invalid 'mingRingSize'/'maxRingSize'")); } diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index aaf7d2848f1..cc87b9c6b6f 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.ChannelCredentials; +import io.grpc.Context; import io.grpc.Grpc; import io.grpc.ManagedChannel; import io.grpc.internal.ExponentialBackoffPolicy; @@ -30,6 +31,7 @@ import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.Bootstrapper.ServerInfo; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import io.grpc.xds.internal.sds.TlsContextManagerImpl; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -105,6 +107,7 @@ private static class SharedXdsClientPoolProviderHolder { @ThreadSafe @VisibleForTesting static class RefCountedXdsClientObjectPool implements ObjectPool { + private final Context context = Context.ROOT; private final BootstrapInfo bootstrapInfo; private final Object lock = new Object(); @GuardedBy("lock") @@ -132,9 +135,9 @@ public XdsClient getObject() { .keepAliveTime(5, TimeUnit.MINUTES) .build(); scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); - xdsClient = new ClientXdsClient(channel, bootstrapInfo, scheduler, + xdsClient = new ClientXdsClient(channel, bootstrapInfo, context, scheduler, new ExponentialBackoffPolicy.Provider(), GrpcUtil.STOPWATCH_SUPPLIER, - TimeProvider.SYSTEM_TIME_PROVIDER); + TimeProvider.SYSTEM_TIME_PROVIDER, new TlsContextManagerImpl(bootstrapInfo)); } refCount++; return xdsClient; diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManager.java b/xds/src/main/java/io/grpc/xds/TlsContextManager.java similarity index 95% rename from xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManager.java rename to xds/src/main/java/io/grpc/xds/TlsContextManager.java index 561806ec670..e35eb68f219 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManager.java +++ b/xds/src/main/java/io/grpc/xds/TlsContextManager.java @@ -14,11 +14,14 @@ * limitations under the License. */ -package io.grpc.xds.internal.sds; +package io.grpc.xds; +import io.grpc.Internal; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.internal.sds.SslContextProvider; +@Internal public interface TlsContextManager { /** Creates a SslContextProvider. Used for retrieving a server-side SslContext. */ diff --git a/xds/src/main/java/io/grpc/xds/VirtualHost.java b/xds/src/main/java/io/grpc/xds/VirtualHost.java index 708b6a922df..59ca702d675 100644 --- a/xds/src/main/java/io/grpc/xds/VirtualHost.java +++ b/xds/src/main/java/io/grpc/xds/VirtualHost.java @@ -25,9 +25,8 @@ import com.google.common.collect.ImmutableMap; import com.google.re2j.Pattern; import io.grpc.xds.Filter.FilterConfig; -import io.grpc.xds.Matchers.FractionMatcher; -import io.grpc.xds.Matchers.HeaderMatcher; -import io.grpc.xds.Matchers.PathMatcher; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher; import java.util.Collections; import java.util.List; import java.util.Map; @@ -90,6 +89,47 @@ static RouteMatch create(PathMatcher pathMatcher, return new AutoValue_VirtualHost_Route_RouteMatch(pathMatcher, ImmutableList.copyOf(headerMatchers), fractionMatcher); } + + /** Matcher for HTTP request path. */ + @AutoValue + abstract static class PathMatcher { + // Exact full path to be matched. + @Nullable + abstract String path(); + + // Path prefix to be matched. + @Nullable + abstract String prefix(); + + // Regular expression pattern of the path to be matched. + @Nullable + abstract Pattern regEx(); + + // Whether case sensitivity is taken into account for matching. + // Only valid for full path matching or prefix matching. + abstract boolean caseSensitive(); + + static PathMatcher fromPath(String path, boolean caseSensitive) { + checkNotNull(path, "path"); + return create(path, null, null, caseSensitive); + } + + static PathMatcher fromPrefix(String prefix, boolean caseSensitive) { + checkNotNull(prefix, "prefix"); + return create(null, prefix, null, caseSensitive); + } + + static PathMatcher fromRegEx(Pattern regEx) { + checkNotNull(regEx, "regEx"); + return create(null, null, regEx, false /* doesn't matter */); + } + + private static PathMatcher create(@Nullable String path, @Nullable String prefix, + @Nullable Pattern regEx, boolean caseSensitive) { + return new AutoValue_VirtualHost_Route_RouteMatch_PathMatcher(path, prefix, regEx, + caseSensitive); + } + } } @AutoValue diff --git a/xds/src/main/java/io/grpc/xds/XdsClient.java b/xds/src/main/java/io/grpc/xds/XdsClient.java index 37c07eb1bcf..1c2238c1e17 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClient.java +++ b/xds/src/main/java/io/grpc/xds/XdsClient.java @@ -191,6 +191,12 @@ abstract static class CdsUpdate implements ResourceUpdate { @Nullable abstract String edsServiceName(); + // Corresponding DNS name to be used if upstream endpoints of the cluster is resolvable + // via DNS. + // Only valid for LOGICAL_DNS cluster. + @Nullable + abstract String dnsHostName(); + // Load report server name for reporting loads via LRS. // Only valid for EDS or LOGICAL_DNS cluster. @Nullable @@ -235,13 +241,15 @@ static Builder forEds(String clusterName, @Nullable String edsServiceName, .upstreamTlsContext(upstreamTlsContext); } - static Builder forLogicalDns(String clusterName, @Nullable String lrsServerName, - @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext upstreamTlsContext) { + static Builder forLogicalDns(String clusterName, String dnsHostName, + @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, + @Nullable UpstreamTlsContext upstreamTlsContext) { return new AutoValue_XdsClient_CdsUpdate.Builder() .clusterName(clusterName) .clusterType(ClusterType.LOGICAL_DNS) .minRingSize(0) .maxRingSize(0) + .dnsHostName(dnsHostName) .lrsServerName(lrsServerName) .maxConcurrentRequests(maxConcurrentRequests) .upstreamTlsContext(upstreamTlsContext); @@ -265,6 +273,7 @@ public final String toString() { .add("minRingSize", minRingSize()) .add("maxRingSize", maxRingSize()) .add("edsServiceName", edsServiceName()) + .add("dnsHostName", dnsHostName()) .add("lrsServerName", lrsServerName()) .add("maxConcurrentRequests", maxConcurrentRequests()) // Exclude upstreamTlsContext as its string representation is cumbersome. @@ -280,21 +289,29 @@ abstract static class Builder { // Private, use one of the static factory methods instead. protected abstract Builder clusterType(ClusterType clusterType); - abstract Builder lbPolicy(LbPolicy lbPolicy); + // Private, use roundRobinLbPolicy() or ringHashLbPolicy(long, long). + protected abstract Builder lbPolicy(LbPolicy lbPolicy); + + Builder roundRobinLbPolicy() { + return this.lbPolicy(LbPolicy.ROUND_ROBIN); + } - Builder lbPolicy(LbPolicy lbPolicy, long minRingSize, long maxRingSize) { - return this.lbPolicy(lbPolicy).minRingSize(minRingSize).maxRingSize(maxRingSize); + Builder ringHashLbPolicy(long minRingSize, long maxRingSize) { + return this.lbPolicy(LbPolicy.RING_HASH).minRingSize(minRingSize).maxRingSize(maxRingSize); } - // Private, use lbPolicy(LbPolicy, long, long). + // Private, use ringHashLbPolicy(long, long). protected abstract Builder minRingSize(long minRingSize); - // Private, use lbPolicy(.LbPolicy, long, long) + // Private, use ringHashLbPolicy(long, long). protected abstract Builder maxRingSize(long maxRingSize); // Private, use CdsUpdate.forEds() instead. protected abstract Builder edsServiceName(String edsServiceName); + // Private, use CdsUpdate.forLogicalDns() instead. + protected abstract Builder dnsHostName(String dnsHostName); + // Private, use one of the static factory methods instead. protected abstract Builder lrsServerName(String lrsServerName); @@ -537,6 +554,13 @@ Bootstrapper.BootstrapInfo getBootstrapInfo() { throw new UnsupportedOperationException(); } + /** + * Returns the {@link TlsContextManager} used in this XdsClient. + */ + TlsContextManager getTlsContextManager() { + throw new UnsupportedOperationException(); + } + /** * Returns the latest accepted version of the given resource type. */ diff --git a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java index 040892c016d..2cfbb325ad9 100644 --- a/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java +++ b/xds/src/main/java/io/grpc/xds/XdsClientWrapperForServerSds.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import com.google.common.collect.ImmutableSet; import com.google.protobuf.UInt32Value; import io.grpc.Internal; @@ -27,23 +28,21 @@ import io.grpc.internal.ObjectPool; import io.grpc.internal.SharedResourceHolder; import io.grpc.xds.EnvoyServerProtoData.CidrRange; -import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.FilterChain; import io.grpc.xds.EnvoyServerProtoData.FilterChainMatch; +import io.grpc.xds.internal.Matchers.CidrMatcher; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.netty.channel.Channel; import io.netty.channel.epoll.Epoll; import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.util.concurrent.DefaultThreadFactory; -import java.math.BigInteger; import java.net.Inet6Address; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.ArrayList; -import java.util.Comparator; import java.util.HashSet; import java.util.List; -import java.util.PriorityQueue; import java.util.Set; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; @@ -74,6 +73,7 @@ public final class XdsClientWrapperForServerSds { private ScheduledExecutorService timeService; private XdsClient.LdsResourceWatcher listenerWatcher; private boolean newServerApi; + private String grpcServerResourceId; @VisibleForTesting final Set serverWatchers = new HashSet<>(); /** @@ -96,6 +96,10 @@ public final class XdsClientWrapperForServerSds { return xdsClient; } + public TlsContextManager getTlsContextManager() { + return xdsClient.getTlsContextManager(); + } + /** Accepts an XdsClient and starts a watch. */ @VisibleForTesting public void start() { @@ -110,14 +114,14 @@ public void start() { new XdsClient.LdsResourceWatcher() { @Override public void onChanged(XdsClient.LdsUpdate update) { - curListener.set(update.listener); + releaseOldSuppliers(curListener.getAndSet(update.listener)); reportSuccess(); } @Override public void onResourceDoesNotExist(String resourceName) { logger.log(Level.WARNING, "Resource {0} is unavailable", resourceName); - curListener.set(null); + releaseOldSuppliers(curListener.getAndSet(null)); reportError(Status.NOT_FOUND.asException(), true); } @@ -125,10 +129,15 @@ public void onResourceDoesNotExist(String resourceName) { public void onError(Status error) { logger.log( Level.WARNING, "LdsResourceWatcher in XdsClientWrapperForServerSds: {0}", error); - reportError(error.asException(), isResourceAbsent(error)); + if (isResourceAbsent(error)) { + releaseOldSuppliers(curListener.getAndSet(null)); + reportError(error.asException(), true); + } else { + reportError(error.asException(), false); + } } }; - String grpcServerResourceId = xdsClient.getBootstrapInfo() + grpcServerResourceId = xdsClient.getBootstrapInfo() .getServerListenerResourceNameTemplate(); newServerApi = xdsClient.getBootstrapInfo().getServers().get(0).isUseProtocolV3(); if (newServerApi && grpcServerResourceId == null) { @@ -141,6 +150,27 @@ public void onError(Status error) { xdsClient.watchLdsResource(grpcServerResourceId, listenerWatcher); } + // go thru the old listener and release all the old SslContextProviderSupplier + private void releaseOldSuppliers(EnvoyServerProtoData.Listener oldListener) { + if (oldListener != null) { + List filterChains = oldListener.getFilterChains(); + for (FilterChain filterChain : filterChains) { + releaseSupplier(filterChain); + } + releaseSupplier(oldListener.getDefaultFilterChain()); + } + } + + private static void releaseSupplier(FilterChain filterChain) { + if (filterChain != null) { + SslContextProviderSupplier sslContextProviderSupplier = + filterChain.getSslContextProviderSupplier(); + if (sslContextProviderSupplier != null) { + sslContextProviderSupplier.close(); + } + } + } + /** Whether the throwable indicates our listener resource is absent/deleted. */ private static boolean isResourceAbsent(Status status) { Status.Code code = status.getCode(); @@ -158,10 +188,10 @@ private static boolean isResourceAbsent(Status status) { /** * Locates the best matching FilterChain to the channel from the current listener and if found - * returns the DownstreamTlsContext from that FilterChain, else null. + * returns the SslContextProviderSupplier from that FilterChain, else null. */ @Nullable - public DownstreamTlsContext getDownstreamTlsContext(Channel channel) { + public SslContextProviderSupplier getSslContextProviderSupplier(Channel channel) { EnvoyServerProtoData.Listener copyListener = curListener.get(); if (copyListener != null && channel != null) { SocketAddress localAddress = channel.localAddress(); @@ -172,7 +202,7 @@ public DownstreamTlsContext getDownstreamTlsContext(Channel channel) { checkState( port == localInetAddr.getPort(), "Channel localAddress port does not match requested listener port"); - return getDownstreamTlsContext(localInetAddr, remoteInetAddr, copyListener); + return getSslContextProviderSupplier(localInetAddr, remoteInetAddr, copyListener); } } return null; @@ -181,19 +211,22 @@ public DownstreamTlsContext getDownstreamTlsContext(Channel channel) { /** * Using the logic specified at * https://www.envoyproxy.io/docs/envoy/latest/api-v2/api/v2/listener/listener_components.proto.html?highlight=filter%20chain#listener-filterchainmatch - * locate a matching filter and return the corresponding DownstreamTlsContext or else return one - * from default filter chain. + * locate a matching filter and return the corresponding SslContextProviderSupplier or else + * return one from default filter chain. * * @param localInetAddr dest address of the inbound connection * @param remoteInetAddr source address of the inbound connection */ - private static DownstreamTlsContext getDownstreamTlsContext( + private static SslContextProviderSupplier getSslContextProviderSupplier( InetSocketAddress localInetAddr, InetSocketAddress remoteInetAddr, EnvoyServerProtoData.Listener listener) { List filterChains = listener.getFilterChains(); filterChains = filterOnDestinationPort(filterChains); filterChains = filterOnIpAddress(filterChains, localInetAddr.getAddress(), true); + filterChains = filterOnServerNames(filterChains); + filterChains = filterOnTransportProtocol(filterChains); + filterChains = filterOnApplicationProtocols(filterChains); filterChains = filterOnSourceType(filterChains, remoteInetAddr.getAddress(), localInetAddr.getAddress()); filterChains = filterOnIpAddress(filterChains, remoteInetAddr.getAddress(), false); @@ -203,9 +236,49 @@ private static DownstreamTlsContext getDownstreamTlsContext( // close the connection throw new IllegalStateException("Found 2 matching filter-chains"); } else if (filterChains.size() == 1) { - return filterChains.get(0).getDownstreamTlsContext(); + return filterChains.get(0).getSslContextProviderSupplier(); + } + return listener.getDefaultFilterChain().getSslContextProviderSupplier(); + } + + // reject if filer-chain-match has non-empty application_protocols + private static List filterOnApplicationProtocols(List filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + if (filterChainMatch.getApplicationProtocols().isEmpty()) { + filtered.add(filterChain); + } } - return listener.getDefaultFilterChain().getDownstreamTlsContext(); + return filtered; + } + + // reject if filer-chain-match has non-empty transport protocol other than "raw_buffer" + private static List filterOnTransportProtocol(List filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + String transportProtocol = filterChainMatch.getTransportProtocol(); + if ( Strings.isNullOrEmpty(transportProtocol) || "raw_buffer".equals(transportProtocol)) { + filtered.add(filterChain); + } + } + return filtered; + } + + // reject if filer-chain-match has server_name(s) + private static List filterOnServerNames(List filterChains) { + ArrayList filtered = new ArrayList<>(filterChains.size()); + for (FilterChain filterChain : filterChains) { + FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); + + if (filterChainMatch.getServerNames().isEmpty()) { + filtered.add(filterChain); + } + } + return filtered; } // destination_port present => Always fail match @@ -265,99 +338,53 @@ private static List filterOnSourceType( return filtered; } - private static boolean isCidrMatching(byte[] cidrBytes, byte[] addressBytes, int prefixLen) { - BigInteger cidrInt = new BigInteger(cidrBytes); - BigInteger addrInt = new BigInteger(addressBytes); - - int shiftAmount = 8 * cidrBytes.length - prefixLen; - - cidrInt = cidrInt.shiftRight(shiftAmount); - addrInt = addrInt.shiftRight(shiftAmount); - return cidrInt.equals(addrInt); - } - - private static class QueueElement { - FilterChain filterChain; - int indexOfMatchingPrefixRange; + private static int getMatchingPrefixLength( + FilterChainMatch filterChainMatch, InetAddress address, boolean forDestination) { + boolean isIPv6 = address instanceof Inet6Address; + List cidrRanges = + forDestination + ? filterChainMatch.getPrefixRanges() + : filterChainMatch.getSourcePrefixRanges(); int matchingPrefixLength; - - public QueueElement(FilterChain filterChain, InetAddress address, boolean forDestination) { - this.filterChain = filterChain; - FilterChainMatch filterChainMatch = filterChain.getFilterChainMatch(); - byte[] addressBytes = address.getAddress(); - boolean isIPv6 = address instanceof Inet6Address; - List cidrRanges = - forDestination - ? filterChainMatch.getPrefixRanges() - : filterChainMatch.getSourcePrefixRanges(); - indexOfMatchingPrefixRange = -1; - if (cidrRanges.isEmpty()) { // if there is no CidrRange assume 0-length match - matchingPrefixLength = 0; - } else { - matchingPrefixLength = -1; - int index = 0; - for (CidrRange cidrRange : cidrRanges) { - InetAddress cidrAddr = cidrRange.getAddressPrefix(); - boolean cidrIsIpv6 = cidrAddr instanceof Inet6Address; - if (isIPv6 == cidrIsIpv6) { - byte[] cidrBytes = cidrAddr.getAddress(); - int prefixLen = cidrRange.getPrefixLen(); - if (isCidrMatching(cidrBytes, addressBytes, prefixLen) - && prefixLen > matchingPrefixLength) { - matchingPrefixLength = prefixLen; - indexOfMatchingPrefixRange = index; - } + if (cidrRanges.isEmpty()) { // if there is no CidrRange assume 0-length match + matchingPrefixLength = 0; + } else { + matchingPrefixLength = -1; + for (CidrRange cidrRange : cidrRanges) { + InetAddress cidrAddr = cidrRange.getAddressPrefix(); + boolean cidrIsIpv6 = cidrAddr instanceof Inet6Address; + if (isIPv6 == cidrIsIpv6) { + int prefixLen = cidrRange.getPrefixLen(); + CidrMatcher matcher = CidrMatcher.create(cidrAddr, prefixLen); + if (matcher.matches(address) && prefixLen > matchingPrefixLength) { + matchingPrefixLength = prefixLen; } - index++; } } } - } - - private static final class QueueElementComparator implements Comparator { - - @Override - public int compare(QueueElement o1, QueueElement o2) { - // descending order for max heap - return o2.matchingPrefixLength - o1.matchingPrefixLength; - } - - @Override - public boolean equals(Object obj) { - return obj instanceof QueueElementComparator; - } - - @Override - public int hashCode() { - return super.hashCode(); - } + return matchingPrefixLength; } // use prefix_ranges (CIDR) and get the most specific matches private static List filterOnIpAddress( List filterChains, InetAddress address, boolean forDestination) { - PriorityQueue heap = new PriorityQueue<>(10, new QueueElementComparator()); - + // curent list of top ones + ArrayList topOnes = new ArrayList<>(filterChains.size()); + int topMatchingPrefixLen = -1; for (FilterChain filterChain : filterChains) { - QueueElement element = new QueueElement(filterChain, address, forDestination); + int currentMatchingPrefixLen = + getMatchingPrefixLength(filterChain.getFilterChainMatch(), address, forDestination); - if (element.matchingPrefixLength >= 0) { - heap.add(element); - } - } - // get the top ones - ArrayList topOnes = new ArrayList<>(heap.size()); - int topMatchingPrefixLen = -1; - while (!heap.isEmpty()) { - QueueElement element = heap.remove(); - if (topMatchingPrefixLen == -1) { - topMatchingPrefixLen = element.matchingPrefixLength; - } else { - if (element.matchingPrefixLength < topMatchingPrefixLen) { - break; + if (currentMatchingPrefixLen >= 0) { + if (currentMatchingPrefixLen < topMatchingPrefixLen) { + continue; + } + if (currentMatchingPrefixLen > topMatchingPrefixLen) { + topMatchingPrefixLen = currentMatchingPrefixLen; + topOnes.clear(); } + topOnes.add(filterChain); } - topOnes.add(element.filterChain); } return topOnes; } @@ -419,8 +446,10 @@ public interface ServerWatcher { public void shutdown() { logger.log(Level.FINER, "Shutdown"); if (xdsClient != null) { + xdsClient.cancelLdsResourceWatch(grpcServerResourceId, listenerWatcher); xdsClient = xdsClientPool.returnObject(xdsClient); } + releaseOldSuppliers(curListener.getAndSet(null)); if (timeService != null) { timeService = SharedResourceHolder.release(timeServiceResource, timeService); } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index b979d6edf06..c88d71857b0 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -46,15 +46,13 @@ import io.grpc.xds.Filter.ClientInterceptorBuilder; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.Matchers.FractionMatcher; -import io.grpc.xds.Matchers.HeaderMatcher; -import io.grpc.xds.Matchers.PathMatcher; import io.grpc.xds.ThreadSafeRandom.ThreadSafeRandomImpl; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteAction; import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.XdsClient.LdsResourceWatcher; import io.grpc.xds.XdsClient.LdsUpdate; import io.grpc.xds.XdsClient.RdsResourceWatcher; @@ -62,6 +60,8 @@ import io.grpc.xds.XdsLogger.XdsLogLevel; import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -334,26 +334,12 @@ static boolean matchHostName(String hostName, String pattern) { private final class ConfigSelector extends InternalConfigSelector { @Override public Result selectConfig(PickSubchannelArgs args) { - // Index ASCII headers by key, multi-value headers are concatenated for matching purposes. - Map asciiHeaders = new HashMap<>(); - Metadata headers = args.getHeaders(); - for (String headerName : headers.keys()) { - if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { - continue; - } - Metadata.Key key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); - Iterable values = headers.getAll(key); - if (values != null) { - asciiHeaders.put(headerName, Joiner.on(",").join(values)); - } - } - // Special hack for exposing headers: "content-type". - asciiHeaders.put("content-type", "application/grpc"); String cluster = null; Route selectedRoute = null; RoutingConfig routingCfg; Map selectedOverrideConfigs; List filterInterceptors = new ArrayList<>(); + Metadata headers = args.getHeaders(); do { routingCfg = routingConfig; selectedOverrideConfigs = new HashMap<>(routingCfg.virtualHostOverrideConfig); @@ -363,7 +349,7 @@ public Result selectConfig(PickSubchannelArgs args) { } for (Route route : routingCfg.routes) { if (matchRoute(route.routeMatch(), "/" + args.getMethodDescriptor().getFullMethodName(), - asciiHeaders, random)) { + headers, random)) { selectedRoute = route; selectedOverrideConfigs.putAll(route.filterConfigOverrides()); break; @@ -442,7 +428,7 @@ public Result selectConfig(PickSubchannelArgs args) { } } final String finalCluster = cluster; - final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), asciiHeaders); + final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), headers); class ClusterSelectionInterceptor implements ClientInterceptor { @Override public ClientCall interceptCall( @@ -517,13 +503,13 @@ public void run() { } } - private long generateHash(List hashPolicies, Map headers) { + private long generateHash(List hashPolicies, Metadata headers) { Long hash = null; for (HashPolicy policy : hashPolicies) { Long newHash = null; if (policy.type() == HashPolicy.Type.HEADER) { - if (headers.containsKey(policy.headerName())) { - String value = headers.get(policy.headerName()); + String value = getHeaderValue(headers, policy.headerName()); + if (value != null) { if (policy.regEx() != null && policy.regExSubstitution() != null) { value = policy.regEx().matcher(value).replaceAll(policy.regExSubstitution()); } @@ -565,19 +551,20 @@ public ClientCall interceptCall( @VisibleForTesting static boolean matchRoute(RouteMatch routeMatch, String fullMethodName, - Map headers, ThreadSafeRandom random) { + Metadata headers, ThreadSafeRandom random) { if (!matchPath(routeMatch.pathMatcher(), fullMethodName)) { return false; } - if (!matchHeaders(routeMatch.headerMatchers(), headers)) { - return false; + for (HeaderMatcher headerMatcher : routeMatch.headerMatchers()) { + if (!matchHeader(headerMatcher, getHeaderValue(headers, headerMatcher.name()))) { + return false; + } } FractionMatcher fraction = routeMatch.fractionMatcher(); return fraction == null || random.nextInt(fraction.denominator()) < fraction.numerator(); } - @VisibleForTesting - static boolean matchPath(PathMatcher pathMatcher, String fullMethodName) { + private static boolean matchPath(PathMatcher pathMatcher, String fullMethodName) { if (pathMatcher.path() != null) { return pathMatcher.caseSensitive() ? pathMatcher.path().equals(fullMethodName) @@ -590,18 +577,8 @@ static boolean matchPath(PathMatcher pathMatcher, String fullMethodName) { return pathMatcher.regEx().matches(fullMethodName); } - private static boolean matchHeaders( - List headerMatchers, Map headers) { - for (HeaderMatcher headerMatcher : headerMatchers) { - if (!matchHeader(headerMatcher, headers.get(headerMatcher.name()))) { - return false; - } - } - return true; - } - - @VisibleForTesting - static boolean matchHeader(HeaderMatcher headerMatcher, @Nullable String value) { + // TODO(zivy): consider reuse Matchers.HeaderMatcher.matches() + private static boolean matchHeader(HeaderMatcher headerMatcher, @Nullable String value) { if (headerMatcher.present() != null) { return (value == null) == headerMatcher.present().equals(headerMatcher.inverted()); } @@ -630,6 +607,24 @@ static boolean matchHeader(HeaderMatcher headerMatcher, @Nullable String value) return baseMatch != headerMatcher.inverted(); } + @Nullable + private static String getHeaderValue(Metadata headers, String headerName) { + if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return null; + } + if (headerName.equals("content-type")) { + return "application/grpc"; + } + Metadata.Key key; + try { + key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); + } catch (IllegalArgumentException e) { + return null; + } + Iterable values = headers.getAll(key); + return values == null ? null : Joiner.on(",").join(values); + } + private class ResolveState implements LdsResourceWatcher { private final ConfigOrError emptyServiceConfig = serviceConfigParser.parseServiceConfig(Collections.emptyMap()); diff --git a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java index fd3f537245d..d201c565caa 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerBuilder.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkState; import com.google.common.annotations.VisibleForTesting; +import com.google.errorprone.annotations.DoNotCall; import io.grpc.Attributes; import io.grpc.ExperimentalApi; import io.grpc.ForwardingServerBuilder; @@ -69,6 +70,7 @@ public XdsServerBuilder xdsServingStatusListener( /** * Unsupported call. Users should only use {@link #forPort(int, ServerCredentials)}. */ + @DoNotCall("Unsupported. Use forPort(int, ServerCredentials) instead") public static ServerBuilder forPort(int port) { throw new UnsupportedOperationException( "Unsupported call - use forPort(int, ServerCredentials)"); diff --git a/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java b/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java new file mode 100644 index 00000000000..0a971655df1 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/MatcherParser.java @@ -0,0 +1,85 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import com.google.re2j.Pattern; +import com.google.re2j.PatternSyntaxException; + +// TODO(zivy@): may reuse common matchers parsers. +public final class MatcherParser { + /** Translates envoy proto HeaderMatcher to internal HeaderMatcher.*/ + public static Matchers.HeaderMatcher parseHeaderMatcher( + io.envoyproxy.envoy.config.route.v3.HeaderMatcher proto) { + switch (proto.getHeaderMatchSpecifierCase()) { + case EXACT_MATCH: + return Matchers.HeaderMatcher.forExactValue( + proto.getName(), proto.getExactMatch(), proto.getInvertMatch()); + case SAFE_REGEX_MATCH: + String rawPattern = proto.getSafeRegexMatch().getRegex(); + Pattern safeRegExMatch; + try { + safeRegExMatch = Pattern.compile(rawPattern); + } catch (PatternSyntaxException e) { + throw new IllegalArgumentException( + "HeaderMatcher [" + proto.getName() + "] contains malformed safe regex pattern: " + + e.getMessage()); + } + return Matchers.HeaderMatcher.forSafeRegEx( + proto.getName(), safeRegExMatch, proto.getInvertMatch()); + case RANGE_MATCH: + Matchers.HeaderMatcher.Range rangeMatch = Matchers.HeaderMatcher.Range.create( + proto.getRangeMatch().getStart(), proto.getRangeMatch().getEnd()); + return Matchers.HeaderMatcher.forRange( + proto.getName(), rangeMatch, proto.getInvertMatch()); + case PRESENT_MATCH: + return Matchers.HeaderMatcher.forPresent( + proto.getName(), proto.getPresentMatch(), proto.getInvertMatch()); + case PREFIX_MATCH: + return Matchers.HeaderMatcher.forPrefix( + proto.getName(), proto.getPrefixMatch(), proto.getInvertMatch()); + case SUFFIX_MATCH: + return Matchers.HeaderMatcher.forSuffix( + proto.getName(), proto.getSuffixMatch(), proto.getInvertMatch()); + case HEADERMATCHSPECIFIER_NOT_SET: + default: + throw new IllegalArgumentException( + "Unknown header matcher type: " + proto.getHeaderMatchSpecifierCase()); + } + } + + /** Translate StringMatcher envoy proto to internal StringMatcher. */ + public static Matchers.StringMatcher parseStringMatcher( + io.envoyproxy.envoy.type.matcher.v3.StringMatcher proto) { + switch (proto.getMatchPatternCase()) { + case EXACT: + return Matchers.StringMatcher.forExact(proto.getExact(), proto.getIgnoreCase()); + case PREFIX: + return Matchers.StringMatcher.forPrefix(proto.getPrefix(), proto.getIgnoreCase()); + case SUFFIX: + return Matchers.StringMatcher.forSuffix(proto.getSuffix(), proto.getIgnoreCase()); + case SAFE_REGEX: + return Matchers.StringMatcher.forSafeRegEx( + Pattern.compile(proto.getSafeRegex().getRegex())); + case CONTAINS: + return Matchers.StringMatcher.forContains(proto.getContains()); + case MATCHPATTERN_NOT_SET: + default: + throw new IllegalArgumentException( + "Unknown StringMatcher match pattern: " + proto.getMatchPatternCase()); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/Matchers.java b/xds/src/main/java/io/grpc/xds/internal/Matchers.java new file mode 100644 index 00000000000..28ec8418297 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/Matchers.java @@ -0,0 +1,301 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import com.google.re2j.Pattern; +import java.math.BigInteger; +import java.net.InetAddress; +import javax.annotation.Nullable; + +/** + * Provides a group of request matchers. A matcher evaluates an input and tells whether certain + * argument in the input matches a predefined matching pattern. + */ +public final class Matchers { + // Prevent instantiation. + private Matchers() {} + + /** Matcher for HTTP request headers. */ + @AutoValue + public abstract static class HeaderMatcher { + // Name of the header to be matched. + public abstract String name(); + + // Matches exact header value. + @Nullable + public abstract String exactValue(); + + // Matches header value with the regular expression pattern. + @Nullable + public abstract Pattern safeRegEx(); + + // Matches header value an integer value in the range. + @Nullable + public abstract Range range(); + + // Matches header presence. + @Nullable + public abstract Boolean present(); + + // Matches header value with the prefix. + @Nullable + public abstract String prefix(); + + // Matches header value with the suffix. + @Nullable + public abstract String suffix(); + + // Whether the matching semantics is inverted. E.g., present && !inverted -> !present + public abstract boolean inverted(); + + /** The request header value should exactly match the specified value. */ + public static HeaderMatcher forExactValue(String name, String exactValue, boolean inverted) { + checkNotNull(name, "name"); + checkNotNull(exactValue, "exactValue"); + return HeaderMatcher.create(name, exactValue, null, null, null, null, null, inverted); + } + + /** The request header value should match the regular expression pattern. */ + public static HeaderMatcher forSafeRegEx(String name, Pattern safeRegEx, boolean inverted) { + checkNotNull(name, "name"); + checkNotNull(safeRegEx, "safeRegEx"); + return HeaderMatcher.create(name, null, safeRegEx, null, null, null, null, inverted); + } + + /** The request header value should be within the range. */ + public static HeaderMatcher forRange(String name, Range range, boolean inverted) { + checkNotNull(name, "name"); + checkNotNull(range, "range"); + return HeaderMatcher.create(name, null, null, range, null, null, null, inverted); + } + + /** The request header value should exist. */ + public static HeaderMatcher forPresent(String name, boolean present, boolean inverted) { + checkNotNull(name, "name"); + return HeaderMatcher.create(name, null, null, null, present, null, null, inverted); + } + + /** The request header value should have this prefix. */ + public static HeaderMatcher forPrefix(String name, String prefix, boolean inverted) { + checkNotNull(name, "name"); + checkNotNull(prefix, "prefix"); + return HeaderMatcher.create(name, null, null, null, null, prefix, null, inverted); + } + + /** The request header value should have this suffix. */ + public static HeaderMatcher forSuffix(String name, String suffix, boolean inverted) { + checkNotNull(name, "name"); + checkNotNull(suffix, "suffix"); + return HeaderMatcher.create(name, null, null, null, null, null, suffix, inverted); + } + + private static HeaderMatcher create(String name, @Nullable String exactValue, + @Nullable Pattern safeRegEx, @Nullable Range range, + @Nullable Boolean present, @Nullable String prefix, + @Nullable String suffix, boolean inverted) { + checkNotNull(name, "name"); + return new AutoValue_Matchers_HeaderMatcher(name, exactValue, safeRegEx, range, present, + prefix, suffix, inverted); + } + + /** Returns the matching result. */ + public boolean matches(@Nullable String value) { + if (present() != null) { + return (value == null) == present().equals(inverted()); + } + // FIXME(zivy@): invert result for null value. + // https://github.com/envoyproxy/envoy/blob/0fae6970ddaf93f024908ba304bbd2b34e997a51/source/common/http/header_utility.cc#L130 + if (value == null) { + return false; + } + boolean baseMatch; + if (exactValue() != null) { + baseMatch = exactValue().equals(value); + } else if (safeRegEx() != null) { + baseMatch = safeRegEx().matches(value); + } else if (range() != null) { + long numValue; + try { + numValue = Long.parseLong(value); + baseMatch = numValue >= range().start() + && numValue <= range().end(); + } catch (NumberFormatException ignored) { + baseMatch = false; + } + } else if (prefix() != null) { + baseMatch = value.startsWith(prefix()); + } else { + baseMatch = value.endsWith(suffix()); + } + return baseMatch != inverted(); + } + + /** Represents an integer range. */ + @AutoValue + public abstract static class Range { + public abstract long start(); + + public abstract long end(); + + public static Range create(long start, long end) { + return new AutoValue_Matchers_HeaderMatcher_Range(start, end); + } + } + } + + /** Represents a fractional value. */ + @AutoValue + public abstract static class FractionMatcher { + public abstract int numerator(); + + public abstract int denominator(); + + public static FractionMatcher create(int numerator, int denominator) { + return new AutoValue_Matchers_FractionMatcher(numerator, denominator); + } + } + + /** Represents various ways to match a string .*/ + @AutoValue + public abstract static class StringMatcher { + @Nullable + abstract String exact(); + + // The input string has this prefix. + @Nullable + abstract String prefix(); + + // The input string has this suffix. + @Nullable + abstract String suffix(); + + // The input string matches the regular expression. + @Nullable + abstract Pattern regEx(); + + // The input string has this substring. + @Nullable + abstract String contains(); + + // If true, exact/prefix/suffix matching should be case insensitive. + abstract boolean ignoreCase(); + + /** The input string should exactly matches the specified string. */ + public static StringMatcher forExact(String exact, boolean ignoreCase) { + checkNotNull(exact, "exact"); + return StringMatcher.create(exact, null, null, null, null, + ignoreCase); + } + + /** The input string should have the prefix. */ + public static StringMatcher forPrefix(String prefix, boolean ignoreCase) { + checkNotNull(prefix, "prefix"); + return StringMatcher.create(null, prefix, null, null, null, + ignoreCase); + } + + /** The input string should have the suffix. */ + public static StringMatcher forSuffix(String suffix, boolean ignoreCase) { + checkNotNull(suffix, "suffix"); + return StringMatcher.create(null, null, suffix, null, null, + ignoreCase); + } + + /** The input string should match this pattern. */ + public static StringMatcher forSafeRegEx(Pattern regEx) { + checkNotNull(regEx, "regEx"); + return StringMatcher.create(null, null, null, regEx, null, + false/* doesn't matter */); + } + + /** The input string should contain this substring. */ + public static StringMatcher forContains(String contains) { + checkNotNull(contains, "contains"); + return StringMatcher.create(null, null, null, null, contains, + false/* doesn't matter */); + } + + /** Returns the matching result for this string. */ + public boolean matches(String args) { + if (args == null) { + return false; + } + if (exact() != null) { + return ignoreCase() + ? exact().equalsIgnoreCase(args) + : exact().equals(args); + } else if (prefix() != null) { + return ignoreCase() + ? args.toLowerCase().startsWith(prefix().toLowerCase()) + : args.startsWith(prefix()); + } else if (suffix() != null) { + return ignoreCase() + ? args.toLowerCase().endsWith(suffix().toLowerCase()) + : args.endsWith(suffix()); + } else if (contains() != null) { + return args.contains(contains()); + } + return regEx().matches(args); + } + + private static StringMatcher create(@Nullable String exact, @Nullable String prefix, + @Nullable String suffix, @Nullable Pattern regEx, @Nullable String contains, + boolean ignoreCase) { + return new AutoValue_Matchers_StringMatcher(exact, prefix, suffix, regEx, contains, + ignoreCase); + } + } + + /** Matcher to evaluate whether an IPv4 or IPv6 address is within a CIDR range. */ + @AutoValue + public abstract static class CidrMatcher { + + abstract InetAddress addressPrefix(); + + abstract int prefixLen(); + + /** Returns matching result for this address. */ + public boolean matches(InetAddress address) { + if (address == null) { + return false; + } + byte[] cidr = addressPrefix().getAddress(); + byte[] addr = address.getAddress(); + if (addr.length != cidr.length) { + return false; + } + BigInteger cidrInt = new BigInteger(cidr); + BigInteger addrInt = new BigInteger(addr); + + int shiftAmount = 8 * cidr.length - prefixLen(); + + cidrInt = cidrInt.shiftRight(shiftAmount); + addrInt = addrInt.shiftRight(shiftAmount); + return cidrInt.equals(addrInt); + } + + /** Constructs a CidrMatcher with this prefix and prefix length. + * Do not provide string addressPrefix constructor to avoid IO exception handling. + * */ + public static CidrMatcher create(InetAddress addressPrefix, int prefixLen) { + return new AutoValue_Matchers_CidrMatcher(addressPrefix, prefixLen); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java index ac5ca9711fd..12eb6f6573f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/CertificateProviderRegistry.java @@ -40,8 +40,6 @@ public static synchronized CertificateProviderRegistry getInstance() { instance = new CertificateProviderRegistry(); // TODO(sanjaypujare): replace with Java's SPI mechanism and META-INF resource instance.register(new FileWatcherCertificateProviderProvider()); - instance.register(new DynamicReloadingCertificateProviderProvider()); - instance.register(new MeshCaCertificateProviderProvider()); } return instance; } diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProvider.java deleted file mode 100644 index af7324f2581..00000000000 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProvider.java +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.annotations.VisibleForTesting; -import io.grpc.InternalLogId; -import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sds.trust.CertificateUtils; -import java.io.File; -import java.io.FileInputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.security.PrivateKey; -import java.security.cert.X509Certificate; -import java.util.Arrays; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.logging.Level; -import java.util.logging.Logger; - -/** Implementation of {@link CertificateProvider} for dynamic reloading cert provider. */ -final class DynamicReloadingCertificateProvider extends CertificateProvider { - private static final Logger logger = - Logger.getLogger(DynamicReloadingCertificateProvider.class.getName()); - - private final SynchronizationContext syncContext; - private final ScheduledExecutorService scheduledExecutorService; - private final TimeProvider timeProvider; - private final Path directory; - private final String certFile; - private final String privateKeyFile; - private final String trustFile; - private final long refreshIntervalInSeconds; - @VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle; - private Path lastModifiedTarget; - - DynamicReloadingCertificateProvider( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String directory, - String certFile, - String privateKeyFile, - String trustFile, - long refreshIntervalInSeconds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider) { - super(watcher, notifyCertUpdates); - this.scheduledExecutorService = - checkNotNull(scheduledExecutorService, "scheduledExecutorService"); - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); - this.directory = Paths.get(checkNotNull(directory, "diretory")); - this.certFile = checkNotNull(certFile, "certFile"); - this.privateKeyFile = checkNotNull(privateKeyFile, "privateKeyFile"); - this.trustFile = checkNotNull(trustFile, "trustFile"); - this.refreshIntervalInSeconds = refreshIntervalInSeconds; - this.syncContext = createSynchronizationContext(directory); - } - - private SynchronizationContext createSynchronizationContext(String details) { - final InternalLogId logId = - InternalLogId.allocate("DynamicReloadingCertificateProvider", details); - return new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - private boolean panicMode; - - @Override - public void uncaughtException(Thread t, Throwable e) { - logger.log( - Level.SEVERE, - "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!", - e); - panic(e); - } - - void panic(final Throwable t) { - if (panicMode) { - // Preserve the first panic information - return; - } - panicMode = true; - close(); - } - }); - } - - @Override - public void start() { - scheduleNextRefreshCertificate(/* delayInSeconds= */0); - } - - @Override - public void close() { - if (scheduledHandle != null) { - scheduledHandle.cancel(); - scheduledHandle = null; - } - getWatcher().close(); - } - - private void scheduleNextRefreshCertificate(long delayInSeconds) { - RefreshCertificateTask runnable = new RefreshCertificateTask(); - scheduledHandle = - syncContext.schedule(runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService); - } - - @VisibleForTesting - void checkAndReloadCertificates() { - try { - Path targetPath = Files.readSymbolicLink(directory); - if (targetPath.equals(lastModifiedTarget)) { - return; - } - try (FileInputStream privateKeyStream = - new FileInputStream(new File(targetPath.toFile(), privateKeyFile)); - FileInputStream certsStream = - new FileInputStream(new File(targetPath.toFile(), certFile)); - FileInputStream caCertsStream = - new FileInputStream(new File(targetPath.toFile(), trustFile))) { - PrivateKey privateKey = CertificateUtils.getPrivateKey(privateKeyStream); - X509Certificate[] certs = CertificateUtils.toX509Certificates(certsStream); - X509Certificate[] caCerts = CertificateUtils.toX509Certificates(caCertsStream); - getWatcher().updateCertificate(privateKey, Arrays.asList(certs)); - getWatcher().updateTrustedRoots(Arrays.asList(caCerts)); - } - lastModifiedTarget = targetPath; - } catch (Throwable t) { - generateErrorIfCurrentCertExpired(t); - } finally { - scheduleNextRefreshCertificate(refreshIntervalInSeconds); - } - } - - private void generateErrorIfCurrentCertExpired(Throwable t) { - X509Certificate currentCert = getWatcher().getLastIdentityCert(); - if (currentCert != null) { - long delaySeconds = computeDelaySecondsToCertExpiry(currentCert); - if (delaySeconds > refreshIntervalInSeconds) { - logger.log(Level.FINER, "reload certificate error", t); - return; - } - // The current cert is going to expire in less than {@link refreshIntervalInSeconds} - // Clear the current cert and notify our watchers thru {@code onError} - getWatcher().clearValues(); - } - getWatcher().onError(Status.fromThrowable(t)); - } - - @SuppressWarnings("JdkObsolete") - private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) { - checkNotNull(lastCert, "lastCert"); - return TimeUnit.NANOSECONDS.toSeconds( - TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - - timeProvider.currentTimeNanos()); - } - - @VisibleForTesting - class RefreshCertificateTask implements Runnable { - @Override - public void run() { - checkAndReloadCertificates(); - } - } - - abstract static class Factory { - private static final Factory DEFAULT_INSTANCE = - new Factory() { - @Override - DynamicReloadingCertificateProvider create( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String directory, - String certFile, - String privateKeyFile, - String trustFile, - long refreshIntervalInSeconds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider) { - return new DynamicReloadingCertificateProvider( - watcher, - notifyCertUpdates, - directory, - certFile, - privateKeyFile, - trustFile, - refreshIntervalInSeconds, - scheduledExecutorService, - timeProvider); - } - }; - - static Factory getInstance() { - return DEFAULT_INSTANCE; - } - - abstract DynamicReloadingCertificateProvider create( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String directory, - String certFile, - String privateKeyFile, - String trustFile, - long refreshIntervalInSeconds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider); - } -} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProvider.java deleted file mode 100644 index 0d1cf509220..00000000000 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProvider.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.grpc.internal.JsonUtil; -import io.grpc.internal.TimeProvider; -import java.util.Map; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; - -/** - * Provider of {@link DynamicReloadingCertificateProvider}s. - */ -final class DynamicReloadingCertificateProviderProvider implements CertificateProviderProvider { - - private static final String DIRECTORY_KEY = "directory"; - private static final String CERT_FILE_KEY = "certificate-file"; - private static final String KEY_FILE_KEY = "private-key-file"; - private static final String ROOT_FILE_KEY = "ca-certificate-file"; - private static final String REFRESH_INTERVAL_KEY = "refresh-interval"; - - @VisibleForTesting static final long REFRESH_INTERVAL_DEFAULT = 600L; - - - static final String DYNAMIC_RELOADING_PROVIDER_NAME = "gke-cas-certs"; - - final DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory; - private final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory; - private final TimeProvider timeProvider; - - DynamicReloadingCertificateProviderProvider() { - this( - DynamicReloadingCertificateProvider.Factory.getInstance(), - ScheduledExecutorServiceFactory.DEFAULT_INSTANCE, - TimeProvider.SYSTEM_TIME_PROVIDER); - } - - @VisibleForTesting - DynamicReloadingCertificateProviderProvider( - DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory, - ScheduledExecutorServiceFactory scheduledExecutorServiceFactory, - TimeProvider timeProvider) { - this.dynamicReloadingCertificateProviderFactory = dynamicReloadingCertificateProviderFactory; - this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory; - this.timeProvider = timeProvider; - } - - @Override - public String getName() { - return DYNAMIC_RELOADING_PROVIDER_NAME; - } - - @Override - public CertificateProvider createCertificateProvider( - Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) { - - Config configObj = validateAndTranslateConfig(config); - return dynamicReloadingCertificateProviderFactory.create( - watcher, - notifyCertUpdates, - configObj.directory, - configObj.certFile, - configObj.keyFile, - configObj.rootFile, - configObj.refrehInterval, - scheduledExecutorServiceFactory.create(), - timeProvider); - } - - private static String checkForNullAndGet(Map map, String key) { - return checkNotNull(JsonUtil.getString(map, key), "'" + key + "' is required in the config"); - } - - private static Config validateAndTranslateConfig(Object config) { - checkArgument(config instanceof Map, "Only Map supported for config"); - @SuppressWarnings("unchecked") Map map = (Map)config; - - Config configObj = new Config(); - configObj.directory = checkForNullAndGet(map, DIRECTORY_KEY); - configObj.certFile = checkForNullAndGet(map, CERT_FILE_KEY); - configObj.keyFile = checkForNullAndGet(map, KEY_FILE_KEY); - configObj.rootFile = checkForNullAndGet(map, ROOT_FILE_KEY); - configObj.refrehInterval = JsonUtil.getNumberAsLong(map, REFRESH_INTERVAL_KEY); - if (configObj.refrehInterval == null) { - configObj.refrehInterval = REFRESH_INTERVAL_DEFAULT; - } - return configObj; - } - - abstract static class ScheduledExecutorServiceFactory { - - private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE = - new ScheduledExecutorServiceFactory() { - - @Override - ScheduledExecutorService create() { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder() - .setNameFormat("dynamicReloading" + "-%d") - .setDaemon(true) - .build()); - } - }; - - abstract ScheduledExecutorService create(); - } - - /** POJO class for storing various config values. */ - @VisibleForTesting - static class Config { - String directory; - String certFile; - String keyFile; - String rootFile; - Long refrehInterval; - } -} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java index bbcb521c0d5..b86de55766e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProvider.java @@ -82,6 +82,7 @@ public void start() { @Override public synchronized void close() { shutdown = true; + scheduledExecutorService.shutdownNow(); if (scheduledFuture != null) { scheduledFuture.cancel(true); scheduledFuture = null; diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java deleted file mode 100644 index dee649a613e..00000000000 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProvider.java +++ /dev/null @@ -1,500 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.Status.Code.ABORTED; -import static io.grpc.Status.Code.CANCELLED; -import static io.grpc.Status.Code.DEADLINE_EXCEEDED; -import static io.grpc.Status.Code.INTERNAL; -import static io.grpc.Status.Code.RESOURCE_EXHAUSTED; -import static io.grpc.Status.Code.UNAVAILABLE; -import static io.grpc.Status.Code.UNKNOWN; -import static java.nio.charset.StandardCharsets.UTF_8; - -import com.google.auth.oauth2.GoogleCredentials; -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; -import com.google.protobuf.Duration; -import com.google.security.meshca.v1.MeshCertificateRequest; -import com.google.security.meshca.v1.MeshCertificateResponse; -import com.google.security.meshca.v1.MeshCertificateServiceGrpc; -import io.grpc.CallOptions; -import io.grpc.Channel; -import io.grpc.ClientCall; -import io.grpc.ClientInterceptor; -import io.grpc.ForwardingClientCall; -import io.grpc.Grpc; -import io.grpc.InternalLogId; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.MethodDescriptor; -import io.grpc.Status; -import io.grpc.SynchronizationContext; -import io.grpc.TlsChannelCredentials; -import io.grpc.auth.MoreCallCredentials; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sds.trust.CertificateUtils; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.StringWriter; -import java.security.KeyPair; -import java.security.KeyPairGenerator; -import java.security.NoSuchAlgorithmException; -import java.security.cert.CertificateException; -import java.security.cert.X509Certificate; -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.List; -import java.util.UUID; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; -import java.util.logging.Level; -import java.util.logging.Logger; -import javax.security.auth.x500.X500Principal; -import org.bouncycastle.openssl.jcajce.JcaPEMWriter; -import org.bouncycastle.operator.ContentSigner; -import org.bouncycastle.operator.OperatorCreationException; -import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; -import org.bouncycastle.pkcs.PKCS10CertificationRequest; -import org.bouncycastle.pkcs.PKCS10CertificationRequestBuilder; -import org.bouncycastle.pkcs.jcajce.JcaPKCS10CertificationRequestBuilder; -import org.bouncycastle.util.io.pem.PemObject; - -/** Implementation of {@link CertificateProvider} for the Google Mesh CA. */ -final class MeshCaCertificateProvider extends CertificateProvider { - private static final Logger logger = Logger.getLogger(MeshCaCertificateProvider.class.getName()); - - MeshCaCertificateProvider( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String meshCaUrl, - String zone, - long validitySeconds, - int keySize, - String unused, //TODO(sanjaypujare): to remove during refactoring - String signatureAlg, MeshCaChannelFactory meshCaChannelFactory, - BackoffPolicy.Provider backoffPolicyProvider, - long renewalGracePeriodSeconds, - int maxRetryAttempts, - GoogleCredentials oauth2Creds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider, - long rpcTimeoutMillis) { - super(watcher, notifyCertUpdates); - this.meshCaUrl = checkNotNull(meshCaUrl, "meshCaUrl"); - checkArgument( - validitySeconds > INITIAL_DELAY_SECONDS, - "validitySeconds must be greater than " + INITIAL_DELAY_SECONDS); - this.validitySeconds = validitySeconds; - this.keySize = keySize; - this.signatureAlg = checkNotNull(signatureAlg, "signatureAlg"); - this.meshCaChannelFactory = checkNotNull(meshCaChannelFactory, "meshCaChannelFactory"); - this.backoffPolicyProvider = checkNotNull(backoffPolicyProvider, "backoffPolicyProvider"); - checkArgument( - renewalGracePeriodSeconds > 0L && renewalGracePeriodSeconds < validitySeconds, - "renewalGracePeriodSeconds should be between 0 and " + validitySeconds); - this.renewalGracePeriodSeconds = renewalGracePeriodSeconds; - checkArgument(maxRetryAttempts >= 0, "maxRetryAttempts must be >= 0"); - this.maxRetryAttempts = maxRetryAttempts; - this.oauth2Creds = checkNotNull(oauth2Creds, "oauth2Creds"); - this.scheduledExecutorService = - checkNotNull(scheduledExecutorService, "scheduledExecutorService"); - this.timeProvider = checkNotNull(timeProvider, "timeProvider"); - this.headerInterceptor = new ZoneInfoClientInterceptor(checkNotNull(zone, "zone")); - this.syncContext = createSynchronizationContext(meshCaUrl); - this.rpcTimeoutMillis = rpcTimeoutMillis; - } - - private SynchronizationContext createSynchronizationContext(String details) { - final InternalLogId logId = InternalLogId.allocate("MeshCaCertificateProvider", details); - return new SynchronizationContext( - new Thread.UncaughtExceptionHandler() { - private boolean panicMode; - - @Override - public void uncaughtException(Thread t, Throwable e) { - logger.log( - Level.SEVERE, - "[" + logId + "] Uncaught exception in the SynchronizationContext. Panic!", - e); - panic(e); - } - - void panic(final Throwable t) { - if (panicMode) { - // Preserve the first panic information - return; - } - panicMode = true; - close(); - } - }); - } - - @Override - public void start() { - scheduleNextRefreshCertificate(INITIAL_DELAY_SECONDS); - } - - @Override - public void close() { - if (scheduledHandle != null) { - scheduledHandle.cancel(); - scheduledHandle = null; - } - getWatcher().close(); - } - - private void scheduleNextRefreshCertificate(long delayInSeconds) { - if (scheduledHandle != null && scheduledHandle.isPending()) { - logger.log(Level.SEVERE, "Pending task found: inconsistent state in scheduledHandle!"); - scheduledHandle.cancel(); - } - RefreshCertificateTask runnable = new RefreshCertificateTask(); - scheduledHandle = syncContext.schedule( - runnable, delayInSeconds, TimeUnit.SECONDS, scheduledExecutorService); - } - - @VisibleForTesting - void refreshCertificate() - throws NoSuchAlgorithmException, IOException, OperatorCreationException { - long refreshDelaySeconds = computeRefreshSecondsFromCurrentCertExpiry(); - ManagedChannel channel = meshCaChannelFactory.createChannel(meshCaUrl); - try { - String uniqueReqIdForAllRetries = UUID.randomUUID().toString(); - Duration duration = Duration.newBuilder().setSeconds(validitySeconds).build(); - KeyPair keyPair = generateKeyPair(); - String csr = generateCsr(keyPair); - MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub = - createStubToMeshCa(channel); - List x509Chain = makeRequestWithRetries(stub, uniqueReqIdForAllRetries, - duration, csr); - if (x509Chain != null) { - refreshDelaySeconds = - computeDelaySecondsToCertExpiry(x509Chain.get(0)) - renewalGracePeriodSeconds; - getWatcher().updateCertificate(keyPair.getPrivate(), x509Chain); - getWatcher().updateTrustedRoots(ImmutableList.of(x509Chain.get(x509Chain.size() - 1))); - } - } finally { - shutdownChannel(channel); - scheduleNextRefreshCertificate(refreshDelaySeconds); - } - } - - private MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub createStubToMeshCa( - ManagedChannel channel) { - return MeshCertificateServiceGrpc - .newBlockingStub(channel) - .withCallCredentials(MoreCallCredentials.from(oauth2Creds)) - .withInterceptors(headerInterceptor); - } - - private List makeRequestWithRetries( - MeshCertificateServiceGrpc.MeshCertificateServiceBlockingStub stub, - String reqId, - Duration duration, - String csr) { - MeshCertificateRequest request = - MeshCertificateRequest.newBuilder() - .setValidity(duration) - .setCsr(csr) - .setRequestId(reqId) - .build(); - - BackoffPolicy backoffPolicy = backoffPolicyProvider.get(); - Throwable lastException = null; - for (int i = 0; i <= maxRetryAttempts; i++) { - try { - MeshCertificateResponse response = - stub.withDeadlineAfter(rpcTimeoutMillis, TimeUnit.MILLISECONDS) - .createCertificate(request); - return getX509CertificatesFromResponse(response); - } catch (Throwable t) { - if (!retriable(t)) { - generateErrorIfCurrentCertExpired(t); - return null; - } - lastException = t; - sleepForNanos(backoffPolicy.nextBackoffNanos()); - } - } - generateErrorIfCurrentCertExpired(lastException); - return null; - } - - private void sleepForNanos(long nanos) { - ScheduledFuture future = scheduledExecutorService.schedule(new Runnable() { - @Override - public void run() { - // do nothing - } - }, nanos, TimeUnit.NANOSECONDS); - try { - future.get(nanos, TimeUnit.NANOSECONDS); - } catch (InterruptedException ie) { - logger.log(Level.SEVERE, "Inside sleep", ie); - Thread.currentThread().interrupt(); - } catch (ExecutionException | TimeoutException ex) { - logger.log(Level.SEVERE, "Inside sleep", ex); - } - } - - private static boolean retriable(Throwable t) { - return RETRIABLE_CODES.contains(Status.fromThrowable(t).getCode()); - } - - private void generateErrorIfCurrentCertExpired(Throwable t) { - X509Certificate currentCert = getWatcher().getLastIdentityCert(); - if (currentCert != null) { - long delaySeconds = computeDelaySecondsToCertExpiry(currentCert); - if (delaySeconds > INITIAL_DELAY_SECONDS) { - return; - } - getWatcher().clearValues(); - } - getWatcher().onError(Status.fromThrowable(t)); - } - - private KeyPair generateKeyPair() throws NoSuchAlgorithmException { - KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA"); - keyPairGenerator.initialize(keySize); - return keyPairGenerator.generateKeyPair(); - } - - private String generateCsr(KeyPair pair) throws IOException, OperatorCreationException { - PKCS10CertificationRequestBuilder p10Builder = - new JcaPKCS10CertificationRequestBuilder( - new X500Principal("CN=EXAMPLE.COM"), pair.getPublic()); - JcaContentSignerBuilder csBuilder = new JcaContentSignerBuilder(signatureAlg); - ContentSigner signer = csBuilder.build(pair.getPrivate()); - PKCS10CertificationRequest csr = p10Builder.build(signer); - PemObject pemObject = new PemObject("NEW CERTIFICATE REQUEST", csr.getEncoded()); - try (StringWriter str = new StringWriter()) { - try (JcaPEMWriter pemWriter = new JcaPEMWriter(str)) { - pemWriter.writeObject(pemObject); - } - return str.toString(); - } - } - - /** Compute refresh interval as half of interval to current cert expiry. */ - private long computeRefreshSecondsFromCurrentCertExpiry() { - X509Certificate lastCert = getWatcher().getLastIdentityCert(); - if (lastCert == null) { - return INITIAL_DELAY_SECONDS; - } - long delayToCertExpirySeconds = computeDelaySecondsToCertExpiry(lastCert) / 2; - return Math.max(delayToCertExpirySeconds, INITIAL_DELAY_SECONDS); - } - - @SuppressWarnings("JdkObsolete") - private long computeDelaySecondsToCertExpiry(X509Certificate lastCert) { - checkNotNull(lastCert, "lastCert"); - return TimeUnit.NANOSECONDS.toSeconds( - TimeUnit.MILLISECONDS.toNanos(lastCert.getNotAfter().getTime()) - timeProvider - .currentTimeNanos()); - } - - private static void shutdownChannel(ManagedChannel channel) { - channel.shutdown(); - try { - channel.awaitTermination(10, TimeUnit.SECONDS); - } catch (InterruptedException ex) { - logger.log(Level.SEVERE, "awaiting channel Termination", ex); - channel.shutdownNow(); - Thread.currentThread().interrupt(); - } - } - - private List getX509CertificatesFromResponse( - MeshCertificateResponse response) throws CertificateException, IOException { - List certChain = response.getCertChainList(); - List x509Chain = new ArrayList<>(certChain.size()); - for (String certString : certChain) { - try (ByteArrayInputStream bais = new ByteArrayInputStream(certString.getBytes(UTF_8))) { - x509Chain.add(CertificateUtils.toX509Certificate(bais)); - } - } - return x509Chain; - } - - @VisibleForTesting - class RefreshCertificateTask implements Runnable { - @Override - public void run() { - try { - refreshCertificate(); - } catch (NoSuchAlgorithmException | OperatorCreationException | IOException ex) { - logger.log(Level.SEVERE, "refreshing certificate", ex); - } - } - } - - /** Factory for creating channels to MeshCA sever. */ - abstract static class MeshCaChannelFactory { - - private static final MeshCaChannelFactory DEFAULT_INSTANCE = - new MeshCaChannelFactory() { - - /** Creates a channel to the URL in the given list. */ - @Override - ManagedChannel createChannel(String serverUri) { - checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!"); - logger.log(Level.INFO, "Creating channel to {0}", serverUri); - - return Grpc.newChannelBuilder(serverUri, TlsChannelCredentials.create()) - .keepAliveTime(1, TimeUnit.MINUTES) - .build(); - } - }; - - static MeshCaChannelFactory getInstance() { - return DEFAULT_INSTANCE; - } - - /** - * Creates a channel to the server. - */ - abstract ManagedChannel createChannel(String serverUri); - } - - /** Factory for creating channels to MeshCA sever. */ - abstract static class Factory { - private static final Factory DEFAULT_INSTANCE = - new Factory() { - - @Override - MeshCaCertificateProvider create( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String meshCaUrl, - String zone, - long validitySeconds, - int keySize, - String alg, - String signatureAlg, - MeshCaChannelFactory meshCaChannelFactory, - BackoffPolicy.Provider backoffPolicyProvider, - long renewalGracePeriodSeconds, - int maxRetryAttempts, - GoogleCredentials oauth2Creds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider, - long rpcTimeoutMillis) { - return new MeshCaCertificateProvider( - watcher, - notifyCertUpdates, - meshCaUrl, - zone, - validitySeconds, - keySize, - alg, - signatureAlg, - meshCaChannelFactory, - backoffPolicyProvider, - renewalGracePeriodSeconds, - maxRetryAttempts, - oauth2Creds, - scheduledExecutorService, - timeProvider, - rpcTimeoutMillis); - } - }; - - static Factory getInstance() { - return DEFAULT_INSTANCE; - } - - abstract MeshCaCertificateProvider create( - DistributorWatcher watcher, - boolean notifyCertUpdates, - String meshCaUrl, - String zone, - long validitySeconds, - int keySize, - String alg, - String signatureAlg, - MeshCaChannelFactory meshCaChannelFactory, - BackoffPolicy.Provider backoffPolicyProvider, - long renewalGracePeriodSeconds, - int maxRetryAttempts, - GoogleCredentials oauth2Creds, - ScheduledExecutorService scheduledExecutorService, - TimeProvider timeProvider, - long rpcTimeoutMillis); - } - - private class ZoneInfoClientInterceptor implements ClientInterceptor { - private final String zone; - - ZoneInfoClientInterceptor(String zone) { - this.zone = zone; - } - - @Override - public ClientCall interceptCall( - MethodDescriptor method, CallOptions callOptions, Channel next) { - return new ForwardingClientCall.SimpleForwardingClientCall( - next.newCall(method, callOptions)) { - - @Override - public void start(Listener responseListener, Metadata headers) { - headers.put(KEY_FOR_ZONE_INFO, "location=locations/" + zone); - super.start(responseListener, headers); - } - }; - } - } - - @VisibleForTesting - static final Metadata.Key KEY_FOR_ZONE_INFO = - Metadata.Key.of("x-goog-request-params", Metadata.ASCII_STRING_MARSHALLER); - @VisibleForTesting - static final long INITIAL_DELAY_SECONDS = 4L; - - private static final EnumSet RETRIABLE_CODES = - EnumSet.of( - CANCELLED, - UNKNOWN, - DEADLINE_EXCEEDED, - RESOURCE_EXHAUSTED, - ABORTED, - INTERNAL, - UNAVAILABLE); - - private final SynchronizationContext syncContext; - private final ScheduledExecutorService scheduledExecutorService; - private final int maxRetryAttempts; - private final ZoneInfoClientInterceptor headerInterceptor; - private final BackoffPolicy.Provider backoffPolicyProvider; - private final String meshCaUrl; - private final long validitySeconds; - private final long renewalGracePeriodSeconds; - private final int keySize; - private final String signatureAlg; - private final GoogleCredentials oauth2Creds; - private final TimeProvider timeProvider; - private final MeshCaChannelFactory meshCaChannelFactory; - @VisibleForTesting SynchronizationContext.ScheduledHandle scheduledHandle; - private final long rpcTimeoutMillis; -} diff --git a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java b/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java deleted file mode 100644 index a605f15ae62..00000000000 --- a/xds/src/main/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProvider.java +++ /dev/null @@ -1,286 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Preconditions.checkState; -import static io.grpc.internal.JsonUtil.getObject; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.ExponentialBackoffPolicy; -import io.grpc.internal.JsonUtil; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sts.StsCredentials; -import java.util.List; -import java.util.Map; -import java.util.concurrent.Executors; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -/** - * Provider of {@link CertificateProvider}s. Implemented by the implementer of the plugin. We may - * move this out of the internal package and make this an official API in the future. - */ -final class MeshCaCertificateProviderProvider implements CertificateProviderProvider { - - private static final String SERVER_CONFIG_KEY = "server"; - private static final String MESHCA_URL_KEY = "target_uri"; - private static final String RPC_TIMEOUT_SECONDS_KEY = "time_out"; - private static final String GKECLUSTER_URL_KEY = "location"; - private static final String CERT_VALIDITY_SECONDS_KEY = "certificate_lifetime"; - private static final String RENEWAL_GRACE_PERIOD_SECONDS_KEY = "renewal_grace_period"; - private static final String KEY_ALGO_KEY = "key_type"; // aka keyType - private static final String KEY_SIZE_KEY = "key_size"; - private static final String STS_SERVICE_KEY = "sts_service"; - private static final String TOKEN_EXCHANGE_SERVICE_KEY = "token_exchange_service"; - private static final String GKE_SA_JWT_LOCATION_KEY = "subject_token_path"; - - @VisibleForTesting static final String MESHCA_URL_DEFAULT = "meshca.googleapis.com"; - @VisibleForTesting static final long RPC_TIMEOUT_SECONDS_DEFAULT = 5L; - @VisibleForTesting static final long CERT_VALIDITY_SECONDS_DEFAULT = 9L * 3600L; - @VisibleForTesting static final long RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT = 1L * 3600L; - @VisibleForTesting static final String KEY_ALGO_DEFAULT = "RSA"; // aka keyType - @VisibleForTesting static final int KEY_SIZE_DEFAULT = 2048; - @VisibleForTesting static final String SIGNATURE_ALGO_DEFAULT = "SHA256withRSA"; - @VisibleForTesting static final int MAX_RETRY_ATTEMPTS_DEFAULT = 3; - @VisibleForTesting - static final String STS_URL_DEFAULT = "https://securetoken.googleapis.com/v1/identitybindingtoken"; - - @VisibleForTesting - static final long RPC_TIMEOUT_SECONDS = 10L; - - private static final Pattern CLUSTER_URL_PATTERN = Pattern - .compile(".*/projects/(.*)/(?:locations|zones)/(.*)/clusters/.*"); - - private static final String TRUST_DOMAIN_SUFFIX = ".svc.id.goog"; - private static final String AUDIENCE_PREFIX = "identitynamespace:"; - static final String MESH_CA_NAME = "meshCA"; - - final StsCredentials.Factory stsCredentialsFactory; - final MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory; - final BackoffPolicy.Provider backoffPolicyProvider; - final MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory; - final ScheduledExecutorServiceFactory scheduledExecutorServiceFactory; - final TimeProvider timeProvider; - - MeshCaCertificateProviderProvider() { - this( - StsCredentials.Factory.getInstance(), - MeshCaCertificateProvider.MeshCaChannelFactory.getInstance(), - new ExponentialBackoffPolicy.Provider(), - MeshCaCertificateProvider.Factory.getInstance(), - ScheduledExecutorServiceFactory.DEFAULT_INSTANCE, - TimeProvider.SYSTEM_TIME_PROVIDER); - } - - @VisibleForTesting - MeshCaCertificateProviderProvider( - StsCredentials.Factory stsCredentialsFactory, - MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory, - BackoffPolicy.Provider backoffPolicyProvider, - MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory, - ScheduledExecutorServiceFactory scheduledExecutorServiceFactory, - TimeProvider timeProvider) { - this.stsCredentialsFactory = stsCredentialsFactory; - this.meshCaChannelFactory = meshCaChannelFactory; - this.backoffPolicyProvider = backoffPolicyProvider; - this.meshCaCertificateProviderFactory = meshCaCertificateProviderFactory; - this.scheduledExecutorServiceFactory = scheduledExecutorServiceFactory; - this.timeProvider = timeProvider; - } - - @Override - public String getName() { - return MESH_CA_NAME; - } - - @Override - public CertificateProvider createCertificateProvider( - Object config, CertificateProvider.DistributorWatcher watcher, boolean notifyCertUpdates) { - - Config configObj = validateAndTranslateConfig(config); - - // Construct audience from project and gkeClusterUrl - String audience = - AUDIENCE_PREFIX + configObj.project + TRUST_DOMAIN_SUFFIX + ":" + configObj.gkeClusterUrl; - StsCredentials stsCredentials = stsCredentialsFactory - .create(configObj.stsUrl, audience, configObj.gkeSaJwtLocation); - - return meshCaCertificateProviderFactory.create( - watcher, - notifyCertUpdates, - configObj.meshCaUrl, - configObj.zone, - configObj.certValiditySeconds, - configObj.keySize, - configObj.keyAlgo, - configObj.signatureAlgo, - meshCaChannelFactory, - backoffPolicyProvider, - configObj.renewalGracePeriodSeconds, - configObj.maxRetryAttempts, - stsCredentials, - scheduledExecutorServiceFactory.create(configObj.meshCaUrl), - timeProvider, - TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS)); - } - - private static Config validateAndTranslateConfig(Object config) { - // TODO(sanjaypujare): add support for string, struct proto etc - checkArgument(config instanceof Map, "Only Map supported for config"); - @SuppressWarnings("unchecked") Map map = (Map)config; - - Config configObj = new Config(); - extractMeshCaServerConfig(configObj, getObject(map, SERVER_CONFIG_KEY)); - configObj.certValiditySeconds = - getSeconds( - JsonUtil.getObject(map, CERT_VALIDITY_SECONDS_KEY), CERT_VALIDITY_SECONDS_DEFAULT); - configObj.renewalGracePeriodSeconds = - getSeconds( - JsonUtil.getObject(map, RENEWAL_GRACE_PERIOD_SECONDS_KEY), - RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT); - String keyType = JsonUtil.getString(map, KEY_ALGO_KEY); - checkArgument( - keyType == null || keyType.equals(KEY_ALGO_DEFAULT), "key_type can only be null or 'RSA'"); - // TODO: remove signatureAlgo, keyType (or keyAlgo), maxRetryAttempts - configObj.maxRetryAttempts = MAX_RETRY_ATTEMPTS_DEFAULT; - configObj.keyAlgo = KEY_ALGO_DEFAULT; - configObj.signatureAlgo = SIGNATURE_ALGO_DEFAULT; - configObj.keySize = JsonUtil.getNumberAsInteger(map, KEY_SIZE_KEY); - if (configObj.keySize == null) { - configObj.keySize = KEY_SIZE_DEFAULT; - } - configObj.gkeClusterUrl = - checkNotNull(JsonUtil.getString(map, GKECLUSTER_URL_KEY), - "'location' is required in the config"); - parseProjectAndZone(configObj.gkeClusterUrl, configObj); - return configObj; - } - - private static void extractMeshCaServerConfig(Config configObj, Map serverConfig) { - // init with defaults - configObj.meshCaUrl = MESHCA_URL_DEFAULT; - configObj.rpcTimeoutSeconds = RPC_TIMEOUT_SECONDS_DEFAULT; - configObj.stsUrl = STS_URL_DEFAULT; - if (serverConfig != null) { - checkArgument( - "GRPC".equals(JsonUtil.getString(serverConfig, "api_type")), - "Only GRPC api_type supported"); - List> grpcServices = - checkNotNull( - JsonUtil.getListOfObjects(serverConfig, "grpc_services"), "grpc_services not found"); - for (Map grpcService : grpcServices) { - Map googleGrpcConfig = JsonUtil.getObject(grpcService, "google_grpc"); - if (googleGrpcConfig != null) { - String value = JsonUtil.getString(googleGrpcConfig, MESHCA_URL_KEY); - if (value != null) { - configObj.meshCaUrl = value; - } - Map channelCreds = - JsonUtil.getObject(googleGrpcConfig, "channel_credentials"); - if (channelCreds != null) { - Map googleDefaultChannelCreds = - checkNotNull( - JsonUtil.getObject(channelCreds, "google_default"), - "channel_credentials need to be google_default!"); - checkArgument( - googleDefaultChannelCreds.isEmpty(), - "google_default credentials contain illegal value"); - } - List> callCreds = - JsonUtil.getListOfObjects(googleGrpcConfig, "call_credentials"); - for (Map callCred : callCreds) { - Map stsCreds = JsonUtil.getObject(callCred, STS_SERVICE_KEY); - if (stsCreds != null) { - value = JsonUtil.getString(stsCreds, TOKEN_EXCHANGE_SERVICE_KEY); - if (value != null) { - configObj.stsUrl = value; - } - configObj.gkeSaJwtLocation = JsonUtil.getString(stsCreds, GKE_SA_JWT_LOCATION_KEY); - } - } - configObj.rpcTimeoutSeconds = - getSeconds( - JsonUtil.getObject(grpcService, RPC_TIMEOUT_SECONDS_KEY), - RPC_TIMEOUT_SECONDS_DEFAULT); - } - } - } - // check required value(s) - checkNotNull(configObj.gkeSaJwtLocation, "'subject_token_path' is required in the config"); - } - - private static Long getSeconds(Map duration, long defaultValue) { - if (duration != null) { - return JsonUtil.getNumberAsLong(duration, "seconds"); - } - return defaultValue; - } - - private static void parseProjectAndZone(String gkeClusterUrl, Config configObj) { - Matcher matcher = CLUSTER_URL_PATTERN.matcher(gkeClusterUrl); - checkState(matcher.find(), "gkeClusterUrl does not have correct format"); - checkState(matcher.groupCount() == 2, "gkeClusterUrl does not have project and location parts"); - configObj.project = matcher.group(1); - configObj.zone = matcher.group(2); - } - - abstract static class ScheduledExecutorServiceFactory { - - private static final ScheduledExecutorServiceFactory DEFAULT_INSTANCE = - new ScheduledExecutorServiceFactory() { - - @Override - ScheduledExecutorService create(String serverUri) { - return Executors.newSingleThreadScheduledExecutor( - new ThreadFactoryBuilder() - .setNameFormat("meshca-" + serverUri + "-%d") - .setDaemon(true) - .build()); - } - }; - - static ScheduledExecutorServiceFactory getInstance() { - return DEFAULT_INSTANCE; - } - - abstract ScheduledExecutorService create(String serverUri); - } - - /** POJO class for storing various config values. */ - @VisibleForTesting - static class Config { - String meshCaUrl; - Long rpcTimeoutSeconds; - String gkeClusterUrl; - Long certValiditySeconds; - Long renewalGracePeriodSeconds; - String keyAlgo; // aka keyType - Integer keySize; - String signatureAlgo; - Integer maxRetryAttempts; - String stsUrl; - String gkeSaJwtLocation; - String zone; - String project; - } -} diff --git a/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java b/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java new file mode 100644 index 00000000000..6d275d322a2 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngine.java @@ -0,0 +1,433 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.rbac.engine; + +import static com.google.common.base.Preconditions.checkNotNull; + +import com.google.auto.value.AutoValue; +import com.google.common.base.Joiner; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.xds.internal.Matchers; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.cert.Certificate; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.annotation.Nullable; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; + +/** + * Implementation of gRPC server access control based on envoy RBAC protocol: + * https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/rbac/v3/rbac.proto + * + *

One GrpcAuthorizationEngine is initialized with one action type and a list of policies. + * Policies are examined sequentially in order in an any match fashion, and the first matched policy + * will be returned. If not matched at all, the opposite action type is returned as a result. + */ +public final class GrpcAuthorizationEngine { + private static final Logger log = Logger.getLogger(GrpcAuthorizationEngine.class.getName()); + + private final AuthConfig authConfig; + + /** Instantiated with envoy policyMatcher configuration. */ + public GrpcAuthorizationEngine(AuthConfig authConfig) { + this.authConfig = authConfig; + } + + /** Return the auth decision for the request argument against the policies. */ + public AuthDecision evaluate(Metadata metadata, ServerCall serverCall) { + checkNotNull(metadata, "metadata"); + checkNotNull(serverCall, "serverCall"); + String firstMatch = null; + EvaluateArgs args = new EvaluateArgs(metadata, serverCall); + for (PolicyMatcher policyMatcher : authConfig.policies) { + if (policyMatcher.matches(args)) { + firstMatch = policyMatcher.name; + break; + } + } + Action decisionType = Action.DENY; + if (Action.DENY.equals(authConfig.action) == (firstMatch == null)) { + decisionType = Action.ALLOW; + } + log.log(Level.FINER, "RBAC decision: {0}, policy match: {1}.", + new Object[]{decisionType, firstMatch}); + return AuthDecision.create(decisionType, firstMatch); + } + + public enum Action { + ALLOW, + DENY, + } + + /** + * An authorization decision provides information about the decision type and the policy name + * identifier based on the authorization engine evaluation. */ + @AutoValue + public abstract static class AuthDecision { + public abstract Action decision(); + + @Nullable + public abstract String matchingPolicyName(); + + static AuthDecision create(Action decisionType, @Nullable String matchingPolicy) { + return new AutoValue_GrpcAuthorizationEngine_AuthDecision(decisionType, matchingPolicy); + } + } + + /** Represents authorization config policy that the engine will evaluate against. */ + public static final class AuthConfig { + private final List policies; + private final Action action; + + public AuthConfig(List policies, Action action) { + this.policies = Collections.unmodifiableList(new ArrayList<>(policies)); + this.action = action; + } + } + + /** + * Implements a top level {@link Matcher} for a single RBAC policy configuration per envoy + * protocol: + * https://www.envoyproxy.io/docs/envoy/latest/api-v3/config/rbac/v3/rbac.proto#config-rbac-v3-policy. + * + *

Currently we only support matching some of the request fields. Those unsupported fields are + * considered not match until we stop ignoring them. + */ + public static final class PolicyMatcher implements Matcher { + private final OrMatcher permissions; + private final OrMatcher principals; + private final String name; + + /** Constructs a matcher for one RBAC policy. */ + public PolicyMatcher(String name, OrMatcher permissions, OrMatcher principals) { + this.name = name; + this.permissions = permissions; + this.principals = principals; + } + + @Override + public boolean matches(EvaluateArgs args) { + return permissions.matches(args) && principals.matches(args); + } + } + + public static final class AuthenticatedMatcher implements Matcher { + private final Matchers.StringMatcher delegate; + + /** + * Passing in null will match all authenticated user, i.e. SSL session is present. + * https://github.com/envoyproxy/envoy/blob/main/api/envoy/config/rbac/v3/rbac.proto#L240 + * */ + public AuthenticatedMatcher(@Nullable Matchers.StringMatcher delegate) { + this.delegate = delegate; + } + + @Override + public boolean matches(EvaluateArgs args) { + Collection principalNames = args.getPrincipalNames(); + log.log(Level.FINER, "Matching principal names: {0}", new Object[]{principalNames}); + // Null means unauthenticated connection. + if (principalNames == null) { + return false; + } + // Connection is authenticated, so returns match when delegated string matcher is not present. + if (delegate == null) { + return true; + } + for (String name : principalNames) { + if (delegate.matches(name)) { + return true; + } + } + return false; + } + } + + public static final class DestinationIpMatcher implements Matcher { + private final Matchers.CidrMatcher delegate; + + public DestinationIpMatcher(Matchers.CidrMatcher delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public boolean matches(EvaluateArgs args) { + return delegate.matches(args.getDestinationIp()); + } + } + + public static final class SourceIpMatcher implements Matcher { + private final Matchers.CidrMatcher delegate; + + public SourceIpMatcher(Matchers.CidrMatcher delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public boolean matches(EvaluateArgs args) { + return delegate.matches(args.getSourceIp()); + } + } + + public static final class PathMatcher implements Matcher { + private final Matchers.StringMatcher delegate; + + public PathMatcher(Matchers.StringMatcher delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public boolean matches(EvaluateArgs args) { + return delegate.matches(args.getPath()); + } + } + + public static final class AuthHeaderMatcher implements Matcher { + private final Matchers.HeaderMatcher delegate; + + public AuthHeaderMatcher(Matchers.HeaderMatcher delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public boolean matches(EvaluateArgs args) { + return delegate.matches(args.getHeader(delegate.name())); + } + } + + public static final class DestinationPortMatcher implements Matcher { + private final int port; + + public DestinationPortMatcher(int port) { + this.port = port; + } + + @Override + public boolean matches(EvaluateArgs args) { + return port == args.getDestinationPort(); + } + } + + public static final class RequestedServerNameMatcher implements Matcher { + private final Matchers.StringMatcher delegate; + + public RequestedServerNameMatcher(Matchers.StringMatcher delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public boolean matches(EvaluateArgs args) { + return delegate.matches(args.getRequestedServerName()); + } + } + + private static final class EvaluateArgs { + private final Metadata metadata; + private final ServerCall serverCall; + // https://github.com/envoyproxy/envoy/blob/63619d578e1abe0c1725ea28ba02f361466662e1/api/envoy/config/rbac/v3/rbac.proto#L238-L240 + private static final int URI_SAN = 6; + private static final int DNS_SAN = 2; + + private EvaluateArgs(Metadata metadata, ServerCall serverCall) { + this.metadata = metadata; + this.serverCall = serverCall; + } + + private String getPath() { + return "/" + serverCall.getMethodDescriptor().getFullMethodName(); + } + + /** + * Returns null for unauthenticated connection. + * Returns empty string collection if no valid certificate and no + * principal names we are interested in. + * https://github.com/envoyproxy/envoy/blob/0fae6970ddaf93f024908ba304bbd2b34e997a51/envoy/ssl/connection.h#L70 + */ + @Nullable + private Collection getPrincipalNames() { + SSLSession sslSession = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_SSL_SESSION); + if (sslSession == null) { + return null; + } + try { + Certificate[] certs = sslSession.getPeerCertificates(); + if (certs == null || certs.length < 1) { + return Collections.singleton(""); + } + X509Certificate cert = (X509Certificate)certs[0]; + if (cert == null) { + return Collections.singleton(""); + } + Collection> names = cert.getSubjectAlternativeNames(); + List principalNames = new ArrayList<>(); + if (names != null) { + for (List name : names) { + if (URI_SAN == (Integer) name.get(0)) { + principalNames.add((String) name.get(1)); + } + } + if (!principalNames.isEmpty()) { + return Collections.unmodifiableCollection(principalNames); + } + for (List name : names) { + if (DNS_SAN == (Integer) name.get(0)) { + principalNames.add((String) name.get(1)); + } + } + if (!principalNames.isEmpty()) { + return Collections.unmodifiableCollection(principalNames); + } + } + if (cert.getSubjectDN() == null || cert.getSubjectDN().getName() == null) { + return Collections.singleton(""); + } + return Collections.singleton(cert.getSubjectDN().getName()); + } catch (SSLPeerUnverifiedException | CertificateParsingException ex) { + log.log(Level.FINE, "Unexpected getPrincipalNames error.", ex); + return Collections.singleton(""); + } + } + + @Nullable + private String getHeader(String headerName) { + if (headerName.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + return null; + } + Metadata.Key key; + try { + key = Metadata.Key.of(headerName, Metadata.ASCII_STRING_MARSHALLER); + } catch (IllegalArgumentException e) { + return null; + } + Iterable values = metadata.getAll(key); + return values == null ? null : Joiner.on(",").join(values); + } + + private InetAddress getDestinationIp() { + SocketAddress addr = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); + return addr == null ? null : ((InetSocketAddress) addr).getAddress(); + } + + private InetAddress getSourceIp() { + SocketAddress addr = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + return addr == null ? null : ((InetSocketAddress) addr).getAddress(); + } + + private int getDestinationPort() { + SocketAddress addr = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); + return addr == null ? -1 : ((InetSocketAddress) addr).getPort(); + } + + private String getRequestedServerName() { + return ""; + } + } + + public interface Matcher { + boolean matches(EvaluateArgs args); + } + + public static final class OrMatcher implements Matcher { + private final List anyMatch; + + /** Matches when any of the matcher matches. */ + public OrMatcher(List matchers) { + checkNotNull(matchers, "matchers"); + for (Matcher matcher : matchers) { + checkNotNull(matcher, "matcher"); + } + this.anyMatch = Collections.unmodifiableList(new ArrayList<>(matchers)); + } + + public static OrMatcher create(Matcher...matchers) { + return new OrMatcher(Arrays.asList(matchers)); + } + + @Override + public boolean matches(EvaluateArgs args) { + for (Matcher m : anyMatch) { + if (m.matches(args)) { + return true; + } + } + return false; + } + } + + public static final class AndMatcher implements Matcher { + private final List allMatch; + + /** Matches when all of the matchers match. */ + public AndMatcher(List matchers) { + checkNotNull(matchers, "matchers"); + for (Matcher matcher : matchers) { + checkNotNull(matcher, "matcher"); + } + this.allMatch = Collections.unmodifiableList(new ArrayList<>(matchers)); + } + + public static AndMatcher create(Matcher...matchers) { + return new AndMatcher(Arrays.asList(matchers)); + } + + @Override + public boolean matches(EvaluateArgs args) { + for (Matcher m : allMatch) { + if (!m.matches(args)) { + return false; + } + } + return true; + } + } + + /** Always true matcher.*/ + public static final class AlwaysTrueMatcher implements Matcher { + public static AlwaysTrueMatcher INSTANCE = new AlwaysTrueMatcher(); + + @Override + public boolean matches(EvaluateArgs args) { + return true; + } + } + + /** Negate matcher.*/ + public static final class InvertMatcher implements Matcher { + private final Matcher toInvertMatcher; + + public InvertMatcher(Matcher matcher) { + this.toInvertMatcher = checkNotNull(matcher, "matcher"); + } + + @Override + public boolean matches(EvaluateArgs args) { + return !toInvertMatcher.matches(args); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java index 9b339cdcabb..b232ee2707d 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactory.java @@ -19,9 +19,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.grpc.xds.Bootstrapper; +import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; -import io.grpc.xds.XdsInitializationException; import io.grpc.xds.internal.certprovider.CertProviderClientSslContextProvider; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import java.util.concurrent.Executors; @@ -30,18 +29,17 @@ final class ClientSslContextProviderFactory implements ValueFactory { - private final Bootstrapper bootstrapper; - private Bootstrapper.BootstrapInfo bootstrapInfo; + private BootstrapInfo bootstrapInfo; private final CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory; - ClientSslContextProviderFactory(Bootstrapper bootstrapper) { - this(bootstrapper, CertProviderClientSslContextProvider.Factory.getInstance()); + ClientSslContextProviderFactory(BootstrapInfo bootstrapInfo) { + this(bootstrapInfo, CertProviderClientSslContextProvider.Factory.getInstance()); } ClientSslContextProviderFactory( - Bootstrapper bootstrapper, CertProviderClientSslContextProvider.Factory factory) { - this.bootstrapper = bootstrapper; + BootstrapInfo bootstrapInfo, CertProviderClientSslContextProvider.Factory factory) { + this.bootstrapInfo = bootstrapInfo; this.certProviderClientSslContextProviderFactory = factory; } @@ -53,38 +51,24 @@ public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) { upstreamTlsContext.getCommonTlsContext(), "upstreamTlsContext should have CommonTlsContext"); if (CommonTlsContextUtil.hasCertProviderInstance( - upstreamTlsContext.getCommonTlsContext())) { - try { - if (bootstrapInfo == null) { - bootstrapInfo = bootstrapper.bootstrap(); - } - return certProviderClientSslContextProviderFactory.getProvider( - upstreamTlsContext, - bootstrapInfo.getNode().toEnvoyProtoNode(), - bootstrapInfo.getCertProviders()); - } catch (XdsInitializationException e) { - throw new RuntimeException(e); - } + upstreamTlsContext.getCommonTlsContext())) { + return certProviderClientSslContextProviderFactory.getProvider( + upstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); } else if (CommonTlsContextUtil.hasAllSecretsUsingFilename( upstreamTlsContext.getCommonTlsContext())) { return SecretVolumeClientSslContextProvider.getProvider(upstreamTlsContext); } else if (CommonTlsContextUtil.hasAllSecretsUsingSds( upstreamTlsContext.getCommonTlsContext())) { - try { - if (bootstrapInfo == null) { - bootstrapInfo = bootstrapper.bootstrap(); - } - return SdsClientSslContextProvider.getProvider( - upstreamTlsContext, - bootstrapInfo.getNode().toEnvoyProtoNodeV2(), - Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() - .setNameFormat("client-sds-sslcontext-provider-%d") - .setDaemon(true) - .build()), - /* channelExecutor= */ null); - } catch (XdsInitializationException e) { - throw new RuntimeException(e); - } + return SdsClientSslContextProvider.getProvider( + upstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNodeV2(), + Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() + .setNameFormat("client-sds-sslcontext-provider-%d") + .setDaemon(true) + .build()), + /* channelExecutor= */ null); } throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java index da464afbb17..37161325746 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SdsProtocolNegotiators.java @@ -23,15 +23,11 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; import io.grpc.netty.GrpcHttp2ConnectionHandler; -import io.grpc.netty.InternalNettyChannelBuilder; -import io.grpc.netty.InternalNettyChannelBuilder.ProtocolNegotiatorFactory; import io.grpc.netty.InternalProtocolNegotiationEvent; import io.grpc.netty.InternalProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; -import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.ProtocolNegotiationEvent; -import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.InternalXdsAttributes; import io.grpc.xds.XdsClientWrapperForServerSds; import io.netty.channel.ChannelHandler; @@ -66,18 +62,7 @@ private SdsProtocolNegotiators() { private static final AsciiString SCHEME = AsciiString.of("http"); /** - * Returns a {@link ProtocolNegotiatorFactory} to be used on {@link NettyChannelBuilder}. - * - * @param fallbackNegotiator protocol negotiator to use as fallback. - */ - public static ProtocolNegotiatorFactory clientProtocolNegotiatorFactory( - @Nullable ProtocolNegotiator fallbackNegotiator) { - return new ClientSdsProtocolNegotiatorFactory(fallbackNegotiator); - } - - /** - * Returns a {@link InternalProtocolNegotiator.ClientFactory} to be used on {@link - * NettyChannelBuilder}. + * Returns a {@link InternalProtocolNegotiator.ClientFactory}. * * @param fallbackNegotiator protocol negotiator to use as fallback. */ @@ -91,17 +76,6 @@ public static InternalProtocolNegotiator.ServerFactory serverProtocolNegotiatorF return new ServerFactory(fallbackNegotiator); } - /** - * Creates an SDS based {@link ProtocolNegotiator} for a {@link io.grpc.netty.NettyServerBuilder}. - * If xDS returns no DownstreamTlsContext, it will fall back to plaintext. - * - * @param fallbackProtocolNegotiator protocol negotiator to use as fallback. - */ - public static ServerSdsProtocolNegotiator serverProtocolNegotiator( - @Nullable ProtocolNegotiator fallbackProtocolNegotiator) { - return new ServerSdsProtocolNegotiator(fallbackProtocolNegotiator); - } - private static final class ServerFactory implements InternalProtocolNegotiator.ServerFactory { private final InternalProtocolNegotiator.ServerFactory fallbackProtocolNegotiator; @@ -136,41 +110,6 @@ public int getDefaultPort() { } } - private static final class ClientSdsProtocolNegotiatorFactory - implements InternalNettyChannelBuilder.ProtocolNegotiatorFactory { - - private final ProtocolNegotiator fallbackProtocolNegotiator; - - private ClientSdsProtocolNegotiatorFactory(ProtocolNegotiator fallbackNegotiator) { - this.fallbackProtocolNegotiator = fallbackNegotiator; - } - - @Override - public InternalProtocolNegotiator.ProtocolNegotiator buildProtocolNegotiator() { - final ClientSdsProtocolNegotiator negotiator = - new ClientSdsProtocolNegotiator(fallbackProtocolNegotiator); - final class LocalSdsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { - - @Override - public AsciiString scheme() { - return negotiator.scheme(); - } - - @Override - public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { - return negotiator.newHandler(grpcHandler); - } - - @Override - public void close() { - negotiator.close(); - } - } - - return new LocalSdsNegotiator(); - } - } - @VisibleForTesting static final class ClientSdsProtocolNegotiator implements ProtocolNegotiator { @@ -297,8 +236,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) } } - @VisibleForTesting - public static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator { + private static final class ServerSdsProtocolNegotiator implements ProtocolNegotiator { @Nullable private final ProtocolNegotiator fallbackProtocolNegotiator; @@ -344,11 +282,11 @@ static final class HandlerPickerHandler @Override public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { if (evt instanceof ProtocolNegotiationEvent) { - DownstreamTlsContext downstreamTlsContext = + SslContextProviderSupplier sslContextProviderSupplier = xdsClientWrapperForServerSds == null ? null - : xdsClientWrapperForServerSds.getDownstreamTlsContext(ctx.channel()); - if (downstreamTlsContext == null) { + : xdsClientWrapperForServerSds.getSslContextProviderSupplier(ctx.channel()); + if (sslContextProviderSupplier == null) { if (fallbackProtocolNegotiator == null) { ctx.fireExceptionCaught(new CertStoreException("No certificate source found!")); return; @@ -368,7 +306,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc this, null, new ServerSdsHandler( - grpcHandler, downstreamTlsContext, fallbackProtocolNegotiator)); + grpcHandler, sslContextProviderSupplier)); ProtocolNegotiationEvent pne = InternalProtocolNegotiationEvent.getDefault(); ctx.fireUserEventTriggered(pne); return; @@ -383,13 +321,11 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc static final class ServerSdsHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; - private final DownstreamTlsContext downstreamTlsContext; - @Nullable private final ProtocolNegotiator fallbackProtocolNegotiator; + private final SslContextProviderSupplier sslContextProviderSupplier; ServerSdsHandler( GrpcHttp2ConnectionHandler grpcHandler, - DownstreamTlsContext downstreamTlsContext, - ProtocolNegotiator fallbackProtocolNegotiator) { + SslContextProviderSupplier sslContextProviderSupplier) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -402,8 +338,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { }, grpcHandler.getNegotiationLogger()); checkNotNull(grpcHandler, "grpcHandler"); this.grpcHandler = grpcHandler; - this.downstreamTlsContext = downstreamTlsContext; - this.fallbackProtocolNegotiator = fallbackProtocolNegotiator; + this.sslContextProviderSupplier = sslContextProviderSupplier; } @Override @@ -411,24 +346,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { final BufferReadsHandler bufferReads = new BufferReadsHandler(); ctx.pipeline().addBefore(ctx.name(), null, bufferReads); - SslContextProvider sslContextProviderTemp = null; - try { - sslContextProviderTemp = - TlsContextManagerImpl.getInstance() - .findOrCreateServerSslContextProvider(downstreamTlsContext); - } catch (Exception e) { - if (fallbackProtocolNegotiator == null) { - ctx.fireExceptionCaught(new CertStoreException("No certificate source found!", e)); - return; - } - logger.log(Level.INFO, "Using fallback for {0}", ctx.channel().localAddress()); - // Delegate rest of handshake to fallback handler - ctx.pipeline().replace(this, null, fallbackProtocolNegotiator.newHandler(grpcHandler)); - ctx.pipeline().remove(bufferReads); - return; - } - final SslContextProvider sslContextProvider = sslContextProviderTemp; - sslContextProvider.addCallback( + sslContextProviderSupplier.updateSslContext( new SslContextProvider.Callback(ctx.executor()) { @Override @@ -442,8 +360,6 @@ public void updateSecret(SslContext sslContext) { fireProtocolNegotiationEvent(ctx); ctx.pipeline().remove(bufferReads); } - TlsContextManagerImpl.getInstance() - .releaseServerSslContextProvider(sslContextProvider); } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java index b6d714e1fe2..9d89b169f2a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactory.java @@ -19,9 +19,8 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.util.concurrent.ThreadFactoryBuilder; -import io.grpc.xds.Bootstrapper; +import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; -import io.grpc.xds.XdsInitializationException; import io.grpc.xds.internal.certprovider.CertProviderServerSslContextProvider; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; import java.util.concurrent.Executors; @@ -30,18 +29,17 @@ final class ServerSslContextProviderFactory implements ValueFactory { - private final Bootstrapper bootstrapper; - private Bootstrapper.BootstrapInfo bootstrapInfo; + private BootstrapInfo bootstrapInfo; private final CertProviderServerSslContextProvider.Factory certProviderServerSslContextProviderFactory; - ServerSslContextProviderFactory(Bootstrapper bootstrapper) { - this(bootstrapper, CertProviderServerSslContextProvider.Factory.getInstance()); + ServerSslContextProviderFactory(BootstrapInfo bootstrapInfo) { + this(bootstrapInfo, CertProviderServerSslContextProvider.Factory.getInstance()); } ServerSslContextProviderFactory( - Bootstrapper bootstrapper, CertProviderServerSslContextProvider.Factory factory) { - this.bootstrapper = bootstrapper; + BootstrapInfo bootstrapInfo, CertProviderServerSslContextProvider.Factory factory) { + this.bootstrapInfo = bootstrapInfo; this.certProviderServerSslContextProviderFactory = factory; } @@ -54,38 +52,24 @@ public SslContextProvider create( downstreamTlsContext.getCommonTlsContext(), "downstreamTlsContext should have CommonTlsContext"); if (CommonTlsContextUtil.hasCertProviderInstance( - downstreamTlsContext.getCommonTlsContext())) { - try { - if (bootstrapInfo == null) { - bootstrapInfo = bootstrapper.bootstrap(); - } - return certProviderServerSslContextProviderFactory.getProvider( - downstreamTlsContext, - bootstrapInfo.getNode().toEnvoyProtoNode(), - bootstrapInfo.getCertProviders()); - } catch (XdsInitializationException e) { - throw new RuntimeException(e); - } + downstreamTlsContext.getCommonTlsContext())) { + return certProviderServerSslContextProviderFactory.getProvider( + downstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNode(), + bootstrapInfo.getCertProviders()); } else if (CommonTlsContextUtil.hasAllSecretsUsingFilename( downstreamTlsContext.getCommonTlsContext())) { return SecretVolumeServerSslContextProvider.getProvider(downstreamTlsContext); } else if (CommonTlsContextUtil.hasAllSecretsUsingSds( downstreamTlsContext.getCommonTlsContext())) { - try { - if (bootstrapInfo == null) { - bootstrapInfo = bootstrapper.bootstrap(); - } - return SdsServerSslContextProvider.getProvider( - downstreamTlsContext, - bootstrapInfo.getNode().toEnvoyProtoNodeV2(), - Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() - .setNameFormat("server-sds-sslcontext-provider-%d") - .setDaemon(true) - .build()), - /* channelExecutor= */ null); - } catch (XdsInitializationException e) { - throw new RuntimeException(e); - } + return SdsServerSslContextProvider.getProvider( + downstreamTlsContext, + bootstrapInfo.getNode().toEnvoyProtoNodeV2(), + Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() + .setNameFormat("server-sds-sslcontext-provider-%d") + .setDaemon(true) + .build()), + /* channelExecutor= */ null); } throw new UnsupportedOperationException("Unsupported configurations in DownstreamTlsContext!"); } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java index 020acd8eee2..3902569d873 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/SslContextProviderSupplier.java @@ -19,10 +19,14 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.MoreObjects; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.Objects; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -40,8 +44,8 @@ public final class SslContextProviderSupplier implements Closeable { public SslContextProviderSupplier( BaseTlsContext tlsContext, TlsContextManager tlsContextManager) { - this.tlsContext = tlsContext; - this.tlsContextManager = tlsContextManager; + this.tlsContext = checkNotNull(tlsContext, "tlsContext"); + this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager"); } public BaseTlsContext getTlsContext() { @@ -51,27 +55,36 @@ public BaseTlsContext getTlsContext() { /** Updates SslContext via the passed callback. */ public synchronized void updateSslContext(final SslContextProvider.Callback callback) { checkNotNull(callback, "callback"); - checkState(!shutdown, "Supplier is shutdown!"); - if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(); + try { + checkState(!shutdown, "Supplier is shutdown!"); + if (sslContextProvider == null) { + sslContextProvider = getSslContextProvider(); + } + // we want to increment the ref-count so call findOrCreate again... + final SslContextProvider toRelease = getSslContextProvider(); + sslContextProvider.addCallback( + new SslContextProvider.Callback(callback.getExecutor()) { + + @Override + public void updateSecret(SslContext sslContext) { + callback.updateSecret(sslContext); + releaseSslContextProvider(toRelease); + } + + @Override + public void onException(Throwable throwable) { + callback.onException(throwable); + releaseSslContextProvider(toRelease); + } + }); + } catch (final Throwable throwable) { + callback.getExecutor().execute(new Runnable() { + @Override + public void run() { + callback.onException(throwable); + } + }); } - // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = getSslContextProvider(); - sslContextProvider.addCallback( - new SslContextProvider.Callback(callback.getExecutor()) { - - @Override - public void updateSecret(SslContext sslContext) { - callback.updateSecret(sslContext); - releaseSslContextProvider(toRelease); - } - - @Override - public void onException(Throwable throwable) { - callback.onException(throwable); - releaseSslContextProvider(toRelease); - } - }); } private void releaseSslContextProvider(SslContextProvider toRelease) { @@ -88,14 +101,50 @@ private SslContextProvider getSslContextProvider() { : tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); } + @VisibleForTesting public boolean isShutdown() { + return shutdown; + } + /** Called by consumer when tlsContext changes. */ @Override public synchronized void close() { - if (tlsContext instanceof UpstreamTlsContext) { - tlsContextManager.releaseClientSslContextProvider(sslContextProvider); - } else { - tlsContextManager.releaseServerSslContextProvider(sslContextProvider); + if (sslContextProvider != null) { + if (tlsContext instanceof UpstreamTlsContext) { + tlsContextManager.releaseClientSslContextProvider(sslContextProvider); + } else { + tlsContextManager.releaseServerSslContextProvider(sslContextProvider); + } } shutdown = true; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SslContextProviderSupplier that = (SslContextProviderSupplier) o; + return shutdown == that.shutdown + && Objects.equals(tlsContext, that.tlsContext) + && Objects.equals(tlsContextManager, that.tlsContextManager) + && Objects.equals(sslContextProvider, that.sslContextProvider); + } + + @Override + public int hashCode() { + return Objects.hash(tlsContext, tlsContextManager, sslContextProvider, shutdown); + } + + @Override + public String toString() { + return MoreObjects.toStringHelper(this) + .add("tlsContext", tlsContext) + .add("tlsContextManager", tlsContextManager) + .add("sslContextProvider", sslContextProvider) + .add("shutdown", shutdown) + .toString(); + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java index c2c10cf3681..75a5d297d90 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/TlsContextManagerImpl.java @@ -20,10 +20,10 @@ import com.google.common.annotations.VisibleForTesting; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.grpc.xds.Bootstrapper; -import io.grpc.xds.BootstrapperImpl; +import io.grpc.xds.Bootstrapper.BootstrapInfo; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.TlsContextManager; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; /** @@ -34,16 +34,16 @@ */ public final class TlsContextManagerImpl implements TlsContextManager { - private static TlsContextManagerImpl instance; - private final ReferenceCountingMap mapForClients; private final ReferenceCountingMap mapForServers; - /** Create a TlsContextManagerImpl instance using the passed in {@link Bootstrapper}. */ - @VisibleForTesting public TlsContextManagerImpl(Bootstrapper bootstrapper) { + /** + * Create a TlsContextManagerImpl instance using the passed in {@link BootstrapInfo}. + */ + @VisibleForTesting public TlsContextManagerImpl(BootstrapInfo bootstrapInfo) { this( - new ClientSslContextProviderFactory(bootstrapper), - new ServerSslContextProviderFactory(bootstrapper)); + new ClientSslContextProviderFactory(bootstrapInfo), + new ServerSslContextProviderFactory(bootstrapInfo)); } @VisibleForTesting @@ -56,14 +56,6 @@ public final class TlsContextManagerImpl implements TlsContextManager { mapForServers = new ReferenceCountingMap<>(serverFactory); } - /** Gets the TlsContextManagerImpl singleton. */ - public static synchronized TlsContextManagerImpl getInstance() { - if (instance == null) { - instance = new TlsContextManagerImpl(new BootstrapperImpl()); - } - return instance; - } - @Override public SslContextProvider findOrCreateServerSslContextProvider( DownstreamTlsContext downstreamTlsContext) { diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/XdsChannelBuilder.java b/xds/src/main/java/io/grpc/xds/internal/sds/XdsChannelBuilder.java deleted file mode 100644 index ab0384c42e9..00000000000 --- a/xds/src/main/java/io/grpc/xds/internal/sds/XdsChannelBuilder.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2019 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.sds; - -import io.grpc.ForwardingChannelBuilder; -import io.grpc.ManagedChannel; -import io.grpc.ManagedChannelBuilder; -import io.grpc.netty.InternalNettyChannelBuilder; -import io.grpc.netty.InternalProtocolNegotiator; -import io.grpc.netty.NettyChannelBuilder; -import java.net.SocketAddress; -import javax.annotation.CheckReturnValue; - -/** - * A version of {@link ManagedChannelBuilder} to create xDS managed channels that will use SDS to - * set up SSL with peers. Note, this is not ready to use yet. - */ -public final class XdsChannelBuilder extends ForwardingChannelBuilder { - - private final NettyChannelBuilder delegate; - private InternalProtocolNegotiator.ProtocolNegotiator fallbackProtocolNegotiator; - - private XdsChannelBuilder(NettyChannelBuilder delegate) { - this.delegate = delegate; - } - - /** - * Creates a new builder with the given server address. See {@link - * NettyChannelBuilder#forAddress(SocketAddress)} for more info. - */ - @CheckReturnValue - public static XdsChannelBuilder forAddress(SocketAddress serverAddress) { - return new XdsChannelBuilder(NettyChannelBuilder.forAddress(serverAddress)); - } - - /** - * Creates a new builder with the given host and port. See {@link - * NettyChannelBuilder#forAddress(String, int)} for more info. - */ - @CheckReturnValue - public static XdsChannelBuilder forAddress(String host, int port) { - return new XdsChannelBuilder(NettyChannelBuilder.forAddress(host, port)); - } - - /** - * Creates a new builder with the given target string. See {@link - * NettyChannelBuilder#forTarget(String)} for more info. - */ - @CheckReturnValue - public static XdsChannelBuilder forTarget(String target) { - return new XdsChannelBuilder(NettyChannelBuilder.forTarget(target)); - } - - /** Set the fallback protocolNegotiator. Pass null to unset a previously set value. */ - public XdsChannelBuilder fallbackProtocolNegotiator( - InternalProtocolNegotiator.ProtocolNegotiator fallbackProtocolNegotiator) { - this.fallbackProtocolNegotiator = fallbackProtocolNegotiator; - return this; - } - - @Override - protected ManagedChannelBuilder delegate() { - return delegate; - } - - @Override - public ManagedChannel build() { - InternalNettyChannelBuilder.setProtocolNegotiatorFactory( - delegate, - SdsProtocolNegotiators.clientProtocolNegotiatorFactory(fallbackProtocolNegotiator)); - return delegate.build(); - } -} diff --git a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java index 1e16676d0a8..3178d2b3e4b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManager.java @@ -151,14 +151,13 @@ private static boolean verifyOneSanInList(List entry, List ver if (altNameType == null) { throw new CertificateParsingException("Invalid SAN entry: null altNameType"); } - String altNameFromCert = (String) entry.get(1); switch (altNameType) { case ALT_DNS_NAME: case ALT_URI_NAME: case ALT_IPA_NAME: - return verifyDnsNameInSanList(altNameFromCert, verifySanList); + return verifyDnsNameInSanList((String) entry.get(1), verifySanList); default: - throw new CertificateParsingException("Unsupported altNameType: " + altNameType); + return false; } } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 2287b8fc1b2..3875e66ded9 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -48,7 +48,6 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.RingHashLoadBalancer.RingHashConfig; import io.grpc.xds.XdsClient.CdsUpdate; -import io.grpc.xds.XdsClient.CdsUpdate.LbPolicy; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import java.util.ArrayList; import java.util.Arrays; @@ -73,9 +72,9 @@ @RunWith(JUnit4.class) public class CdsLoadBalancer2Test { - private static final String CLUSTER = "cluster-foo.googleapis.com"; // cluster of entry point - + private static final String CLUSTER = "cluster-foo.googleapis.com"; private static final String EDS_SERVICE_NAME = "backend-service-1.googleapis.com"; + private static final String DNS_HOST_NAME = "backend-service-dns.googleapis.com:443"; private static final String LRS_SERVER_NAME = "lrs.googleapis.com"; private final UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( @@ -148,7 +147,7 @@ public void tearDown() { public void discoverTopLevelEdsCluster() { CdsUpdate update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_NAME, 100L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); @@ -157,15 +156,15 @@ public void discoverTopLevelEdsCluster() { assertThat(childLbConfig.discoveryMechanisms).hasSize(1); DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - LRS_SERVER_NAME, 100L, upstreamTlsContext); + null, LRS_SERVER_NAME, 100L, upstreamTlsContext); assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); } @Test public void discoverTopLevelLogicalDnsCluster() { CdsUpdate update = - CdsUpdate.forLogicalDns(CLUSTER, LRS_SERVER_NAME, 100L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + CdsUpdate.forLogicalDns(CLUSTER, DNS_HOST_NAME, LRS_SERVER_NAME, 100L, upstreamTlsContext) + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); @@ -174,7 +173,7 @@ public void discoverTopLevelLogicalDnsCluster() { assertThat(childLbConfig.discoveryMechanisms).hasSize(1); DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, - LRS_SERVER_NAME, 100L, upstreamTlsContext); + DNS_HOST_NAME, LRS_SERVER_NAME, 100L, upstreamTlsContext); assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()).isEqualTo("round_robin"); } @@ -192,36 +191,36 @@ public void nonAggregateCluster_resourceNotExist_returnErrorPicker() { public void nonAggregateCluster_resourceUpdate() { CdsUpdate update = CdsUpdate.forEds(CLUSTER, null, null, 100L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, null, null, 100L, - upstreamTlsContext); + assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, null, null, null, + 100L, upstreamTlsContext); update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_NAME, 200L, null) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); childLbConfig = (ClusterResolverConfig) childBalancer.config; instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - LRS_SERVER_NAME, 200L, null); + null, LRS_SERVER_NAME, 200L, null); } @Test public void nonAggregateCluster_resourceRevoked() { CdsUpdate update = - CdsUpdate.forLogicalDns(CLUSTER, null, 100L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + CdsUpdate.forLogicalDns(CLUSTER, DNS_HOST_NAME, null, 100L, upstreamTlsContext) + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, null, - 100L, upstreamTlsContext); + assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, + DNS_HOST_NAME, null, 100L, upstreamTlsContext); xdsClient.deliverResourceNotExist(CLUSTER); assertThat(childBalancer.shutdown).isTrue(); @@ -240,7 +239,7 @@ public void discoverAggregateCluster() { // CLUSTER (aggr.) -> [cluster1 (aggr.), cluster2 (logical DNS)] CdsUpdate update = CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .lbPolicy(LbPolicy.RING_HASH, 100L, 1000L).build(); + .ringHashLbPolicy(100L, 1000L).build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); assertThat(childBalancers).isEmpty(); @@ -249,24 +248,24 @@ public void discoverAggregateCluster() { // cluster1 (aggr.) -> [cluster3 (EDS), cluster4 (EDS)] CdsUpdate update1 = CdsUpdate.forAggregate(cluster1, Arrays.asList(cluster3, cluster4)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster1, update1); assertThat(xdsClient.watchers.keySet()).containsExactly( CLUSTER, cluster1, cluster2, cluster3, cluster4); assertThat(childBalancers).isEmpty(); CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_NAME, 200L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster3, update3); assertThat(childBalancers).isEmpty(); CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, null, 100L, null) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, null, 100L, null) + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster2, update2); assertThat(childBalancers).isEmpty(); CdsUpdate update4 = CdsUpdate.forEds(cluster4, null, LRS_SERVER_NAME, 300L, null) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster4, update4); assertThat(childBalancers).hasSize(1); // all non-aggregate clusters discovered FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); @@ -275,11 +274,12 @@ public void discoverAggregateCluster() { assertThat(childLbConfig.discoveryMechanisms).hasSize(3); // Clusters on higher level has higher priority: [cluster2, cluster3, cluster4] assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, null, 100L, null); + DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, null, 100L, null); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster3, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, LRS_SERVER_NAME, 200L, upstreamTlsContext); + DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_NAME, 200L, + upstreamTlsContext); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(2), cluster4, - DiscoveryMechanism.Type.EDS, null, LRS_SERVER_NAME, 300L, null); + DiscoveryMechanism.Type.EDS, null, null, LRS_SERVER_NAME, 300L, null); assertThat(childLbConfig.lbPolicy.getProvider().getPolicyName()) .isEqualTo("ring_hash"); // dominated by top-level cluster's config assertThat(((RingHashConfig) childLbConfig.lbPolicy.getConfig()).minRingSize).isEqualTo(100L); @@ -292,7 +292,7 @@ public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { // CLUSTER (aggr.) -> [cluster1 (EDS)] CdsUpdate update = CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); xdsClient.deliverResourceNotExist(cluster1); @@ -310,24 +310,25 @@ public void aggregateCluster_descendantClustersRevoked() { // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (logical DNS)] CdsUpdate update = CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_NAME, 200L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster1, update1); CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, LRS_SERVER_NAME, 100L, null) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_NAME, 100L, null) + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster2, update2); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanisms).hasSize(2); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, LRS_SERVER_NAME, 200L, upstreamTlsContext); + DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_NAME, 200L, + upstreamTlsContext); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, LRS_SERVER_NAME, 100L, null); + DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_NAME, 100L, null); // Revoke cluster1, should still be able to proceed with cluster2. xdsClient.deliverResourceNotExist(cluster1); @@ -335,7 +336,7 @@ public void aggregateCluster_descendantClustersRevoked() { childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanisms).hasSize(1); assertDiscoveryMechanism(Iterables.getOnlyElement(childLbConfig.discoveryMechanisms), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, LRS_SERVER_NAME, 100L, null); + DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_NAME, 100L, null); verify(helper, never()).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); @@ -356,24 +357,25 @@ public void aggregateCluster_rootClusterRevoked() { // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (logical DNS)] CdsUpdate update = CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_NAME, 200L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster1, update1); CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, LRS_SERVER_NAME, 100L, null) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_NAME, 100L, null) + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster2, update2); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanisms).hasSize(2); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, LRS_SERVER_NAME, 200L, upstreamTlsContext); + DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_NAME, 200L, + upstreamTlsContext); assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, LRS_SERVER_NAME, 100L, null); + DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_NAME, 100L, null); xdsClient.deliverResourceNotExist(CLUSTER); assertThat(xdsClient.watchers.keySet()) @@ -392,7 +394,7 @@ public void aggregateCluster_intermediateClusterChanges() { // CLUSTER (aggr.) -> [cluster1] CdsUpdate update = CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); @@ -400,7 +402,7 @@ public void aggregateCluster_intermediateClusterChanges() { String cluster2 = "cluster-02.googleapis.com"; update = CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster2)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2); @@ -408,19 +410,19 @@ public void aggregateCluster_intermediateClusterChanges() { String cluster3 = "cluster-03.googleapis.com"; CdsUpdate update2 = CdsUpdate.forAggregate(cluster2, Collections.singletonList(cluster3)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster2, update2); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2, cluster3); CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_NAME, 100L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster3, update3); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanisms).hasSize(1); DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - LRS_SERVER_NAME, 100L, upstreamTlsContext); + null, LRS_SERVER_NAME, 100L, upstreamTlsContext); // cluster2 revoked xdsClient.deliverResourceNotExist(cluster2); @@ -440,7 +442,7 @@ public void aggregateCluster_discoveryErrorBeforeChildLbCreated_returnErrorPicke // CLUSTER (aggr.) -> [cluster1] CdsUpdate update = CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); Status error = Status.RESOURCE_EXHAUSTED.withDescription("OOM"); @@ -457,11 +459,11 @@ public void aggregateCluster_discoveryErrorAfterChildLbCreated_propagateToChildL // CLUSTER (aggr.) -> [cluster1 (logical DNS)] CdsUpdate update = CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); CdsUpdate update1 = - CdsUpdate.forLogicalDns(cluster1, LRS_SERVER_NAME, 200L, null) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + CdsUpdate.forLogicalDns(cluster1, DNS_HOST_NAME, LRS_SERVER_NAME, 200L, null) + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(cluster1, update1); FakeLoadBalancer childLb = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childLb.config; @@ -486,7 +488,7 @@ public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErr public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThrough() { CdsUpdate update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_NAME, 100L, upstreamTlsContext) - .lbPolicy(LbPolicy.ROUND_ROBIN).build(); + .roundRobinLbPolicy().build(); xdsClient.deliverCdsUpdate(CLUSTER, update); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.shutdown).isFalse(); @@ -509,12 +511,13 @@ private static void assertPicker(SubchannelPicker picker, Status expectedStatus, } private static void assertDiscoveryMechanism(DiscoveryMechanism instance, String name, - DiscoveryMechanism.Type type, @Nullable String edsServiceName, + DiscoveryMechanism.Type type, @Nullable String edsServiceName, @Nullable String dnsHostName, @Nullable String lrsServerName, @Nullable Long maxConcurrentRequests, @Nullable UpstreamTlsContext tlsContext) { assertThat(instance.cluster).isEqualTo(name); assertThat(instance.type).isEqualTo(type); assertThat(instance.edsServiceName).isEqualTo(edsServiceName); + assertThat(instance.dnsHostName).isEqualTo(dnsHostName); assertThat(instance.lrsServerName).isEqualTo(lrsServerName); assertThat(instance.maxConcurrentRequests).isEqualTo(maxConcurrentRequests); assertThat(instance.tlsContext).isEqualTo(tlsContext); diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java index e092fd759c5..913e1693f79 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientDataTest.java @@ -23,9 +23,18 @@ import com.google.protobuf.BoolValue; import com.google.protobuf.StringValue; import com.google.protobuf.UInt32Value; +import com.google.protobuf.UInt64Value; import com.google.protobuf.util.Durations; import com.google.re2j.Pattern; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.DiscoveryType; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.EdsClusterConfig; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.LbPolicy; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig; +import io.envoyproxy.envoy.config.cluster.v3.Cluster.RingHashLbConfig.HashFunction; import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; +import io.envoyproxy.envoy.config.core.v3.ConfigSource; import io.envoyproxy.envoy.config.core.v3.ExtensionConfigSource; import io.envoyproxy.envoy.config.core.v3.Locality; import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; @@ -58,31 +67,40 @@ import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; import io.envoyproxy.envoy.type.v3.Int64Range; import io.grpc.Status.Code; +import io.grpc.xds.ClientXdsClient.ResourceInvalidException; import io.grpc.xds.ClientXdsClient.StructOrError; import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.FaultConfig.FaultAbort; import io.grpc.xds.Filter.FilterConfig; -import io.grpc.xds.Matchers.FractionMatcher; -import io.grpc.xds.Matchers.HeaderMatcher; -import io.grpc.xds.Matchers.PathMatcher; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteAction; import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; +import io.grpc.xds.XdsClient.CdsUpdate; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher; import java.util.Arrays; import java.util.Collections; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ClientXdsClientDataTest { + @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 + @Rule + public final ExpectedException thrown = ExpectedException.none(); + @Test public void parseRoute_withRouteAction() { io.envoyproxy.envoy.config.route.v3.Route proto = @@ -680,6 +698,80 @@ public void parseOverrideFilterConfigs_unsupportedAndRequired() { + "type.googleapis.com/google.protobuf.StringValue"); } + @Test + public void parseCluster_ringHashLbPolicy_defaultLbConfig() throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.RING_HASH) + .build(); + + CdsUpdate update = ClientXdsClient.parseCluster(cluster, new HashSet()); + assertThat(update.lbPolicy()).isEqualTo(CdsUpdate.LbPolicy.RING_HASH); + assertThat(update.minRingSize()) + .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MIN_RING_SIZE); + assertThat(update.maxRingSize()) + .isEqualTo(ClientXdsClient.DEFAULT_RING_HASH_LB_POLICY_MAX_RING_SIZE); + } + + @Test + public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_minGreaterThanMax() + throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.RING_HASH) + .setRingHashLbConfig( + RingHashLbConfig.newBuilder() + .setHashFunction(HashFunction.XX_HASH) + .setMinimumRingSize(UInt64Value.newBuilder().setValue(1000L)) + .setMaximumRingSize(UInt64Value.newBuilder().setValue(100L))) + .build(); + + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid ring_hash_lb_config"); + ClientXdsClient.parseCluster(cluster, new HashSet()); + } + + @Test + public void parseCluster_ringHashLbPolicy_invalidRingSizeConfig_tooLargeRingSize() + throws ResourceInvalidException { + Cluster cluster = Cluster.newBuilder() + .setName("cluster-foo.googleapis.com") + .setType(DiscoveryType.EDS) + .setEdsClusterConfig( + EdsClusterConfig.newBuilder() + .setEdsConfig( + ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.getDefaultInstance())) + .setServiceName("service-foo.googleapis.com")) + .setLbPolicy(LbPolicy.RING_HASH) + .setRingHashLbConfig( + RingHashLbConfig.newBuilder() + .setHashFunction(HashFunction.XX_HASH) + .setMinimumRingSize(UInt64Value.newBuilder().setValue(1000L)) + .setMaximumRingSize( + UInt64Value.newBuilder() + .setValue(ClientXdsClient.MAX_RING_HASH_LB_POLICY_RING_SIZE + 1))) + .build(); + + thrown.expect(ResourceInvalidException.class); + thrown.expectMessage("Cluster cluster-foo.googleapis.com: invalid ring_hash_lb_config"); + ClientXdsClient.parseCluster(cluster, new HashSet()); + } + @Test public void parseServerSideListener_invalidTrafficDirection() { Listener listener = @@ -688,7 +780,7 @@ public void parseServerSideListener_invalidTrafficDirection() { .setTrafficDirection(TrafficDirection.OUTBOUND) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()).isEqualTo("Listener listener1 is not INBOUND"); } @@ -701,7 +793,7 @@ public void parseServerSideListener_listenerFiltersPresent() { .addListenerFilters(ListenerFilter.newBuilder().build()) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("Listener listener1 cannot have listener_filters"); } @@ -715,7 +807,7 @@ public void parseServerSideListener_useOriginalDst() { .setUseOriginalDst(BoolValue.of(true)) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("Listener listener1 cannot have use_original_dst set to true"); } @@ -729,7 +821,7 @@ public void parseServerSideListener_noHcm() { .addFilterChains(FilterChain.newBuilder().build()) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("filerChain has to have envoy.http_connection_manager"); } @@ -753,7 +845,7 @@ public void parseServerSideListener_duplicateFilterName() { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("filerChain has non-unique filter name:envoy.http_connection_manager"); } @@ -773,7 +865,7 @@ public void parseServerSideListener_configDiscoveryFilter() { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("filter envoy.http_connection_manager with config_discovery not supported"); } @@ -789,7 +881,7 @@ public void parseServerSideListener_expectTypedConfigFilter() { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("filter envoy.http_connection_manager expected to have typed_config"); } @@ -809,7 +901,7 @@ public void parseServerSideListener_wrongTypeUrl() { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo( "filter envoy.http_connection_manager with unsupported typed_config type:badTypeUrl"); @@ -830,7 +922,7 @@ public void parseServerSideListener_duplicateHttpFilter() { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo("http-connection-manager has non-unique http-filter name:hf"); } @@ -852,7 +944,7 @@ public void parseServerSideListener_configDiscoveryHttpFilter() { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo( "http-connection-manager http-filter envoy.router uses " @@ -877,7 +969,7 @@ public void parseServerSideListener_badTypeUrlHttpFilter() { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo( "http-connection-manager http-filter envoy.router has unsupported typed-config type:" @@ -898,7 +990,7 @@ public void parseServerSideListener_missingTypeUrlHttpFilter() { .addFilterChains(filterChain) .build(); StructOrError struct = - ClientXdsClient.parseServerSideListener(listener); + ClientXdsClient.parseServerSideListener(listener, null); assertThat(struct.getErrorDetail()) .isEqualTo( "http-connection-manager http-filter envoy.filters.http.router should have " diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java index 20a9da63079..85367118848 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientTestBase.java @@ -22,7 +22,9 @@ import static io.grpc.xds.AbstractXdsClient.ResourceType.EDS; import static io.grpc.xds.AbstractXdsClient.ResourceType.LDS; import static io.grpc.xds.AbstractXdsClient.ResourceType.RDS; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -40,6 +42,8 @@ import io.envoyproxy.envoy.config.route.v3.FilterConfig; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.SdsSecretConfig; import io.grpc.BindableService; +import io.grpc.Context; +import io.grpc.Context.CancellableContext; import io.grpc.InsecureChannelCredentials; import io.grpc.ManagedChannel; import io.grpc.Status; @@ -230,6 +234,8 @@ public long currentTimeNanos() { private CdsResourceWatcher cdsResourceWatcher; @Mock private EdsResourceWatcher edsResourceWatcher; + @Mock + private TlsContextManager tlsContextManager; private ManagedChannel channel; private ClientXdsClient xdsClient; @@ -270,10 +276,12 @@ public void setUp() throws IOException { new ClientXdsClient( channel, bootstrapInfo, + Context.ROOT, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, fakeClock.getStopwatchSupplier(), - timeProvider); + timeProvider, + tlsContextManager); assertThat(resourceDiscoveryCalls).isEmpty(); assertThat(loadReportCalls).isEmpty(); @@ -1227,8 +1235,8 @@ public void cdsResponseWithUpstreamTlsContext() { null, true, mf.buildUpstreamTlsContext("secret1", "unix:/var/uds2"), null)); List clusters = ImmutableList.of( - Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", "round_robin", null, - false, null, null)), + Any.pack(mf.buildLogicalDnsCluster("cluster-bar.googleapis.com", + "dns-service-bar.googleapis.com", 443, "round_robin", null, false, null, null)), clusterEds, Any.pack(mf.buildEdsCluster("cluster-baz.googleapis.com", null, "round_robin", null, false, null, null))); @@ -1299,14 +1307,18 @@ public void cdsResourceUpdated() { verifyResourceMetadataRequested(CDS, CDS_RESOURCE); // Initial CDS response. + String dnsHostAddr = "dns-service-bar.googleapis.com"; + int dnsHostPort = 443; Any clusterDns = - Any.pack(mf.buildLogicalDnsCluster(CDS_RESOURCE, "round_robin", null, false, null, null)); + Any.pack(mf.buildLogicalDnsCluster(CDS_RESOURCE, dnsHostAddr, dnsHostPort, "round_robin", + null, false, null, null)); call.sendResponse(CDS, clusterDns, VERSION_1, "0000"); call.verifyRequest(CDS, CDS_RESOURCE, VERSION_1, "0000", NODE); verify(cdsResourceWatcher).onChanged(cdsUpdateCaptor.capture()); CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); + assertThat(cdsUpdate.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); assertThat(cdsUpdate.lrsServerName()).isNull(); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); @@ -1383,9 +1395,12 @@ public void multipleCdsWatchers() { verifyResourceMetadataDoesNotExist(CDS, cdsResourceTwo); verifySubscribedResourcesMetadataSizes(0, 2, 0, 0); + String dnsHostAddr = "dns-service-bar.googleapis.com"; + int dnsHostPort = 443; String edsService = "eds-service-bar.googleapis.com"; List clusters = ImmutableList.of( - Any.pack(mf.buildLogicalDnsCluster(CDS_RESOURCE, "round_robin", null, false, null, null)), + Any.pack(mf.buildLogicalDnsCluster(CDS_RESOURCE, dnsHostAddr, dnsHostPort, "round_robin", + null, false, null, null)), Any.pack(mf.buildEdsCluster(cdsResourceTwo, edsService, "round_robin", null, true, null, null))); call.sendResponse(CDS, clusters, VERSION_1, "0000"); @@ -1393,6 +1408,7 @@ public void multipleCdsWatchers() { CdsUpdate cdsUpdate = cdsUpdateCaptor.getValue(); assertThat(cdsUpdate.clusterName()).isEqualTo(CDS_RESOURCE); assertThat(cdsUpdate.clusterType()).isEqualTo(ClusterType.LOGICAL_DNS); + assertThat(cdsUpdate.dnsHostName()).isEqualTo(dnsHostAddr + ":" + dnsHostPort); assertThat(cdsUpdate.lbPolicy()).isEqualTo(LbPolicy.ROUND_ROBIN); assertThat(cdsUpdate.lrsServerName()).isNull(); assertThat(cdsUpdate.maxConcurrentRequests()).isNull(); @@ -1776,6 +1792,26 @@ public void multipleEdsWatchers() { verifySubscribedResourcesMetadataSizes(0, 0, 0, 2); } + @Test + public void useIndependentRpcContext() { + // Simulates making RPCs within the context of an inbound RPC. + CancellableContext cancellableContext = Context.current().withCancellation(); + Context prevContext = cancellableContext.attach(); + try { + DiscoveryRpcCall call = startResourceWatcher(LDS, LDS_RESOURCE, ldsResourceWatcher); + + // The inbound RPC finishes and closes its context. The outbound RPC's control plane RPC + // should not be impacted. + cancellableContext.close(); + verify(ldsResourceWatcher, never()).onError(any(Status.class)); + + call.sendResponse(LDS, testListenerRds, VERSION_1, "0000"); + verify(ldsResourceWatcher).onChanged(any(LdsUpdate.class)); + } finally { + cancellableContext.detach(prevContext); + } + } + @Test public void streamClosedAndRetryWithBackoff() { InOrder inOrder = Mockito.inOrder(backoffPolicyProvider, backoffPolicy1, backoffPolicy2); @@ -1987,7 +2023,7 @@ public void serverSideListenerFound() throws InvalidProtocolBufferException { ClientXdsClientTestBase.DiscoveryRpcCall call = startResourceWatcher(LDS, LISTENER_RESOURCE, ldsResourceWatcher); Message listener = - mf.buildListenerWithFilterChain( + mf.buildListenerWithFilterChain( LISTENER_RESOURCE, 7000, "0.0.0.0", "google-sds-config-default", "ROOTCA"); List listeners = ImmutableList.of(Any.pack(listener)); call.sendResponse(ResourceType.LDS, listeners, "0", "0000"); @@ -1996,10 +2032,11 @@ public void serverSideListenerFound() throws InvalidProtocolBufferException { ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "0", "0000", NODE); verify(ldsResourceWatcher).onChanged(ldsUpdateCaptor.capture()); assertThat(ldsUpdateCaptor.getValue().listener) - .isEqualTo(EnvoyServerProtoData.Listener.fromEnvoyProtoListener((Listener)listener)); + .isEqualTo(EnvoyServerProtoData.Listener + .fromEnvoyProtoListener((Listener) listener, tlsContextManager)); listener = - mf.buildListenerWithFilterChain( + mf.buildListenerWithFilterChain( LISTENER_RESOURCE, 7000, "0.0.0.0", "CERT2", "ROOTCA2"); listeners = ImmutableList.of(Any.pack(listener)); call.sendResponse(ResourceType.LDS, listeners, "1", "0001"); @@ -2009,7 +2046,8 @@ public void serverSideListenerFound() throws InvalidProtocolBufferException { ResourceType.LDS, Collections.singletonList(LISTENER_RESOURCE), "1", "0001", NODE); verify(ldsResourceWatcher, times(2)).onChanged(ldsUpdateCaptor.capture()); assertThat(ldsUpdateCaptor.getValue().listener) - .isEqualTo(EnvoyServerProtoData.Listener.fromEnvoyProtoListener((Listener)listener)); + .isEqualTo(EnvoyServerProtoData.Listener + .fromEnvoyProtoListener((Listener) listener, tlsContextManager)); assertThat(fakeClock.getPendingTasks(LDS_RESOURCE_FETCH_TIMEOUT_TASK_FILTER)).isEmpty(); } @@ -2161,8 +2199,8 @@ protected abstract Message buildEdsCluster(String clusterName, @Nullable String String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers); - protected abstract Message buildLogicalDnsCluster(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, boolean enableLrs, + protected abstract Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, + int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers); protected abstract Message buildAggregateCluster(String clusterName, String lbPolicy, diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java index 94e93c4e9b6..27ef2ba9eba 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV2Test.java @@ -408,12 +408,20 @@ protected Message buildEdsCluster(String clusterName, @Nullable String edsServic } @Override - protected Message buildLogicalDnsCluster(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, boolean enableLrs, + protected Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, + int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers) { Cluster.Builder builder = initClusterBuilder(clusterName, lbPolicy, ringHashLbConfig, enableLrs, upstreamTlsContext, circuitBreakers); builder.setType(DiscoveryType.LOGICAL_DNS); + builder.setLoadAssignment( + ClusterLoadAssignment.newBuilder().addEndpoints( + LocalityLbEndpoints.newBuilder().addLbEndpoints( + LbEndpoint.newBuilder().setEndpoint( + Endpoint.newBuilder().setAddress( + Address.newBuilder().setSocketAddress( + SocketAddress.newBuilder() + .setAddress(dnsHostAddr).setPortValue(dnsHostPort)))))).build()); return builder.build(); } diff --git a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java index 9332297ff96..a68decba885 100644 --- a/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java +++ b/xds/src/test/java/io/grpc/xds/ClientXdsClientV3Test.java @@ -453,12 +453,20 @@ protected Message buildEdsCluster(String clusterName, @Nullable String edsServic } @Override - protected Message buildLogicalDnsCluster(String clusterName, String lbPolicy, - @Nullable Message ringHashLbConfig, boolean enableLrs, + protected Message buildLogicalDnsCluster(String clusterName, String dnsHostAddr, + int dnsHostPort, String lbPolicy, @Nullable Message ringHashLbConfig, boolean enableLrs, @Nullable Message upstreamTlsContext, @Nullable Message circuitBreakers) { Cluster.Builder builder = initClusterBuilder(clusterName, lbPolicy, ringHashLbConfig, enableLrs, upstreamTlsContext, circuitBreakers); builder.setType(DiscoveryType.LOGICAL_DNS); + builder.setLoadAssignment( + ClusterLoadAssignment.newBuilder().addEndpoints( + LocalityLbEndpoints.newBuilder().addLbEndpoints( + LbEndpoint.newBuilder().setEndpoint( + Endpoint.newBuilder().setAddress( + Address.newBuilder().setSocketAddress( + SocketAddress.newBuilder() + .setAddress(dnsHostAddr).setPortValue(dnsHostPort)))))).build()); return builder.build(); } diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index b1e5dceefac..47926cae55e 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -58,7 +58,6 @@ import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProvider; import io.grpc.xds.internal.sds.SslContextProviderSupplier; -import io.grpc.xds.internal.sds.TlsContextManager; import java.net.SocketAddress; import java.util.ArrayList; import java.util.Arrays; @@ -134,12 +133,14 @@ public AtomicLong getOrCreate(String cluster, @Nullable String edsServiceName) { @Before public void setUp() { MockitoAnnotations.initMocks(this); - loadBalancer = new ClusterImplLoadBalancer(helper, mockRandom, tlsContextManager); + loadBalancer = new ClusterImplLoadBalancer(helper, mockRandom); } @After public void tearDown() { - loadBalancer.shutdown(); + if (loadBalancer != null) { + loadBalancer.shutdown(); + } assertThat(xdsClientRefs).isEqualTo(0); assertThat(downstreamBalancers).isEmpty(); } @@ -554,11 +555,21 @@ private void subtest_endpointAddressesAttachedWithTlsConfig(boolean enableSecuri SslContextProviderSupplier supplier = eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); if (enableSecurity) { + assertThat(supplier.isShutdown()).isFalse(); assertThat(supplier.getTlsContext()).isEqualTo(upstreamTlsContext); } else { assertThat(supplier).isNull(); } } + loadBalancer.shutdown(); + for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { + SslContextProviderSupplier supplier = + eag.getAttributes().get(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + if (enableSecurity) { + assertThat(supplier.isShutdown()).isTrue(); + } + } + loadBalancer = null; } private void deliverAddressesAndConfig(List addresses, @@ -765,6 +776,11 @@ ClusterLocalityStats addClusterLocalityStats(String clusterName, @Nullable String edsServiceName, Locality locality) { return loadStatsManager.getClusterLocalityStats(clusterName, edsServiceName, locality); } + + @Override + TlsContextManager getTlsContextManager() { + return tlsContextManager; + } } private static final class FakeTlsContextManager implements TlsContextManager { diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index 294a33b0bfc..b987ca95c33 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -102,6 +102,7 @@ public class ClusterResolverLoadBalancerTest { private static final String CLUSTER_DNS = "cluster-dns.googleapis.com"; private static final String EDS_SERVICE_NAME1 = "backend-service-foo.googleapis.com"; private static final String EDS_SERVICE_NAME2 = "backend-service-bar.googleapis.com"; + private static final String DNS_HOST_NAME = "dns-service.googleapis.com"; private static final String LRS_SERVER_NAME = "lrs.googleapis.com"; private final Locality locality1 = Locality.create("test-region-1", "test-zone-1", "test-subzone-1"); @@ -119,7 +120,7 @@ public class ClusterResolverLoadBalancerTest { private final DiscoveryMechanism edsDiscoveryMechanism2 = DiscoveryMechanism.forEds(CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_NAME, 200L, tlsContext); private final DiscoveryMechanism logicalDnsDiscoveryMechanism = - DiscoveryMechanism.forLogicalDns(CLUSTER_DNS, LRS_SERVER_NAME, 300L, null); + DiscoveryMechanism.forLogicalDns(CLUSTER_DNS, DNS_HOST_NAME, LRS_SERVER_NAME, 300L, null); private final SynchronizationContext syncContext = new SynchronizationContext( new Thread.UncaughtExceptionHandler() { @@ -216,29 +217,37 @@ public void edsClustersWithRingHashEndpointLbPolicy() { // One priority with two localities of different weights. EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); + EquivalentAddressGroup endpoint3 = makeAddress("endpoint-addr-3"); LocalityLbEndpoints localityLbEndpoints1 = LocalityLbEndpoints.create( - Collections.singletonList( - LbEndpoint.create(endpoint1, 100 /* loadBalancingWeight */, true)), + Arrays.asList( + LbEndpoint.create(endpoint1, 0 /* loadBalancingWeight */, true), + LbEndpoint.create(endpoint2, 0 /* loadBalancingWeight */, true)), 10 /* localityWeight */, 1 /* priority */); LocalityLbEndpoints localityLbEndpoints2 = LocalityLbEndpoints.create( Collections.singletonList( - LbEndpoint.create(endpoint2, 60 /* loadBalancingWeight */, true)), + LbEndpoint.create(endpoint3, 60 /* loadBalancingWeight */, true)), 50 /* localityWeight */, 1 /* priority */); xdsClient.deliverClusterLoadAssignment( EDS_SERVICE_NAME1, ImmutableMap.of(locality1, localityLbEndpoints1, locality2, localityLbEndpoints2)); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.addresses).hasSize(2); + assertThat(childBalancer.addresses).hasSize(3); EquivalentAddressGroup addr1 = childBalancer.addresses.get(0); EquivalentAddressGroup addr2 = childBalancer.addresses.get(1); + EquivalentAddressGroup addr3 = childBalancer.addresses.get(2); + // Endpoints in locality1 have no endpoint-level weight specified, so all endpoints within + // locality1 are equally weighted. assertThat(addr1.getAddresses()).isEqualTo(endpoint1.getAddresses()); assertThat(addr1.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) - .isEqualTo(10 * 100); + .isEqualTo(10); assertThat(addr2.getAddresses()).isEqualTo(endpoint2.getAddresses()); assertThat(addr2.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) + .isEqualTo(10); + assertThat(addr3.getAddresses()).isEqualTo(endpoint3.getAddresses()); + assertThat(addr3.getAttributes().get(InternalXdsAttributes.ATTR_SERVER_WEIGHT)) .isEqualTo(50 * 60); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; @@ -290,42 +299,53 @@ public void onlyEdsClusters_receivedEndpoints() { String priority2 = CLUSTER2 + "[priority2]"; String priority3 = CLUSTER1 + "[priority1]"; - // First deliver CLUSTER2's endpoints, two priorities with each has one locality. + // CLUSTER2: locality1 with priority 1 and locality3 with priority 2. xdsClient.deliverClusterLoadAssignment( EDS_SERVICE_NAME2, ImmutableMap.of(locality1, localityLbEndpoints1, locality3, localityLbEndpoints3)); + assertThat(childBalancers).isEmpty(); // not created until all clusters resolved + + // CLUSTER1: locality2 with priority 1. + xdsClient.deliverClusterLoadAssignment( + EDS_SERVICE_NAME1, Collections.singletonMap(locality2, localityLbEndpoints2)); + + // Endpoints of all clusters have been resolved. assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLbConfig priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities).containsExactly(priority1, priority2).inOrder(); - PriorityChildConfig priorityChildConfig = priorityLbConfig.childConfigs.get(priority1); - assertThat(priorityChildConfig.ignoreReresolution).isTrue(); - assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) + assertThat(priorityLbConfig.priorities) + .containsExactly(priority3, priority1, priority2).inOrder(); + + PriorityChildConfig priorityChildConfig1 = priorityLbConfig.childConfigs.get(priority1); + assertThat(priorityChildConfig1.ignoreReresolution).isTrue(); + assertThat(priorityChildConfig1.policySelection.getProvider().getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - ClusterImplConfig clusterImplConfig = - (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_NAME, 200L, + ClusterImplConfig clusterImplConfig1 = + (ClusterImplConfig) priorityChildConfig1.policySelection.getConfig(); + assertClusterImplConfig(clusterImplConfig1, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_NAME, 200L, tlsContext, Collections.emptyList(), WEIGHTED_TARGET_POLICY_NAME); - WeightedTargetConfig weightedTargetConfig = - (WeightedTargetConfig) clusterImplConfig.childPolicy.getConfig(); - assertThat(weightedTargetConfig.targets.keySet()).containsExactly(locality1.toString()); - WeightedPolicySelection target = weightedTargetConfig.targets.get(locality1.toString()); - assertThat(target.weight).isEqualTo(70); - assertThat(target.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); - - priorityChildConfig = priorityLbConfig.childConfigs.get(priority2); - assertThat(priorityChildConfig.ignoreReresolution).isTrue(); - assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) + WeightedTargetConfig weightedTargetConfig1 = + (WeightedTargetConfig) clusterImplConfig1.childPolicy.getConfig(); + assertThat(weightedTargetConfig1.targets.keySet()).containsExactly(locality1.toString()); + WeightedPolicySelection target1 = weightedTargetConfig1.targets.get(locality1.toString()); + assertThat(target1.weight).isEqualTo(70); + assertThat(target1.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); + + PriorityChildConfig priorityChildConfig2 = priorityLbConfig.childConfigs.get(priority2); + assertThat(priorityChildConfig2.ignoreReresolution).isTrue(); + assertThat(priorityChildConfig2.policySelection.getProvider().getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - clusterImplConfig = (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_NAME, 200L, + ClusterImplConfig clusterImplConfig2 = + (ClusterImplConfig) priorityChildConfig2.policySelection.getConfig(); + assertClusterImplConfig(clusterImplConfig2, CLUSTER2, EDS_SERVICE_NAME2, LRS_SERVER_NAME, 200L, tlsContext, Collections.emptyList(), WEIGHTED_TARGET_POLICY_NAME); - weightedTargetConfig = (WeightedTargetConfig) clusterImplConfig.childPolicy.getConfig(); - assertThat(weightedTargetConfig.targets.keySet()).containsExactly(locality3.toString()); - target = weightedTargetConfig.targets.get(locality3.toString()); - assertThat(target.weight).isEqualTo(20); - assertThat(target.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); + WeightedTargetConfig weightedTargetConfig2 = + (WeightedTargetConfig) clusterImplConfig2.childPolicy.getConfig(); + assertThat(weightedTargetConfig2.targets.keySet()).containsExactly(locality3.toString()); + WeightedPolicySelection target2 = weightedTargetConfig2.targets.get(locality3.toString()); + assertThat(target2.weight).isEqualTo(20); + assertThat(target2.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); List priorityAddrs1 = AddressFilter.filter(childBalancer.addresses, priority1); assertThat(priorityAddrs1).hasSize(2); @@ -335,26 +355,20 @@ public void onlyEdsClusters_receivedEndpoints() { assertThat(priorityAddrs2).hasSize(1); assertAddressesEqual(Collections.singletonList(endpoint4), priorityAddrs2); - // Then deliver CLUSTER1's endpoints, one priority with one locality. - xdsClient.deliverClusterLoadAssignment( - EDS_SERVICE_NAME1, Collections.singletonMap(locality2, localityLbEndpoints2)); - - priorityLbConfig = (PriorityLbConfig) childBalancer.config; - assertThat(priorityLbConfig.priorities) - .containsExactly(priority3, priority1, priority2).inOrder(); - - priorityChildConfig = priorityLbConfig.childConfigs.get(priority3); - assertThat(priorityChildConfig.ignoreReresolution).isTrue(); - assertThat(priorityChildConfig.policySelection.getProvider().getPolicyName()) + PriorityChildConfig priorityChildConfig3 = priorityLbConfig.childConfigs.get(priority3); + assertThat(priorityChildConfig3.ignoreReresolution).isTrue(); + assertThat(priorityChildConfig3.policySelection.getProvider().getPolicyName()) .isEqualTo(CLUSTER_IMPL_POLICY_NAME); - clusterImplConfig = (ClusterImplConfig) priorityChildConfig.policySelection.getConfig(); - assertClusterImplConfig(clusterImplConfig, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_NAME, 100L, + ClusterImplConfig clusterImplConfig3 = + (ClusterImplConfig) priorityChildConfig3.policySelection.getConfig(); + assertClusterImplConfig(clusterImplConfig3, CLUSTER1, EDS_SERVICE_NAME1, LRS_SERVER_NAME, 100L, tlsContext, Collections.emptyList(), WEIGHTED_TARGET_POLICY_NAME); - weightedTargetConfig = (WeightedTargetConfig) clusterImplConfig.childPolicy.getConfig(); - assertThat(weightedTargetConfig.targets.keySet()).containsExactly(locality2.toString()); - target = weightedTargetConfig.targets.get(locality2.toString()); - assertThat(target.weight).isEqualTo(10); - assertThat(target.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); + WeightedTargetConfig weightedTargetConfig3 = + (WeightedTargetConfig) clusterImplConfig3.childPolicy.getConfig(); + assertThat(weightedTargetConfig3.targets.keySet()).containsExactly(locality2.toString()); + WeightedPolicySelection target3 = weightedTargetConfig3.targets.get(locality2.toString()); + assertThat(target3.weight).isEqualTo(10); + assertThat(target3.policySelection.getProvider().getPolicyName()).isEqualTo("round_robin"); List priorityAddrs3 = AddressFilter.filter(childBalancer.addresses, priority3); assertThat(priorityAddrs3).hasSize(1); @@ -370,16 +384,17 @@ public void onlyEdsClusters_resourceNeverExist_returnErrorPicker() { assertThat(childBalancers).isEmpty(); reset(helper); xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); - verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); - PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(result.getStatus().isOk()).isTrue(); - assertThat(result.getSubchannel()).isNull(); // buffer picker expected + verify(helper, never()).updateBalancingState( + any(ConnectivityState.class), any(SubchannelPicker.class)); // wait for CLUSTER2's results xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME2); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status expectedError = Status.UNAVAILABLE.withDescription("No usable endpoint"); - assertPicker(pickerCaptor.getValue(), expectedError, null); + assertPicker( + pickerCaptor.getValue(), + Status.UNAVAILABLE.withDescription( + "No usable endpoint from cluster(s): " + Arrays.asList(CLUSTER1, CLUSTER2)), + null); } @Test @@ -413,7 +428,8 @@ public void onlyEdsClusters_allResourcesRevoked_shutDownChildLbPolicy() { xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status expectedError = Status.UNAVAILABLE.withDescription("No usable endpoint"); + Status expectedError = Status.UNAVAILABLE.withDescription( + "No usable endpoint from cluster(s): " + Arrays.asList(CLUSTER1, CLUSTER2)); assertPicker(pickerCaptor.getValue(), expectedError, null); } @@ -507,8 +523,11 @@ public void handleEdsResource_noHealthyEndpoint() { assertThat(childBalancers).isEmpty(); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription("No usable endpoint"), null); + assertPicker( + pickerCaptor.getValue(), + Status.UNAVAILABLE.withDescription( + "No usable endpoint from cluster(s): " + Collections.singleton(CLUSTER1)), + null); } @Test @@ -516,11 +535,10 @@ public void onlyLogicalDnsCluster_endpointsResolved() { ClusterResolverConfig config = new ClusterResolverConfig( Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); assertThat(childBalancers).hasSize(1); @@ -544,11 +562,10 @@ public void onlyLogicalDnsCluster_handleRefreshNameResolution() { ClusterResolverConfig config = new ClusterResolverConfig( Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); assertThat(resolver.refreshCount).isEqualTo(0); verify(helper).ignoreRefreshNameResolutionCheck(); @@ -564,9 +581,8 @@ public void onlyLogicalDnsCluster_resolutionError_backoffAndRefresh() { ClusterResolverConfig config = new ClusterResolverConfig( Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); Status error = Status.UNAVAILABLE.withDescription("cannot reach DNS server"); resolver.deliverError(error); inOrder.verify(helper).updateBalancingState( @@ -611,10 +627,9 @@ public void onlyLogicalDnsCluster_refreshNameResolutionRaceWithResolutionError() ClusterResolverConfig config = new ClusterResolverConfig( Collections.singletonList(logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr"); - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); resolver.deliverEndpointAddresses(Collections.singletonList(endpoint)); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); @@ -652,12 +667,11 @@ public void edsClustersAndLogicalDnsCluster_receivedEndpoints() { Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); // DNS endpoint EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); // DNS endpoint EquivalentAddressGroup endpoint3 = makeAddress("endpoint-addr-3"); // EDS endpoint - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); LocalityLbEndpoints localityLbEndpoints = LocalityLbEndpoints.create( @@ -687,16 +701,13 @@ public void noEdsResourceExists_useDnsResolutionResults() { Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); reset(helper); xdsClient.deliverResourceNotFound(EDS_SERVICE_NAME1); - verify(helper).updateBalancingState(eq(ConnectivityState.CONNECTING), pickerCaptor.capture()); - PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); - assertThat(result.getStatus().isOk()).isTrue(); - assertThat(result.getSubchannel()).isNull(); // buffer picker expected, waiting for DNS + verify(helper, never()).updateBalancingState( + any(ConnectivityState.class), any(SubchannelPicker.class)); // wait for DNS results - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); EquivalentAddressGroup endpoint2 = makeAddress("endpoint-addr-2"); resolver.deliverEndpointAddresses(Arrays.asList(endpoint1, endpoint2)); @@ -714,7 +725,7 @@ public void edsResourceRevoked_dnsResolutionError_shutDownChildLbPolicyAndReturn Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); reset(helper); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); @@ -724,7 +735,6 @@ public void edsResourceRevoked_dnsResolutionError_shutDownChildLbPolicyAndReturn 10 /* localityWeight */, 1 /* priority */); xdsClient.deliverClusterLoadAssignment( EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); resolver.deliverError(Status.UNKNOWN.withDescription("I am lost")); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); @@ -737,7 +747,7 @@ public void edsResourceRevoked_dnsResolutionError_shutDownChildLbPolicyAndReturn verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); assertPicker(pickerCaptor.getValue(), - Status.UNAVAILABLE.withDescription("No usable endpoint"), null); + Status.UNAVAILABLE.withDescription("I am lost"), null); } @Test @@ -746,7 +756,7 @@ public void resolutionErrorAfterChildLbCreated_propagateErrorIfAllClustersEncoun Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); reset(helper); EquivalentAddressGroup endpoint = makeAddress("endpoint-addr-1"); @@ -756,10 +766,16 @@ public void resolutionErrorAfterChildLbCreated_propagateErrorIfAllClustersEncoun 10 /* localityWeight */, 1 /* priority */); xdsClient.deliverClusterLoadAssignment( EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); // child LB created - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); + assertThat(childBalancers).isEmpty(); // not created until all clusters resolved. + resolver.deliverError(Status.UNKNOWN.withDescription("I am lost")); + + // DNS resolution failed, but there are EDS endpoints can be used. + assertThat(childBalancers).hasSize(1); + FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); // child LB created assertThat(childBalancer.upstreamError).isNull(); // should not propagate error to child LB + assertAddressesEqual(Collections.singletonList(endpoint), childBalancer.addresses); + xdsClient.deliverError(Status.RESOURCE_EXHAUSTED.withDescription("out of memory")); assertThat(childBalancer.upstreamError).isNotNull(); // last cluster's (DNS) error propagated assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNKNOWN); @@ -775,19 +791,21 @@ public void resolutionErrorBeforeChildLbCreated_returnErrorPickerIfAllClustersEn Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); reset(helper); xdsClient.deliverError(Status.UNIMPLEMENTED.withDescription("not found")); assertThat(childBalancers).isEmpty(); verify(helper, never()).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); // wait for DNS - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); Status dnsError = Status.UNKNOWN.withDescription("I am lost"); resolver.deliverError(dnsError); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), dnsError, null); + assertPicker( + pickerCaptor.getValue(), + Status.UNAVAILABLE.withDescription(dnsError.getDescription()), + null); } @Test @@ -796,7 +814,7 @@ public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErr Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(resolvers).hasSize(1); + assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); reset(helper); Status upstreamError = Status.UNAVAILABLE.withDescription("unreachable"); @@ -812,7 +830,7 @@ public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThroug Arrays.asList(edsDiscoveryMechanism1, logicalDnsDiscoveryMechanism), roundRobin); deliverLbConfig(config); assertThat(xdsClient.watchers.keySet()).containsExactly(EDS_SERVICE_NAME1); - assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = assertResolverCreated("/" + DNS_HOST_NAME); assertThat(childBalancers).isEmpty(); reset(helper); EquivalentAddressGroup endpoint1 = makeAddress("endpoint-addr-1"); @@ -823,7 +841,6 @@ public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThroug 10 /* localityWeight */, 1 /* priority */); xdsClient.deliverClusterLoadAssignment( EDS_SERVICE_NAME1, Collections.singletonMap(locality1, localityLbEndpoints)); - FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); resolver.deliverEndpointAddresses(Collections.singletonList(endpoint2)); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); @@ -851,6 +868,13 @@ private void deliverLbConfig(ClusterResolverConfig config) { .build()); } + private FakeNameResolver assertResolverCreated(String uriPath) { + assertThat(resolvers).hasSize(1); + FakeNameResolver resolver = Iterables.getOnlyElement(resolvers); + assertThat(resolver.targetUri.getPath()).isEqualTo(uriPath); + return resolver; + } + private static void assertPicker(SubchannelPicker picker, Status expectedStatus, @Nullable Subchannel expectedSubchannel) { PickResult result = picker.pickSubchannel(mock(PickSubchannelArgs.class)); @@ -964,8 +988,7 @@ private class FakeNameResolverProvider extends NameResolverProvider { @Override public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) { assertThat(targetUri.getScheme()).isEqualTo("dns"); - assertThat(targetUri.getPath()).isEqualTo("/" + AUTHORITY); - FakeNameResolver resolver = new FakeNameResolver(); + FakeNameResolver resolver = new FakeNameResolver(targetUri); resolvers.add(resolver); return resolver; } @@ -987,9 +1010,14 @@ protected int priority() { } private class FakeNameResolver extends NameResolver { + private final URI targetUri; private Listener2 listener; private int refreshCount; + private FakeNameResolver(URI targetUri) { + this.targetUri = targetUri; + } + @Override public String getServiceAuthority() { throw new UnsupportedOperationException("should not be called"); diff --git a/xds/src/test/java/io/grpc/xds/EnvoyServerProtoDataTest.java b/xds/src/test/java/io/grpc/xds/EnvoyServerProtoDataTest.java index d2c0ca39aaa..5e641e8a65a 100644 --- a/xds/src/test/java/io/grpc/xds/EnvoyServerProtoDataTest.java +++ b/xds/src/test/java/io/grpc/xds/EnvoyServerProtoDataTest.java @@ -17,6 +17,7 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; @@ -34,6 +35,8 @@ import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.Listener; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; +import java.util.Arrays; import java.util.List; import org.junit.Test; import org.junit.runner.RunWith; @@ -61,7 +64,7 @@ public void listener_convertFromListenerProto() throws InvalidProtocolBufferExce .setTrafficDirection(TrafficDirection.INBOUND) .build(); - Listener xdsListener = Listener.fromEnvoyProtoListener(listener); + Listener xdsListener = Listener.fromEnvoyProtoListener(listener, mock(TlsContextManager.class)); assertThat(xdsListener.getName()).isEqualTo("8000"); assertThat(xdsListener.getAddress()).isEqualTo("10.2.1.34:8000"); List filterChains = xdsListener.getFilterChains(); @@ -73,7 +76,11 @@ public void listener_convertFromListenerProto() throws InvalidProtocolBufferExce EnvoyServerProtoData.FilterChainMatch inFilterChainMatch = inFilter.getFilterChainMatch(); assertThat(inFilterChainMatch).isNotNull(); assertThat(inFilterChainMatch.getDestinationPort()).isEqualTo(8000); - assertThat(inFilterChainMatch.getApplicationProtocols()).isEmpty(); + assertThat(inFilterChainMatch.getApplicationProtocols()) + .containsExactlyElementsIn(Arrays.asList("managed-mtls", "h2")); + assertThat(inFilterChainMatch.getServerNames()) + .containsExactlyElementsIn(Arrays.asList("server1", "server2")); + assertThat(inFilterChainMatch.getTransportProtocol()).isEqualTo("tls"); assertThat(inFilterChainMatch.getPrefixRanges()) .containsExactly(new EnvoyServerProtoData.CidrRange("10.20.0.15", 32)); assertThat(inFilterChainMatch.getSourcePrefixRanges()) @@ -81,7 +88,11 @@ public void listener_convertFromListenerProto() throws InvalidProtocolBufferExce assertThat(inFilterChainMatch.getConnectionSourceType()) .isEqualTo(EnvoyServerProtoData.ConnectionSourceType.EXTERNAL); assertThat(inFilterChainMatch.getSourcePorts()).containsExactly(200, 300); - DownstreamTlsContext inFilterTlsContext = inFilter.getDownstreamTlsContext(); + SslContextProviderSupplier sslContextProviderSupplier = inFilter + .getSslContextProviderSupplier(); + assertThat(sslContextProviderSupplier.getTlsContext()).isInstanceOf(DownstreamTlsContext.class); + DownstreamTlsContext inFilterTlsContext = (DownstreamTlsContext) sslContextProviderSupplier + .getTlsContext(); assertThat(inFilterTlsContext.getCommonTlsContext()).isNotNull(); CommonTlsContext commonTlsContext = inFilterTlsContext.getCommonTlsContext(); List tlsCertSdsConfigs = commonTlsContext @@ -105,6 +116,9 @@ private static FilterChain createInFilter() { .setFilterChainMatch( FilterChainMatch.newBuilder() .setDestinationPort(UInt32Value.of(8000)) + .addAllServerNames(Arrays.asList("server1", "server2")) + .setTransportProtocol("tls") + .addAllApplicationProtocols(Arrays.asList("managed-mtls", "h2")) .addPrefixRanges(CidrRange.newBuilder() .setAddressPrefix("10.20.0.15") .setPrefixLen(UInt32Value.of(32)) diff --git a/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java b/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java index 086cba7e4b3..dfe80cd3393 100644 --- a/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java +++ b/xds/src/test/java/io/grpc/xds/FilterChainMatchTest.java @@ -18,11 +18,23 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.UInt32Value; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.CidrRange; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TrafficDirection; +import io.envoyproxy.envoy.config.core.v3.TransportSocket; +import io.envoyproxy.envoy.config.listener.v3.Filter; +import io.envoyproxy.envoy.config.listener.v3.FilterChain; +import io.envoyproxy.envoy.config.listener.v3.FilterChainMatch; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.netty.channel.Channel; import java.io.IOException; import java.net.InetAddress; @@ -46,6 +58,7 @@ public class FilterChainMatchTest { private static final String REMOTE_IP = "10.4.2.3"; // source @Mock private Channel channel; + @Mock private TlsContextManager tlsContextManager; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private XdsClient.LdsResourceWatcher registeredWatcher; @@ -53,7 +66,8 @@ public class FilterChainMatchTest { @Before public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - xdsClientWrapperForServerSds = XdsServerTestHelper.createXdsClientWrapperForServerSds(PORT); + xdsClientWrapperForServerSds = XdsServerTestHelper + .createXdsClientWrapperForServerSds(PORT, tlsContextManager); registeredWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); } @@ -63,6 +77,17 @@ public void tearDown() { xdsClientWrapperForServerSds.shutdown(); } + private DownstreamTlsContext getDownstreamTlsContext() { + SslContextProviderSupplier sslContextProviderSupplier = + xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); + if (sslContextProviderSupplier != null) { + EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); + assertThat(tlsContext).isInstanceOf(DownstreamTlsContext.class); + return (DownstreamTlsContext) tlsContext; + } + return null; + } + @Test public void singleFilterChainWithoutAlpn() throws UnknownHostException { setupChannel(LOCAL_IP, REMOTE_IP, 15000); @@ -73,17 +98,18 @@ public void singleFilterChainWithoutAlpn() throws UnknownHostException { Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); DownstreamTlsContext tlsContext = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); EnvoyServerProtoData.FilterChain filterChain = - new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); + new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContext); } @@ -97,18 +123,24 @@ public void singleFilterChainWithAlpn() throws UnknownHostException { Arrays.asList("managed-mtls"), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); DownstreamTlsContext tlsContext = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); EnvoyServerProtoData.FilterChain filterChain = - new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); + new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager); + DownstreamTlsContext defaultTlsContext = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + EnvoyServerProtoData.FilterChain defaultFilterChain = + new EnvoyServerProtoData.FilterChain(null, defaultTlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = - new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), null); + new EnvoyServerProtoData.Listener("listener1", LOCAL_IP, Arrays.asList(filterChain), + defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); - assertThat(tlsContext1).isSameInstanceAs(tlsContext); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); + assertThat(tlsContext1).isSameInstanceAs(defaultTlsContext); } @Test @@ -117,14 +149,13 @@ public void defaultFilterChain() throws UnknownHostException { DownstreamTlsContext tlsContext = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); EnvoyServerProtoData.FilterChain filterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContext); + new EnvoyServerProtoData.FilterChain(null, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(), filterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContext); } @@ -140,20 +171,23 @@ public void destPortFails_returnDefaultFilterChain() throws UnknownHostException Arrays.asList("managed-mtls"), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainWithDestPort = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithDestPort, tlsContextWithDestPort); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithDestPort, tlsContextWithDestPort, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithDestPort), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); } @@ -169,20 +203,23 @@ public void destPrefixRangeMatch() throws UnknownHostException, InvalidProtocolB Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainWithMatch = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); } @@ -200,26 +237,29 @@ public void destPrefixRangeMismatch_returnDefaultFilterChain() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainWithMismatch = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); } @Test public void dest0LengthPrefixRange() - throws UnknownHostException, InvalidProtocolBufferException { + throws UnknownHostException, InvalidProtocolBufferException { setupChannel(LOCAL_IP, REMOTE_IP, 15000); DownstreamTlsContext tlsContext0Length = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -231,20 +271,23 @@ public void dest0LengthPrefixRange() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain0Length = - new EnvoyServerProtoData.FilterChain(filterChainMatch0Length, tlsContext0Length); + new EnvoyServerProtoData.FilterChain(filterChainMatch0Length, tlsContext0Length, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChain0Length), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContext0Length); } @@ -261,9 +304,12 @@ public void destPrefixRange_moreSpecificWins() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); DownstreamTlsContext tlsContextMoreSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -274,11 +320,14 @@ public void destPrefixRange_moreSpecificWins() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -287,14 +336,13 @@ public void destPrefixRange_moreSpecificWins() defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); } @Test public void destPrefixRange_emptyListLessSpecific() - throws UnknownHostException, InvalidProtocolBufferException { + throws UnknownHostException, InvalidProtocolBufferException { setupChannel(LOCAL_IP, REMOTE_IP, 15000); DownstreamTlsContext tlsContextLessSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -305,9 +353,12 @@ public void destPrefixRange_emptyListLessSpecific() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); DownstreamTlsContext tlsContextMoreSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -318,11 +369,14 @@ public void destPrefixRange_emptyListLessSpecific() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -331,8 +385,7 @@ public void destPrefixRange_emptyListLessSpecific() defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); } @@ -349,9 +402,12 @@ public void destPrefixRangeIpv6_moreSpecificWins() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); DownstreamTlsContext tlsContextMoreSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -362,11 +418,14 @@ public void destPrefixRangeIpv6_moreSpecificWins() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainMoreSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchMoreSpecific, tlsContextMoreSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -375,8 +434,7 @@ public void destPrefixRangeIpv6_moreSpecificWins() defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecific); } @@ -395,10 +453,12 @@ public void destPrefixRange_moreSpecificWith2Wins() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = new EnvoyServerProtoData.FilterChain( - filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2); + filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2, tlsContextManager); DownstreamTlsContext tlsContextLessSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -409,11 +469,14 @@ public void destPrefixRange_moreSpecificWith2Wins() Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -422,8 +485,7 @@ public void destPrefixRange_moreSpecificWith2Wins() defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); } @@ -439,20 +501,23 @@ public void sourceTypeMismatch_returnDefaultFilterChain() throws UnknownHostExce Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainWithMismatch = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithMismatch, tlsContextMismatch, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithMismatch), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextForDefaultFilterChain); } @@ -468,20 +533,23 @@ public void sourceTypeLocal() throws UnknownHostException { Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainWithMatch = - new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch); + new EnvoyServerProtoData.FilterChain(filterChainMatchWithMatch, tlsContextMatch, + tlsContextManager); DownstreamTlsContext tlsContextForDefaultFilterChain = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChainWithMatch), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMatch); } @@ -500,10 +568,12 @@ public void sourcePrefixRange_moreSpecificWith2Wins() new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainMoreSpecificWith2 = new EnvoyServerProtoData.FilterChain( - filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2); + filterChainMatchMoreSpecificWith2, tlsContextMoreSpecificWith2, tlsContextManager); DownstreamTlsContext tlsContextLessSpecific = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -514,11 +584,14 @@ public void sourcePrefixRange_moreSpecificWith2Wins() Arrays.asList(), Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainLessSpecific = - new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific); + new EnvoyServerProtoData.FilterChain(filterChainMatchLessSpecific, tlsContextLessSpecific, + tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -527,14 +600,13 @@ public void sourcePrefixRange_moreSpecificWith2Wins() defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextMoreSpecificWith2); } @Test public void sourcePrefixRange_2Matchers_expectException() - throws UnknownHostException, InvalidProtocolBufferException { + throws UnknownHostException, InvalidProtocolBufferException { setupChannel(LOCAL_IP, REMOTE_IP, 15000); DownstreamTlsContext tlsContext1 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); @@ -547,9 +619,11 @@ public void sourcePrefixRange_2Matchers_expectException() new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), new EnvoyServerProtoData.CidrRange("192.168.10.2", 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain1 = - new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1); + new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1, tlsContextManager); DownstreamTlsContext tlsContext2 = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -560,18 +634,20 @@ public void sourcePrefixRange_2Matchers_expectException() Arrays.asList(), Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain2 = - new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2); + new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, null); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", LOCAL_IP, Arrays.asList(filterChain1, filterChain2), defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); try { - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); fail("expect exception!"); } catch (IllegalStateException ise) { assertThat(ise).hasMessageThat().isEqualTo("Found 2 matching filter-chains"); @@ -593,10 +669,12 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainEmptySourcePorts = new EnvoyServerProtoData.FilterChain( - filterChainMatchEmptySourcePorts, tlsContextEmptySourcePorts); + filterChainMatchEmptySourcePorts, tlsContextEmptySourcePorts, tlsContextManager); DownstreamTlsContext tlsContextSourcePortMatch = CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); @@ -607,12 +685,14 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() Arrays.asList(), Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.2", 31)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(7000, 15000)); + Arrays.asList(7000, 15000), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChainSourcePortMatch = new EnvoyServerProtoData.FilterChain( - filterChainMatchSourcePortMatch, tlsContextSourcePortMatch); + filterChainMatchSourcePortMatch, tlsContextSourcePortMatch, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -621,8 +701,7 @@ public void sourcePortMatch_exactMatchWinsOverEmptyList() defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext1 = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext1 = getDownstreamTlsContext(); assertThat(tlsContext1).isSameInstanceAs(tlsContextSourcePortMatch); } @@ -657,9 +736,11 @@ public void filterChain_5stepMatch() throws UnknownHostException, InvalidProtoco Arrays.asList(), Arrays.asList(new EnvoyServerProtoData.CidrRange(REMOTE_IP, 32)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain1 = - new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1); + new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext1, tlsContextManager); // next 5 use prefix range: 4 with prefixLen of 30 and last one with 29 @@ -671,9 +752,11 @@ public void filterChain_5stepMatch() throws UnknownHostException, InvalidProtoco Arrays.asList(), Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.0.0", 16)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain2 = - new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2); + new EnvoyServerProtoData.FilterChain(filterChainMatch2, tlsContext2, tlsContextManager); // has prefix ranges with one not matching and source type local: gets eliminated in step 3 EnvoyServerProtoData.FilterChainMatch filterChainMatch3 = @@ -685,9 +768,11 @@ public void filterChain_5stepMatch() throws UnknownHostException, InvalidProtoco Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.SAME_IP_OR_LOOPBACK, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain3 = - new EnvoyServerProtoData.FilterChain(filterChainMatch3, tlsContext3); + new EnvoyServerProtoData.FilterChain(filterChainMatch3, tlsContext3, tlsContextManager); // has prefix ranges with both matching and source type external but non matching source port: // gets eliminated in step 5 @@ -700,9 +785,11 @@ public void filterChain_5stepMatch() throws UnknownHostException, InvalidProtoco Arrays.asList(), Arrays.asList(new EnvoyServerProtoData.CidrRange("10.4.2.0", 24)), EnvoyServerProtoData.ConnectionSourceType.EXTERNAL, - Arrays.asList(16000, 9000)); + Arrays.asList(16000, 9000), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain4 = - new EnvoyServerProtoData.FilterChain(filterChainMatch4, tlsContext4); + new EnvoyServerProtoData.FilterChain(filterChainMatch4, tlsContext4, tlsContextManager); // has prefix ranges with both matching and source type external and matching source port: this // gets selected @@ -717,9 +804,11 @@ public void filterChain_5stepMatch() throws UnknownHostException, InvalidProtoco new EnvoyServerProtoData.CidrRange("10.4.2.0", 24), new EnvoyServerProtoData.CidrRange("192.168.2.0", 24)), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList(15000, 8000)); + Arrays.asList(15000, 8000), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain5 = - new EnvoyServerProtoData.FilterChain(filterChainMatch5, tlsContext5); + new EnvoyServerProtoData.FilterChain(filterChainMatch5, tlsContext5, tlsContextManager); // has prefix range with prefixLen of 29: gets eliminated in step 2 EnvoyServerProtoData.FilterChainMatch filterChainMatch6 = @@ -729,12 +818,14 @@ public void filterChain_5stepMatch() throws UnknownHostException, InvalidProtoco Arrays.asList(), Arrays.asList(), EnvoyServerProtoData.ConnectionSourceType.ANY, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain6 = - new EnvoyServerProtoData.FilterChain(filterChainMatch6, tlsContext6); + new EnvoyServerProtoData.FilterChain(filterChainMatch6, tlsContext6, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, null, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( "listener1", @@ -744,11 +835,112 @@ public void filterChain_5stepMatch() throws UnknownHostException, InvalidProtoco defaultFilterChain); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContextPicked = - xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext(); assertThat(tlsContextPicked).isSameInstanceAs(tlsContext5); } + @Test + public void filterChainMatch_unsupportedMatchers() + throws InvalidProtocolBufferException, UnknownHostException { + setupChannel(LOCAL_IP, REMOTE_IP, 15000); + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext tlsContext1 = + CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( + "CERT1", "ROOTCA"); + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext tlsContext2 = + CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( + "CERT2", "ROOTCA"); + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext tlsContext3 = + CommonTlsContextTestsUtil.buildTestDownstreamTlsContext( + "CERT3", "ROOTCA"); + + FilterChainMatch filterChainMatch1 = + FilterChainMatch.newBuilder() + .addAllServerNames(Arrays.asList("server1", "server2")) + .setTransportProtocol("tls") + .addAllApplicationProtocols(Arrays.asList("managed-mtls", "h2")) + .addPrefixRanges(CidrRange.newBuilder() + .setAddressPrefix("10.1.0.0") + .setPrefixLen(UInt32Value.of(16)) + .build()) + .build(); + + FilterChainMatch filterChainMatch2 = + FilterChainMatch.newBuilder() + .addPrefixRanges(CidrRange.newBuilder() + .setAddressPrefix("10.0.0.0") + .setPrefixLen(UInt32Value.of(8)) + .build()) + .build(); + + FilterChain filterChain1 = + FilterChain.newBuilder() + .setFilterChainMatch(filterChainMatch1) + .setTransportSocket(TransportSocket.newBuilder().setName("envoy.transport_sockets.tls") + .setTypedConfig(Any.pack(tlsContext1)) + .build()) + .addFilters(Filter.newBuilder() + .setName("envoy.http_connection_manager") + .setTypedConfig(Any.newBuilder() + .setTypeUrl( + "type.googleapis.com/" + + "envoy.extensions.filters.network.http_connection_manager" + + ".v3.HttpConnectionManager")) + .build()) + .build(); + FilterChain filterChain2 = + FilterChain.newBuilder() + .setFilterChainMatch(filterChainMatch2) + .setTransportSocket(TransportSocket.newBuilder().setName("envoy.transport_sockets.tls") + .setTypedConfig(Any.pack(tlsContext2)) + .build()) + .addFilters(Filter.newBuilder() + .setName("envoy.http_connection_manager") + .setTypedConfig(Any.newBuilder() + .setTypeUrl( + "type.googleapis.com/" + + "envoy.extensions.filters.network.http_connection_manager" + + ".v3.HttpConnectionManager")) + .build()) + .build(); + FilterChain defaultFilterChain = + FilterChain.newBuilder() + .setTransportSocket(TransportSocket.newBuilder().setName("envoy.transport_sockets.tls") + .setTypedConfig(Any.pack(tlsContext3)) + .build()) + .addFilters(Filter.newBuilder() + .setName("envoy.http_connection_manager") + .setTypedConfig(Any.newBuilder() + .setTypeUrl( + "type.googleapis.com/" + + "envoy.extensions.filters.network.http_connection_manager" + + ".v3.HttpConnectionManager")) + .build()) + .build(); + Address address = + Address.newBuilder() + .setSocketAddress( + SocketAddress.newBuilder().setPortValue(8000).setAddress("10.2.1.34").build()) + .build(); + io.envoyproxy.envoy.config.listener.v3.Listener listener = + io.envoyproxy.envoy.config.listener.v3.Listener.newBuilder() + .setName("8000") + .setAddress(address) + .addFilterChains(filterChain1) + .addFilterChains(filterChain2) + .setDefaultFilterChain(defaultFilterChain) + .setTrafficDirection(TrafficDirection.INBOUND) + .build(); + + EnvoyServerProtoData.Listener xdsListener = EnvoyServerProtoData.Listener + .fromEnvoyProtoListener(listener, mock(TlsContextManager.class)); + XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(xdsListener); + registeredWatcher.onChanged(listenerUpdate); + DownstreamTlsContext tlsContextPicked = getDownstreamTlsContext(); + // assert defaultFilterChain match + assertThat(tlsContextPicked.getCommonTlsContext().getTlsCertificateSdsSecretConfigsList().get(0) + .getName()).isEqualTo("CERT3"); + } + private void setupChannel(String localIp, String remoteIp, int remotePort) throws UnknownHostException { when(channel.localAddress()) diff --git a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java b/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java index 9df8953485a..07e957b24c4 100644 --- a/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/GoogleCloudToProdNameResolverTest.java @@ -47,6 +47,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; @@ -103,6 +104,8 @@ public void close(Executor instance) {} @Mock private NameResolver.Listener2 mockListener; + @Mock + private Random mockRandom; @Captor private ArgumentCaptor errorCaptor; private boolean originalIsOnGcp; @@ -111,6 +114,7 @@ public void close(Executor instance) {} @Before public void setUp() { + when(mockRandom.nextInt()).thenReturn(123456789); nsRegistry.register(new FakeNsProvider("dns")); nsRegistry.register(new FakeNsProvider("xds")); originalIsOnGcp = GoogleCloudToProdNameResolver.isOnGcp; @@ -142,7 +146,8 @@ public HttpURLConnection createConnection(String url) throws IOException { } }; resolver = new GoogleCloudToProdNameResolver( - TARGET_URI, args, fakeExecutorResource, fakeXdsClientPoolFactory, nsRegistry.asFactory()); + TARGET_URI, args, fakeExecutorResource, mockRandom, fakeXdsClientPoolFactory, + nsRegistry.asFactory()); resolver.setHttpConnectionProvider(httpConnections); } @@ -178,7 +183,8 @@ public void onGcpAndNoProvidedBootstrapDelegateToXds() { Map bootstrap = fakeXdsClientPoolFactory.bootstrapRef.get(); Map node = (Map) bootstrap.get("node"); assertThat(node).containsExactly( - "id", "C2P", "locality", ImmutableMap.of("zone", ZONE), + "id", "C2P-123456789", + "locality", ImmutableMap.of("zone", ZONE), "metadata", ImmutableMap.of("TRAFFICDIRECTOR_DIRECTPATH_C2P_IPV6_CAPABLE", true)); Map server = Iterables.getOnlyElement( (List>) bootstrap.get("xds_servers")); diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index cc46d572020..53952f89478 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -172,8 +172,8 @@ public void cancelled(Context context) { when(backoffPolicy2.nextBackoffNanos()) .thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L)); addFakeStatsData(); - lrsClient = new LoadReportClient(loadStatsManager, channel, false, NODE, syncContext, - fakeClock.getScheduledExecutorService(), backoffPolicyProvider, + lrsClient = new LoadReportClient(loadStatsManager, channel, Context.ROOT, false, NODE, + syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, fakeClock.getStopwatchSupplier()); syncContext.execute(new Runnable() { @Override diff --git a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java new file mode 100644 index 00000000000..c5fe7b3d1bd --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java @@ -0,0 +1,343 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import com.google.api.expr.v1alpha1.Expr; +import com.google.protobuf.Any; +import com.google.protobuf.Message; +import com.google.protobuf.UInt32Value; +import io.envoyproxy.envoy.config.core.v3.CidrRange; +import io.envoyproxy.envoy.config.rbac.v3.Permission; +import io.envoyproxy.envoy.config.rbac.v3.Policy; +import io.envoyproxy.envoy.config.rbac.v3.Principal; +import io.envoyproxy.envoy.config.rbac.v3.Principal.Authenticated; +import io.envoyproxy.envoy.config.rbac.v3.RBAC; +import io.envoyproxy.envoy.config.rbac.v3.RBAC.Action; +import io.envoyproxy.envoy.config.route.v3.HeaderMatcher; +import io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBACPerRoute; +import io.envoyproxy.envoy.type.matcher.v3.MetadataMatcher; +import io.envoyproxy.envoy.type.matcher.v3.PathMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.Status; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.Filter.ConfigOrError; +import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AlwaysTrueMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthConfig; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthDecision; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationPortMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.OrMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.PolicyMatcher; +import java.net.InetSocketAddress; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import javax.net.ssl.SSLSession; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +/** Tests for {@link RbacFilter}. */ +@RunWith(JUnit4.class) +public class RbacFilterTest { + private static final String PATH = "auth"; + private static final StringMatcher STRING_MATCHER = + StringMatcher.newBuilder().setExact("/" + PATH).setIgnoreCase(true).build(); + + @Test + @SuppressWarnings({"unchecked", "deprecation"}) + public void ipPortParser() { + CidrRange cidrRange = CidrRange.newBuilder().setAddressPrefix("10.10.10.0") + .setPrefixLen(UInt32Value.of(24)).build(); + List permissionList = Arrays.asList( + Permission.newBuilder().setAndRules(Permission.Set.newBuilder() + .addRules(Permission.newBuilder().setDestinationIp(cidrRange).build()) + .addRules(Permission.newBuilder().setDestinationPort(9090).build()).build() + ).build()); + List principalList = Arrays.asList( + Principal.newBuilder().setAndIds(Principal.Set.newBuilder() + .addIds(Principal.newBuilder().setDirectRemoteIp(cidrRange).build()) + .addIds(Principal.newBuilder().setRemoteIp(cidrRange).build()) + .addIds(Principal.newBuilder().setSourceIp(cidrRange).build()) + .build()).build()); + ConfigOrError result = parseRaw(permissionList, principalList); + assertThat(result.errorDetail).isNull(); + ServerCall serverCall = mock(ServerCall.class); + Attributes attributes = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, new InetSocketAddress("10.10.10.0", 1)) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress("10.10.10.0",9090)) + .build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(method().build()); + GrpcAuthorizationEngine engine = + new GrpcAuthorizationEngine(((RbacConfig)result.config).authConfig()); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); + } + + @Test + @SuppressWarnings("unchecked") + public void pathParser() { + PathMatcher pathMatcher = PathMatcher.newBuilder().setPath(STRING_MATCHER).build(); + List permissionList = Arrays.asList( + Permission.newBuilder().setUrlPath(pathMatcher).build()); + List principalList = Arrays.asList( + Principal.newBuilder().setUrlPath(pathMatcher).build()); + ConfigOrError result = parse(permissionList, principalList); + assertThat(result.errorDetail).isNull(); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getMethodDescriptor()).thenReturn(method().build()); + GrpcAuthorizationEngine engine = + new GrpcAuthorizationEngine(result.config.authConfig()); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); + } + + @Test + @SuppressWarnings("unchecked") + public void authenticatedParser() throws Exception { + List permissionList = Arrays.asList( + Permission.newBuilder().setNotRule( + Permission.newBuilder().setRequestedServerName(STRING_MATCHER).build()).build()); + List principalList = Arrays.asList( + Principal.newBuilder().setAuthenticated(Authenticated.newBuilder() + .setPrincipalName(STRING_MATCHER).build()).build()); + ConfigOrError result = parse(permissionList, principalList); + assertThat(result.errorDetail).isNull(); + SSLSession sslSession = mock(SSLSession.class); + X509Certificate mockCert = mock(X509Certificate.class); + when(sslSession.getPeerCertificates()).thenReturn(new X509Certificate[]{mockCert}); + when(mockCert.getSubjectAlternativeNames()).thenReturn( + Arrays.>asList(Arrays.asList(2, "/" + PATH))); + Attributes attributes = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession) + .build(); + ServerCall serverCall = mock(ServerCall.class); + when(serverCall.getAttributes()).thenReturn(attributes); + GrpcAuthorizationEngine engine = + new GrpcAuthorizationEngine(((RbacConfig)result.config).authConfig()); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); + } + + @Test + @SuppressWarnings("unchecked") + public void headerParser() { + HeaderMatcher headerMatcher = HeaderMatcher.newBuilder() + .setName("party").setExactMatch("win").build(); + List permissionList = Arrays.asList( + Permission.newBuilder().setHeader(headerMatcher).build()); + List principalList = Arrays.asList( + Principal.newBuilder().setHeader(headerMatcher).build()); + ConfigOrError result = parseOverride(permissionList, principalList); + assertThat(result.errorDetail).isNull(); + ServerCall serverCall = mock(ServerCall.class); + GrpcAuthorizationEngine engine = + new GrpcAuthorizationEngine(result.config.authConfig()); + AuthDecision decision = engine.evaluate(metadata("party", "win"), serverCall); + assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.DENY); + } + + @Test + @SuppressWarnings("unchecked") + public void compositeRules() { + MetadataMatcher metadataMatcher = MetadataMatcher.newBuilder().build(); + List permissionList = Arrays.asList( + Permission.newBuilder().setOrRules(Permission.Set.newBuilder().addRules( + Permission.newBuilder().setMetadata(metadataMatcher).build() + ).build()).build()); + List principalList = Arrays.asList( + Principal.newBuilder().setNotId( + Principal.newBuilder().setMetadata(metadataMatcher).build() + ).build()); + ConfigOrError result = parse(permissionList, principalList); + assertThat(result.errorDetail).isNull(); + assertThat(result.config).isInstanceOf(RbacConfig.class); + ServerCall serverCall = mock(ServerCall.class); + GrpcAuthorizationEngine engine = + new GrpcAuthorizationEngine(((RbacConfig)result.config).authConfig()); + AuthDecision decision = engine.evaluate(new Metadata(), serverCall); + assertThat(decision.decision()).isEqualTo(GrpcAuthorizationEngine.Action.ALLOW); + } + + @SuppressWarnings("unchecked") + @Test + public void testAuthorizationInterceptor() { + ServerCallHandler mockHandler = mock(ServerCallHandler.class); + ServerCall mockServerCall = mock(ServerCall.class); + Attributes attr = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress("1::", 20)) + .build(); + when(mockServerCall.getAttributes()).thenReturn(attr); + PolicyMatcher policyMatcher = new PolicyMatcher("policy-matcher", + OrMatcher.create(new DestinationPortMatcher(99999)), + OrMatcher.create(AlwaysTrueMatcher.INSTANCE)); + AuthConfig authconfig = new AuthConfig(Collections.singletonList(policyMatcher), + GrpcAuthorizationEngine.Action.ALLOW); + new RbacFilter().buildServerInterceptor(RbacConfig.create(authconfig), null) + .interceptCall(mockServerCall, new Metadata(), mockHandler); + verify(mockHandler, never()).startCall(eq(mockServerCall), any(Metadata.class)); + ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); + verify(mockServerCall).close(captor.capture(), any(Metadata.class)); + assertThat(captor.getValue().getCode()).isEqualTo(Status.UNAUTHENTICATED.getCode()); + verify(mockServerCall).getAttributes(); + verifyNoMoreInteractions(mockServerCall); + + authconfig = new AuthConfig(Collections.singletonList(policyMatcher), + GrpcAuthorizationEngine.Action.DENY); + new RbacFilter().buildServerInterceptor(RbacConfig.create(authconfig), null) + .interceptCall(mockServerCall, new Metadata(), mockHandler); + verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); + } + + @Test + public void handleException() { + PathMatcher pathMatcher = PathMatcher.newBuilder() + .setPath(StringMatcher.newBuilder().build()).build(); + List permissionList = Arrays.asList( + Permission.newBuilder().setUrlPath(pathMatcher).build()); + List principalList = Arrays.asList( + Principal.newBuilder().setUrlPath(pathMatcher).build()); + ConfigOrError result = parse(permissionList, principalList); + assertThat(result.errorDetail).isNotNull(); + + permissionList = Arrays.asList(Permission.newBuilder().build()); + principalList = Arrays.asList(Principal.newBuilder().build()); + result = parse(permissionList, principalList); + assertThat(result.errorDetail).isNotNull(); + + Message rawProto = io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC.newBuilder() + .setRules(RBAC.newBuilder().setAction(Action.DENY) + .putPolicies("policy-name", + Policy.newBuilder().setCondition(Expr.newBuilder().build()).build()) + .build()).build(); + result = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + assertThat(result.errorDetail).isNotNull(); + } + + @Test + @SuppressWarnings("unchecked") + public void overrideConfig() { + ServerCallHandler mockHandler = mock(ServerCallHandler.class); + ServerCall mockServerCall = mock(ServerCall.class); + Attributes attr = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress("1::", 20)) + .build(); + when(mockServerCall.getAttributes()).thenReturn(attr); + + PolicyMatcher policyMatcher = new PolicyMatcher("policy-matcher", + OrMatcher.create(new DestinationPortMatcher(99999)), + OrMatcher.create(AlwaysTrueMatcher.INSTANCE)); + AuthConfig authconfig = new AuthConfig(Collections.singletonList(policyMatcher), + GrpcAuthorizationEngine.Action.ALLOW); + RbacConfig original = RbacConfig.create(authconfig); + + RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder().build(); + RbacConfig override = + new RbacFilter().parseFilterConfigOverride(Any.pack(rbacPerRoute)).config; + assertThat(override).isEqualTo(RbacConfig.create(null)); + ServerInterceptor interceptor = new RbacFilter().buildServerInterceptor(original, override); + assertThat(interceptor).isNull(); + + policyMatcher = new PolicyMatcher("policy-matcher-override", + OrMatcher.create(new DestinationPortMatcher(20)), + OrMatcher.create(AlwaysTrueMatcher.INSTANCE)); + authconfig = new AuthConfig(Collections.singletonList(policyMatcher), + GrpcAuthorizationEngine.Action.ALLOW); + override = RbacConfig.create(authconfig); + + new RbacFilter().buildServerInterceptor(original, override) + .interceptCall(mockServerCall, new Metadata(), mockHandler); + verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); + verify(mockServerCall).getAttributes(); + verifyNoMoreInteractions(mockServerCall); + } + + @Test + public void ignoredConfig() { + Message rawProto = io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC.newBuilder() + .setRules(RBAC.newBuilder().setAction(Action.LOG) + .putPolicies("policy-name", Policy.newBuilder().build()).build()).build(); + ConfigOrError result = new RbacFilter().parseFilterConfig(Any.pack(rawProto)); + assertThat(result.config).isEqualTo(RbacConfig.create(null)); + } + + private static Metadata metadata(String key, String value) { + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value); + return metadata; + } + + private MethodDescriptor.Builder method() { + return MethodDescriptor.newBuilder() + .setType(MethodType.BIDI_STREAMING) + .setFullMethodName(PATH) + .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) + .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()); + } + + private ConfigOrError parse(List permissionList, + List principalList) { + + return RbacFilter.parseRbacConfig(buildRbac(permissionList, principalList)); + } + + private ConfigOrError parseRaw(List permissionList, + List principalList) { + Message rawProto = buildRbac(permissionList, principalList); + Any proto = Any.pack(rawProto); + return new RbacFilter().parseFilterConfig(proto); + } + + private io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC buildRbac( + List permissionList, List principalList) { + return io.envoyproxy.envoy.extensions.filters.http.rbac.v3.RBAC.newBuilder() + .setRules(RBAC.newBuilder().setAction(Action.DENY) + .putPolicies("policy-name", Policy.newBuilder() + .addAllPermissions(permissionList) + .addAllPrincipals(principalList).build()).build()).build(); + + } + + private ConfigOrError parseOverride(List permissionList, + List principalList) { + RBACPerRoute rbacPerRoute = RBACPerRoute.newBuilder().setRbac( + buildRbac(permissionList, principalList)).build(); + Any proto = Any.pack(rbacPerRoute); + return new RbacFilter().parseFilterConfigOverride(proto); + } +} diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java index 2d7eb4fd59f..052868a2fb1 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerProviderTest.java @@ -82,14 +82,48 @@ public void parseLoadBalancingConfig_valid() throws IOException { } @Test - public void parseLoadBalancingConfig_missingRingSize() throws IOException { - String lbConfig = "{\"minRingSize\" : 10}"; + public void parseLoadBalancingConfig_missingRingSize_useDefaults() throws IOException { + String lbConfig = "{}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getConfig()).isNotNull(); + RingHashConfig config = (RingHashConfig) configOrError.getConfig(); + assertThat(config.minRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MIN_RING_SIZE); + assertThat(config.maxRingSize).isEqualTo(RingHashLoadBalancerProvider.DEFAULT_MAX_RING_SIZE); + } + + @Test + public void parseLoadBalancingConfig_invalid_negativeSize() throws IOException { + String lbConfig = "{\"minRingSize\" : -10}"; + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getError()).isNotNull(); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getDescription()) + .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + } + + @Test + public void parseLoadBalancingConfig_invalid_minGreaterThanMax() throws IOException { + String lbConfig = "{\"minRingSize\" : 1000, \"maxRingSize\" : 100}"; ConfigOrError configOrError = provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); assertThat(configOrError.getError()).isNotNull(); assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); assertThat(configOrError.getError().getDescription()) - .isEqualTo("Missing 'mingRingSize'/'maxRingSize'"); + .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); + } + + @Test + public void parseLoadBalancingConfig_invalid_ringTooLarge() throws IOException { + long ringSize = RingHashLoadBalancerProvider.MAX_RING_SIZE + 1; + String lbConfig = String.format("{\"minRingSize\" : 10, \"maxRingSize\" : %d}", ringSize); + ConfigOrError configOrError = + provider.parseLoadBalancingPolicyConfig(parseJsonObject(lbConfig)); + assertThat(configOrError.getError()).isNotNull(); + assertThat(configOrError.getError().getCode()).isEqualTo(Code.INVALID_ARGUMENT); + assertThat(configOrError.getError().getDescription()) + .isEqualTo("Invalid 'mingRingSize'/'maxRingSize'"); } @Test diff --git a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java index eaa2160c546..5a9bb7ff4a8 100644 --- a/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/RingHashLoadBalancerTest.java @@ -473,6 +473,8 @@ public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() { assertThat(result.getSubchannel()).isNull(); // buffer request verify(subchannels.get(Collections.singletonList(servers.get(2)))) .requestConnection(); // kick off connection to server2 + verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) + .requestConnection(); // no excessive connection deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(2))), @@ -496,16 +498,15 @@ public void skipFailingHosts_pickNextNonFailingHostInFirstTwoHosts() { @Test public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { // Map each server address to exactly one ring entry. - RingHashConfig config = new RingHashConfig(4, 4); - List servers = createWeightedServerAddrs(1, 1, 1, 1); + RingHashConfig config = new RingHashConfig(3, 3); + List servers = createWeightedServerAddrs(1, 1, 1); loadBalancer.handleResolvedAddresses( ResolvedAddresses.newBuilder() .setAddresses(servers).setLoadBalancingPolicyConfig(config).build()); - verify(helper, times(4)).createSubchannel(any(CreateSubchannelArgs.class)); + verify(helper, times(3)).createSubchannel(any(CreateSubchannelArgs.class)); verify(helper).updateBalancingState(eq(IDLE), any(SubchannelPicker.class)); // initial IDLE reset(helper); // ring: - // "[FakeSocketAddress-server3]_0" // "[FakeSocketAddress-server1]_0" // "[FakeSocketAddress-server0]_0" // "[FakeSocketAddress-server2]_0" @@ -515,7 +516,7 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { TestMethodDescriptors.voidMethod(), new Metadata(), CallOptions.DEFAULT.withOption(XdsNameResolver.RPC_HASH_KEY, rpcHash)); - // Bring down server0 and server2 to force trying other servers. + // Bring down server0 and server2 to force trying server1. deliverSubchannelState( subchannels.get(Collections.singletonList(servers.get(0))), ConnectivityStateInfo.forTransientFailure( @@ -525,20 +526,20 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { ConnectivityStateInfo.forTransientFailure( Status.PERMISSION_DENIED.withDescription("permission denied"))); verify(helper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); + verify(subchannels.get(Collections.singletonList(servers.get(1)))) + .requestConnection(); // LB attempts to recover by itself PickResult result = pickerCaptor.getValue().pickSubchannel(args); assertThat(result.getStatus().isOk()).isFalse(); // fail the RPC assertThat(result.getStatus().getCode()) .isEqualTo(Code.UNAVAILABLE); // with error status for the original server hit by hash assertThat(result.getStatus().getDescription()).isEqualTo("unreachable"); - verify(subchannels.get(Collections.singletonList(servers.get(3)))) + verify(subchannels.get(Collections.singletonList(servers.get(1))), times(2)) .requestConnection(); // kickoff connection to server3 (next first non-failing) - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) - .requestConnection(); // no excessive connection - // Now connecting to server3. + // Now connecting to server1. deliverSubchannelState( - subchannels.get(Collections.singletonList(servers.get(3))), + subchannels.get(Collections.singletonList(servers.get(1))), ConnectivityStateInfo.forNonError(CONNECTING)); verify(helper, times(2)).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture()); @@ -547,8 +548,6 @@ public void skipFailingHosts_firstTwoHostsFailed_pickNextFirstReady() { assertThat(result.getStatus().getCode()) .isEqualTo(Code.UNAVAILABLE); // with error status for the original server hit by hash assertThat(result.getStatus().getDescription()).isEqualTo("unreachable"); - verify(subchannels.get(Collections.singletonList(servers.get(1))), never()) - .requestConnection(); // no excessive connection (server3 connection already in progress) // Simulate server1 becomes READY. deliverSubchannelState( diff --git a/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java b/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java index bae48dd3a67..c4e888f5439 100644 --- a/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java +++ b/xds/src/test/java/io/grpc/xds/ServerWrapperForXdsTest.java @@ -66,12 +66,15 @@ public class ServerWrapperForXdsTest { private XdsServerBuilder.XdsServingStatusListener mockXdsServingStatusListener; private XdsClient.LdsResourceWatcher listenerWatcher; private Server mockServer; + private TlsContextManager tlsContextManager; @Before public void setUp() throws IOException { port = XdsServerTestHelper.findFreePort(); mockDelegateBuilder = mock(ServerBuilder.class); - xdsClientWrapperForServerSds = XdsServerTestHelper.createXdsClientWrapperForServerSds(port); + tlsContextManager = mock(TlsContextManager.class); + xdsClientWrapperForServerSds = XdsServerTestHelper + .createXdsClientWrapperForServerSds(port, tlsContextManager); mockXdsServingStatusListener = mock(XdsServerBuilder.XdsServingStatusListener.class); listenerWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); @@ -116,8 +119,8 @@ public void start() verifyCapturedCodeAndNotServing(Status.Code.ABORTED, ServerWrapperForXds.ServingState.STARTING); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); Throwable exception = future.get(2, TimeUnit.SECONDS); assertThat(exception).isNull(); assertThat(serverWrapperForXds.getCurrentServingState()) @@ -162,8 +165,8 @@ public Server answer(InvocationOnMock invocation) throws Throwable { public void run() { XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); } }).start(); assertThat(settableFutureToSignalStart.get()).isNull(); @@ -196,9 +199,9 @@ public void delegateInitialStartError() @Override public void run() { XdsServerTestHelper.generateListenerUpdate( - listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + listenerWatcher, + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); } }).start(); Throwable exception = future.get(2, TimeUnit.SECONDS); @@ -241,8 +244,8 @@ public void delegateStartError_shutdown() Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); Throwable exception = future.get(2, TimeUnit.SECONDS); assertThat(exception).isNull(); assertThat(serverWrapperForXds.getCurrentServingState()) @@ -255,8 +258,8 @@ public void delegateStartError_shutdown() when(mockDelegateBuilder.build()).thenReturn(mockServer); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"), + tlsContextManager); Thread.sleep(100L); assertThat(serverWrapperForXds.getCurrentServingState()) .isEqualTo(ServerWrapperForXds.ServingState.SHUTDOWN); @@ -268,8 +271,8 @@ public void shutdownDuringRestart() Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); Throwable exception = future.get(2, TimeUnit.SECONDS); assertThat(exception).isNull(); assertThat(serverWrapperForXds.getCurrentServingState()) @@ -301,8 +304,8 @@ public Server answer(InvocationOnMock invocation) public void run() { XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"), + tlsContextManager); } }).start(); assertThat(settableFutureToSignalStart.get()).isNull(); diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index fc32cb4e3b5..75d7b76dbbd 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -19,9 +19,11 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -30,12 +32,15 @@ import io.grpc.inprocess.InProcessSocketAddress; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.sds.SslContextProvider; +import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.netty.channel.Channel; import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.UnknownHostException; +import java.util.Arrays; import java.util.Collections; import org.junit.After; import org.junit.Before; @@ -53,14 +58,27 @@ public class XdsClientWrapperForServerSdsTestMisc { private static final int PORT = 7000; @Mock private Channel channel; + @Mock private TlsContextManager tlsContextManager; + @Mock private XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; private XdsClient.LdsResourceWatcher registeredWatcher; + private InetSocketAddress localAddress; + private DownstreamTlsContext tlsContext1; + private DownstreamTlsContext tlsContext2; + private DownstreamTlsContext tlsContext3; @Before public void setUp() throws IOException { MockitoAnnotations.initMocks(this); - xdsClientWrapperForServerSds = XdsServerTestHelper.createXdsClientWrapperForServerSds(PORT); + tlsContext1 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); + tlsContext2 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT2", "VA2"); + tlsContext3 = + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT3", "VA3"); + xdsClientWrapperForServerSds = XdsServerTestHelper + .createXdsClientWrapperForServerSds(PORT, tlsContextManager); } @After @@ -72,7 +90,9 @@ public void tearDown() { public void nonInetSocketAddress_expectNull() throws UnknownHostException { registeredWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - assertThat(sendListenerUpdate(new InProcessSocketAddress("test1"), null)).isNull(); + assertThat( + sendListenerUpdate(new InProcessSocketAddress("test1"), null, null, tlsContextManager)) + .isNull(); } @Test @@ -82,7 +102,7 @@ public void nonMatchingPort_expectException() throws UnknownHostException { try { InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT + 1); - DownstreamTlsContext unused = sendListenerUpdate(localAddress, null); + sendListenerUpdate(localAddress, null, null, tlsContextManager); fail("exception expected"); } catch (IllegalStateException expected) { assertThat(expected) @@ -113,96 +133,180 @@ public void emptyFilterChain_expectNull() throws UnknownHostException { null); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); - DownstreamTlsContext tlsContext = xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + DownstreamTlsContext tlsContext = getDownstreamTlsContext(); assertThat(tlsContext).isNull(); } @Test - public void registerServerWatcher() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher = - mock(XdsClientWrapperForServerSds.ServerWatcher.class); - xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - verify(mockServerWatcher, never()) - .onListenerUpdate(); - DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext); - assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); + public void registerServerWatcher_afterListenerUpdate() throws UnknownHostException { + registerWatcherAndCreateListenerUpdate(tlsContext1); verify(mockServerWatcher).onListenerUpdate(); - xdsClientWrapperForServerSds.removeServerWatcher(mockServerWatcher); } @Test - public void registerServerWatcher_afterListenerUpdate() throws UnknownHostException { - registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext); - assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); - XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher = - mock(XdsClientWrapperForServerSds.ServerWatcher.class); - xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); - verify(mockServerWatcher).onListenerUpdate(); + public void registerServerWatcher_notifyNotFound() throws UnknownHostException { + commonErrorCheck(true, Status.NOT_FOUND, true); } @Test - public void registerServerWatcher_notifyError() throws UnknownHostException { + public void registerServerWatcher_notifyInternalError() throws UnknownHostException { + commonErrorCheck(false, Status.INTERNAL, false); + } + + @Test + public void registerServerWatcher_notifyPermDeniedError() throws UnknownHostException { + commonErrorCheck(false, Status.PERMISSION_DENIED, true); + } + + @Test + public void releaseOldSupplierOnChanged_noCloseDueToLazyLoading() throws UnknownHostException { + registerWatcherAndCreateListenerUpdate(tlsContext1); + XdsServerTestHelper.generateListenerUpdate(registeredWatcher, tlsContext2, tlsContextManager); + verify(tlsContextManager, never()) + .findOrCreateServerSslContextProvider(any(DownstreamTlsContext.class)); + } + + @Test + public void releaseOldSupplierOnChangedOnShutdown_verifyClose() throws UnknownHostException { + SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) + .thenReturn(sslContextProvider1); + registerWatcherAndCreateListenerUpdate(tlsContext1); + callUpdateSslContext(channel); + XdsServerTestHelper + .generateListenerUpdate(registeredWatcher, Arrays.asList(1234), tlsContext2, + tlsContext3, tlsContextManager); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); + reset(tlsContextManager); + SslContextProvider sslContextProvider2 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext2))) + .thenReturn(sslContextProvider2); + SslContextProvider sslContextProvider3 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext3))) + .thenReturn(sslContextProvider3); + callUpdateSslContext(channel); + InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); + InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1111); + when(channel.remoteAddress()).thenReturn(remoteAddress); + callUpdateSslContext(channel); + XdsClient mockXdsClient = xdsClientWrapperForServerSds.getXdsClient(); + xdsClientWrapperForServerSds.shutdown(); + verify(mockXdsClient, times(1)) + .cancelLdsResourceWatch(eq("grpc/server?udpa.resource.listening_address=0.0.0.0:" + PORT), + eq(registeredWatcher)); + verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider2)); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider3)); + } + + @Test + public void releaseOldSupplierOnNotFound_verifyClose() throws UnknownHostException { + SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) + .thenReturn(sslContextProvider1); + registerWatcherAndCreateListenerUpdate(tlsContext1); + callUpdateSslContext(channel); + registeredWatcher.onResourceDoesNotExist("not-found Error"); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); + } + + @Test + public void releaseOldSupplierOnPermDeniedError_verifyClose() throws UnknownHostException { + SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) + .thenReturn(sslContextProvider1); + registerWatcherAndCreateListenerUpdate(tlsContext1); + callUpdateSslContext(channel); + registeredWatcher.onError(Status.PERMISSION_DENIED); + verify(tlsContextManager, times(1)).releaseServerSslContextProvider(eq(sslContextProvider1)); + } + + @Test + public void releaseOldSupplierOnInternalError_noClose() throws UnknownHostException { + SslContextProvider sslContextProvider1 = mock(SslContextProvider.class); + when(tlsContextManager.findOrCreateServerSslContextProvider(eq(tlsContext1))) + .thenReturn(sslContextProvider1); + registerWatcherAndCreateListenerUpdate(tlsContext1); + callUpdateSslContext(channel); + registeredWatcher.onError(Status.INTERNAL); + verify(tlsContextManager, never()).releaseServerSslContextProvider(eq(sslContextProvider1)); + } + + private void callUpdateSslContext(Channel channel) { + SslContextProviderSupplier sslContextProviderSupplier = + xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); + assertThat(sslContextProviderSupplier).isNotNull(); + SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class); + sslContextProviderSupplier.updateSslContext(callback); + } + + private void registerWatcherAndCreateListenerUpdate(DownstreamTlsContext tlsContext) + throws UnknownHostException { registeredWatcher = - XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); - XdsClientWrapperForServerSds.ServerWatcher mockServerWatcher = - mock(XdsClientWrapperForServerSds.ServerWatcher.class); + XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); + InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); + localAddress = new InetSocketAddress(ipLocalAddress, PORT); xdsClientWrapperForServerSds.addServerWatcher(mockServerWatcher); - registeredWatcher.onError(Status.INTERNAL); + DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext, null, + tlsContextManager); + assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); + } + + private void commonErrorCheck(boolean generateResourceDoesNotExist, Status status, + boolean isAbsent) throws UnknownHostException { + registerWatcherAndCreateListenerUpdate(tlsContext1); + reset(mockServerWatcher); + if (generateResourceDoesNotExist) { + registeredWatcher.onResourceDoesNotExist("not-found Error"); + } else { + registeredWatcher.onError(status); + } ArgumentCaptor argCaptor = ArgumentCaptor.forClass(null); - verify(mockServerWatcher).onError(argCaptor.capture(), eq(false)); + verify(mockServerWatcher).onError(argCaptor.capture(), eq(isAbsent)); Throwable throwable = argCaptor.getValue(); assertThat(throwable).isInstanceOf(StatusException.class); - Status captured = ((StatusException)throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.INTERNAL); - reset(mockServerWatcher); - registeredWatcher.onResourceDoesNotExist("not-found Error"); - ArgumentCaptor argCaptor1 = ArgumentCaptor.forClass(null); - verify(mockServerWatcher).onError(argCaptor1.capture(), eq(true)); - throwable = argCaptor1.getValue(); - assertThat(throwable).isInstanceOf(StatusException.class); - captured = ((StatusException)throwable).getStatus(); - assertThat(captured.getCode()).isEqualTo(Status.Code.NOT_FOUND); - InetAddress ipLocalAddress = InetAddress.getByName("10.1.2.3"); - InetSocketAddress localAddress = new InetSocketAddress(ipLocalAddress, PORT); - EnvoyServerProtoData.DownstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"); - verify(mockServerWatcher, never()) - .onListenerUpdate(); - DownstreamTlsContext returnedTlsContext = sendListenerUpdate(localAddress, tlsContext); - assertThat(returnedTlsContext).isSameInstanceAs(tlsContext); - verify(mockServerWatcher).onListenerUpdate(); + Status captured = ((StatusException) throwable).getStatus(); + assertThat(captured.getCode()).isEqualTo(status.getCode()); + if (isAbsent) { + assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNull(); + } else { + assertThat(xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel)).isNotNull(); + } } private DownstreamTlsContext sendListenerUpdate( - SocketAddress localAddress, DownstreamTlsContext tlsContext) throws UnknownHostException { + SocketAddress localAddress, DownstreamTlsContext tlsContext, + DownstreamTlsContext tlsContextForDefaultFilterChain, TlsContextManager tlsContextManager) + throws UnknownHostException { when(channel.localAddress()).thenReturn(localAddress); InetAddress ipRemoteAddress = InetAddress.getByName("10.4.5.6"); InetSocketAddress remoteAddress = new InetSocketAddress(ipRemoteAddress, 1234); when(channel.remoteAddress()).thenReturn(remoteAddress); - XdsServerTestHelper.generateListenerUpdate(registeredWatcher, tlsContext); - return xdsClientWrapperForServerSds.getDownstreamTlsContext(channel); + XdsServerTestHelper + .generateListenerUpdate(registeredWatcher, Arrays.asList(), tlsContext, + tlsContextForDefaultFilterChain, tlsContextManager); + return getDownstreamTlsContext(); + } + + private DownstreamTlsContext getDownstreamTlsContext() { + SslContextProviderSupplier sslContextProviderSupplier = + xdsClientWrapperForServerSds.getSslContextProviderSupplier(channel); + if (sslContextProviderSupplier != null) { + EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); + assertThat(tlsContext).isInstanceOf(DownstreamTlsContext.class); + return (DownstreamTlsContext)tlsContext; + } + return null; } /** Creates XdsClientWrapperForServerSds: also used by other classes. */ public static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds( - int port, DownstreamTlsContext downstreamTlsContext) { + int port, DownstreamTlsContext downstreamTlsContext, TlsContextManager tlsContextManager) { XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsServerTestHelper.createXdsClientWrapperForServerSds(port); + XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManager); xdsClientWrapperForServerSds.start(); XdsSdsClientServerTest.generateListenerUpdateToWatcher( - downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher()); + downstreamTlsContext, xdsClientWrapperForServerSds.getListenerWatcher(), tlsContextManager); return xdsClientWrapperForServerSds; } } diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index 46ed3dafbc5..75f81da558c 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -42,6 +42,7 @@ import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptors; +import io.grpc.Deadline; import io.grpc.InternalConfigSelector; import io.grpc.InternalConfigSelector.Result; import io.grpc.Metadata; @@ -66,18 +67,17 @@ import io.grpc.xds.FaultConfig.FaultDelay; import io.grpc.xds.Filter.FilterConfig; import io.grpc.xds.Filter.NamedFilterConfig; -import io.grpc.xds.Matchers.HeaderMatcher; -import io.grpc.xds.Matchers.PathMatcher; import io.grpc.xds.VirtualHost.Route; import io.grpc.xds.VirtualHost.Route.RouteAction; import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight; import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy; import io.grpc.xds.VirtualHost.Route.RouteMatch; +import io.grpc.xds.VirtualHost.Route.RouteMatch.PathMatcher; import io.grpc.xds.XdsNameResolverProvider.XdsClientPoolFactory; +import io.grpc.xds.internal.Matchers.HeaderMatcher; import java.io.IOException; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; @@ -449,7 +449,38 @@ public void resolved_simpleCallFailedToRoute() { } @Test - public void resolved_rpcHashingByHeader() { + public void resolved_rpcHashingByHeader_withoutSubstitution() { + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + xdsClient.deliverLdsUpdate( + Collections.singletonList( + Route.create( + RouteMatch.withPathExactOnly( + "/" + TestMethodDescriptors.voidMethod().getFullMethodName()), + RouteAction.forCluster(cluster1, Collections.singletonList(HashPolicy.forHeader( + false, "custom-key", null, null)), + null), + ImmutableMap.of()))); + verify(mockListener).onResult(resolutionResultCaptor.capture()); + InternalConfigSelector configSelector = + resolutionResultCaptor.getValue().getAttributes().get(InternalConfigSelector.KEY); + + // First call, with header "custom-key": "custom-value". + startNewCall(TestMethodDescriptors.voidMethod(), configSelector, + ImmutableMap.of("custom-key", "custom-value"), CallOptions.DEFAULT); + long hash1 = testCall.callOptions.getOption(XdsNameResolver.RPC_HASH_KEY); + + // Second call, with header "custom-key": "custom-val". + startNewCall(TestMethodDescriptors.voidMethod(), configSelector, + ImmutableMap.of("custom-key", "custom-val"), + CallOptions.DEFAULT); + long hash2 = testCall.callOptions.getOption(XdsNameResolver.RPC_HASH_KEY); + + assertThat(hash2).isNotEqualTo(hash1); + } + + @Test + public void resolved_rpcHashingByHeader_withSubstitution() { resolver.start(mockListener); FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); xdsClient.deliverLdsUpdate( @@ -1150,6 +1181,36 @@ public void resolved_faultDelayWithMaxActiveStreamsInLdsUpdate() { verifyRpcDelayed(observer3, 5000L); } + @Test + public void resolved_faultDelayInLdsUpdate_callWithEarlyDeadline() { + resolver.start(mockListener); + FakeXdsClient xdsClient = (FakeXdsClient) resolver.getXdsClient(); + when(mockRandom.nextInt(1000_000)).thenReturn(500_000); // 50% + + FaultConfig httpFilterFaultConfig = FaultConfig.create( + FaultDelay.forFixedDelay(5000L, FaultConfig.FractionalPercent.perMillion(1000_000)), + null, + null); + xdsClient.deliverLdsUpdateWithFaultInjection(cluster1, httpFilterFaultConfig, null, null, null); + verify(mockListener).onResult(resolutionResultCaptor.capture()); + ResolutionResult result = resolutionResultCaptor.getValue(); + InternalConfigSelector configSelector = result.getAttributes().get(InternalConfigSelector.KEY); + + Deadline.Ticker fakeTicker = new Deadline.Ticker() { + @Override + public long nanoTime() { + return fakeClock.getTicker().read(); + } + }; + ClientCall.Listener observer = startNewCall(TestMethodDescriptors.voidMethod(), + configSelector, Collections.emptyMap(), CallOptions.DEFAULT.withDeadline( + Deadline.after(4000, TimeUnit.NANOSECONDS, fakeTicker))); + assertThat(testCall).isNull(); + verifyRpcDelayedThenAborted(observer, 4000L, Status.DEADLINE_EXCEEDED.withDescription( + "Deadline exceeded after up to 5000 ns of fault-injected delay:" + + " Deadline exceeded after 0.000004000s. ")); + } + @Test public void resolved_withNoRouterFilter() { resolver.start(mockListener); @@ -1320,7 +1381,7 @@ private void verifyRpcDelayedThenAborted( @Test public void routeMatching_pathOnly() { - Map headers = Collections.emptyMap(); + Metadata headers = new Metadata(); ThreadSafeRandom random = mock(ThreadSafeRandom.class); RouteMatch routeMatch1 = @@ -1351,14 +1412,36 @@ public void routeMatching_pathOnly() { .isTrue(); } + @Test + public void routeMatching_pathOnly_caseInsensitive() { + Metadata headers = new Metadata(); + ThreadSafeRandom random = mock(ThreadSafeRandom.class); + + RouteMatch routeMatch1 = + RouteMatch.create( + PathMatcher.fromPath("/FooService/barMethod", false), + Collections.emptyList(), null); + assertThat(XdsNameResolver.matchRoute(routeMatch1, "/fooservice/barmethod", headers, random)) + .isTrue(); + + RouteMatch routeMatch2 = + RouteMatch.create( + PathMatcher.fromPrefix("/FooService", false), + Collections.emptyList(), null); + assertThat(XdsNameResolver.matchRoute(routeMatch2, "/fooservice/barmethod", headers, random)) + .isTrue(); + } + @Test public void routeMatching_withHeaders() { - Map headers = new HashMap<>(); - headers.put("authority", "foo.googleapis.com"); - headers.put("grpc-encoding", "gzip"); - headers.put("user-agent", "gRPC-Java"); - headers.put("content-length", "1000"); - headers.put("custom-key", "custom-value1,custom-value2"); + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("authority", Metadata.ASCII_STRING_MARSHALLER), + "foo.googleapis.com"); + headers.put(Metadata.Key.of("grpc-encoding", Metadata.ASCII_STRING_MARSHALLER), "gzip"); + headers.put(Metadata.Key.of("user-agent", Metadata.ASCII_STRING_MARSHALLER), "gRPC-Java"); + headers.put(Metadata.Key.of("content-length", Metadata.ASCII_STRING_MARSHALLER), "1000"); + headers.put(Metadata.Key.of("custom-key", Metadata.ASCII_STRING_MARSHALLER), "custom-value1"); + headers.put(Metadata.Key.of("custom-key", Metadata.ASCII_STRING_MARSHALLER), "custom-value2"); ThreadSafeRandom random = mock(ThreadSafeRandom.class); PathMatcher pathMatcher = PathMatcher.fromPath("/FooService/barMethod", true); @@ -1419,15 +1502,22 @@ public void routeMatching_withHeaders() { null); assertThat(XdsNameResolver.matchRoute(routeMatch7, "/FooService/barMethod", headers, random)) .isTrue(); - } - @Test - public void pathMatching_caseInsensitive() { - PathMatcher pathMatcher1 = PathMatcher.fromPath("/FooService/barMethod", false); - assertThat(XdsNameResolver.matchPath(pathMatcher1, "/fooservice/barmethod")).isTrue(); + RouteMatch routeMatch8 = RouteMatch.create( + pathMatcher, + Collections.singletonList( + HeaderMatcher.forExactValue("content-type", "application/grpc", false)), + null); + assertThat(XdsNameResolver.matchRoute( + routeMatch8, "/FooService/barMethod", new Metadata(), random)).isTrue(); - PathMatcher pathMatcher2 = PathMatcher.fromPrefix("/FooService", false); - assertThat(XdsNameResolver.matchPath(pathMatcher2, "/fooservice/barmethod")).isTrue(); + RouteMatch routeMatch9 = RouteMatch.create( + pathMatcher, + Collections.singletonList( + HeaderMatcher.forExactValue("custom-key!", "custom-value1,custom-value2", false)), + null); + assertThat(XdsNameResolver.matchRoute(routeMatch9, "/FooService/barMethod", headers, random)) + .isFalse(); } private final class FakeXdsClientPoolFactory implements XdsClientPoolFactory { diff --git a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java index a8f326c621c..2d7a2fbce14 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSdsClientServerTest.java @@ -27,7 +27,6 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; import com.google.common.collect.ImmutableList; import io.grpc.Attributes; @@ -42,7 +41,6 @@ import io.grpc.ServerCredentials; import io.grpc.Status; import io.grpc.StatusRuntimeException; -import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.protobuf.SimpleRequest; @@ -53,7 +51,6 @@ import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; import io.grpc.xds.internal.sds.SslContextProviderSupplier; import io.grpc.xds.internal.sds.TlsContextManagerImpl; -import io.grpc.xds.internal.sds.XdsChannelBuilder; import io.netty.handler.ssl.NotSslRecordException; import java.io.IOException; import java.net.Inet4Address; @@ -73,7 +70,7 @@ import org.junit.runners.JUnit4; /** - * Unit tests for {@link XdsChannelBuilder} and {@link XdsServerBuilder} for plaintext/TLS/mTLS + * Unit tests for {@link XdsChannelCredentials} and {@link XdsServerBuilder} for plaintext/TLS/mTLS * modes. */ @RunWith(JUnit4.class) @@ -82,12 +79,11 @@ public class XdsSdsClientServerTest { @Rule public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); private int port; private FakeNameResolverFactory fakeNameResolverFactory; - private Bootstrapper mockBootstrapper; + private final TlsContextManagerImpl tlsContextManager = new TlsContextManagerImpl(null); @Before public void setUp() throws IOException { port = XdsServerTestHelper.findFreePort(); - mockBootstrapper = mock(Bootstrapper.class); } @After @@ -106,28 +102,6 @@ public void plaintextClientServer() throws IOException, URISyntaxException { assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); } - @Test - public void plaintextClientServer_withXdsChannelCreds() throws IOException, URISyntaxException { - buildServerWithTlsContext(/* downstreamTlsContext= */ null); - - SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStubNewApi(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null); - assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); - } - - @Test - public void plaintextClientServer_withDefaultTlsContext() throws IOException, URISyntaxException { - DownstreamTlsContext defaultTlsContext = - EnvoyServerProtoData.DownstreamTlsContext.fromEnvoyProtoDownstreamTlsContext( - io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext - .getDefaultInstance()); - buildServerWithTlsContext(/* downstreamTlsContext= */ defaultTlsContext); - - SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = - getBlockingStub(/* upstreamTlsContext= */ null, /* overrideAuthority= */ null); - assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); - } - @Test public void nullFallbackCredentials_expectException() throws IOException, URISyntaxException { try { @@ -291,7 +265,7 @@ public void mtlsClientServer_changeServerContext_expectException() DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( BAD_SERVER_KEY_FILE, BAD_SERVER_PEM_FILE, CA_PEM_FILE); - generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher); + generateListenerUpdateToWatcher(downstreamTlsContext, listenerWatcher, tlsContextManager); try { SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); @@ -319,7 +293,7 @@ private XdsClient.LdsResourceWatcher performMtlsTestAndGetListenerWatcher( .getListenerWatcher(); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = newApi - ? getBlockingStubNewApi(upstreamTlsContext, "foo.test.google.fr") : + ? getBlockingStub(upstreamTlsContext, "foo.test.google.fr") : getBlockingStub(upstreamTlsContext, "foo.test.google.fr"); assertThat(unaryRpc("buddy", blockingStub)).isEqualTo("Hello buddy"); return listenerWatcher; @@ -350,16 +324,18 @@ private void buildServerWithFallbackServerCredentials( } /** Creates XdsClientWrapperForServerSds. */ - private static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port) { + private XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port) { XdsClientWrapperForServerSds xdsClientWrapperForServerSds = - XdsServerTestHelper.createXdsClientWrapperForServerSds(port); + XdsServerTestHelper.createXdsClientWrapperForServerSds(port, tlsContextManager); xdsClientWrapperForServerSds.start(); return xdsClientWrapperForServerSds; } static void generateListenerUpdateToWatcher( - DownstreamTlsContext tlsContext, XdsClient.LdsResourceWatcher registeredWatcher) { - EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", tlsContext); + DownstreamTlsContext tlsContext, XdsClient.LdsResourceWatcher registeredWatcher, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildListener("listener1", "0.0.0.0", tlsContext, + tlsContextManager); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); } @@ -373,12 +349,13 @@ private void buildServer( XdsServerBuilder builder = XdsServerBuilder.forPort(port, serverCredentials) .addService(new SimpleServiceImpl()); XdsServerTestHelper.generateListenerUpdate( - xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext); + xdsClientWrapperForServerSds.getListenerWatcher(), downstreamTlsContext, tlsContextManager); cleanupRule.register(builder.buildServer(xdsClientWrapperForServerSds)).start(); } static EnvoyServerProtoData.Listener buildListener( - String name, String address, DownstreamTlsContext tlsContext) { + String name, String address, DownstreamTlsContext tlsContext, + TlsContextManager tlsContextManager) { EnvoyServerProtoData.FilterChainMatch filterChainMatch = new EnvoyServerProtoData.FilterChainMatch( 0, @@ -386,9 +363,11 @@ static EnvoyServerProtoData.Listener buildListener( Arrays.asList(), Arrays.asList(), null, - Arrays.asList()); + Arrays.asList(), + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext); + new EnvoyServerProtoData.FilterChain(filterChainMatch, tlsContext, tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener(name, address, Arrays.asList(defaultFilterChain), null); return listener; @@ -400,33 +379,6 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( URI expectedUri = new URI("sdstest://localhost:" + port); fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory); - XdsChannelBuilder channelBuilder = - XdsChannelBuilder.forTarget("sdstest://localhost:" + port) - .fallbackProtocolNegotiator(InternalProtocolNegotiators.plaintext()); - if (overrideAuthority != null) { - channelBuilder = channelBuilder.overrideAuthority(overrideAuthority); - } - InetSocketAddress socketAddress = - new InetSocketAddress(Inet4Address.getLoopbackAddress(), port); - Attributes attrs = - (upstreamTlsContext != null) - ? Attributes.newBuilder() - .set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - new SslContextProviderSupplier( - upstreamTlsContext, new TlsContextManagerImpl(mockBootstrapper))) - .build() - : Attributes.EMPTY; - fakeNameResolverFactory.setServers( - ImmutableList.of(new EquivalentAddressGroup(socketAddress, attrs))); - return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); - } - - private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStubNewApi( - final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) - throws URISyntaxException { - URI expectedUri = new URI("sdstest://localhost:" + port); - fakeNameResolverFactory = new FakeNameResolverFactory.Builder(expectedUri).build(); - NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverFactory); ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilder( "sdstest://localhost:" + port, @@ -442,7 +394,7 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStubNewApi( ? Attributes.newBuilder() .set(InternalXdsAttributes.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, new SslContextProviderSupplier( - upstreamTlsContext, new TlsContextManagerImpl(mockBootstrapper))) + upstreamTlsContext, tlsContextManager)) .build() : Attributes.EMPTY; fakeNameResolverFactory.setServers( diff --git a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java index 6c09338d213..0b174a4a313 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerBuilderTest.java @@ -63,6 +63,7 @@ public class XdsServerBuilderTest { private XdsClient.LdsResourceWatcher listenerWatcher; private int port; private XdsClientWrapperForServerSds xdsClientWrapperForServerSds; + private TlsContextManager tlsContextManager; private void buildServer(XdsServerBuilder.XdsServingStatusListener xdsServingStatusListener) throws IOException { @@ -79,7 +80,9 @@ private void buildBuilder(XdsServerBuilder.XdsServingStatusListener xdsServingSt if (xdsServingStatusListener != null) { builder = builder.xdsServingStatusListener(xdsServingStatusListener); } - xdsClientWrapperForServerSds = XdsServerTestHelper.createXdsClientWrapperForServerSds(port); + tlsContextManager = mock(TlsContextManager.class); + xdsClientWrapperForServerSds = XdsServerTestHelper + .createXdsClientWrapperForServerSds(port, tlsContextManager); listenerWatcher = XdsServerTestHelper.startAndGetWatcher(xdsClientWrapperForServerSds); } @@ -131,7 +134,12 @@ public void run() { } }); // wait until xdsClientWrapperForServerSds.serverWatchers populated - for (int i = 0; i < 10 && xdsClientWrapperForServerSds.serverWatchers.isEmpty(); i++) { + for (int i = 0; i < 10; i++) { + synchronized (xdsClientWrapperForServerSds.serverWatchers) { + if (!xdsClientWrapperForServerSds.serverWatchers.isEmpty()) { + break; + } + } Thread.sleep(100L); } return settableFuture; @@ -144,8 +152,8 @@ public void xdsServerStartAndShutdown() Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); verifyServer(future, null, null); verifyShutdown(); } @@ -156,8 +164,8 @@ public void xdsServerStartAfterListenerUpdate() buildServer(null); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); xdsServer.start(); try { xdsServer.start(); @@ -177,8 +185,8 @@ public void xdsServerStartAndShutdownWithXdsServingStatusListener() Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); verifyServer(future, mockXdsServingStatusListener, null); } @@ -218,8 +226,8 @@ public void xdsServer_serverWatcher() reset(mockXdsServingStatusListener); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); verifyServer(future, mockXdsServingStatusListener, null); } @@ -234,8 +242,8 @@ public void xdsServer_startError() ServerSocket serverSocket = new ServerSocket(port); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); Throwable exception = future.get(5, TimeUnit.SECONDS); assertThat(exception).isInstanceOf(IOException.class); assertThat(exception).hasMessageThat().contains("Failed to bind"); @@ -252,12 +260,12 @@ public void xdsServerStartSecondUpdateAndError() Future future = startServerAsync(); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); XdsServerTestHelper.generateListenerUpdate( listenerWatcher, - CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1") - ); + CommonTlsContextTestsUtil.buildTestInternalDownstreamTlsContext("CERT1", "VA1"), + tlsContextManager); verify(mockXdsServingStatusListener, never()).onNotServing(any(Throwable.class)); verifyServer(future, mockXdsServingStatusListener, null); listenerWatcher.onError(Status.ABORTED); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index cf5c110e16e..a960818059f 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -26,6 +26,7 @@ import java.io.IOException; import java.net.ServerSocket; import java.util.Arrays; +import java.util.List; import java.util.Map; import javax.annotation.Nullable; import org.mockito.ArgumentCaptor; @@ -84,15 +85,17 @@ public XdsClient returnObject(Object object) { } } - static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port) { + static XdsClientWrapperForServerSds createXdsClientWrapperForServerSds(int port, + TlsContextManager tlsContextManager) { FakeXdsClientPoolFactory fakeXdsClientPoolFactory = new FakeXdsClientPoolFactory( - buildMockXdsClient()); + buildMockXdsClient(tlsContextManager)); return new XdsClientWrapperForServerSds(port, fakeXdsClientPoolFactory); } - private static XdsClient buildMockXdsClient() { + private static XdsClient buildMockXdsClient(TlsContextManager tlsContextManager) { XdsClient xdsClient = mock(XdsClient.class); when(xdsClient.getBootstrapInfo()).thenReturn(BOOTSTRAP_INFO); + when(xdsClient.getTlsContextManager()).thenReturn(tlsContextManager); return xdsClient; } @@ -110,14 +113,25 @@ static XdsClient.LdsResourceWatcher startAndGetWatcher( * Creates a {@link XdsClient.LdsUpdate} with {@link * io.grpc.xds.EnvoyServerProtoData.FilterChain} with a destination port and an optional {@link * EnvoyServerProtoData.DownstreamTlsContext}. - * * @param registeredWatcher the watcher on which to generate the update * @param tlsContext if non-null, used to populate filterChain */ static void generateListenerUpdate( XdsClient.LdsResourceWatcher registeredWatcher, - EnvoyServerProtoData.DownstreamTlsContext tlsContext) { - EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", tlsContext); + EnvoyServerProtoData.DownstreamTlsContext tlsContext, TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", + Arrays.asList(), tlsContext, null, tlsContextManager); + XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); + registeredWatcher.onChanged(listenerUpdate); + } + + static void generateListenerUpdate( + XdsClient.LdsResourceWatcher registeredWatcher, List sourcePorts, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, + TlsContextManager tlsContextManager) { + EnvoyServerProtoData.Listener listener = buildTestListener("listener1", "10.1.2.3", sourcePorts, + tlsContext, tlsContextForDefaultFilterChain, tlsContextManager); XdsClient.LdsUpdate listenerUpdate = new XdsClient.LdsUpdate(listener); registeredWatcher.onChanged(listenerUpdate); } @@ -130,7 +144,10 @@ static int findFreePort() throws IOException { } static EnvoyServerProtoData.Listener buildTestListener( - String name, String address, EnvoyServerProtoData.DownstreamTlsContext tlsContext) { + String name, String address, List sourcePorts, + EnvoyServerProtoData.DownstreamTlsContext tlsContext, + EnvoyServerProtoData.DownstreamTlsContext tlsContextForDefaultFilterChain, + TlsContextManager tlsContextManager) { EnvoyServerProtoData.FilterChainMatch filterChainMatch1 = new EnvoyServerProtoData.FilterChainMatch( 0, @@ -138,11 +155,14 @@ static EnvoyServerProtoData.Listener buildTestListener( Arrays.asList(), Arrays.asList(), null, - Arrays.asList()); + sourcePorts, + Arrays.asList(), + null); EnvoyServerProtoData.FilterChain filterChain1 = - new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext); + new EnvoyServerProtoData.FilterChain(filterChainMatch1, tlsContext, tlsContextManager); EnvoyServerProtoData.FilterChain defaultFilterChain = - new EnvoyServerProtoData.FilterChain(null, null); + new EnvoyServerProtoData.FilterChain(null, tlsContextForDefaultFilterChain, + tlsContextManager); EnvoyServerProtoData.Listener listener = new EnvoyServerProtoData.Listener( name, address, Arrays.asList(filterChain1), defaultFilterChain); diff --git a/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java b/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java new file mode 100644 index 00000000000..4fb4acc41f6 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/MatcherTest.java @@ -0,0 +1,171 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.re2j.Pattern; +import io.grpc.xds.internal.Matchers.CidrMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher; +import io.grpc.xds.internal.Matchers.HeaderMatcher.Range; +import io.grpc.xds.internal.Matchers.StringMatcher; +import java.net.InetAddress; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class MatcherTest { + + @Test + public void ipMatcher_ipv4() throws Exception { + CidrMatcher matcher = CidrMatcher.create(InetAddress.getByName("10.10.24.10"), 20); + assertThat(matcher.matches(InetAddress.getByName("::0"))).isFalse(); + assertThat(matcher.matches(InetAddress.getByName("10.10.20.0"))).isTrue(); + assertThat(matcher.matches(InetAddress.getByName("10.10.16.0"))).isTrue(); + assertThat(matcher.matches(InetAddress.getByName("10.10.24.10"))).isTrue(); + assertThat(matcher.matches(InetAddress.getByName("10.10.31.0"))).isTrue(); + assertThat(matcher.matches(InetAddress.getByName("10.10.17.0"))).isTrue(); + assertThat(matcher.matches(InetAddress.getByName("10.32.20.0"))).isFalse(); + assertThat(matcher.matches(InetAddress.getByName("10.10.40.0"))).isFalse(); + matcher = CidrMatcher.create(InetAddress.getByName("0.0.0.0"), 20); + assertThat(matcher.matches(InetAddress.getByName("10.32.20.0"))).isFalse(); + assertThat(matcher.matches(InetAddress.getByName("0.0.31.0"))).isFalse(); + assertThat(matcher.matches(InetAddress.getByName("0.0.15.0"))).isTrue(); + assertThat(matcher.matches(null)).isFalse(); + } + + @Test + public void ipMatcher_ipv6() throws Exception { + CidrMatcher matcher = CidrMatcher.create(InetAddress.getByName("2012:00fe:d808::"), 36); + assertThat(matcher.matches(InetAddress.getByName("0.0.0.0"))).isFalse(); + assertThat(matcher.matches(InetAddress.getByName("2012:00fe:d000::0"))).isTrue(); + assertThat(matcher.matches(InetAddress.getByName("2012:00fe:d808::"))).isTrue(); + assertThat(matcher.matches(InetAddress.getByName("2012:00fe:da81:0909:0008:4018:e930:b019"))) + .isTrue(); + assertThat(matcher.matches(InetAddress.getByName("2013:00fe:d000::0"))).isFalse(); + } + + @Test + public void stringMatcher() { + StringMatcher matcher = StringMatcher.forExact("essence", false); + assertThat(matcher.matches("elite")).isFalse(); + assertThat(matcher.matches("ess")).isFalse(); + assertThat(matcher.matches("")).isFalse(); + assertThat(matcher.matches("essential")).isFalse(); + assertThat(matcher.matches("Essence")).isFalse(); + assertThat(matcher.matches("essence")).isTrue(); + assertThat(matcher.matches((String)null)).isFalse(); + matcher = StringMatcher.forExact("essence", true); + assertThat(matcher.matches("Essence")).isTrue(); + assertThat(matcher.matches("essence")).isTrue(); + matcher = StringMatcher.forExact("", true); + assertThat(matcher.matches("essence")).isFalse(); + assertThat(matcher.matches("")).isTrue(); + + matcher = StringMatcher.forPrefix("Ess", false); + assertThat(matcher.matches("elite")).isFalse(); + assertThat(matcher.matches("ess")).isFalse(); + assertThat(matcher.matches("")).isFalse(); + assertThat(matcher.matches("e")).isFalse(); + assertThat(matcher.matches("essential")).isFalse(); + assertThat(matcher.matches("Essence")).isTrue(); + assertThat(matcher.matches("essence")).isFalse(); + assertThat(matcher.matches((String)null)).isFalse(); + matcher = StringMatcher.forPrefix("Ess", true); + assertThat(matcher.matches("esSEncE")).isTrue(); + assertThat(matcher.matches("ess")).isTrue(); + assertThat(matcher.matches("ES")).isFalse(); + matcher = StringMatcher.forPrefix("", false); + assertThat(matcher.matches("elite")).isTrue(); + + matcher = StringMatcher.forSuffix("ess", false); + assertThat(matcher.matches("elite")).isFalse(); + assertThat(matcher.matches("es")).isFalse(); + assertThat(matcher.matches("")).isFalse(); + assertThat(matcher.matches("ess")).isTrue(); + assertThat(matcher.matches("Excess")).isTrue(); + assertThat(matcher.matches("ExcesS")).isFalse(); + assertThat(matcher.matches((String)null)).isFalse(); + matcher = StringMatcher.forSuffix("ess", true); + assertThat(matcher.matches("esSEncESs")).isTrue(); + assertThat(matcher.matches("ess")).isTrue(); + matcher = StringMatcher.forSuffix("", true); + assertThat(matcher.matches("")).isTrue(); + assertThat(matcher.matches("any")).isTrue(); + + matcher = StringMatcher.forContains("ess"); + assertThat(matcher.matches("elite")).isFalse(); + assertThat(matcher.matches("es")).isFalse(); + assertThat(matcher.matches("")).isFalse(); + assertThat(matcher.matches("essence")).isTrue(); + assertThat(matcher.matches("eSs")).isFalse(); + assertThat(matcher.matches("ExcesS")).isFalse(); + assertThat(matcher.matches((String)null)).isFalse(); + + matcher = StringMatcher.forSafeRegEx(Pattern.compile("^es*.*")); + assertThat(matcher.matches("essence")).isTrue(); + assertThat(matcher.matches("")).isFalse(); + } + + @Test + public void headerMatcher() { + HeaderMatcher matcher = HeaderMatcher.forExactValue("version", "v1", false); + assertThat(matcher.matches("v1")).isTrue(); + assertThat(matcher.matches("v2")).isFalse(); + + matcher = HeaderMatcher.forExactValue("version", "v1", true); + assertThat(matcher.matches("v1")).isFalse(); + assertThat(matcher.matches( "v2")).isTrue(); + + matcher = HeaderMatcher.forPresent("version", true, false); + assertThat(matcher.matches("any")).isTrue(); + assertThat(matcher.matches(null)).isFalse(); + matcher = HeaderMatcher.forPresent("version", true, true); + assertThat(matcher.matches("version")).isFalse(); + matcher = HeaderMatcher.forPresent("version", false, true); + assertThat(matcher.matches("tag")).isTrue(); + matcher = HeaderMatcher.forPresent("version", false, false); + assertThat(matcher.matches("tag")).isFalse(); + + matcher = HeaderMatcher.forPrefix("version", "v2", false); + assertThat(matcher.matches("v22")).isTrue(); + matcher = HeaderMatcher.forPrefix("version", "v2", true); + assertThat(matcher.matches("v22")).isFalse(); + + matcher = HeaderMatcher.forSuffix("version", "v1", false); + assertThat(matcher.matches("xv1")).isTrue(); + assertThat(matcher.matches("v1x")).isFalse(); + matcher = HeaderMatcher.forSuffix("version", "v2", true); + assertThat(matcher.matches("xv1")).isTrue(); + assertThat(matcher.matches("1v2")).isFalse(); + + matcher = HeaderMatcher.forSafeRegEx("version", Pattern.compile("v2.*"), false); + assertThat(matcher.matches("v2..")).isTrue(); + assertThat(matcher.matches("v1")).isFalse(); + matcher = HeaderMatcher.forSafeRegEx("version", Pattern.compile("v1\\..*"), true); + assertThat(matcher.matches("v1.43")).isFalse(); + assertThat(matcher.matches("v2")).isTrue(); + + matcher = HeaderMatcher.forRange("version", Range.create(8080L, 8090L), false); + assertThat(matcher.matches("8080")).isTrue(); + assertThat(matcher.matches("1")).isFalse(); + matcher = HeaderMatcher.forRange("version", Range.create(8080L, 8090L), true); + assertThat(matcher.matches("1")).isTrue(); + assertThat(matcher.matches("8080")).isFalse(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProviderTest.java deleted file mode 100644 index 2ee5fb39640..00000000000 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderProviderTest.java +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.truth.Truth.assertThat; -import static org.junit.Assert.fail; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.internal.JsonParser; -import io.grpc.internal.TimeProvider; -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Unit tests for {@link DynamicReloadingCertificateProviderProvider}. */ -@RunWith(JUnit4.class) -public class DynamicReloadingCertificateProviderProviderTest { - - @Mock DynamicReloadingCertificateProvider.Factory dynamicReloadingCertificateProviderFactory; - @Mock private DynamicReloadingCertificateProviderProvider.ScheduledExecutorServiceFactory - scheduledExecutorServiceFactory; - @Mock private TimeProvider timeProvider; - - private DynamicReloadingCertificateProviderProvider provider; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - provider = - new DynamicReloadingCertificateProviderProvider( - dynamicReloadingCertificateProviderFactory, - scheduledExecutorServiceFactory, - timeProvider); - } - - @Test - public void providerRegisteredName() { - CertificateProviderProvider certProviderProvider = - CertificateProviderRegistry.getInstance() - .getProvider( - DynamicReloadingCertificateProviderProvider.DYNAMIC_RELOADING_PROVIDER_NAME); - assertThat(certProviderProvider) - .isInstanceOf(DynamicReloadingCertificateProviderProvider.class); - DynamicReloadingCertificateProviderProvider dynamicReloadingCertificateProviderProvider = - (DynamicReloadingCertificateProviderProvider) certProviderProvider; - assertThat( - dynamicReloadingCertificateProviderProvider.dynamicReloadingCertificateProviderFactory) - .isSameInstanceAs(DynamicReloadingCertificateProvider.Factory.getInstance()); - } - - @Test - public void createProvider_minimalConfig() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MINIMAL_DYNAMIC_RELOADING_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(dynamicReloadingCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq("/var/run/gke-spiffe/certs/..data"), - eq("certificates.pem"), - eq("private_key.pem"), - eq("ca_certificates.pem"), - eq(600L), - eq(mockService), - eq(timeProvider)); - } - - @Test - public void createProvider_fullConfig() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(FULL_DYNAMIC_RELOADING_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create()).thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(dynamicReloadingCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq("/var/run/gke-spiffe/certs/..data1"), - eq("certificates2.pem"), - eq("private_key3.pem"), - eq("ca_certificates4.pem"), - eq(7890L), - eq(mockService), - eq(timeProvider)); - } - - @Test - public void createProvider_missingDir_expectException() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_DIR_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'directory' is required in the config"); - } - } - - @Test - public void createProvider_missingCert_expectException() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_CERT_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'certificate-file' is required in the config"); - } - } - - @Test - public void createProvider_missingKey_expectException() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_KEY_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'private-key-file' is required in the config"); - } - } - - @Test - public void createProvider_missingRoot_expectException() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_ROOT_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'ca-certificate-file' is required in the config"); - } - } - - private static final String MINIMAL_DYNAMIC_RELOADING_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data\"," - + " \"certificate-file\": \"certificates.pem\"," - + " \"private-key-file\": \"private_key.pem\"," - + " \"ca-certificate-file\": \"ca_certificates.pem\"" - + " }"; - - private static final String FULL_DYNAMIC_RELOADING_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data1\"," - + " \"certificate-file\": \"certificates2.pem\"," - + " \"private-key-file\": \"private_key3.pem\"," - + " \"ca-certificate-file\": \"ca_certificates4.pem\"," - + " \"refresh-interval\": 7890" - + " }"; - - private static final String MISSING_DIR_CONFIG = - "{\n" - + " \"certificate-file\": \"certificates.pem\"," - + " \"private-key-file\": \"private_key.pem\"," - + " \"ca-certificate-file\": \"ca_certificates.pem\"" - + " }"; - - private static final String MISSING_CERT_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data\"," - + " \"private-key-file\": \"private_key.pem\"," - + " \"ca-certificate-file\": \"ca_certificates.pem\"" - + " }"; - - private static final String MISSING_KEY_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data\"," - + " \"certificate-file\": \"certificates.pem\"," - + " \"ca-certificate-file\": \"ca_certificates.pem\"" - + " }"; - - private static final String MISSING_ROOT_CONFIG = - "{\n" - + " \"directory\": \"/var/run/gke-spiffe/certs/..data\"," - + " \"certificate-file\": \"certificates.pem\"," - + " \"private-key-file\": \"private_key.pem\"" - + " }"; -} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderTest.java deleted file mode 100644 index a2f7cb7e32f..00000000000 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/DynamicReloadingCertificateProviderTest.java +++ /dev/null @@ -1,303 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CLIENT_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_KEY_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import io.grpc.Status; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import java.io.File; -import java.io.IOException; -import java.nio.file.Files; -import java.nio.file.NoSuchFileException; -import java.nio.file.Paths; -import java.security.PrivateKey; -import java.security.cert.CertificateException; -import java.security.cert.X509Certificate; -import java.util.List; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Unit tests for {@link DynamicReloadingCertificateProvider}. */ -@RunWith(JUnit4.class) -public class DynamicReloadingCertificateProviderTest { - private static final String CERT_FILE = "cert.pem"; - private static final String KEY_FILE = "key.pem"; - private static final String ROOT_FILE = "root.pem"; - - @Mock private CertificateProvider.Watcher mockWatcher; - @Mock private ScheduledExecutorService timeService; - @Mock private TimeProvider timeProvider; - - @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); - private String symlink; - - private DynamicReloadingCertificateProvider provider; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - - DistributorWatcher watcher = new DistributorWatcher(); - watcher.addWatcher(mockWatcher); - - symlink = new File(tempFolder.getRoot(), "..data").getAbsolutePath(); - provider = - new DynamicReloadingCertificateProvider( - watcher, - true, - symlink, - CERT_FILE, - KEY_FILE, - ROOT_FILE, - 600L, - timeService, - timeProvider); - } - - private void populateTarget( - String certFile, String keyFile, String rootFile, boolean deleteExisting, boolean createNew) - throws IOException { - String target = tempFolder.newFolder().getAbsolutePath(); - if (certFile != null) { - certFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(certFile); - Files.copy(Paths.get(certFile), Paths.get(target, CERT_FILE)); - } - if (keyFile != null) { - keyFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(keyFile); - Files.copy(Paths.get(keyFile), Paths.get(target, KEY_FILE)); - } - if (rootFile != null) { - rootFile = CommonTlsContextTestsUtil.getTempFileNameForResourcesFile(rootFile); - Files.copy(Paths.get(rootFile), Paths.get(target, ROOT_FILE)); - } - if (deleteExisting) { - Files.delete(Paths.get(symlink)); - } - if (createNew) { - Files.createSymbolicLink(Paths.get(symlink), Paths.get(target)); - } - } - - @Test - public void getCertificateAndCheckUpdates() throws IOException, CertificateException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(CLIENT_PEM_FILE, CLIENT_KEY_FILE, CA_PEM_FILE, false, true); - provider.checkAndReloadCertificates(); - verifyWatcherUpdates(CLIENT_PEM_FILE, CA_PEM_FILE); - verifyTimeServiceAndScheduledHandle(); - - reset(mockWatcher, timeService); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.scheduledHandle.cancel(); - provider.checkAndReloadCertificates(); - verifyWatcherErrorUpdates(null, null, (String[]) null); - verifyTimeServiceAndScheduledHandle(); - - reset(mockWatcher, timeService); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.scheduledHandle.cancel(); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, true, true); - provider.checkAndReloadCertificates(); - verifyWatcherUpdates(SERVER_0_PEM_FILE, SERVER_1_PEM_FILE); - verifyTimeServiceAndScheduledHandle(); - } - - @Test - public void getCertificate_initialMissingCertFile() throws IOException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(null, CLIENT_KEY_FILE, CA_PEM_FILE, false, true); - when(timeProvider.currentTimeNanos()) - .thenReturn(TimeProvider.SYSTEM_TIME_PROVIDER.currentTimeNanos()); - provider.checkAndReloadCertificates(); - verifyWatcherErrorUpdates(Status.Code.UNKNOWN, java.io.FileNotFoundException.class, "cert.pem"); - } - - @Test - public void getCertificate_missingSymlink() throws IOException { - commonErrorTest(null, null, null, true, false, NoSuchFileException.class, "..data"); - } - - @Test - public void getCertificate_missingCertFile() throws IOException { - commonErrorTest( - null, - CLIENT_KEY_FILE, - CA_PEM_FILE, - true, - true, - java.io.FileNotFoundException.class, - "cert.pem"); - } - - @Test - public void getCertificate_missingKeyFile() throws IOException { - commonErrorTest( - CLIENT_PEM_FILE, - null, - CA_PEM_FILE, - true, - true, - java.io.FileNotFoundException.class, - "key.pem"); - } - - @Test - public void getCertificate_badKeyFile() throws IOException { - commonErrorTest( - CLIENT_PEM_FILE, - SERVER_0_PEM_FILE, - CA_PEM_FILE, - true, - true, - java.security.KeyException.class, - "could not find a PKCS #8 private key in input stream"); - } - - @Test - public void getCertificate_missingRootFile() throws IOException { - commonErrorTest( - CLIENT_PEM_FILE, - CLIENT_KEY_FILE, - null, - true, - true, - java.io.FileNotFoundException.class, - "root.pem"); - } - - private void commonErrorTest( - String certFile, - String keyFile, - String rootFile, - boolean deleteExisting, - boolean createNew, - Class throwableType, - String... causeMessages) - throws IOException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - populateTarget(SERVER_0_PEM_FILE, SERVER_0_KEY_FILE, SERVER_1_PEM_FILE, false, true); - provider.checkAndReloadCertificates(); - - reset(mockWatcher); - populateTarget(certFile, keyFile, rootFile, deleteExisting, createNew); - when(timeProvider.currentTimeNanos()) - .thenReturn( - TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); - provider.scheduledHandle.cancel(); - provider.checkAndReloadCertificates(); - verifyWatcherErrorUpdates(null, null, (String[]) null); - - reset(mockWatcher, timeProvider); - when(timeProvider.currentTimeNanos()) - .thenReturn( - TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 590_000L)); - provider.scheduledHandle.cancel(); - provider.checkAndReloadCertificates(); - verifyWatcherErrorUpdates(Status.Code.UNKNOWN, throwableType, causeMessages); - } - - private void verifyWatcherErrorUpdates( - Status.Code code, Class throwableType, String... causeMessages) { - verify(mockWatcher, never()) - .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); - verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); - if (code == null && throwableType == null && causeMessages == null) { - verify(mockWatcher, never()).onError(any(Status.class)); - } else { - ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)).onError(statusCaptor.capture()); - Status status = statusCaptor.getValue(); - assertThat(status.getCode()).isEqualTo(code); - Throwable cause = status.getCause(); - assertThat(cause).isInstanceOf(throwableType); - for (String causeMessage : causeMessages) { - assertThat(cause).hasMessageThat().contains(causeMessage); - cause = cause.getCause(); - } - } - } - - private void verifyTimeServiceAndScheduledHandle() { - verify(timeService, times(1)).schedule(any(Runnable.class), eq(600L), eq(TimeUnit.SECONDS)); - assertThat(provider.scheduledHandle).isNotNull(); - assertThat(provider.scheduledHandle.isPending()).isTrue(); - } - - private void verifyWatcherUpdates(String certPemFile, String rootPemFile) - throws IOException, CertificateException { - ArgumentCaptor> certChainCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)) - .updateCertificate(any(PrivateKey.class), certChainCaptor.capture()); - List certChain = certChainCaptor.getValue(); - assertThat(certChain).hasSize(1); - assertThat(certChain.get(0)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(certPemFile)); - - ArgumentCaptor> rootsCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture()); - List roots = rootsCaptor.getValue(); - assertThat(roots).hasSize(1); - assertThat(roots.get(0)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(rootPemFile)); - verify(mockWatcher, never()).onError(any(Status.class)); - } -} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java index 474c05d0489..4b22cfb4e34 100644 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/certprovider/FileWatcherCertificateProviderTest.java @@ -45,8 +45,11 @@ import java.security.PrivateKey; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Delayed; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; import org.junit.Before; import org.junit.Rule; @@ -62,6 +65,10 @@ /** Unit tests for {@link FileWatcherCertificateProvider}. */ @RunWith(JUnit4.class) public class FileWatcherCertificateProviderTest { + /** + * Expire time of cert SERVER_0_PEM_FILE. + */ + static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L; private static final String CERT_FILE = "cert.pem"; private static final String KEY_FILE = "key.pem"; private static final String ROOT_FILE = "root.pem"; @@ -126,8 +133,8 @@ private void populateTarget( @Test public void getCertificateAndCheckUpdates() throws IOException, CertificateException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -147,8 +154,8 @@ public void getCertificateAndCheckUpdates() throws IOException, CertificateExcep @Test public void allUpdateSecondTime() throws IOException, CertificateException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -168,8 +175,8 @@ public void allUpdateSecondTime() throws IOException, CertificateException, Inte @Test public void closeDoesNotScheduleNext() throws IOException, CertificateException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -180,13 +187,14 @@ public void closeDoesNotScheduleNext() throws IOException, CertificateException .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); verify(timeService, never()).schedule(any(Runnable.class), any(Long.TYPE), any(TimeUnit.class)); + verify(timeService, times(1)).shutdownNow(); } @Test public void rootFileUpdateOnly() throws IOException, CertificateException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -207,8 +215,8 @@ public void rootFileUpdateOnly() throws IOException, CertificateException, Inter @Test public void certAndKeyFileUpdateOnly() throws IOException, CertificateException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -228,8 +236,8 @@ public void certAndKeyFileUpdateOnly() @Test public void getCertificate_initialMissingCertFile() throws IOException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -268,8 +276,8 @@ public void getCertificate_badKeyFile() throws IOException, InterruptedException @Test public void getCertificate_missingRootFile() throws IOException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -282,7 +290,7 @@ public void getCertificate_missingRootFile() throws IOException, InterruptedExce when(timeProvider.currentTimeNanos()) .thenReturn( TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); + CERT0_EXPIRY_TIME_MILLIS - 610_000L)); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates(Status.Code.UNKNOWN, NoSuchFileException.class, 1, 0, "root.pem"); } @@ -298,8 +306,8 @@ private void commonErrorTest( int secondUpdateRootCount, String... causeMessages) throws IOException, InterruptedException { - MeshCaCertificateProviderTest.TestScheduledFuture scheduledFuture = - new MeshCaCertificateProviderTest.TestScheduledFuture<>(); + TestScheduledFuture scheduledFuture = + new TestScheduledFuture<>(); doReturn(scheduledFuture) .when(timeService) .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); @@ -313,7 +321,7 @@ private void commonErrorTest( when(timeProvider.currentTimeNanos()) .thenReturn( TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 610_000L)); + CERT0_EXPIRY_TIME_MILLIS - 610_000L)); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates( null, null, firstUpdateCertCount, firstUpdateRootCount, (String[]) null); @@ -322,7 +330,7 @@ private void commonErrorTest( when(timeProvider.currentTimeNanos()) .thenReturn( TimeUnit.MILLISECONDS.toNanos( - MeshCaCertificateProviderTest.CERT0_EXPIRY_TIME_MILLIS - 590_000L)); + CERT0_EXPIRY_TIME_MILLIS - 590_000L)); provider.checkAndReloadCertificates(); verifyWatcherErrorUpdates( Status.Code.UNKNOWN, @@ -391,4 +399,55 @@ private void verifyWatcherUpdates(String certPemFile, String rootPemFile) verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); } } + + static class TestScheduledFuture implements ScheduledFuture { + + static class Record { + long timeout; + TimeUnit unit; + + Record(long timeout, TimeUnit unit) { + this.timeout = timeout; + this.unit = unit; + } + } + + ArrayList calls = new ArrayList<>(); + + @Override + public long getDelay(TimeUnit unit) { + return 0; + } + + @Override + public int compareTo(Delayed o) { + return 0; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public V get() { + return null; + } + + @Override + public V get(long timeout, TimeUnit unit) { + calls.add(new Record(timeout, unit)); + return null; + } + } } diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java deleted file mode 100644 index 791d5a395c5..00000000000 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderProviderTest.java +++ /dev/null @@ -1,409 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.certprovider.MeshCaCertificateProviderProvider.RPC_TIMEOUT_SECONDS; -import static org.junit.Assert.fail; -import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.ArgumentMatchers.isNull; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.auth.oauth2.GoogleCredentials; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.ExponentialBackoffPolicy; -import io.grpc.internal.JsonParser; -import io.grpc.internal.TimeProvider; -import io.grpc.xds.internal.sts.StsCredentials; -import java.io.IOException; -import java.util.Map; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; - -/** Unit tests for {@link MeshCaCertificateProviderProvider}. */ -@RunWith(JUnit4.class) -public class MeshCaCertificateProviderProviderTest { - - public static final String EXPECTED_AUDIENCE = - "identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3"; - public static final String EXPECTED_AUDIENCE_V1BETA1_ZONE = - "identitynamespace:test-project1.svc.id.goog:https://container.googleapis.com/v1beta1/projects/test-project1/zones/test-zone2/clusters/test-cluster3"; - public static final String TMP_PATH_4 = "/tmp/path4"; - public static final String NON_DEFAULT_MESH_CA_URL = "nonDefaultMeshCaUrl"; - - @Mock - StsCredentials.Factory stsCredentialsFactory; - - @Mock - MeshCaCertificateProvider.MeshCaChannelFactory meshCaChannelFactory; - - @Mock - BackoffPolicy.Provider backoffPolicyProvider; - - @Mock - MeshCaCertificateProvider.Factory meshCaCertificateProviderFactory; - - @Mock - private MeshCaCertificateProviderProvider.ScheduledExecutorServiceFactory - scheduledExecutorServiceFactory; - - @Mock - private TimeProvider timeProvider; - - private MeshCaCertificateProviderProvider provider; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - provider = - new MeshCaCertificateProviderProvider( - stsCredentialsFactory, - meshCaChannelFactory, - backoffPolicyProvider, - meshCaCertificateProviderFactory, - scheduledExecutorServiceFactory, - timeProvider); - } - - @Test - public void providerRegisteredName() { - CertificateProviderProvider certProviderProvider = CertificateProviderRegistry.getInstance() - .getProvider(MeshCaCertificateProviderProvider.MESH_CA_NAME); - assertThat(certProviderProvider).isInstanceOf(MeshCaCertificateProviderProvider.class); - MeshCaCertificateProviderProvider meshCaCertificateProviderProvider = - (MeshCaCertificateProviderProvider) certProviderProvider; - assertThat(meshCaCertificateProviderProvider.stsCredentialsFactory) - .isSameInstanceAs(StsCredentials.Factory.getInstance()); - assertThat(meshCaCertificateProviderProvider.meshCaChannelFactory) - .isSameInstanceAs(MeshCaCertificateProvider.MeshCaChannelFactory.getInstance()); - assertThat(meshCaCertificateProviderProvider.backoffPolicyProvider) - .isInstanceOf(ExponentialBackoffPolicy.Provider.class); - assertThat(meshCaCertificateProviderProvider.meshCaCertificateProviderFactory) - .isSameInstanceAs(MeshCaCertificateProvider.Factory.getInstance()); - } - - @Test - public void createProvider_minimalConfig() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MINIMAL_MESHCA_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create( - eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT))) - .thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(stsCredentialsFactory, times(1)) - .create( - eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT), - eq(EXPECTED_AUDIENCE), - eq("/tmp/path5")); - verify(meshCaCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT), - eq("test-zone2"), - eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT), - eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT), - eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT), - eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT), - eq(meshCaChannelFactory), - eq(backoffPolicyProvider), - eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT), - eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT), - (GoogleCredentials) isNull(), - eq(mockService), - eq(timeProvider), - eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS))); - } - - @Test - public void createProvider_minimalConfig_v1beta1AndZone() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(V1BETA1_ZONE_MESHCA_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create( - eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT))) - .thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(stsCredentialsFactory, times(1)) - .create( - eq(MeshCaCertificateProviderProvider.STS_URL_DEFAULT), - eq(EXPECTED_AUDIENCE_V1BETA1_ZONE), - eq("/tmp/path5")); - verify(meshCaCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq(MeshCaCertificateProviderProvider.MESHCA_URL_DEFAULT), - eq("test-zone2"), - eq(MeshCaCertificateProviderProvider.CERT_VALIDITY_SECONDS_DEFAULT), - eq(MeshCaCertificateProviderProvider.KEY_SIZE_DEFAULT), - eq(MeshCaCertificateProviderProvider.KEY_ALGO_DEFAULT), - eq(MeshCaCertificateProviderProvider.SIGNATURE_ALGO_DEFAULT), - eq(meshCaChannelFactory), - eq(backoffPolicyProvider), - eq(MeshCaCertificateProviderProvider.RENEWAL_GRACE_PERIOD_SECONDS_DEFAULT), - eq(MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT), - (GoogleCredentials) isNull(), - eq(mockService), - eq(timeProvider), - eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS))); - } - - @Test - public void createProvider_missingGkeUrl_expectException() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_GKE_CLUSTER_URL_MESHCA_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'location' is required in the config"); - } - } - - @Test - public void createProvider_missingSaJwtLocation_expectException() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MISSING_SAJWT_MESHCA_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException npe) { - assertThat(npe).hasMessageThat().isEqualTo("'subject_token_path' is required in the config"); - } - } - - @Test - public void createProvider_missingProject_expectException() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(MINIMAL_BAD_CLUSTER_URL_MESHCA_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (IllegalStateException ex) { - assertThat(ex).hasMessageThat().isEqualTo("gkeClusterUrl does not have correct format"); - } - } - - @Test - public void createProvider_badChannelCreds_expectException() - throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(BAD_CHANNEL_CREDS_MESHCA_CONFIG); - try { - provider.createCertificateProvider(map, distWatcher, true); - fail("exception expected"); - } catch (NullPointerException ex) { - assertThat(ex).hasMessageThat().isEqualTo("channel_credentials need to be google_default!"); - } - } - - @Test - public void createProvider_nonDefaultFullConfig() throws IOException { - CertificateProvider.DistributorWatcher distWatcher = - new CertificateProvider.DistributorWatcher(); - @SuppressWarnings("unchecked") - Map map = (Map) JsonParser.parse(NONDEFAULT_MESHCA_CONFIG); - ScheduledExecutorService mockService = mock(ScheduledExecutorService.class); - when(scheduledExecutorServiceFactory.create(eq(NON_DEFAULT_MESH_CA_URL))) - .thenReturn(mockService); - provider.createCertificateProvider(map, distWatcher, true); - verify(stsCredentialsFactory, times(1)) - .create( - eq("test.sts.com"), - eq(EXPECTED_AUDIENCE), - eq(TMP_PATH_4)); - verify(meshCaCertificateProviderFactory, times(1)) - .create( - eq(distWatcher), - eq(true), - eq(NON_DEFAULT_MESH_CA_URL), - eq("test-zone2"), - eq(234567L), - eq(512), - eq("RSA"), - eq("SHA256withRSA"), - eq(meshCaChannelFactory), - eq(backoffPolicyProvider), - eq(4321L), - eq(3), - (GoogleCredentials) isNull(), - eq(mockService), - eq(timeProvider), - eq(TimeUnit.SECONDS.toMillis(RPC_TIMEOUT_SECONDS))); - } - - private static final String NONDEFAULT_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"target_uri\": \"nonDefaultMeshCaUrl\",\n" - + " \"channel_credentials\": {\"google_default\": {}},\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"token_exchange_service\": \"test.sts.com\",\n" - + " \"subject_token_path\": \"/tmp/path4\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " },\n" // end google_grpc - + " \"time_out\": {\"seconds\": 12}\n" - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"certificate_lifetime\": {\"seconds\": 234567},\n" - + " \"renewal_grace_period\": {\"seconds\": 4321},\n" - + " \"key_type\": \"RSA\",\n" - + " \"key_size\": 512,\n" - + " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String MINIMAL_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"subject_token_path\": \"/tmp/path5\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String V1BETA1_ZONE_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"subject_token_path\": \"/tmp/path5\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1beta1/projects/test-project1/zones/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String MINIMAL_BAD_CLUSTER_URL_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"subject_token_path\": \"/tmp/path5\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1/project/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String MISSING_SAJWT_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; - - private static final String MISSING_GKE_CLUSTER_URL_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"target_uri\": \"meshca.com\",\n" - + " \"channel_credentials\": {\"google_default\": {}},\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"token_exchange_service\": \"securetoken.googleapis.com\",\n" - + " \"subject_token_path\": \"/etc/secret/sajwt.token\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " },\n" // end google_grpc - + " \"time_out\": {\"seconds\": 10}\n" - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"certificate_lifetime\": {\"seconds\": 86400},\n" - + " \"renewal_grace_period\": {\"seconds\": 3600},\n" - + " \"key_type\": \"RSA\",\n" - + " \"key_size\": 2048\n" - + " }"; - - private static final String BAD_CHANNEL_CREDS_MESHCA_CONFIG = - "{\n" - + " \"server\": {\n" - + " \"api_type\": \"GRPC\",\n" - + " \"grpc_services\": [{\n" - + " \"google_grpc\": {\n" - + " \"channel_credentials\": {\"mtls\": \"true\"},\n" - + " \"call_credentials\": [{\n" - + " \"sts_service\": {\n" - + " \"subject_token_path\": \"/tmp/path5\"\n" - + " }\n" - + " }]\n" // end call_credentials - + " }\n" // end google_grpc - + " }]\n" // end grpc_services - + " },\n" // end server - + " \"location\": \"https://container.googleapis.com/v1/projects/test-project1/locations/test-zone2/clusters/test-cluster3\"\n" - + " }"; -} diff --git a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java deleted file mode 100644 index 4b4791dd0c6..00000000000 --- a/xds/src/test/java/io/grpc/xds/internal/certprovider/MeshCaCertificateProviderTest.java +++ /dev/null @@ -1,591 +0,0 @@ -/* - * Copyright 2020 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.certprovider; - -import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_0_PEM_FILE; -import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import static org.mockito.AdditionalAnswers.delegatesTo; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.doReturn; -import static org.mockito.Mockito.eq; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import com.google.auth.http.AuthHttpConstants; -import com.google.auth.oauth2.AccessToken; -import com.google.auth.oauth2.GoogleCredentials; -import com.google.common.collect.ImmutableList; -import com.google.common.util.concurrent.MoreExecutors; -import com.google.security.meshca.v1.MeshCertificateRequest; -import com.google.security.meshca.v1.MeshCertificateResponse; -import com.google.security.meshca.v1.MeshCertificateServiceGrpc; -import io.grpc.Context; -import io.grpc.ManagedChannel; -import io.grpc.Metadata; -import io.grpc.ServerCall; -import io.grpc.ServerCallHandler; -import io.grpc.ServerInterceptor; -import io.grpc.Status; -import io.grpc.StatusRuntimeException; -import io.grpc.SynchronizationContext; -import io.grpc.inprocess.InProcessChannelBuilder; -import io.grpc.inprocess.InProcessServerBuilder; -import io.grpc.internal.BackoffPolicy; -import io.grpc.internal.TimeProvider; -import io.grpc.stub.StreamObserver; -import io.grpc.testing.GrpcCleanupRule; -import io.grpc.xds.internal.certprovider.CertificateProvider.DistributorWatcher; -import io.grpc.xds.internal.sds.CommonTlsContextTestsUtil; -import java.io.IOException; -import java.security.NoSuchAlgorithmException; -import java.security.PrivateKey; -import java.security.cert.CertificateException; -import java.security.cert.X509Certificate; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Date; -import java.util.List; -import java.util.Queue; -import java.util.concurrent.Delayed; -import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicBoolean; -import org.bouncycastle.operator.OperatorCreationException; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; -import org.mockito.ArgumentCaptor; -import org.mockito.ArgumentMatchers; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.mockito.Spy; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; - -/** Unit tests for {@link MeshCaCertificateProvider}. */ -@RunWith(JUnit4.class) -public class MeshCaCertificateProviderTest { - - private static final String TEST_STS_TOKEN = "test-stsToken"; - private static final long RENEWAL_GRACE_PERIOD_SECONDS = TimeUnit.HOURS.toSeconds(1L); - private static final Metadata.Key KEY_FOR_AUTHORIZATION = - Metadata.Key.of(AuthHttpConstants.AUTHORIZATION, Metadata.ASCII_STRING_MARSHALLER); - private static final String ZONE = "us-west2-a"; - private static final long START_DELAY = 200_000_000L; // 0.2 seconds - private static final long[] DELAY_VALUES = {START_DELAY, START_DELAY * 2, START_DELAY * 4}; - private static final long RPC_TIMEOUT_MILLIS = 1000L; - /** - * Expire time of cert SERVER_0_PEM_FILE. - */ - static final long CERT0_EXPIRY_TIME_MILLIS = 1899853658000L; - /** - * Cert validity of 12 hours for the above cert. - */ - private static final long CERT0_VALIDITY_MILLIS = TimeUnit.MILLISECONDS - .convert(12, TimeUnit.HOURS); - /** - * Compute current time based on cert expiry and cert validity. - */ - private static final long CURRENT_TIME_NANOS = - TimeUnit.MILLISECONDS.toNanos(CERT0_EXPIRY_TIME_MILLIS - CERT0_VALIDITY_MILLIS); - @Rule - public final GrpcCleanupRule cleanupRule = new GrpcCleanupRule(); - - private static class ResponseToSend { - Throwable getThrowable() { - throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName()); - } - - List getList() { - throw new UnsupportedOperationException("Called on " + getClass().getCanonicalName()); - } - } - - private static class ResponseThrowable extends ResponseToSend { - final Throwable throwableToSend; - - ResponseThrowable(Throwable throwable) { - throwableToSend = throwable; - } - - @Override - Throwable getThrowable() { - return throwableToSend; - } - } - - private static class ResponseList extends ResponseToSend { - final List listToSend; - - ResponseList(List list) { - listToSend = list; - } - - @Override - List getList() { - return listToSend; - } - } - - private final Queue receivedRequests = new ArrayDeque<>(); - private final Queue receivedStsCreds = new ArrayDeque<>(); - private final Queue receivedZoneValues = new ArrayDeque<>(); - private final Queue responsesToSend = new ArrayDeque<>(); - private final Queue oauth2Tokens = new ArrayDeque<>(); - private final AtomicBoolean callEnded = new AtomicBoolean(true); - - @Mock private MeshCertificateServiceGrpc.MeshCertificateServiceImplBase mockedMeshCaService; - @Mock private CertificateProvider.Watcher mockWatcher; - @Mock private BackoffPolicy.Provider backoffPolicyProvider; - @Mock private BackoffPolicy backoffPolicy; - @Spy private GoogleCredentials oauth2Creds; - @Mock private ScheduledExecutorService timeService; - @Mock private TimeProvider timeProvider; - - private ManagedChannel channel; - private MeshCaCertificateProvider provider; - - @Before - public void setUp() throws IOException { - MockitoAnnotations.initMocks(this); - when(backoffPolicyProvider.get()).thenReturn(backoffPolicy); - when(backoffPolicy.nextBackoffNanos()) - .thenReturn(DELAY_VALUES[0], DELAY_VALUES[1], DELAY_VALUES[2]); - doAnswer( - new Answer() { - @Override - public AccessToken answer(InvocationOnMock invocation) throws Throwable { - return new AccessToken( - oauth2Tokens.poll(), new Date(System.currentTimeMillis() + 1000L)); - } - }) - .when(oauth2Creds) - .refreshAccessToken(); - final String meshCaUri = InProcessServerBuilder.generateName(); - MeshCertificateServiceGrpc.MeshCertificateServiceImplBase meshCaServiceImpl = - new MeshCertificateServiceGrpc.MeshCertificateServiceImplBase() { - - @Override - public void createCertificate( - MeshCertificateRequest request, - StreamObserver responseObserver) { - assertThat(callEnded.get()).isTrue(); // ensure previous call was ended - callEnded.set(false); - Context.current() - .addListener( - new Context.CancellationListener() { - @Override - public void cancelled(Context context) { - callEnded.set(true); - } - }, - MoreExecutors.directExecutor()); - receivedRequests.offer(request); - ResponseToSend response = responsesToSend.poll(); - if (response instanceof ResponseThrowable) { - responseObserver.onError(response.getThrowable()); - } else if (response instanceof ResponseList) { - List certChainInResponse = response.getList(); - MeshCertificateResponse responseToSend = - MeshCertificateResponse.newBuilder() - .addAllCertChain(certChainInResponse) - .build(); - responseObserver.onNext(responseToSend); - responseObserver.onCompleted(); - } else { - callEnded.set(true); - } - } - }; - mockedMeshCaService = - mock( - MeshCertificateServiceGrpc.MeshCertificateServiceImplBase.class, - delegatesTo(meshCaServiceImpl)); - ServerInterceptor interceptor = - new ServerInterceptor() { - @Override - public ServerCall.Listener interceptCall( - ServerCall call, Metadata headers, ServerCallHandler next) { - receivedStsCreds.offer(headers.get(KEY_FOR_AUTHORIZATION)); - receivedZoneValues.offer(headers.get(MeshCaCertificateProvider.KEY_FOR_ZONE_INFO)); - return next.startCall(call, headers); - } - }; - cleanupRule.register( - InProcessServerBuilder.forName(meshCaUri) - .addService(mockedMeshCaService) - .intercept(interceptor) - .directExecutor() - .build() - .start()); - channel = - cleanupRule.register(InProcessChannelBuilder.forName(meshCaUri).directExecutor().build()); - MeshCaCertificateProvider.MeshCaChannelFactory channelFactory = - new MeshCaCertificateProvider.MeshCaChannelFactory() { - @Override - ManagedChannel createChannel(String serverUri) { - assertThat(serverUri).isEqualTo(meshCaUri); - return channel; - } - }; - CertificateProvider.DistributorWatcher watcher = new CertificateProvider.DistributorWatcher(); - watcher.addWatcher(mockWatcher); // - provider = - new MeshCaCertificateProvider( - watcher, - true, - meshCaUri, - ZONE, - TimeUnit.HOURS.toSeconds(9L), - 2048, - "RSA", - "SHA256withRSA", - channelFactory, - backoffPolicyProvider, - RENEWAL_GRACE_PERIOD_SECONDS, - MeshCaCertificateProviderProvider.MAX_RETRY_ATTEMPTS_DEFAULT, - oauth2Creds, - timeService, - timeProvider, - RPC_TIMEOUT_MILLIS); - } - - @Test - public void startAndClose() { - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.start(); - SynchronizationContext.ScheduledHandle savedScheduledHandle = provider.scheduledHandle; - assertThat(savedScheduledHandle).isNotNull(); - assertThat(savedScheduledHandle.isPending()).isTrue(); - verify(timeService, times(1)) - .schedule( - any(Runnable.class), - eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS), - eq(TimeUnit.SECONDS)); - DistributorWatcher distWatcher = provider.getWatcher(); - assertThat(distWatcher.downstreamWatchers).hasSize(1); - PrivateKey mockKey = mock(PrivateKey.class); - X509Certificate mockCert = mock(X509Certificate.class); - distWatcher.updateCertificate(mockKey, ImmutableList.of(mockCert)); - distWatcher.updateTrustedRoots(ImmutableList.of(mockCert)); - provider.close(); - assertThat(provider.scheduledHandle).isNull(); - assertThat(savedScheduledHandle.isPending()).isFalse(); - assertThat(distWatcher.downstreamWatchers).isEmpty(); - assertThat(distWatcher.getLastIdentityCert()).isNull(); - } - - @Test - public void startTwice_noException() { - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.start(); - SynchronizationContext.ScheduledHandle savedScheduledHandle1 = provider.scheduledHandle; - provider.start(); - SynchronizationContext.ScheduledHandle savedScheduledHandle2 = provider.scheduledHandle; - assertThat(savedScheduledHandle2).isNotSameInstanceAs(savedScheduledHandle1); - assertThat(savedScheduledHandle2.isPending()).isTrue(); - } - - @Test - public void getCertificate() - throws IOException, CertificateException, OperatorCreationException, - NoSuchAlgorithmException { - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - responsesToSend.offer( - new ResponseList(ImmutableList.of( - CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE)))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture) - .when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.refreshCertificate(); - MeshCertificateRequest receivedReq = receivedRequests.poll(); - assertThat(receivedReq.getValidity().getSeconds()).isEqualTo(TimeUnit.HOURS.toSeconds(9L)); - // cannot decode CSR: just check the PEM format delimiters - String csr = receivedReq.getCsr(); - assertThat(csr).startsWith("-----BEGIN NEW CERTIFICATE REQUEST-----"); - verifyReceivedMetadataValues(1); - verify(timeService, times(1)) - .schedule( - any(Runnable.class), - eq( - TimeUnit.MILLISECONDS.toSeconds( - CERT0_VALIDITY_MILLIS - - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))), - eq(TimeUnit.SECONDS)); - verifyMockWatcher(); - } - - @Test - public void getCertificate_withError() - throws IOException, OperatorCreationException, NoSuchAlgorithmException { - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - responsesToSend - .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION))); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.refreshCertificate(); - verify(mockWatcher, never()) - .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); - verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); - verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS), - eq(TimeUnit.SECONDS)); - verifyReceivedMetadataValues(1); - } - - @Test - public void getCertificate_withError_withExistingCert() - throws IOException, OperatorCreationException, NoSuchAlgorithmException { - PrivateKey mockKey = mock(PrivateKey.class); - X509Certificate mockCert = mock(X509Certificate.class); - // have current cert expire in 3 hours from current time - long threeHoursFromNowMillis = TimeUnit.NANOSECONDS - .toMillis(CURRENT_TIME_NANOS + TimeUnit.HOURS.toNanos(3)); - when(mockCert.getNotAfter()).thenReturn(new Date(threeHoursFromNowMillis)); - provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert)); - reset(mockWatcher); - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - responsesToSend - .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.refreshCertificate(); - verify(mockWatcher, never()) - .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); - verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); - verify(mockWatcher, never()).onError(any(Status.class)); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(5400L), - eq(TimeUnit.SECONDS)); - assertThat(provider.getWatcher().getLastIdentityCert()).isNotNull(); - verifyReceivedMetadataValues(1); - } - - @Test - public void getCertificate_withError_withExistingExpiredCert() - throws IOException, OperatorCreationException, NoSuchAlgorithmException { - PrivateKey mockKey = mock(PrivateKey.class); - X509Certificate mockCert = mock(X509Certificate.class); - // have current cert expire in 3 seconds from current time - long threeSecondsFromNowMillis = TimeUnit.NANOSECONDS - .toMillis(CURRENT_TIME_NANOS + TimeUnit.SECONDS.toNanos(3)); - when(mockCert.getNotAfter()).thenReturn(new Date(threeSecondsFromNowMillis)); - provider.getWatcher().updateCertificate(mockKey, ImmutableList.of(mockCert)); - reset(mockWatcher); - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - responsesToSend - .offer(new ResponseThrowable(new StatusRuntimeException(Status.FAILED_PRECONDITION))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - provider.refreshCertificate(); - verify(mockWatcher, never()) - .updateCertificate(any(PrivateKey.class), ArgumentMatchers.anyList()); - verify(mockWatcher, never()).updateTrustedRoots(ArgumentMatchers.anyList()); - verify(mockWatcher, times(1)).onError(Status.FAILED_PRECONDITION); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(MeshCaCertificateProvider.INITIAL_DELAY_SECONDS), - eq(TimeUnit.SECONDS)); - assertThat(provider.getWatcher().getLastIdentityCert()).isNull(); - verifyReceivedMetadataValues(1); - } - - @Test - public void getCertificate_retriesWithErrors() - throws IOException, CertificateException, OperatorCreationException, - NoSuchAlgorithmException { - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - oauth2Tokens.offer(TEST_STS_TOKEN + "1"); - oauth2Tokens.offer(TEST_STS_TOKEN + "2"); - responsesToSend.offer(new ResponseThrowable(new StatusRuntimeException(Status.UNKNOWN))); - responsesToSend.offer( - new ResponseThrowable( - new Exception(new StatusRuntimeException(Status.RESOURCE_EXHAUSTED)))); - responsesToSend.offer(new ResponseList(ImmutableList.of( - CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE)))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - TestScheduledFuture scheduledFutureSleep = new TestScheduledFuture<>(); - doReturn(scheduledFutureSleep).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS)); - provider.refreshCertificate(); - assertThat(receivedRequests.size()).isEqualTo(3); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(TimeUnit.MILLISECONDS.toSeconds( - CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))), - eq(TimeUnit.SECONDS)); - verifyRetriesWithBackoff(scheduledFutureSleep, 2); - verifyMockWatcher(); - verifyReceivedMetadataValues(3); - } - - @Test - public void getCertificate_retriesWithTimeouts() - throws IOException, CertificateException, OperatorCreationException, - NoSuchAlgorithmException { - oauth2Tokens.offer(TEST_STS_TOKEN + "0"); - oauth2Tokens.offer(TEST_STS_TOKEN + "1"); - oauth2Tokens.offer(TEST_STS_TOKEN + "2"); - oauth2Tokens.offer(TEST_STS_TOKEN + "3"); - responsesToSend.offer(new ResponseToSend()); - responsesToSend.offer(new ResponseToSend()); - responsesToSend.offer(new ResponseToSend()); - responsesToSend.offer(new ResponseList(ImmutableList.of( - CommonTlsContextTestsUtil.getResourceContents(SERVER_0_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(SERVER_1_PEM_FILE), - CommonTlsContextTestsUtil.getResourceContents(CA_PEM_FILE)))); - when(timeProvider.currentTimeNanos()).thenReturn(CURRENT_TIME_NANOS); - TestScheduledFuture scheduledFuture = new TestScheduledFuture<>(); - doReturn(scheduledFuture).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.SECONDS)); - TestScheduledFuture scheduledFutureSleep = new TestScheduledFuture<>(); - doReturn(scheduledFutureSleep).when(timeService) - .schedule(any(Runnable.class), any(Long.TYPE), eq(TimeUnit.NANOSECONDS)); - provider.refreshCertificate(); - assertThat(receivedRequests.size()).isEqualTo(4); - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(TimeUnit.MILLISECONDS.toSeconds( - CERT0_VALIDITY_MILLIS - TimeUnit.SECONDS.toMillis(RENEWAL_GRACE_PERIOD_SECONDS))), - eq(TimeUnit.SECONDS)); - verifyRetriesWithBackoff(scheduledFutureSleep, 3); - verifyMockWatcher(); - verifyReceivedMetadataValues(4); - } - - private void verifyRetriesWithBackoff( - TestScheduledFuture scheduledFutureSleep, int numOfRetries) { - for (int i = 0; i < numOfRetries; i++) { - long delayValue = DELAY_VALUES[i]; - verify(timeService, times(1)).schedule(any(Runnable.class), - eq(delayValue), - eq(TimeUnit.NANOSECONDS)); - assertThat(scheduledFutureSleep.calls.get(i).timeout).isEqualTo(delayValue); - assertThat(scheduledFutureSleep.calls.get(i).unit).isEqualTo(TimeUnit.NANOSECONDS); - } - } - - private void verifyMockWatcher() throws IOException, CertificateException { - ArgumentCaptor> certChainCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)) - .updateCertificate(any(PrivateKey.class), certChainCaptor.capture()); - List certChain = certChainCaptor.getValue(); - assertThat(certChain).hasSize(3); - assertThat(certChain.get(0)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_0_PEM_FILE)); - assertThat(certChain.get(1)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(SERVER_1_PEM_FILE)); - assertThat(certChain.get(2)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE)); - - ArgumentCaptor> rootsCaptor = ArgumentCaptor.forClass(null); - verify(mockWatcher, times(1)).updateTrustedRoots(rootsCaptor.capture()); - List roots = rootsCaptor.getValue(); - assertThat(roots).hasSize(1); - assertThat(roots.get(0)) - .isEqualTo(CommonTlsContextTestsUtil.getCertFromResourceName(CA_PEM_FILE)); - verify(mockWatcher, never()).onError(any(Status.class)); - } - - private void verifyReceivedMetadataValues(int count) { - assertThat(receivedStsCreds).hasSize(count); - assertThat(receivedZoneValues).hasSize(count); - for (int i = 0; i < count; i++) { - assertThat(receivedStsCreds.poll()).isEqualTo("Bearer " + TEST_STS_TOKEN + i); - assertThat(receivedZoneValues.poll()).isEqualTo("location=locations/us-west2-a"); - } - } - - static class TestScheduledFuture implements ScheduledFuture { - - static class Record { - long timeout; - TimeUnit unit; - - Record(long timeout, TimeUnit unit) { - this.timeout = timeout; - this.unit = unit; - } - } - - ArrayList calls = new ArrayList<>(); - - @Override - public long getDelay(TimeUnit unit) { - return 0; - } - - @Override - public int compareTo(Delayed o) { - return 0; - } - - @Override - public boolean cancel(boolean mayInterruptIfRunning) { - return false; - } - - @Override - public boolean isCancelled() { - return false; - } - - @Override - public boolean isDone() { - return false; - } - - @Override - public V get() { - return null; - } - - @Override - public V get(long timeout, TimeUnit unit) { - calls.add(new Record(timeout, unit)); - return null; - } - } -} diff --git a/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java new file mode 100644 index 00000000000..504c9e8df2a --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/rbac/engine/GrpcAuthorizationEngineTest.java @@ -0,0 +1,304 @@ +/* + * Copyright 2021 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.rbac.engine; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.MethodDescriptor.MethodType; +import io.grpc.ServerCall; +import io.grpc.internal.testing.TestUtils; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.Matchers.CidrMatcher; +import io.grpc.xds.internal.Matchers.StringMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.Action; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AlwaysTrueMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AndMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthConfig; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthDecision; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthHeaderMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.AuthenticatedMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationIpMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.DestinationPortMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.InvertMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.OrMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.PathMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.PolicyMatcher; +import io.grpc.xds.internal.rbac.engine.GrpcAuthorizationEngine.SourceIpMatcher; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.security.Principal; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class GrpcAuthorizationEngineTest { + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + private static final String POLICY_NAME = "policy-name"; + private static final String HEADER_KEY = "header-key"; + private static final String HEADER_VALUE = "header-val"; + private static final String IP_ADDR1 = "10.10.10.0"; + private static final String IP_ADDR2 = "68.36.0.19"; + private static final int PORT = 100; + private static final String PATH = "/auth/engine"; + private static final StringMatcher STRING_MATCHER = StringMatcher.forExact("/" + PATH, false); + private static final Metadata HEADER = metadata(HEADER_KEY, HEADER_VALUE); + + @Mock + private ServerCall serverCall; + @Mock + private SSLSession sslSession; + + @Before + public void setUp() throws Exception { + X509Certificate[] certs = {TestUtils.loadX509Cert("server1.pem")}; + when(sslSession.getPeerCertificates()).thenReturn(certs); + Attributes attributes = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, new InetSocketAddress(IP_ADDR2, PORT)) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress(IP_ADDR1, PORT)) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession) + .build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(method().build()); + } + + @Test + public void ipMatcher() throws Exception { + CidrMatcher ip1 = CidrMatcher.create(InetAddress.getByName(IP_ADDR1), 24); + DestinationIpMatcher destIpMatcher = new DestinationIpMatcher(ip1); + CidrMatcher ip2 = CidrMatcher.create(InetAddress.getByName(IP_ADDR2), 24); + SourceIpMatcher sourceIpMatcher = new SourceIpMatcher(ip2); + DestinationPortMatcher portMatcher = new DestinationPortMatcher(PORT); + OrMatcher permission = OrMatcher.create(AndMatcher.create(portMatcher, destIpMatcher)); + OrMatcher principal = OrMatcher.create(sourceIpMatcher); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + AuthDecision decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + + Attributes attributes = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, new InetSocketAddress(IP_ADDR2, PORT)) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress(IP_ADDR1, 2)) + .build(); + when(serverCall.getAttributes()).thenReturn(attributes); + decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.DENY); + assertThat(decision.matchingPolicyName()).isEqualTo(null); + + attributes = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, null) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress("1.1.1.1", PORT)) + .build(); + when(serverCall.getAttributes()).thenReturn(attributes); + decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.DENY); + assertThat(decision.matchingPolicyName()).isEqualTo(null); + + engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.DENY)); + decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(null); + } + + @Test + public void headerMatcher() { + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(HEADER_KEY, HEADER_VALUE, false)); + OrMatcher principal = OrMatcher.create(headerMatcher); + OrMatcher permission = OrMatcher.create( + new InvertMatcher(new DestinationPortMatcher(PORT + 1))); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + AuthDecision decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + + HEADER.put(Metadata.Key.of(HEADER_KEY, Metadata.ASCII_STRING_MARSHALLER), HEADER_VALUE); + headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(HEADER_KEY, HEADER_VALUE + "," + HEADER_VALUE, false)); + principal = OrMatcher.create(headerMatcher); + policyMatcher = new PolicyMatcher(POLICY_NAME, + OrMatcher.create(AlwaysTrueMatcher.INSTANCE), principal); + engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + + headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(HEADER_KEY + Metadata.BINARY_HEADER_SUFFIX, HEADER_VALUE, false)); + principal = OrMatcher.create(headerMatcher); + policyMatcher = new PolicyMatcher(POLICY_NAME, + OrMatcher.create(AlwaysTrueMatcher.INSTANCE), principal); + engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.DENY); + } + + @Test + public void pathMatcher() { + PathMatcher pathMatcher = new PathMatcher(STRING_MATCHER); + OrMatcher permission = OrMatcher.create(AlwaysTrueMatcher.INSTANCE); + OrMatcher principal = OrMatcher.create(pathMatcher); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.DENY)); + AuthDecision decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.DENY); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + + @Test + public void authenticatedMatcher() throws Exception { + AuthenticatedMatcher authMatcher = new AuthenticatedMatcher( + StringMatcher.forExact("*.test.google.fr", false)); + PathMatcher pathMatcher = new PathMatcher(STRING_MATCHER); + OrMatcher permission = OrMatcher.create(authMatcher); + OrMatcher principal = OrMatcher.create(pathMatcher); + PolicyMatcher policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(Collections.singletonList(policyMatcher), Action.ALLOW)); + AuthDecision decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.ALLOW); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + + X509Certificate[] certs = {TestUtils.loadX509Cert("badserver.pem")}; + when(sslSession.getPeerCertificates()).thenReturn(certs); + decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.DENY); + assertThat(decision.matchingPolicyName()).isEqualTo(null); + + X509Certificate mockCert = mock(X509Certificate.class); + when(sslSession.getPeerCertificates()).thenReturn(new X509Certificate[]{mockCert}); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.DENY); + when(mockCert.getSubjectDN()).thenReturn(mock(Principal.class)); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.DENY); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Arrays.>asList( + Arrays.asList(2, "*.test.google.fr"))); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.ALLOW); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Arrays.>asList( + Arrays.asList(6, "*.test.google.fr"))); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.ALLOW); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Arrays.>asList( + Arrays.asList(10, "*.test.google.fr"))); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.DENY); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Arrays.>asList( + Arrays.asList(2, "google.com"), Arrays.asList(6, "*.test.google.fr"))); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.ALLOW); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Arrays.>asList( + Arrays.asList(6, "*.test.google.fr"), Arrays.asList(2, "google.com"))); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.ALLOW); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Arrays.>asList( + Arrays.asList(2, "*.test.google.fr"), Arrays.asList(6, "google.com"))); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.DENY); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Arrays.>asList( + Arrays.asList(2, "*.test.google.fr"), Arrays.asList(6, "google.com"), + Arrays.asList(6, "*.test.google.fr"))); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.ALLOW); + + // match any authenticated connection if StringMatcher not set in AuthenticatedMatcher + permission = OrMatcher.create(new AuthenticatedMatcher(null)); + policyMatcher = new PolicyMatcher(POLICY_NAME, permission, principal); + when(mockCert.getSubjectAlternativeNames()).thenReturn( + Arrays.>asList(Arrays.asList(6, "random"))); + engine = new GrpcAuthorizationEngine(new AuthConfig(Collections.singletonList(policyMatcher), + Action.ALLOW)); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.ALLOW); + + // not match any unauthenticated connection + Attributes attributes = Attributes.newBuilder() + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, new InetSocketAddress(IP_ADDR2, PORT)) + .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, new InetSocketAddress(IP_ADDR1, PORT)) + .build(); + when(serverCall.getAttributes()).thenReturn(attributes); + assertThat(engine.evaluate(HEADER, serverCall).decision()).isEqualTo(Action.DENY); + + doThrow(new SSLPeerUnverifiedException("bad")).when(sslSession).getPeerCertificates(); + decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.DENY); + assertThat(decision.matchingPolicyName()).isEqualTo(null); + } + + @Test + public void multiplePolicies() throws Exception { + AuthenticatedMatcher authMatcher = new AuthenticatedMatcher( + StringMatcher.forSuffix("TEST.google.fr", true)); + PathMatcher pathMatcher = new PathMatcher(STRING_MATCHER); + OrMatcher principal = OrMatcher.create(AndMatcher.create(authMatcher, pathMatcher)); + OrMatcher permission = OrMatcher.create(AndMatcher.create(pathMatcher, + new InvertMatcher(new DestinationPortMatcher(PORT + 1)))); + PolicyMatcher policyMatcher1 = new PolicyMatcher(POLICY_NAME, permission, principal); + + AuthHeaderMatcher headerMatcher = new AuthHeaderMatcher(Matchers.HeaderMatcher + .forExactValue(HEADER_KEY, HEADER_VALUE + 1, false)); + authMatcher = new AuthenticatedMatcher( + StringMatcher.forContains("TEST.google.fr")); + principal = OrMatcher.create(headerMatcher, authMatcher); + CidrMatcher ip1 = CidrMatcher.create(InetAddress.getByName(IP_ADDR1), 24); + DestinationIpMatcher destIpMatcher = new DestinationIpMatcher(ip1); + permission = OrMatcher.create(destIpMatcher, pathMatcher); + PolicyMatcher policyMatcher2 = new PolicyMatcher(POLICY_NAME + "-2", permission, principal); + + GrpcAuthorizationEngine engine = new GrpcAuthorizationEngine( + new AuthConfig(ImmutableList.of(policyMatcher1, policyMatcher2), Action.DENY)); + AuthDecision decision = engine.evaluate(HEADER, serverCall); + assertThat(decision.decision()).isEqualTo(Action.DENY); + assertThat(decision.matchingPolicyName()).isEqualTo(POLICY_NAME); + } + + private MethodDescriptor.Builder method() { + return MethodDescriptor.newBuilder() + .setType(MethodType.BIDI_STREAMING) + .setFullMethodName(PATH) + .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) + .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()); + } + + private static Metadata metadata(String key, String value) { + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), value); + return metadata; + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java index 2fb30a91be9..e2d0fdab58a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ClientSslContextProviderFactoryTest.java @@ -23,8 +23,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableSet; @@ -56,7 +54,6 @@ @RunWith(JUnit4.class) public class ClientSslContextProviderFactoryTest { - Bootstrapper bootstrapper; CertificateProviderRegistry certificateProviderRegistry; CertificateProviderStore certificateProviderStore; CertProviderClientSslContextProvider.Factory certProviderClientSslContextProviderFactory; @@ -64,18 +61,17 @@ public class ClientSslContextProviderFactoryTest { @Before public void setUp() { - bootstrapper = mock(Bootstrapper.class); certificateProviderRegistry = new CertificateProviderRegistry(); certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); certProviderClientSslContextProviderFactory = new CertProviderClientSslContextProvider.Factory(certificateProviderStore); - clientSslContextProviderFactory = - new ClientSslContextProviderFactory( - bootstrapper, certProviderClientSslContextProviderFactory); } @Test public void createSslContextProvider_allFilenames() { + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + null, certProviderClientSslContextProviderFactory); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); @@ -87,6 +83,9 @@ public void createSslContextProvider_allFilenames() { @Test public void createSslContextProvider_sdsConfigForTlsCert_expectException() { + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + null, certProviderClientSslContextProviderFactory); CommonTlsContext commonTlsContext = CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate( /* name= */ "name", /* targetUri= */ "unix:/tmp/sds/path", CA_PEM_FILE); @@ -103,6 +102,9 @@ public void createSslContextProvider_sdsConfigForTlsCert_expectException() { @Test public void createSslContextProvider_sdsConfigForCertValidationContext_expectException() { + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + null, certProviderClientSslContextProviderFactory); CommonTlsContext commonTlsContext = CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext( /* name= */ "name", @@ -136,7 +138,9 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio /* staticCertValidationContext= */ null); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); @@ -145,7 +149,6 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); - verify(bootstrapper, times(1)).bootstrap(); } @Test @@ -168,7 +171,9 @@ public void bothPresent_expectCertProviderClientSslContextProvider() upstreamTlsContext = new UpstreamTlsContext(builder.build()); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); @@ -191,7 +196,9 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() /* staticCertValidationContext= */ null); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); @@ -221,7 +228,9 @@ public void createCertProviderClientSslContextProvider_withStaticContext() staticCertValidationContext); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + clientSslContextProviderFactory = + new ClientSslContextProviderFactory(bootstrapInfo, + certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); @@ -248,7 +257,9 @@ public void createCertProviderClientSslContextProvider_2providers() /* staticCertValidationContext= */ null); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderClientSslContextProvider.class); @@ -256,29 +267,11 @@ public void createCertProviderClientSslContextProvider_2providers() verifyWatcher(sslContextProvider, watcherCaptor[1]); } - @Test - public void createCertProviderClientSslContextProvider_exception() - throws XdsInitializationException { - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - "gcp_id", - "cert-default", - "gcp_id", - "root-default", - /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); - when(bootstrapper.bootstrap()) - .thenThrow(new XdsInitializationException("test exception")); - try { - clientSslContextProviderFactory.create(upstreamTlsContext); - Assert.fail("no exception thrown"); - } catch (RuntimeException expected) { - assertThat(expected).hasMessageThat().contains("test exception"); - } - } - @Test public void createEmptyCommonTlsContext_exception() throws IOException { + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + null, certProviderClientSslContextProviderFactory); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames(null, null, null); try { @@ -293,6 +286,9 @@ public void createEmptyCommonTlsContext_exception() throws IOException { @Test public void createNullCommonTlsContext_exception() throws IOException { + clientSslContextProviderFactory = + new ClientSslContextProviderFactory( + null, certProviderClientSslContextProviderFactory); UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(null); try { clientSslContextProviderFactory.create(upstreamTlsContext); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java index 87e85e615bc..1f1d32644e6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SdsProtocolNegotiatorsTest.java @@ -48,6 +48,7 @@ import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.InternalXdsAttributes; +import io.grpc.xds.TlsContextManager; import io.grpc.xds.XdsClientWrapperForServerSds; import io.grpc.xds.XdsClientWrapperForServerSdsTestMisc; import io.grpc.xds.internal.sds.SdsProtocolNegotiators.ClientSdsHandler; @@ -206,7 +207,7 @@ public void clientSdsHandler_addLast() throws IOException { buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, TlsContextManagerImpl.getInstance()); + new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(null)); SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler = new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); pipeline.addLast(clientSdsHandler); @@ -248,7 +249,7 @@ public SocketAddress remoteAddress() { XdsClientWrapperForServerSds xdsClientWrapperForServerSds = XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( - 80, downstreamTlsContext); + 80, downstreamTlsContext, new TlsContextManagerImpl(null)); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler(grpcHandler, xdsClientWrapperForServerSds, InternalProtocolNegotiators.serverPlaintext()); @@ -296,7 +297,7 @@ public SocketAddress localAddress() { XdsClientWrapperForServerSds xdsClientWrapperForServerSds = XdsClientWrapperForServerSdsTestMisc.createXdsClientWrapperForServerSds( - 80, downstreamTlsContext); + 80, downstreamTlsContext, mock(TlsContextManager.class)); SdsProtocolNegotiators.HandlerPickerHandler handlerPickerHandler = new SdsProtocolNegotiators.HandlerPickerHandler( grpcHandler, xdsClientWrapperForServerSds, mockProtocolNegotiator); @@ -369,7 +370,7 @@ public void clientSdsProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() buildUpstreamTlsContextFromFilenames(CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE); SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, TlsContextManagerImpl.getInstance()); + new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(null)); SdsProtocolNegotiators.ClientSdsHandler clientSdsHandler = new SdsProtocolNegotiators.ClientSdsHandler(grpcHandler, sslContextProviderSupplier); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java index 75247654ff2..b6a0fdd2e2c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/ServerSslContextProviderFactoryTest.java @@ -22,10 +22,6 @@ import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.CA_PEM_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_KEY_FILE; import static io.grpc.xds.internal.sds.CommonTlsContextTestsUtil.SERVER_1_PEM_FILE; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableSet; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; @@ -51,7 +47,6 @@ @RunWith(JUnit4.class) public class ServerSslContextProviderFactoryTest { - Bootstrapper bootstrapper; CertificateProviderRegistry certificateProviderRegistry; CertificateProviderStore certificateProviderStore; CertProviderServerSslContextProvider.Factory certProviderServerSslContextProviderFactory; @@ -59,18 +54,17 @@ public class ServerSslContextProviderFactoryTest { @Before public void setUp() { - bootstrapper = mock(Bootstrapper.class); certificateProviderRegistry = new CertificateProviderRegistry(); certificateProviderStore = new CertificateProviderStore(certificateProviderRegistry); certProviderServerSslContextProviderFactory = new CertProviderServerSslContextProvider.Factory(certificateProviderStore); - serverSslContextProviderFactory = - new ServerSslContextProviderFactory( - bootstrapper, certProviderServerSslContextProviderFactory); } @Test public void createSslContextProvider_allFilenames() { + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + null, certProviderServerSslContextProviderFactory); DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, CA_PEM_FILE); @@ -82,6 +76,9 @@ public void createSslContextProvider_allFilenames() { @Test public void createSslContextProvider_sdsConfigForTlsCert_expectException() { + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + null, certProviderServerSslContextProviderFactory); CommonTlsContext commonTlsContext = CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForTlsCertificate( "name", "unix:/tmp/sds/path", CA_PEM_FILE); @@ -100,6 +97,9 @@ public void createSslContextProvider_sdsConfigForTlsCert_expectException() { @Test public void createSslContextProvider_sdsConfigForCertValidationContext_expectException() { + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + null, certProviderServerSslContextProviderFactory); CommonTlsContext commonTlsContext = CommonTlsContextTestsUtil.buildCommonTlsContextFromSdsConfigForValidationContext( "name", "unix:/tmp/sds/path", SERVER_1_KEY_FILE, SERVER_1_PEM_FILE); @@ -132,7 +132,9 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio /* requireClientCert= */ true); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); @@ -141,7 +143,6 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); - verify(bootstrapper, times(1)).bootstrap(); } @Test @@ -168,7 +169,9 @@ public void bothPresent_expectCertProviderServerSslContextProvider() builder.build(), downstreamTlsContext.isRequireClientCertificate()); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); @@ -192,7 +195,9 @@ public void createCertProviderServerSslContextProvider_onlyCertInstance() /* requireClientCert= */ true); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); @@ -223,7 +228,9 @@ public void createCertProviderServerSslContextProvider_withStaticContext() /* requireClientCert= */ true); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); @@ -251,7 +258,9 @@ public void createCertProviderServerSslContextProvider_2providers() /* requireClientCert= */ true); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); - when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + bootstrapInfo, certProviderServerSslContextProviderFactory); SslContextProvider sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider).isInstanceOf(CertProviderServerSslContextProvider.class); @@ -259,32 +268,13 @@ public void createCertProviderServerSslContextProvider_2providers() verifyWatcher(sslContextProvider, watcherCaptor[1]); } - @Test - public void createCertProviderServerSslContextProvider_exception() - throws XdsInitializationException { - DownstreamTlsContext downstreamTlsContext = - CommonTlsContextTestsUtil.buildDownstreamTlsContextForCertProviderInstance( - "gcp_id", - "cert-default", - "gcp_id", - "root-default", - /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, - /* requireClientCert= */ true); - when(bootstrapper.bootstrap()) - .thenThrow(new XdsInitializationException("test exception")); - try { - serverSslContextProviderFactory.create(downstreamTlsContext); - Assert.fail("no exception thrown"); - } catch (RuntimeException expected) { - assertThat(expected).hasMessageThat().contains("test exception"); - } - } - @Test public void createEmptyCommonTlsContext_exception() throws IOException { DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames(null, null, null); + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + null, certProviderServerSslContextProviderFactory); try { serverSslContextProviderFactory.create(downstreamTlsContext); Assert.fail("no exception thrown"); @@ -297,6 +287,9 @@ public void createEmptyCommonTlsContext_exception() throws IOException { @Test public void createNullCommonTlsContext_exception() throws IOException { + serverSslContextProviderFactory = + new ServerSslContextProviderFactory( + null, certProviderServerSslContextProviderFactory); DownstreamTlsContext downstreamTlsContext = new DownstreamTlsContext(null, true); try { serverSslContextProviderFactory.create(downstreamTlsContext); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java index 881d2f1efd2..0395f3055e9 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/SslContextProviderSupplierTest.java @@ -23,11 +23,16 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import com.google.common.util.concurrent.MoreExecutors; import io.grpc.xds.EnvoyServerProtoData; +import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; import java.util.concurrent.Executor; import org.junit.Assert; @@ -64,16 +69,20 @@ private void prepareSupplier() { doReturn(mockSslContextProvider) .when(mockTlsContextManager) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); + } + + private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); supplier.updateSslContext(mockCallback); } @Test public void get_updateSecret() { prepareSupplier(); + callUpdateSslContext(); verify(mockTlsContextManager, times(2)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); verify(mockTlsContextManager, times(0)) @@ -96,6 +105,7 @@ public void get_updateSecret() { @Test public void get_onException() { prepareSupplier(); + callUpdateSslContext(); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(null); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); @@ -108,15 +118,47 @@ public void get_onException() { @Test public void testClose() { prepareSupplier(); + callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); - SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - try { - supplier.updateSslContext(mockCallback); - Assert.fail("no exception thrown"); - } catch (IllegalStateException expected) { - assertThat(expected).hasMessageThat().isEqualTo("Supplier is shutdown!"); - } + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = spy( + new SslContextProvider.Callback(MoreExecutors.directExecutor()) { + @Override + public void updateSecret(SslContext sslContext) { + Assert.fail("unexpected call"); + } + + @Override + protected void onException(Throwable argument) { + assertThat(argument).isInstanceOf(IllegalStateException.class); + assertThat(argument).hasMessageThat().contains("Supplier is shutdown!"); + } + }); + supplier.updateSslContext(mockCallback); + } + + @Test + public void testClose_nullSslContextProvider() { + prepareSupplier(); + doThrow(new NullPointerException()).when(mockTlsContextManager) + .releaseClientSslContextProvider(null); + supplier.close(); + verify(mockTlsContextManager, never()) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = spy( + new SslContextProvider.Callback(MoreExecutors.directExecutor()) { + @Override + public void updateSecret(SslContext sslContext) { + Assert.fail("unexpected call"); + } + + @Override + protected void onException(Throwable argument) { + assertThat(argument).isInstanceOf(IllegalStateException.class); + assertThat(argument).hasMessageThat().contains("Supplier is shutdown!"); + } + }); + supplier.updateSslContext(mockCallback); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java index d69892b7c5f..5f6ba418e18 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/TlsContextManagerTest.java @@ -33,8 +33,6 @@ import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.internal.sds.ReferenceCountingMap.ValueFactory; -import java.lang.reflect.Field; -import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -53,20 +51,13 @@ public class TlsContextManagerTest { @Mock ValueFactory mockServerFactory; - @Before - public void clearInstance() throws NoSuchFieldException, IllegalAccessException { - Field field = TlsContextManagerImpl.class.getDeclaredField("instance"); - field.setAccessible(true); - field.set(null, null); - } - @Test public void createServerSslContextProvider() { DownstreamTlsContext downstreamTlsContext = CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); - TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); + TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(null); SslContextProvider serverSecretProvider = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isNotNull(); @@ -82,7 +73,7 @@ public void createClientSslContextProvider() { CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); - TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); + TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(null); SslContextProvider clientSecretProvider = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isNotNull(); @@ -98,7 +89,7 @@ public void createServerSslContextProvider_differentInstance() { CommonTlsContextTestsUtil.buildDownstreamTlsContextFromFilenames( SERVER_1_KEY_FILE, SERVER_1_PEM_FILE, /* trustCa= */ null); - TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); + TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(null); SslContextProvider serverSecretProvider = tlsContextManagerImpl.findOrCreateServerSslContextProvider(downstreamTlsContext); assertThat(serverSecretProvider).isNotNull(); @@ -118,7 +109,7 @@ public void createClientSslContextProvider_differentInstance() { CommonTlsContextTestsUtil.buildUpstreamTlsContextFromFilenames( /* privateKey= */ null, /* certChain= */ null, CA_PEM_FILE); - TlsContextManagerImpl tlsContextManagerImpl = TlsContextManagerImpl.getInstance(); + TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(null); SslContextProvider clientSecretProvider = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); assertThat(clientSecretProvider).isNotNull(); diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/XdsChannelBuilderTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/XdsChannelBuilderTest.java deleted file mode 100644 index b09e42a1a79..00000000000 --- a/xds/src/test/java/io/grpc/xds/internal/sds/XdsChannelBuilderTest.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright 2019 The gRPC Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.grpc.xds.internal.sds; - -import static com.google.common.truth.Truth.assertThat; - -import io.grpc.ManagedChannel; -import io.grpc.netty.NettyChannelBuilder; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; - -/** - * Unit tests for {@link XdsChannelBuilder}. - */ -@RunWith(JUnit4.class) -public class XdsChannelBuilderTest { - - @Test - public void buildsXdsChannelBuilder() { - XdsChannelBuilder builder = XdsChannelBuilder.forTarget("localhost:8080"); - assertThat(builder).isNotNull(); - assertThat(builder.delegate()).isInstanceOf(NettyChannelBuilder.class); - ManagedChannel channel = builder.build(); - assertThat(channel).isNotNull(); - } -} diff --git a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java index 35218929651..166b60f4caf 100644 --- a/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/sds/trust/SdsX509TrustManagerTest.java @@ -29,6 +29,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.common.collect.ImmutableList; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; @@ -37,6 +38,8 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.List; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; import javax.net.ssl.SSLSession; @@ -551,6 +554,29 @@ public void checkServerTrustedSslSocket_untrustedServer_expectException() verify(sslSocket, times(1)).getHandshakeSession(); } + @Test + public void unsupportedAltNameType() throws CertificateException, IOException { + StringMatcher stringMatcher = + StringMatcher.newBuilder() + .setExact("waterzooi.test.google.be") + .setIgnoreCase(false) + .build(); + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new SdsX509TrustManager(certContext, mockDelegate); + X509Certificate mockCert = mock(X509Certificate.class); + + when(mockCert.getSubjectAlternativeNames()) + .thenReturn(Collections.>singleton(ImmutableList.of(Integer.valueOf(1), "foo"))); + X509Certificate[] certs = new X509Certificate[] {mockCert}; + try { + trustManager.verifySubjectAltNameInChain(certs); + fail("no exception thrown"); + } catch (CertificateException expected) { + assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } + } + private TestSslEngine buildTrustManagerAndGetSslEngine() throws CertificateException, IOException, CertStoreException { SSLParameters sslParams = buildTrustManagerAndGetSslParameters(); diff --git a/xds/third_party/envoy/import.sh b/xds/third_party/envoy/import.sh index e1636db3198..7569de35ee9 100755 --- a/xds/third_party/envoy/import.sh +++ b/xds/third_party/envoy/import.sh @@ -125,6 +125,7 @@ envoy/config/trace/v3/zipkin.proto envoy/extensions/clusters/aggregate/v3/cluster.proto envoy/extensions/filters/common/fault/v3/fault.proto envoy/extensions/filters/http/fault/v3/fault.proto +envoy/extensions/filters/http/rbac/v3/rbac.proto envoy/extensions/filters/http/router/v3/router.proto envoy/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto envoy/extensions/transport_sockets/tls/v3/cert.proto diff --git a/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto new file mode 100644 index 00000000000..67cb338ef1f --- /dev/null +++ b/xds/third_party/envoy/src/main/proto/envoy/extensions/filters/http/rbac/v3/rbac.proto @@ -0,0 +1,48 @@ +syntax = "proto3"; + +package envoy.extensions.filters.http.rbac.v3; + +import "envoy/config/rbac/v3/rbac.proto"; + +import "udpa/annotations/status.proto"; +import "udpa/annotations/versioning.proto"; + +option java_package = "io.envoyproxy.envoy.extensions.filters.http.rbac.v3"; +option java_outer_classname = "RbacProto"; +option java_multiple_files = true; +option (udpa.annotations.file_status).package_version_status = ACTIVE; + +// [#protodoc-title: RBAC] +// Role-Based Access Control :ref:`configuration overview `. +// [#extension: envoy.filters.http.rbac] + +// RBAC filter config. +message RBAC { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.rbac.v2.RBAC"; + + // Specify the RBAC rules to be applied globally. + // If absent, no enforcing RBAC policy will be applied. + config.rbac.v3.RBAC rules = 1; + + // Shadow rules are not enforced by the filter (i.e., returning a 403) + // but will emit stats and logs and can be used for rule testing. + // If absent, no shadow RBAC policy will be applied. + config.rbac.v3.RBAC shadow_rules = 2; + + // If specified, shadow rules will emit stats with the given prefix. + // This is useful to distinguish the stat when there are more than 1 RBAC filter configured with + // shadow rules. + string shadow_rules_stat_prefix = 3; +} + +message RBACPerRoute { + option (udpa.annotations.versioning).previous_message_type = + "envoy.config.filter.http.rbac.v2.RBACPerRoute"; + + reserved 1; + + // Override the global configuration of the filter with this new config. + // If absent, the global RBAC policy will be disabled for this route. + RBAC rbac = 2; +}