Skip to content

Add enterprise license check for Inference API actions #119893

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jan 13, 2025
5 changes: 5 additions & 0 deletions docs/changelog/119893.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 119893
summary: Add enterprise license check for Inference API actions
area: Machine Learning
type: enhancement
issues: []
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.junit.ClassRule;

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.mockSparseServiceModelConfig;

public class InferenceBasicLicenseIT extends InferenceLicenseBaseRestTest {
@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.setting("xpack.license.self_generated.type", "basic")
.setting("xpack.security.enabled", "true")
.user("x_pack_rest_user", "x-pack-test-password")
.plugin("inference-service-test")
.build();

@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}

@Override
protected Settings restClientSettings() {
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}

public void testPutModel_RestrictedWithBasicLicense() throws Exception {
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
var modelConfig = mockSparseServiceModelConfig(null, true);
sendRestrictedRequest("PUT", endpoint, modelConfig);
}

public void testUpdateModel_RestrictedWithBasicLicense() throws Exception {
var endpoint = Strings.format("_inference/%s/%s/_update?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
var requestBody = """
{
"task_settings": {
"num_threads": 2
}
}
""";
sendRestrictedRequest("PUT", endpoint, requestBody);
}

public void testPerformInference_RestrictedWithBasicLicense() throws Exception {
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
var requestBody = """
{
"input": ["washing", "machine"]
}
""";
sendRestrictedRequest("POST", endpoint, requestBody);
}

public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
var endpoint = "_inference/_services";
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
}

public void testGetModels_NonRestrictedWithBasicLicense() throws Exception {
var endpoint = "_inference/_all";
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
}

