Skip to content

Commit

Permalink
fixing create index step and array input for processors
Browse files Browse the repository at this point in the history
Signed-off-by: Amit Galitzky <[email protected]>
  • Loading branch information
amitgalitz committed Mar 20, 2024
1 parent 149e22a commit ce78f8f
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,13 @@ public enum DefaultUseCases {
"substitutionTemplates/semantic-search-with-model-and-query-enricher-template.json"
),
/** defaults file and substitution ready template for hybrid search, no model creation*/
HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json");
HYBRID_SEARCH("hybrid_search", "defaults/hybrid-search-defaults.json", "substitutionTemplates/hybrid-search-template.json"),
/** defaults file and substitution ready template for conversational search with cohere chat model*/
CONVERSATIONAL_SEARCH_WITH_COHERE_DEPLOY(
"conversational_search_with_llm_deploy",
"defaults/conversational-search-defaults.json",
"substitutionTemplates/conversational-search-with-cohere-model-template.json"
);

private final String useCaseName;
private final String defaultsFile;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
try {
XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
Map<String, String> userDefaults = ParseUtils.parseStringToStringMap(parser);
Map<String, Object> userDefaults = ParseUtils.parseStringToObjectMap(parser);
// updates the default params with anything user has given that matches
for (Map.Entry<String, String> userDefaultsEntry : userDefaults.entrySet()) {
for (Map.Entry<String, Object> userDefaultsEntry : userDefaults.entrySet()) {
String key = userDefaultsEntry.getKey();
String value = userDefaultsEntry.getValue();
String value = userDefaultsEntry.getValue().toString();
if (useCaseDefaultsMap.containsKey(key)) {
useCaseDefaultsMap.put(key, value);
}
Expand All @@ -154,7 +154,6 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli
null,
useCaseDefaultsMap
);

XContentParser parserTestJson = ParseUtils.jsonToParser(useCaseTemplateFileInStringFormat);
ensureExpectedToken(XContentParser.Token.START_OBJECT, parserTestJson.currentToken(), parserTestJson);
template = Template.parse(parserTestJson);
Expand Down
28 changes: 28 additions & 0 deletions src/main/java/org/opensearch/flowframework/util/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.time.Instant;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
Expand Down Expand Up @@ -169,6 +170,7 @@ public static Map<String, String> parseStringToStringMap(XContentParser parser)
/**
* Parses an XContent object representing a map of String keys to Object values.
* The Object value here can either be a string or a map
* If an array is found in the given parser we convert the array to a string representation of the array
* @param parser An XContent parser whose position is at the start of the map object to parse
* @return A map as identified by the key-value pairs in the XContent
* @throws IOException on a parse failure
Expand All @@ -182,6 +184,15 @@ public static Map<String, Object> parseStringToObjectMap(XContentParser parser)
if (parser.currentToken() == XContentParser.Token.START_OBJECT) {
// If the current token is a START_OBJECT, parse it as Map<String, String>
map.put(fieldName, parseStringToStringMap(parser));
} else if (parser.currentToken() == XContentParser.Token.START_ARRAY) {
// If an array, parse it to a string
// Handle array: convert it to a string representation
List<String> elements = new ArrayList<>();
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
elements.add("\"" + parser.text() + "\""); // Adding escaped quotes around each element
}
String arrayString = "[" + String.join(", ", elements) + "]";
map.put(fieldName, arrayString);
} else {
// Otherwise, parse it as a string
map.put(fieldName, parser.text());
Expand Down Expand Up @@ -413,4 +424,21 @@ public static Map<String, String> parseJsonFileToStringToStringMap(String path)
Map<String, String> mappedJsonFile = mapper.readValue(jsonContent, Map.class);
return mappedJsonFile;
}

/**
 * Transforms a JSON string so that any stringified array becomes a real JSON array.
 * Removes the outer quotes around a bracketed value and unescapes the quotes of its
 * string elements, which is needed for processors that take array-typed configuration:
 * e.g. {@code "weights": "[0.7, 0.3]"} becomes {@code "weights": [0.7, 0.3]}, and
 * {@code "[\"text\", \"hello\"]"} becomes {@code ["text", "hello"]}.
 * @param input the input string to be transformed
 * @return the transformed string with stringified arrays converted to valid JSON arrays
 */
public static String removingBackslashesAndQuotesInArrayInJsonString(String input) {
    return Pattern.compile("\"\\[(.*?)]\"").matcher(input).replaceAll(matchResult -> {
        // Extract matched content and remove backslashes before quotes
        String withoutEscapes = matchResult.group(1).replaceAll("\\\\\"", "\"");
        // The string returned by the replacer function is still subject to '$' group-reference
        // and '\' escape processing (see Matcher.replaceAll(Function) Javadoc), so a '$' inside
        // the array content (e.g. a "${...}" template placeholder) would throw or corrupt the
        // output without quoting. quoteReplacement makes the result a literal replacement.
        return java.util.regex.Matcher.quoteReplacement("[" + withoutEscapes + "]");
    });
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,11 @@ public PlainActionFuture<WorkflowData> execute(
String pipelineId = (String) inputs.get(PIPELINE_ID);
String configurations = (String) inputs.get(CONFIGURATIONS);

// Special case for processors that have arrays that need to have the quotes removed
// (e.g. "weights": "[0.7, 0.3]" -> "weights": [0.7, 0.3]
// Define a regular expression pattern to match stringified arrays
String transformedJsonString = configurations.replaceAll("\"\\[(.*?)]\"", "[$1]");
// Special case for processors that have arrays that need to have the quotes around or
// backslashes around strings in array removed
String transformedJsonStringForStringArray = ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(configurations);

byte[] byteArr = transformedJsonString.getBytes(StandardCharsets.UTF_8);
byte[] byteArr = transformedJsonStringForStringArray.getBytes(StandardCharsets.UTF_8);
BytesReference configurationsBytes = new BytesArray(byteArr);

String pipelineToBeCreated = this.getName();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,27 @@
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.client.Client;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.exception.WorkflowStepException;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.index.mapper.MapperService;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import static java.util.Collections.singletonMap;
import static org.opensearch.flowframework.common.CommonValue.CONFIGURATIONS;
import static org.opensearch.flowframework.common.WorkflowResources.INDEX_NAME;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;
Expand Down Expand Up @@ -85,8 +92,13 @@ public PlainActionFuture<WorkflowData> execute(

byte[] byteArr = configurations.getBytes(StandardCharsets.UTF_8);
BytesReference configurationsBytes = new BytesArray(byteArr);
CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName);
if (!configurations.isEmpty()) {
Map<String, Object> sourceAsMap = XContentHelper.convertToMap(configurationsBytes, false, MediaTypeRegistry.JSON).v2();
sourceAsMap = prepareMappings(sourceAsMap);
createIndexRequest.source(sourceAsMap, LoggingDeprecationHandler.INSTANCE);
}

CreateIndexRequest createIndexRequest = new CreateIndexRequest(indexName).source(configurationsBytes, XContentType.JSON);
client.admin().indices().create(createIndexRequest, ActionListener.wrap(acknowledgedResponse -> {
String resourceName = getResourceByWorkflowStep(getName());
logger.info("Created index: {}", indexName);
Expand Down Expand Up @@ -129,6 +141,26 @@ public PlainActionFuture<WorkflowData> execute(
return createIndexFuture;
}

/**
 * Duplicates the behavior users get from the create index REST API today: a mapping
 * definition nested under a type (e.g. {@code _doc}) is rejected, so users do not have
 * to learn different rules for the create_index workflow step. A bare mappings object
 * is re-nested under {@link MapperService#SINGLE_MAPPING_NAME} before the request is built.
 */
private static Map<String, Object> prepareMappings(Map<String, Object> source) {
    Object rawMappings = source.get("mappings");
    // Nothing to normalize when mappings are absent or not an object
    if (!(rawMappings instanceof Map)) {
        return source;
    }

    @SuppressWarnings("unchecked")
    Map<String, Object> mappings = (Map<String, Object>) rawMappings;
    if (MapperService.isMappingSourceTyped(MapperService.SINGLE_MAPPING_NAME, mappings)) {
        throw new WorkflowStepException("The mapping definition cannot be nested under a type", RestStatus.BAD_REQUEST);
    }

    Map<String, Object> normalized = new HashMap<>(source);
    normalized.put("mappings", singletonMap(MapperService.SINGLE_MAPPING_NAME, mappings));
    return normalized;
}

@Override
public String getName() {
return NAME;
Expand Down
20 changes: 20 additions & 0 deletions src/main/resources/defaults/conversational-search-defaults.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
{
"template.name": "deploy-cohere-chat-model",
"template.description": "deploying cohere chat model",
"create_connector.name": "Cohere Chat Model",
"create_connector.description": "The connector to Cohere's public chat API",
"create_connector.protocol": "http",
"create_connector.model": "command",
"create_connector.endpoint": "api.cohere.ai",
"create_connector.credential.key": "123",
"create_connector.actions.url": "https://api.cohere.ai/v1/chat",
"create_connector.actions.request_body": "{ \"message\": \"${parameters.message}\", \"model\": \"${parameters.model}\" }",
"register_remote_model.name": "Cohere chat model",
"register_remote_model.description": "cohere-chat-model",
"create_search_pipeline.pipeline_id": "rag-pipeline",
"create_search_pipeline.retrieval_augmented_generation.tag": "openai_pipeline_demo",
"create_search_pipeline.retrieval_augmented_generation.description": "Demo pipeline Using cohere Connector",
"create_search_pipeline.retrieval_augmented_generation.context_field_list": "[\"text\", \"hello\"]",
"create_search_pipeline.retrieval_augmented_generation.system_prompt": "You are a helpful assistant",
"create_search_pipeline.retrieval_augmented_generation.user_instructions": "Generate a concise and informative answer in less than 100 words for the given question"
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"create_connector.name": "Amazon Bedrock Connector: multi-modal embedding",
"create_connector.description": "The connector to bedrock Titan multi-modal embedding model",
"create_connector.region": "us-east-1",
"create_connector.input_docs_processed_step_size": 2,
"create_connector.input_docs_processed_step_size": "2",
"create_connector.endpoint": "api.openai.com",
"create_connector.credential.access_key": "123",
"create_connector.credential.secret_key": "123",
Expand All @@ -17,12 +17,12 @@
"register_remote_model.description": "bedrock-multi-modal-embedding-model",
"create_ingest_pipeline.pipeline_id": "nlp-multimodal-ingest-pipeline",
"create_ingest_pipeline.description": "A text/image embedding pipeline",
"create_ingest_pipeline.embedding": "vector_embedding",
"text_image_embedding.create_ingest_pipeline.embedding": "vector_embedding",
"text_image_embedding.field_map.text": "image_description",
"text_image_embedding.field_map.image": "image_binary",
"create_index.name": "my-multimodal-nlp-index",
"create_index.settings.number_of_shards": 2,
"text_image_embedding.field_map.output.dimension": 1024,
"create_index.settings.number_of_shards": "2",
"text_image_embedding.field_map.output.dimension": "1024",
"create_index.mappings.method.engine": "lucene",
"create_index.mappings.method.name": "hnsw"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
{
"name": "${{template.name}}",
"description": "${{template.description}}",
"use_case": "SEMANTIC_SEARCH",
"version": {
"template": "1.0.0",
"compatibility": [
"2.12.0",
"3.0.0"
]
},
"workflows": {
"provision": {
"nodes": [
{
"id": "create_connector",
"type": "create_connector",
"user_inputs": {
"name": "${{create_connector}}",
"description": "${{create_connector.description}}",
"version": "1",
"protocol": "${{create_connector.protocol}}",
"parameters": {
"endpoint": "${{create_connector.endpoint}}",
"model": "${{create_connector.model}}"
},
"credential": {
"key": "${{create_connector.credential.key}}"
},
"actions": [
{
"action_type": "predict",
"method": "POST",
"url": "${{create_connector.actions.url}}",
"headers": {
"Authorization": "Bearer ${credential.key}"
},
"request_body": "${{create_connector.actions.request_body}}"
}
]
}
},
{
"id": "register_model",
"type": "register_remote_model",
"previous_node_inputs": {
"create_connector": "parameters"
},
"user_inputs": {
"name": "${{register_remote_model.name}}",
"function_name": "remote",
"description": "${{register_remote_model.description}}",
"deploy": true
}
},
{
"id": "create_search_pipeline",
"type": "create_search_pipeline",
"previous_node_inputs": {
"register_model": "model_id"
},
"user_inputs": {
"pipeline_id": "${{create_search_pipeline.pipeline_id}}",
"configurations": {
"response_processors": [
{
"retrieval_augmented_generation": {
"tag": "${{create_search_pipeline.retrieval_augmented_generation.tag}}",
"description": "${{create_search_pipeline.retrieval_augmented_generation.description}}",
"model_id": "${{register_model.model_id}}",
"context_field_list": "${{create_search_pipeline.retrieval_augmented_generation.context_field_list}}",
"system_prompt": "${{create_search_pipeline.retrieval_augmented_generation.system_prompt}}",
"user_instructions": "${{create_search_pipeline.retrieval_augmented_generation.user_instructions}}"
}
}
]
}
}
}
]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"input_docs_processed_step_size": "${{create_connector.input_docs_processed_step_size}}"
},
"credential": {
"access_ key": "${{create_connector.credential.access_key}}",
"access_key": "${{create_connector.credential.access_key}}",
"secret_key": "${{create_connector.credential.secret_key}}",
"session_token": "${{create_connector.credential.session_token}}"
},
Expand Down Expand Up @@ -73,7 +73,7 @@
{
"text_image_embedding": {
"model_id": "${{register_model.model_id}}",
"embedding": "${{create_ingest_pipeline.embedding}}",
"embedding": "${{text_image_embedding.create_ingest_pipeline.embedding}}",
"field_map": {
"text": "${{text_image_embedding.field_map.text}}",
"image": "${{text_image_embedding.field_map.image}}"
Expand Down Expand Up @@ -103,7 +103,7 @@
"id": {
"type": "text"
},
"${{text_embedding.field_map.output}}": {
"${{text_image_embedding.create_ingest_pipeline.embedding}}": {
"type": "knn_vector",
"dimension": "${{text_image_embedding.field_map.output.dimension}}",
"method": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,16 @@ public void testConditionallySubstituteWithUnmatchedPlaceholders() {
assertEquals("This string has unmatched ${{placeholder}}", result);
}

public void testRemovingBackslashesAndQuotesInArrayInJsonString() {
    // Numeric array: only the surrounding quotes need stripping
    String numericArrayInput = "normalization-processor.combination.parameters.weights: \"[0.3, 0.7]\"";
    assertEquals(
        "normalization-processor.combination.parameters.weights: [0.3, 0.7]",
        ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(numericArrayInput)
    );
    // String array: escaped quotes around each element are unescaped as well
    String stringArrayInput =
        "create_search_pipeline.retrieval_augmented_generation.context_field_list: \"[\\\"text\\\", \\\"hello\\\"]\"";
    assertEquals(
        "create_search_pipeline.retrieval_augmented_generation.context_field_list: [\"text\", \"hello\"]",
        ParseUtils.removingBackslashesAndQuotesInArrayInJsonString(stringArrayInput)
    );
}

public void testConditionallySubstituteWithOutputsSubstitution() {
String input = "This string contains ${{node.step}}";
Map<String, WorkflowData> outputs = new HashMap<>();
Expand Down

0 comments on commit ce78f8f

Please sign in to comment.