From fc8815336931186b94b0a8b8a5390c735f27d9fb Mon Sep 17 00:00:00 2001 From: "Shin, Jungseob" Date: Fri, 11 Oct 2024 13:39:55 +0900 Subject: [PATCH] Update chunking config & web-app prompt --- .../web-app/pages/rag_integration.py | 73 +++++++++++++++---- .../other_stack/bedrock_agent_stack.py | 5 +- 2 files changed, 64 insertions(+), 14 deletions(-) diff --git a/cdk/examples/generative_ai_rag/web-app/pages/rag_integration.py b/cdk/examples/generative_ai_rag/web-app/pages/rag_integration.py index 0728e087..01b3f6e4 100644 --- a/cdk/examples/generative_ai_rag/web-app/pages/rag_integration.py +++ b/cdk/examples/generative_ai_rag/web-app/pages/rag_integration.py @@ -78,9 +78,9 @@ def get_prompt_template(retrieved_passages: List[str]) -> str: - Use the information from the retrieved passages to recommend sessions. - Prioritize recommendations based on: a) Exact matches in the "Related AWS Services" field (highest priority) - b) Relevance in the "Description" and "Title" fields - c) Relevance to the Topic and Areas of Interest - - Always include "Description", "Time" and "Venue" information for recommended sessions and summary it. + b) Relevance in the "Related AWS Services", "Description" and "Title" fields + c) Relevance to the "Topic" and "Areas of Interest" + - Must to include "Description", "Time" and "Venue" information for recommended sessions. 3. For re:Invent session information questions: - Use the information from the retrieved passages to provide detailed session information. @@ -101,20 +101,41 @@ def get_prompt_template(retrieved_passages: List[str]) -> str: - Verify user assertions against search results; don't assume user statements are factual. Here are the retrieved passages: + {retrieved_passages} - IMPORTANT: Your final response should only contain the actual answer to the user's question. + IMPORTANT: + Your final response should only contain the actual answer to the user's question. Do not include any explanation of your thought process, categorization, or analysis in the final response. + If retrieved passages are empty and question type is not GENERAL, respond with "Sorry. I couldn't find any related information." - Format your response as follows: + CRITICAL RESPONSE FORMAT: + You MUST format your entire response EXACTLY as follows, with no exceptions: + + [QUESTION_TYPE] + GENERAL or REINVENT_RECOMMENDATION or REINVENT_INFORMATION + [/QUESTION_TYPE] + [RESPONSE] + Your complete answer here, with no text outside these tags. + [/RESPONSE] + + IMPORTANT RULES: + 1. ALWAYS include both [QUESTION_TYPE] and [RESPONSE] tags. + 2. [QUESTION_TYPE] must contain ONLY ONE of the three specified types, nothing else. + 3. [RESPONSE] must contain your COMPLETE answer and nothing else. + 4. DO NOT include ANY text outside of these tags. + 5. If you cannot provide an answer, still use the tags and put your "Sorry" message inside the [RESPONSE] tags. + 6. Never use [GENERAL], [REINVENT_RECOMMENDATION], or [REINVENT_INFORMATION] as standalone tags. + + EXAMPLE OF CORRECT FORMAT: [QUESTION_TYPE] - The type of question (GENERAL, REINVENT_RECOMMENDATION, or REINVENT_INFORMATION) + GENERAL [/QUESTION_TYPE] [RESPONSE] - Your actual response here, without any preamble or explanation of your thought process. + This is where your complete answer would go, with no other text outside these tags. [/RESPONSE] - Everything outside the [RESPONSE] tags will be discarded, so ensure your complete answer is within these tags. + Failure to follow this format exactly will result in an error. Double-check your response before submitting. """ def retrieve_from_knowledge_base(client: Any, knowledge_base_id: str, prompt: str) -> List[Dict[str, Any]]: @@ -164,14 +185,36 @@ def main(): retrieved_passages = [result['content']['text'] for result in retrieved_results] formatted_passages = "\n\n".join(f"Passage {i+1}:\n{passage}" for i, passage in enumerate(retrieved_passages)) + print(formatted_passages) + + # Initialize variables to store the full response and the content after the [RESPONSE] tag + full_response = "" + response_content = "" + response_started = False + + message_placeholder = st.chat_message("assistant").empty() + # Generate the final response using the invoke_model API system_prompt = get_prompt_template(formatted_passages) - - message_placeholder = st.chat_message("assistant").empty() - full_response = "" + + # Process each chunk from the Bedrock stream for chunk in invoke_bedrock_stream(bedrock_runtime_client, system_prompt, prompt): + # Accumulate the full response full_response += chunk - message_placeholder.markdown(full_response + "▌") + + # Check if we've reached the [RESPONSE] tag + if '[RESPONSE]' in chunk: + response_started = True + # Extract the content after the [RESPONSE] tag + response_content = chunk.split('[RESPONSE]')[1] + elif response_started: + # If we're past the [RESPONSE] tag, continue accumulating the response content + response_content += chunk + + # Display the response content if we've started collecting it + if response_started: + # Remove the [/RESPONSE] tag if present and display the content + message_placeholder.markdown(response_content.replace('[/RESPONSE]', '') + "▌") # Extract question type and response import re @@ -181,12 +224,16 @@ def main(): question_type = question_type_match.group(1).strip() if question_type_match else "UNKNOWN" final_response = response_match.group(1).strip() if response_match else "I apologize. There was an issue generating an appropriate response." + print(full_response) + print(question_type) + print(final_response) + message_placeholder.markdown(final_response) st.session_state.messages.append({"role": "assistant", "content": final_response}) # Display citations only for non-general questions - if question_type != "GENERAL": + if question_type not in ["GENERAL", "UNKNOWN"]: with st.expander("Data Sources"): for i, result in enumerate(retrieved_results, 1): content = result['content']['text'] diff --git a/cdk/examples/other_stack/bedrock_agent_stack.py b/cdk/examples/other_stack/bedrock_agent_stack.py index 7b6b6f82..2a196aa3 100644 --- a/cdk/examples/other_stack/bedrock_agent_stack.py +++ b/cdk/examples/other_stack/bedrock_agent_stack.py @@ -58,7 +58,10 @@ def __init__( bucket=self.bucket, knowledge_base=self.knowledge_base, data_source_name='ReinventSessionInformationText', - chunking_strategy= bedrock.ChunkingStrategy.FIXED_SIZE, + chunking_strategy= bedrock.ChunkingStrategy.fixed_size( + max_tokens= 512, + overlap_percentage= 20 + ) ) # create parameter store