From ce78f8f4e5463a5c91e4b08a43208e7f5504fed6 Mon Sep 17 00:00:00 2001 From: Amit Galitzky Date: Wed, 20 Mar 2024 10:21:11 -0700 Subject: [PATCH] fixing create index step and array input for processors Signed-off-by: Amit Galitzky --- .../flowframework/common/DefaultUseCases.java | 8 +- .../rest/RestCreateWorkflowAction.java | 7 +- .../flowframework/util/ParseUtils.java | 28 +++++++ .../workflow/AbstractCreatePipelineStep.java | 9 +- .../workflow/CreateIndexStep.java | 36 +++++++- .../conversational-search-defaults.json | 20 +++++ ...timodal-search-bedrock-titan-defaults.json | 8 +- ...nal-search-with-cohere-model-template.json | 83 +++++++++++++++++++ ...al-search-with-bedrock-titan-template.json | 6 +- .../flowframework/util/ParseUtilsTests.java | 10 +++ 10 files changed, 196 insertions(+), 19 deletions(-) create mode 100644 src/main/resources/defaults/conversational-search-defaults.json create mode 100644 src/main/resources/substitutionTemplates/conversational-search-with-cohere-model-template.json diff --git a/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java b/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java index 71be18a65..f4b4ce49d 100644 --- a/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java +++ b/src/main/java/org/opensearch/flowframework/common/DefaultUseCases.java @@ -93,7 +93,13 @@ public enum DefaultUseCases { "substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json" ), /** defaults file and substitution ready template for hybrid search, no model creation*/ - HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json"); + HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json"), + /** defaults file and substitution ready template for conversational search with cohere chat model*/ + CONVERSATIONAL_SEARCH_WITH_COHERE_DEPLOY( + "conversational_search_with_llm_deploy", + "defaults/conversational-search-defaults.json", + "substitutionTemplates/conversational-search-with-cohere-model-template.json" + ); private final String useCaseName; private final String defaultsFile; diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index c2ec444c4..ffad1a732 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -131,11 +131,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli try { XContentParser parser = request.contentParser(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - Map userDefaults = ParseUtils.parseStringToStringMap(parser); + Map userDefaults = ParseUtils.parseStringToObjectMap(parser); // updates the default params with anything user has given that matches - for (Map.Entry userDefaultsEntry : userDefaults.entrySet()) { + for (Map.Entry userDefaultsEntry : userDefaults.entrySet()) { String key = userDefaultsEntry.getKey(); - String value = userDefaultsEntry.getValue(); + String value = userDefaultsEntry.getValue().toString(); if (useCaseDefaultsMap.containsKey(key)) { useCaseDefaultsMap.put(key, value); } @@ -154,7 +154,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli null, useCaseDefaultsMap ); - XContentParser parserTestJson = ParseUtils.jsonToParser(useCaseTemplateFileInStringFormat); ensureExpectedToken(XContentParser.Token.START_OBJECT, parserTestJson.currentToken(), parserTestJson); template = Template.parse(parserTestJson); diff --git a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java index 0a923721f..c1d82ca46 100644 --- a/src/main/java/org/opensearch/flowframework/util/ParseUtils.java +++ b/src/main/java/org/opensearch/flowframework/util/ParseUtils.java @@ -34,6 +34,7 @@ import java.io.IOException; import java.io.InputStream; import java.time.Instant; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -169,6 +170,7 @@ public static Map parseStringToStringMap(XContentParser parser) /** * Parses an XContent object representing a map of String keys to Object values. * The Object value here can either be a string or a map + * If an array is found in the given parser we conver the array to a string representation of the array * @param parser An XContent parser whose position is at the start of the map object to parse * @return A map as identified by the key-value pairs in the XContent * @throws IOException on a parse failure @@ -182,6 +184,15 @@ public static Map parseStringToObjectMap(XContentParser parser) if (parser.currentToken() == XContentParser.Token.START_OBJECT) { // If the current token is a START_OBJECT, parse it as Map map.put(fieldName, parseStringToStringMap(parser)); + } else if (parser.currentToken() == XContentParser.Token.START_ARRAY) { + // If an array, parse it to a string + // Handle array: convert it to a string representation + List elements = new ArrayList<>(); + while (parser.nextToken() != XContentParser.Token.END_ARRAY) { + elements.add("\"" + parser.text() + "\""); // Adding escaped quotes around each element + } + String arrayString = "[" + String.join(", ", elements) + "]"; + map.put(fieldName, arrayString); } else { // Otherwise, parse it as a string map.put(fieldName, parser.text()); @@ -413,4 +424,21 @@ public static Map parseJsonFileToStringToStringMap(String path) Map mappedJsonFile = mapper.readValue(jsonContent, Map.class); return mappedJsonFile; } + + /** + * Takes an input string, then checks if there is an array in the string with backslashes around strings + * (e.g. "[\"text\", \"hello\"]" to "["text", "hello"]"), this is needed for processors that take in string arrays, + * This also removes the quotations around the array making the array valid to consume + * (e.g. "weights": "[0.7, 0.3]" -> "weights": [0.7, 0.3]) + * @param input The inputString given to be transformed + * @return the transformed string + */ + public static String removingBackslashesAndQuotesInArrayInJsonString(String input) { + return Pattern.compile("\"\\[(.*?)]\"").matcher(input).replaceAll(matchResult -> { + // Extract matched content and remove backslashes before quotes + String withoutEscapes = matchResult.group(1).replaceAll("\\\\\"", "\""); + // Return the transformed string with the brackets but without the outer quotes + return "[" + withoutEscapes + "]"; + }); + } } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java index 0689ce4e4..ce9bca27e 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractCreatePipelineStep.java @@ -85,12 +85,11 @@ public PlainActionFuture execute( String pipelineId = (String) inputs.get(PIPELINE_ID); String configurations = (String) inputs.get(CONFIGURATIONS); - // Special case for processors that have arrays that need to have the quotes removed - // (e.g. "weights": "[0.7, 0.3]" -> "weights": [0.7, 0.3] - // Define a regular expression pattern to match stringified arrays - String transformedJsonString = configurations.replaceAll("\"\\[(.*?)]\"", "[$1]"); + // Special case for processors that have arrays that need to have the quotes around or + // backslashes around strings in array removed + String transformedJsonStringForStringArray = ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(configurations); - byte[] byteArr = transformedJsonString.getBytes(StandardCharsets.UTF_8); + byte[] byteArr = transformedJsonStringForStringArray.getBytes(StandardCharsets.UTF_8); BytesReference configurationsBytes = new BytesArray(byteArr); String pipelineToBeCreated = this.getName(); diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 509d4b417..03e00e2a4 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -14,20 +14,27 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; -import org.opensearch.common.xcontent.XContentType; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.exception.WorkflowStepException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.index.mapper.MapperService; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.Set; +import static java.util.Collections.singletonMap; import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS; import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME; import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep; @@ -85,8 +92,13 @@ public PlainActionFuture execute( byte[] byteArr = configurations.getBytes(StandardCharsets.UTF_8); BytesReference configurationsBytes = new BytesArray(byteArr); + CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName); + if (!configurations.isEmpty()) { + Map sourceAsMap = XContentHelper.convertToMap(configurationsBytes, false, MediaTypeRegistry.JSON).v2(); + sourceAsMap = prepareMappings(sourceAsMap); + createIndexRequest.source(sourceAsMap, LoggingDeprecationHandler.INSTANCE); + } - CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).source(configurationsBytes, XContentType.JSON); client.admin().indices().create(createIndexRequest, ActionListener.wrap(acknowledgedResponse -> { String resourceName = getResourceByWorkflowStep(getName()); logger.info("Created index: {}", indexName); @@ -129,6 +141,26 @@ public PlainActionFuture execute( return createIndexFuture; } + // This method to check if the mapping contains a type `_doc` and if yes we fail the request + // is to duplicate the behavior we have today through create index rest API, we want users + // to encounter the same behavior and not suddenly have to add `_doc` while using our create_index step + private static Map prepareMappings(Map source) { + if (source.containsKey("mappings") == false || (source.get("mappings") instanceof Map) == false) { + return source; + } + + Map newSource = new HashMap<>(source); + + @SuppressWarnings("unchecked") + Map mappings = (Map) source.get("mappings"); + if (MapperService.isMappingSourceTyped(MapperService.SINGLE_MAPPING_NAME, mappings)) { + throw new WorkflowStepException("The mapping definition cannot be nested under a type", RestStatus.BAD_REQUEST); + } + + newSource.put("mappings", singletonMap(MapperService.SINGLE_MAPPING_NAME, mappings)); + return newSource; + } + @Override public String getName() { return NAME; diff --git a/src/main/resources/defaults/conversational-search-defaults.json b/src/main/resources/defaults/conversational-search-defaults.json new file mode 100644 index 000000000..1371da552 --- /dev/null +++ b/src/main/resources/defaults/conversational-search-defaults.json @@ -0,0 +1,20 @@ +{ + "template.name": "deploy-cohere-chat-model", + "template.description": "deploying cohere chat model", + "create_connector.name": "Cohere Chat Model", + "create_connector.description": "The connector to Cohere's public chat API", + "create_connector.protocol": "http", + "create_connector.model": "command", + "create_connector.endpoint": "api.cohere.ai", + "create_connector.credential.key": "123", + "create_connector.actions.url": "https://api.cohere.ai/v1/chat", + "create_connector.actions.request_body": "{ \"message\": \"${parameters.message}\", \"model\": \"${parameters.model}\" }", + "register_remote_model.name": "Cohere chat model", + "register_remote_model.description": "cohere-chat-model", + "create_search_pipeline.pipeline_id": "rag-pipeline", + "create_search_pipeline.retrieval_augmented_generation.tag": "openai_pipeline_demo", + "create_search_pipeline.retrieval_augmented_generation.description": "Demo pipeline Using cohere Connector", + "create_search_pipeline.retrieval_augmented_generation.context_field_list": "[\"text\", \"hello\"]", + "create_search_pipeline.retrieval_augmented_generation.system_prompt": "You are a helpful assistant", + "create_search_pipeline.retrieval_augmented_generation.user_instructions": "Generate a concise and informative answer in less than 100 words for the given question" +} diff --git a/src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json b/src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json index 222053db1..f7656d967 100644 --- a/src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json +++ b/src/main/resources/defaults/multimodal-search-bedrock-titan-defaults.json @@ -4,7 +4,7 @@ "create_connector.name": "Amazon Bedrock Connector: multi-modal embedding", "create_connector.description": "The connector to bedrock Titan multi-modal embedding model", "create_connector.region": "us-east-1", - "create_connector.input_docs_processed_step_size": 2, + "create_connector.input_docs_processed_step_size": "2", "create_connector.endpoint": "api.openai.com", "create_connector.credential.access_key": "123", "create_connector.credential.secret_key": "123", @@ -17,12 +17,12 @@ "register_remote_model.description": "bedrock-multi-modal-embedding-model", "create_ingest_pipeline.pipeline_id": "nlp-multimodal-ingest-pipeline", "create_ingest_pipeline.description": "A text/image embedding pipeline", - "create_ingest_pipeline.embedding": "vector_embedding", + "text_image_embedding.create_ingest_pipeline.embedding": "vector_embedding", "text_image_embedding.field_map.text": "image_description", "text_image_embedding.field_map.image": "image_binary", "create_index.name": "my-multimodal-nlp-index", - "create_index.settings.number_of_shards": 2, - "text_image_embedding.field_map.output.dimension": 1024, + "create_index.settings.number_of_shards": "2", + "text_image_embedding.field_map.output.dimension": "1024", "create_index.mappings.method.engine": "lucene", "create_index.mappings.method.name": "hnsw" } diff --git a/src/main/resources/substitutionTemplates/conversational-search-with-cohere-model-template.json b/src/main/resources/substitutionTemplates/conversational-search-with-cohere-model-template.json new file mode 100644 index 000000000..9c919f553 --- /dev/null +++ b/src/main/resources/substitutionTemplates/conversational-search-with-cohere-model-template.json @@ -0,0 +1,83 @@ +{ + "name": "${{template.name}}", + "description": "${{template.description}}", + "use_case": "SEMANTIC_SEARCH", + "version": { + "template": "1.0.0", + "compatibility": [ + "2.12.0", + "3.0.0" + ] + }, + "workflows": { + "provision": { + "nodes": [ + { + "id": "create_connector", + "type": "create_connector", + "user_inputs": { + "name": "${{create_connector}}", + "description": "${{create_connector.description}}", + "version": "1", + "protocol": "${{create_connector.protocol}}", + "parameters": { + "endpoint": "${{create_connector.endpoint}}", + "model": "${{create_connector.model}}" + }, + "credential": { + "key": "${{create_connector.credential.key}}" + }, + "actions": [ + { + "action_type": "predict", + "method": "POST", + "url": "${{create_connector.actions.url}}", + "headers": { + "Authorization": "Bearer ${credential.key}" + }, + "request_body": "${{create_connector.actions.request_body}}" + } + ] + } + }, + { + "id": "register_model", + "type": "register_remote_model", + "previous_node_inputs": { + "create_connector": "parameters" + }, + "user_inputs": { + "name": "${{register_remote_model.name}}", + "function_name": "remote", + "description": "${{register_remote_model.description}}", + "deploy": true + } + }, + { + "id": "create_search_pipeline", + "type": "create_search_pipeline", + "previous_node_inputs": { + "register_model": "model_id" + }, + "user_inputs": { + "pipeline_id": "${{create_search_pipeline.pipeline_id}}", + "configurations": { + "response_processors": [ + { + "retrieval_augmented_generation": { + "tag": "${{create_search_pipeline.retrieval_augmented_generation.tag}}", + "description": "${{create_search_pipeline.retrieval_augmented_generation.description}}", + "model_id": "${{register_model.model_id}}", + "context_field_list": "${{create_search_pipeline.retrieval_augmented_generation.context_field_list}}", + "system_prompt": "${{create_search_pipeline.retrieval_augmented_generation.system_prompt}}", + "user_instructions": "${{create_search_pipeline.retrieval_augmented_generation.user_instructions}}" + } + } + ] + } + } + } + ] + } + } +} diff --git a/src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json b/src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json index 54df3710a..2c5d1efd2 100644 --- a/src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json +++ b/src/main/resources/substitutionTemplates/multi-modal-search-with-bedrock-titan-template.json @@ -26,7 +26,7 @@ "input_docs_processed_step_size": "${{create_connector.input_docs_processed_step_size}}" }, "credential": { - "access_ key": "${{create_connector.credential.access_key}}", + "access_key": "${{create_connector.credential.access_key}}", "secret_key": "${{create_connector.credential.secret_key}}", "session_token": "${{create_connector.credential.session_token}}" }, @@ -73,7 +73,7 @@ { "text_image_embedding": { "model_id": "${{register_model.model_id}}", - "embedding": "${{create_ingest_pipeline.embedding}}", + "embedding": "${{text_image_embedding.create_ingest_pipeline.embedding}}", "field_map": { "text": "${{text_image_embedding.field_map.text}}", "image": "${{text_image_embedding.field_map.image}}" @@ -103,7 +103,7 @@ "id": { "type": "text" }, - "${{text_embedding.field_map.output}}": { + "${{text_image_embedding.create_ingest_pipeline.embedding}}": { "type": "knn_vector", "dimension": "${{text_image_embedding.field_map.output.dimension}}", "method": { diff --git a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java index 7ae057d24..92406b3e7 100644 --- a/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java +++ b/src/test/java/org/opensearch/flowframework/util/ParseUtilsTests.java @@ -110,6 +110,16 @@ public void testConditionallySubstituteWithUnmatchedPlaceholders() { assertEquals("This string has unmatched ${{placeholder}}", result); } + public void testRemovingBackslashesAndQuotesInArrayInJsonString() { + String inputNumArray = "normalization-processor.combination.parameters.weights: \"[0.3, 0.7]\""; + String outputNumArray = ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(inputNumArray); + assertEquals("normalization-processor.combination.parameters.weights: [0.3, 0.7]", outputNumArray); + String inputStringArray = + "create_search_pipeline.retrieval_augmented_generation.context_field_list: \"[\\\"text\\\", \\\"hello\\\"]\""; + String outputStringArray = ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(inputStringArray); + assertEquals("create_search_pipeline.retrieval_augmented_generation.context_field_list: [\"text\", \"hello\"]", outputStringArray); + } + public void testConditionallySubstituteWithOutputsSubstitution() { String input = "This string contains ${{node.step}}"; Map outputs = new HashMap<>();