From 594484689931c2909615bf2ac075d54da0d8a6fe Mon Sep 17 00:00:00 2001 From: Navarone Feekery <13634519+navarone-feekery@users.noreply.github.com> Date: Wed, 7 Feb 2024 15:15:03 +0100 Subject: [PATCH] Use API keys for native connector syncs (#2115) * Add method in `ESManagementClient` for fetching connector secrets * Change authorization for native connector syncs to API key found in connector secrets doc * Add feature flag `native_connector_api_keys` with default value `False` * Add unit tests --- connectors/es/management_client.py | 10 ++++ connectors/es/sink.py | 20 +++++++ connectors/protocol/connectors.py | 11 ++++ connectors/source.py | 4 ++ connectors/sync_job_runner.py | 8 +++ docs/CONNECTOR_PROTOCOL.md | 4 +- tests/es/test_management_client.py | 34 +++++++++++ tests/protocol/test_connectors.py | 3 + tests/test_sink.py | 51 +++++++++++++++++ tests/test_sync_job_runner.py | 91 ++++++++++++++++++++++++++++++ 10 files changed, 235 insertions(+), 1 deletion(-) diff --git a/connectors/es/management_client.py b/connectors/es/management_client.py index 32ed2e063..5102b46f3 100644 --- a/connectors/es/management_client.py +++ b/connectors/es/management_client.py @@ -183,3 +183,13 @@ async def yield_existing_documents_metadata(self, index): timestamp = source.get(TIMESTAMP_FIELD) yield doc_id, timestamp + + async def get_connector_secret(self, connector_secret_id): + secret = await self._retrier.execute_with_retry( + partial( + self.client.perform_request, + "GET", + f"/_connector/_secret/{connector_secret_id}", + ) + ) + return secret.get("value") diff --git a/connectors/es/sink.py b/connectors/es/sink.py index 52bb9e479..6c029a9be 100644 --- a/connectors/es/sink.py +++ b/connectors/es/sink.py @@ -24,6 +24,10 @@ import time from collections import defaultdict +from elasticsearch import ( + NotFoundError as ElasticNotFoundError, +) + from connectors.config import ( DEFAULT_ELASTICSEARCH_MAX_RETRIES, DEFAULT_ELASTICSEARCH_RETRY_INTERVAL, @@ -74,6 +78,10 @@ class ContentIndexDoesNotExistError(Exception): pass +class ApiKeyNotFoundError(Exception): + pass + + class Sink: """Send bulk operations in batches by consuming a queue. @@ -658,6 +666,18 @@ def __init__(self, elastic_config, logger_=None): async def close(self): await self.es_management_client.close() + async def update_authorization(self, index_name, secret_id): + # Updates the ESManagementClient auth options for native connectors after fetching API key + try: + api_key = await self.es_management_client.get_connector_secret(secret_id) + self._logger.debug( + f"Using API key found in secrets storage for authorization for index [{index_name}]." + ) + self.es_management_client.client.options(api_key=api_key) + except ElasticNotFoundError as e: + msg = f"API key not found in secrets storage for index [{index_name}]." + raise ApiKeyNotFoundError(msg) from e + async def has_active_license_enabled(self, license_): # TODO: think how to make it not a proxy method to the client return await self.es_management_client.has_active_license_enabled(license_) diff --git a/connectors/protocol/connectors.py b/connectors/protocol/connectors.py index 434652247..f2e063452 100644 --- a/connectors/protocol/connectors.py +++ b/connectors/protocol/connectors.py @@ -467,6 +467,8 @@ class Features: BASIC_RULES_OLD = "basic_rules_old" ADVANCED_RULES_OLD = "advanced_rules_old" + NATIVE_CONNECTOR_API_KEYS = "native_connector_api_keys" + def __init__(self, features=None): if features is None: features = {} @@ -483,6 +485,11 @@ def document_level_security_enabled(self): ["document_level_security", "enabled"], default=False ) + def native_connector_api_keys_enabled(self): + return self._nested_feature_enabled( + ["native_connector_api_keys", "enabled"], default=False + ) + def sync_rules_enabled(self): return any( [ @@ -624,6 +631,10 @@ def last_sync_scheduled_at_by_job_type(self, job_type): def sync_cursor(self): return self.get("sync_cursor") + @property + def api_key_secret_id(self): + return self.get("api_key_secret_id") + async def heartbeat(self, interval): if ( self.last_seen is None diff --git a/connectors/source.py b/connectors/source.py index d7e887b7b..8509f7bb5 100644 --- a/connectors/source.py +++ b/connectors/source.py @@ -385,6 +385,7 @@ class BaseDataSource: advanced_rules_enabled = False dls_enabled = False incremental_sync_enabled = False + native_connector_api_keys_enabled = False def __init__(self, configuration): # Initialize to the global logger @@ -490,6 +491,9 @@ def features(cls): "incremental_sync": { "enabled": cls.incremental_sync_enabled, }, + "native_connector_api_keys": { + "enabled": cls.native_connector_api_keys_enabled, + }, } def set_features(self, features): diff --git a/connectors/sync_job_runner.py b/connectors/sync_job_runner.py index 2159fa601..3a451f9b5 100644 --- a/connectors/sync_job_runner.py +++ b/connectors/sync_job_runner.py @@ -143,6 +143,14 @@ async def execute(self): self.es_config, self.sync_job.logger ) + if ( + self.connector.native + and self.connector.features.native_connector_api_keys_enabled() + ): + await self.sync_orchestrator.update_authorization( + self.connector.index_name, self.connector.api_key_secret_id + ) + if job_type in [JobType.INCREMENTAL, JobType.FULL]: self.sync_job.log_info(f"Executing {job_type.value} sync") await self._execute_content_sync_job(job_type, bulk_options) diff --git a/docs/CONNECTOR_PROTOCOL.md b/docs/CONNECTOR_PROTOCOL.md index f9910b6fe..644bac8d2 100644 --- a/docs/CONNECTOR_PROTOCOL.md +++ b/docs/CONNECTOR_PROTOCOL.md @@ -52,7 +52,8 @@ All communication will need to go through Elasticsearch. We've created a connect This is our main communication index, used to communicate the connector's configuration, status and other related data. All dates in UTC. ``` { - api_key_id: string; -> ID of the current API key in use + api_key_id: string; -> ID of the current API key in use + api_key_secret_id: string; -> ID of Connector Secret doc that stores the API key configuration: { [key]: { default_value: any; -> The value used if `value` is empty (only for non-required fields) @@ -188,6 +189,7 @@ This is our main communication index, used to communicate the connector's config "dynamic": false, "properties" : { "api_key_id" : { "type" : "keyword" }, + "api_key_secret_id" : { "type" : "keyword" }, "configuration" : { "type" : "object" }, "custom_scheduling" : { "type" : "object" }, "description" : { "type" : "text" }, diff --git a/tests/es/test_management_client.py b/tests/es/test_management_client.py index 0b05df3a6..0be6a1b52 100644 --- a/tests/es/test_management_client.py +++ b/tests/es/test_management_client.py @@ -252,3 +252,37 @@ async def test_yield_existing_documents_metadata_when_index_exists( ids.append(doc_id) assert ids == ["1", "2"] + + @pytest.mark.asyncio + async def test_get_connector_secret(self, es_management_client, mock_responses): + secret_id = "secret-id" + + es_management_client.client.perform_request = AsyncMock( + return_value={"id": secret_id, "value": "secret-value"} + ) + + secret = await es_management_client.get_connector_secret(secret_id) + assert secret == "secret-value" + es_management_client.client.perform_request.assert_awaited_with( + "GET", f"/_connector/_secret/{secret_id}" + ) + + @pytest.mark.asyncio + async def test_get_connector_secret_when_secret_does_not_exist( + self, es_management_client, mock_responses + ): + secret_id = "secret-id" + + error_meta = Mock() + error_meta.status = 404 + es_management_client.client.perform_request = AsyncMock( + side_effect=ElasticNotFoundError( + "resource_not_found_exception", + error_meta, + f"No secret with id [{secret_id}]", + ) + ) + + with pytest.raises(ElasticNotFoundError): + secret = await es_management_client.get_connector_secret(secret_id) + assert secret is None diff --git a/tests/protocol/test_connectors.py b/tests/protocol/test_connectors.py index 119ed717c..a08e73cfc 100644 --- a/tests/protocol/test_connectors.py +++ b/tests/protocol/test_connectors.py @@ -201,6 +201,7 @@ def test_utc(): mongo = { "api_key_id": "", + "api_key_secret_id": "", "configuration": { "host": {"value": "mongodb://127.0.0.1:27021", "label": "MongoDB Host"}, "database": {"value": "sample_airbnb", "label": "MongoDB Database"}, @@ -321,6 +322,7 @@ async def test_connector_properties(): connector_src = { "_id": "test", "_source": { + "api_key_secret_id": "api-key-secret-id", "service_type": "test", "index_name": "search-some-index", "configuration": {}, @@ -360,6 +362,7 @@ async def test_connector_properties(): assert connector.incremental_sync_scheduling["enabled"] assert connector.incremental_sync_scheduling["interval"] == "* * * * *" assert connector.sync_cursor == SYNC_CURSOR + assert connector.api_key_secret_id == "api-key-secret-id" assert isinstance(connector.last_seen, datetime) assert isinstance(connector.filtering, Filtering) assert isinstance(connector.pipeline, Pipeline) diff --git a/tests/test_sink.py b/tests/test_sink.py index 16b422272..92f9321e7 100644 --- a/tests/test_sink.py +++ b/tests/test_sink.py @@ -12,6 +12,9 @@ import pytest from elasticsearch import ApiError, BadRequestError +from elasticsearch import ( + NotFoundError as ElasticNotFoundError, +) from connectors.es import Mappings from connectors.es.management_client import ESManagementClient @@ -19,6 +22,7 @@ OP_DELETE, OP_INDEX, OP_UPSERT, + ApiKeyNotFoundError, AsyncBulkRunningError, Extractor, ForceCanceledError, @@ -1298,3 +1302,50 @@ async def test_cancel_sync(extractor_task_done, sink_task_done, force_cancel): else: es._extractor.force_cancel.assert_not_called() es._sink.force_cancel.assert_not_called() + + +@pytest.mark.asyncio +async def test_update_authorization(): + config = { + "host": "http://nowhere.com:9200", + "user": "someone", + "password": "something", + } + sync_orchestrator = SyncOrchestrator(config) + + sync_orchestrator.es_management_client.get_connector_secret = AsyncMock( + return_value="secret-value" + ) + sync_orchestrator.es_management_client.client.options = AsyncMock() + + await sync_orchestrator.update_authorization("my-index", "my-secret-id") + + sync_orchestrator.es_management_client.get_connector_secret.assert_called_with( + "my-secret-id" + ) + sync_orchestrator.es_management_client.client.options.assert_called_with( + api_key="secret-value" + ) + + +@pytest.mark.asyncio +async def test_update_authorization_when_api_key_not_found(): + config = { + "host": "http://nowhere.com:9200", + "user": "someone", + "password": "something", + } + sync_orchestrator = SyncOrchestrator(config) + + error_meta = Mock() + error_meta.status = 404 + sync_orchestrator.es_management_client.get_connector_secret = AsyncMock( + side_effect=ElasticNotFoundError( + "resource_not_found_exception", + error_meta, + "No secret with id [my-secret-id]", + ) + ) + + with pytest.raises(ApiKeyNotFoundError): + await sync_orchestrator.update_authorization("my-index", "my-secret-id") diff --git a/tests/test_sync_job_runner.py b/tests/test_sync_job_runner.py index 0d4e63d55..ef9accfac 100644 --- a/tests/test_sync_job_runner.py +++ b/tests/test_sync_job_runner.py @@ -11,6 +11,7 @@ from connectors.es.client import License from connectors.es.index import DocumentNotFoundError +from connectors.es.sink import ApiKeyNotFoundError from connectors.filtering.validation import InvalidFilteringError from connectors.protocol import Filter, JobStatus, JobType, Pipeline from connectors.source import BaseDataSource @@ -32,11 +33,13 @@ def mock_connector(): connector.last_sync_status = JobStatus.COMPLETED connector.features.sync_rules_enabled.return_value = True connector.features.incremental_sync_enabled.return_value = True + connector.features.native_connector_api_keys_enabled.return_value = True connector.sync_cursor = SYNC_CURSOR connector.document_count = AsyncMock(return_value=TOTAL_DOCUMENT_COUNT) connector.sync_starts = AsyncMock(return_value=True) connector.sync_done = AsyncMock() connector.reload = AsyncMock() + connector.native = True return connector @@ -126,6 +129,7 @@ def sync_orchestrator_mock(): sync_orchestrator_mock.has_active_license_enabled = AsyncMock( return_value=(True, License.PLATINUM) ) + sync_orchestrator_mock.update_authorization = AsyncMock() sync_orchestrator_klass_mock.return_value = sync_orchestrator_mock yield sync_orchestrator_mock @@ -891,3 +895,90 @@ async def test_unsupported_job_type(): with pytest.raises(SyncJobStartError): await sync_job_runner.execute() + + +@pytest.mark.parametrize( + "job_type, sync_cursor", + [ + (JobType.FULL, SYNC_CURSOR), + (JobType.INCREMENTAL, SYNC_CURSOR), + (JobType.ACCESS_CONTROL, None), + ], +) +@pytest.mark.asyncio +async def test_native_connector_sync_fails_when_api_key_secret_missing( + job_type, sync_cursor, sync_orchestrator_mock +): + ingestion_stats = { + "indexed_document_count": 0, + "indexed_document_volume": 0, + "deleted_document_count": 0, + "total_document_count": TOTAL_DOCUMENT_COUNT, + } + sync_orchestrator_mock.ingestion_stats.return_value = ingestion_stats + sync_orchestrator_mock.update_authorization = AsyncMock( + side_effect=ApiKeyNotFoundError() + ) + + sync_job_runner = create_runner(job_type=job_type, sync_cursor=sync_cursor) + + await sync_job_runner.execute() + + sync_job_runner.sync_job.claim.assert_awaited() + sync_job_runner.sync_job.fail.assert_awaited_with( + ANY, ingestion_stats=ingestion_stats + ) + sync_job_runner.sync_job.done.assert_not_awaited() + sync_job_runner.sync_job.cancel.assert_not_awaited() + sync_job_runner.sync_job.suspend.assert_not_awaited() + + sync_job_runner.sync_orchestrator.async_bulk.assert_not_awaited() + + sync_job_runner.connector.sync_starts.assert_awaited_with(job_type) + sync_job_runner.connector.sync_done.assert_awaited_with( + sync_job_runner.sync_job, cursor=sync_cursor + ) + + +@pytest.mark.parametrize( + "job_type, sync_cursor", + [ + (JobType.FULL, SYNC_CURSOR), + (JobType.INCREMENTAL, SYNC_CURSOR), + (JobType.ACCESS_CONTROL, None), + ], +) +@pytest.mark.asyncio +async def test_connector_client_sync_succeeds_when_api_key_secret_missing( + job_type, sync_cursor, sync_orchestrator_mock +): + connector = mock_connector() + connector.native = False + + ingestion_stats = { + "indexed_document_count": 25, + "indexed_document_volume": 30, + "deleted_document_count": 20, + } + sync_orchestrator_mock.ingestion_stats.return_value = ingestion_stats + sync_orchestrator_mock.update_authorization = AsyncMock( + side_effect=ApiKeyNotFoundError() + ) + + sync_job_runner = create_runner( + job_type=job_type, connector=connector, sync_cursor=sync_cursor + ) + await sync_job_runner.execute() + + ingestion_stats["total_document_count"] = TOTAL_DOCUMENT_COUNT + + sync_job_runner.connector.sync_starts.assert_awaited_with(job_type) + sync_job_runner.sync_job.claim.assert_awaited() + sync_job_runner.sync_orchestrator.async_bulk.assert_awaited() + sync_job_runner.sync_job.done.assert_awaited_with(ingestion_stats=ingestion_stats) + sync_job_runner.sync_job.fail.assert_not_awaited() + sync_job_runner.sync_job.cancel.assert_not_awaited() + sync_job_runner.sync_job.suspend.assert_not_awaited() + sync_job_runner.connector.sync_done.assert_awaited_with( + sync_job_runner.sync_job, cursor=sync_cursor + )