public void testDeleteModel_NonRestrictedWithBasicLicense() throws Exception {
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
sendNonRestrictedRequest("DELETE", endpoint, null, 404, true);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.client.Request;
import org.elasticsearch.client.ResponseException;
import org.elasticsearch.test.rest.ESRestTestCase;

import java.io.IOException;

import static org.hamcrest.Matchers.containsString;

public class InferenceLicenseBaseRestTest extends ESRestTestCase {
protected void sendRestrictedRequest(String method, String endpoint, String body) throws IOException {
var request = new Request(method, endpoint);
request.setJsonEntity(body);

var exception = assertThrows(ResponseException.class, () -> client().performRequest(request));
assertEquals(403, exception.getResponse().getStatusLine().getStatusCode());
assertThat(exception.getMessage(), containsString("current license is non-compliant for [inference]"));
}

protected void sendNonRestrictedRequest(String method, String endpoint, String body, int expectedStatusCode, boolean exceptionExpected)
throws IOException {
var request = new Request(method, endpoint);
request.setJsonEntity(body);

int actualStatusCode;
if (exceptionExpected) {
var exception = assertThrows(ResponseException.class, () -> client().performRequest(request));
actualStatusCode = exception.getResponse().getStatusLine().getStatusCode();
} else {
var response = client().performRequest(request);
actualStatusCode = response.getStatusLine().getStatusCode();
}
assertEquals(expectedStatusCode, actualStatusCode);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference;

import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.junit.ClassRule;

import static org.elasticsearch.xpack.inference.InferenceBaseRestTest.mockSparseServiceModelConfig;

public class InferenceTrialLicenseIT extends InferenceLicenseBaseRestTest {
@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.setting("xpack.license.self_generated.type", "trial")
.setting("xpack.security.enabled", "true")
.user("x_pack_rest_user", "x-pack-test-password")
.plugin("inference-service-test")
.build();

@Override
protected String getTestRestCluster() {
return cluster.getHttpAddresses();
}

@Override
protected Settings restClientSettings() {
String token = basicAuthHeaderValue("x_pack_rest_user", new SecureString("x-pack-test-password".toCharArray()));
return Settings.builder().put(ThreadContext.PREFIX + ".Authorization", token).build();
}

public void testPutModel_NonRestrictedWithTrialLicense() throws Exception {
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
var modelConfig = mockSparseServiceModelConfig(null, true);
sendNonRestrictedRequest("PUT", endpoint, modelConfig, 200, false);
}

public void testUpdateModel_NonRestrictedWithTrialLicense() throws Exception {
var endpoint = Strings.format("_inference/%s/%s/_update?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
var requestBody = """
{
"task_settings": {
"num_threads": 2
}
}
""";
sendNonRestrictedRequest("PUT", endpoint, requestBody, 404, true);
}

public void testPerformInference_NonRestrictedWithTrialLicense() throws Exception {
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
var requestBody = """
{
"input": ["washing", "machine"]
}
""";
sendNonRestrictedRequest("POST", endpoint, requestBody, 404, true);
}

public void testGetServices_NonRestrictedWithBasicLicense() throws Exception {
var endpoint = "_inference/_services";
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
}

public void testGetModels_NonRestrictedWithBasicLicense() throws Exception {
var endpoint = "_inference/_all";
sendNonRestrictedRequest("GET", endpoint, null, 200, false);
}

public void testDeleteModel_NonRestrictedWithBasicLicense() throws Exception {
var endpoint = Strings.format("_inference/%s/%s?error_trace", TaskType.SPARSE_EMBEDDING, "endpoint-id");
sendNonRestrictedRequest("DELETE", endpoint, null, 404, true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
import org.elasticsearch.client.Request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.cluster.ElasticsearchCluster;
import org.elasticsearch.test.cluster.local.distribution.DistributionType;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.upgrades.AbstractRollingUpgradeTestCase;
import org.elasticsearch.upgrades.ParameterizedRollingUpgradeTestCase;
import org.junit.ClassRule;

import java.io.IOException;
import java.util.LinkedList;
Expand All @@ -22,14 +25,28 @@

import static org.elasticsearch.core.Strings.format;

public class InferenceUpgradeTestCase extends AbstractRollingUpgradeTestCase {
public class InferenceUpgradeTestCase extends ParameterizedRollingUpgradeTestCase {

static final String MODELS_RENAMED_TO_ENDPOINTS = "8.15.0";

public InferenceUpgradeTestCase(@Name("upgradedNodes") int upgradedNodes) {
super(upgradedNodes);
}

@ClassRule
public static ElasticsearchCluster cluster = ElasticsearchCluster.local()
.distribution(DistributionType.DEFAULT)
.version(getOldClusterTestVersion())
.nodes(NODE_NUM)
.setting("xpack.security.enabled", "false")
.setting("xpack.license.self_generated.type", "trial")
.build();

@Override
protected ElasticsearchCluster getUpgradeCluster() {
return cluster;
}

protected static String getUrl(MockWebServer webServer) {
return format("http://%s:%s", webServer.getHostName(), webServer.getPort());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.Utils;
import org.elasticsearch.xpack.inference.mock.TestDenseInferenceServiceExtension;
import org.elasticsearch.xpack.inference.mock.TestSparseInferenceServiceExtension;
Expand Down Expand Up @@ -73,7 +74,7 @@ public void setup() throws Exception {

@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(Utils.TestInferencePlugin.class);
return Arrays.asList(Utils.TestInferencePlugin.class, LocalStateCompositeXPackPlugin.class);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.LocalStateCompositeXPackPlugin;
import org.elasticsearch.xpack.inference.InferencePlugin;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsTests;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
Expand Down Expand Up @@ -76,7 +77,7 @@ public void createComponents() {

@Override
protected Collection<Class<? extends Plugin>> getPlugins() {
return pluginList(ReindexPlugin.class, InferencePlugin.class);
return pluginList(ReindexPlugin.class, InferencePlugin.class, LocalStateCompositeXPackPlugin.class);
}

public void testStoreModel() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import org.elasticsearch.indices.SystemIndexDescriptor;
import org.elasticsearch.inference.InferenceServiceExtension;
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.license.License;
import org.elasticsearch.license.LicensedFeature;
import org.elasticsearch.node.PluginComponentBinding;
import org.elasticsearch.plugins.ActionPlugin;
import org.elasticsearch.plugins.ExtensiblePlugin;
Expand Down Expand Up @@ -150,6 +152,12 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP
Setting.Property.Dynamic
);

public static final LicensedFeature.Momentary INFERENCE_API_FEATURE = LicensedFeature.momentary(
"inference",
"api",
License.OperationMode.ENTERPRISE
);

public static final String NAME = "inference";
public static final String UTILITY_THREAD_POOL_NAME = "inference_utility";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.license.LicenseUtils;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
Expand All @@ -38,6 +41,7 @@
import java.util.stream.Collectors;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.InferencePlugin.INFERENCE_API_FEATURE;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes;
import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes;

Expand All @@ -48,6 +52,7 @@ public abstract class BaseTransportInferenceAction<Request extends BaseInference
private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class);
private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference";
private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]";
private final XPackLicenseState licenseState;
private final ModelRegistry modelRegistry;
private final InferenceServiceRegistry serviceRegistry;
private final InferenceStats inferenceStats;
Expand All @@ -57,13 +62,15 @@ public BaseTransportInferenceAction(
String inferenceActionName,
TransportService transportService,
ActionFilters actionFilters,
XPackLicenseState licenseState,
ModelRegistry modelRegistry,
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats,
StreamingTaskManager streamingTaskManager,
Writeable.Reader<Request> requestReader
) {
super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE);
this.licenseState = licenseState;
this.modelRegistry = modelRegistry;
this.serviceRegistry = serviceRegistry;
this.inferenceStats = inferenceStats;
Expand All @@ -72,6 +79,11 @@ public BaseTransportInferenceAction(

@Override
protected void doExecute(Task task, Request request, ActionListener<InferenceAction.Response> listener) {
if (INFERENCE_API_FEATURE.check(licenseState) == false) {
listener.onFailure(LicenseUtils.newComplianceException(XPackField.INFERENCE));
return;
}

var timer = InferenceTimer.start();

var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.UnparsedModel;
import org.elasticsearch.injection.guice.Inject;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager;
Expand All @@ -28,6 +29,7 @@ public class TransportInferenceAction extends BaseTransportInferenceAction<Infer
public TransportInferenceAction(
TransportService transportService,
ActionFilters actionFilters,
XPackLicenseState licenseState,
ModelRegistry modelRegistry,
InferenceServiceRegistry serviceRegistry,
InferenceStats inferenceStats,
Expand All @@ -37,6 +39,7 @@ public TransportInferenceAction(
InferenceAction.NAME,
transportService,
actionFilters,
licenseState,
modelRegistry,
serviceRegistry,
inferenceStats,
Expand Down
Loading