Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chat memory implementation for Neo4j #2063

Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ Refer to the xref:_retrieval_augmented_generation[Retrieval Augmented Generation

The interface `ChatMemory` represents a storage for chat conversation history. It provides methods to add messages to a conversation, retrieve messages from a conversation, and clear the conversation history.

There are currently two implementations, `InMemoryChatMemory` and `CassandraChatMemory`, that provide storage for chat conversation history, in-memory and persisted with `time-to-live`, correspondingly.
There are currently three implementations, `InMemoryChatMemory`, `CassandraChatMemory` and `Neo4jChatMemory`, that provide storage for chat conversation history, in-memory, persisted with `time-to-live` in Cassandra, and persisted in Neo4j without `time-to-live` correspondingly.

To create a `CassandraChatMemory` with `time-to-live`:

Expand All @@ -383,6 +383,42 @@ To create a `CassandraChatMemory` with `time-to-live`:
CassandraChatMemory.create(CassandraChatMemoryConfig.builder().withTimeToLive(Duration.ofDays(1)).build());
----

The Neo4j chat memory implementation supports the following configuration parameters

[cols="2,5,1",stripes=even]
|===
|Property | Description | Default Value

|`spring.ai.chat.memory.neo4j.media-label` | The label for nodes that contain a message media objects | `Media`
|`spring.ai.chat.memory.neo4j.message-label` | The label for nodes that contain a message | `Message`
|`spring.ai.chat.memory.neo4j.metadata-label` | The label to use for nodes that contain a message metadata
| `Metadata`
|`spring.ai.chat.memory.neo4j.session-label`| The label to use for nodes that contain a chat session | `Session`
|`spring.ai.chat.memory.neo4j.tool-call-label` | The label to use for nodes that contain tool call information
| `ToolCall`
|`spring.ai.chat.memory.neo4j.tool-response-label` | The label for nodes that contain a tool response message |
`ToolResponse`
|===


The Neo4j chat memory supports the following configuration parameters:

[cols="2,5,1",stripes=even]
|===
|Property | Description | Default Value

| `spring.ai.chat.memory.neo4j.messageLabel` | The label for the nodes that store messages | `Message`
| `spring.ai.chat.memory.neo4j.sessionLabel` | The label for the nodes that store conversation sessions | `Session`
| `spring.ai.chat.memory.neo4j.toolCallLabel` | The label for nodes that store tool calls, for example
in Assistant Messages | `ToolCall`
| `spring.ai.chat.memory.neo4j.metadataLabel` | The label for the node that store a message metadata | `Metadata`
| `spring.ai.chat.memory.neo4j.toolResponseLabel` | The label for the nodes that store tool responses | `ToolResponse`
| `spring.ai.chat.memory.neo4j.mediaLabel` | The label for the nodes that store the media associated to a message | `ToolResponse`


|===


The following advisor implementations use the `ChatMemory` interface to advice the prompt with conversation history which differ in the details of how the memory is added to the prompt

* `MessageChatMemoryAdvisor` : Memory is retrieved and added as a collection of messages to the prompt
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.chat.memory.neo4j;

import org.neo4j.driver.Driver;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.neo4j.Neo4jAutoConfiguration;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;

/**
* {@link AutoConfiguration Auto-configuration} for {@link Neo4jChatMemory}.
*
* @author Enrico Rampazzo
* @since 1.0.0
*/
@AutoConfiguration(after = Neo4jAutoConfiguration.class)
@ConditionalOnClass({ Neo4jChatMemory.class, Driver.class })
@EnableConfigurationProperties(Neo4jChatMemoryProperties.class)
public class Neo4jChatMemoryAutoConfiguration {

@Bean
@ConditionalOnMissingBean
public Neo4jChatMemory chatMemory(Neo4jChatMemoryProperties properties, Driver driver) {

var builder = Neo4jChatMemoryConfig.builder().withMediaLabel(properties.getMediaLabel())
.withMessageLabel(properties.getMessageLabel()).withMetadataLabel(properties.getMetadataLabel())
.withSessionLabel(properties.getSessionLabel()).withToolCallLabel(properties.getToolCallLabel())
.withToolResponseLabel(properties.getToolResponseLabel())
.withDriver(driver);

return Neo4jChatMemory.create(builder.build());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.chat.memory.neo4j;

import org.springframework.ai.autoconfigure.chat.memory.CommonChatMemoryProperties;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
* Configuration properties for Neo4j chat memory.
*
* @author Enrico Rampazzo
*/
@ConfigurationProperties(Neo4jChatMemoryProperties.CONFIG_PREFIX)
public class Neo4jChatMemoryProperties {

public static final String CONFIG_PREFIX = "spring.ai.chat.memory.neo4j";
private String sessionLabel = Neo4jChatMemoryConfig.DEFAULT_SESSION_LABEL;
private String toolCallLabel = Neo4jChatMemoryConfig.DEFAULT_TOOL_CALL_LABEL;
private String metadataLabel = Neo4jChatMemoryConfig.DEFAULT_METADATA_LABEL;
private String messageLabel = Neo4jChatMemoryConfig.DEFAULT_MESSAGE_LABEL;
private String toolResponseLabel = Neo4jChatMemoryConfig.DEFAULT_TOOL_RESPONSE_LABEL;
private String mediaLabel = Neo4jChatMemoryConfig.DEFAULT_MEDIA_LABEL;

public String getSessionLabel() {
return sessionLabel;
}

public void setSessionLabel(String sessionLabel) {
this.sessionLabel = sessionLabel;
}

public String getToolCallLabel() {
return toolCallLabel;
}

public String getMetadataLabel() {
return metadataLabel;
}

public String getMessageLabel() {
return messageLabel;
}

public String getToolResponseLabel() {
return toolResponseLabel;
}

public String getMediaLabel() {
return mediaLabel;
}

public void setToolCallLabel(String toolCallLabel) {
this.toolCallLabel = toolCallLabel;
}

public void setMetadataLabel(String metadataLabel) {
this.metadataLabel = metadataLabel;
}

public void setMessageLabel(String messageLabel) {
this.messageLabel = messageLabel;
}

public void setToolResponseLabel(String toolResponseLabel) {
this.toolResponseLabel = toolResponseLabel;
}

public void setMediaLabel(String mediaLabel) {
this.mediaLabel = mediaLabel;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.chat.memory.neo4j;

import com.datastax.driver.core.utils.UUIDs;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemory;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.messages.ToolResponseMessage.ToolResponse;
import org.springframework.ai.model.Media;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.neo4j.Neo4jAutoConfiguration;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.util.MimeType;
import org.testcontainers.containers.Neo4jContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;
import org.testcontainers.utility.DockerImageName;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Enrico Rampazzo
*/
@Testcontainers
class Neo4jChatMemoryAutoConfigurationIT {

static final DockerImageName DEFAULT_IMAGE_NAME = DockerImageName.parse("neo4j");

@SuppressWarnings({"rawtypes", "resource"})
@Container
static Neo4jContainer neo4jContainer = (Neo4jContainer) new Neo4jContainer(DEFAULT_IMAGE_NAME.withTag("5")).withoutAuthentication().withExposedPorts(7474,7687);

private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(
AutoConfigurations.of(Neo4jChatMemoryAutoConfiguration.class, Neo4jAutoConfiguration.class));


@Test
void addAndGet() {
this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl())
.run(context -> {
Neo4jChatMemory memory = context.getBean(Neo4jChatMemory.class);

String sessionId = UUIDs.timeBased().toString();
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();

UserMessage userMessage = new UserMessage("test question");


memory.add(sessionId, userMessage);
List<Message> messages = memory.get(sessionId, Integer.MAX_VALUE);
assertThat(messages).hasSize(1);
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(userMessage);

memory.clear(sessionId);
assertThat(memory.get(sessionId, Integer.MAX_VALUE)).isEmpty();

AssistantMessage assistantMessage = new AssistantMessage("test answer", Map.of(),
List.of(new AssistantMessage.ToolCall(
"id", "type", "name", "arguments")));

memory.add(sessionId, List.of(userMessage, assistantMessage));
messages = memory.get(sessionId, Integer.MAX_VALUE);
assertThat(messages).hasSize(2);
assertThat(messages.get(1)).isEqualTo(userMessage);

assertThat(messages.get(0)).isEqualTo(assistantMessage);
memory.clear(sessionId);
MimeType textPlain = MimeType.valueOf("text/plain");
List<Media> media = List.of(Media.builder().name("some media").id(UUIDs.random().toString())
.mimeType(textPlain).data("hello".getBytes(StandardCharsets.UTF_8)).build(),
Media.builder().data(URI.create("http://www.google.com").toURL()).mimeType(textPlain).build());
UserMessage userMessageWithMedia = new UserMessage("Message with media", media);
memory.add(sessionId, userMessageWithMedia);

messages = memory.get(sessionId, Integer.MAX_VALUE);
assertThat(messages.size()).isEqualTo(1);
assertThat(messages.get(0)).isEqualTo(userMessageWithMedia);
assertThat(((UserMessage)messages.get(0)).getMedia()).hasSize(2);
assertThat(((UserMessage) messages.get(0)).getMedia()).usingRecursiveFieldByFieldElementComparator().isEqualTo(media);
memory.clear(sessionId);
ToolResponseMessage toolResponseMessage = new ToolResponseMessage(List.of(
new ToolResponse("id", "name", "responseData"),
new ToolResponse("id2", "name2", "responseData2")),
Map.of("id", "id", "metadataKey", "metadata"));
memory.add(sessionId, toolResponseMessage);
messages = memory.get(sessionId, Integer.MAX_VALUE);
assertThat(messages.size()).isEqualTo(1);
assertThat(messages.get(0)).isEqualTo(toolResponseMessage);

memory.clear(sessionId);
SystemMessage sm = new SystemMessage("this is a System message");
memory.add(sessionId, sm);
messages = memory.get(sessionId, Integer.MAX_VALUE);
assertThat(messages).hasSize(1);
assertThat(messages.get(0)).usingRecursiveAssertion().isEqualTo(sm);
});
}
@Test
void setCustomConfiguration(){
final String sessionLabel = "LabelSession";
final String toolCallLabel = "LabelToolCall";
final String metadataLabel = "LabelMetadata";
final String messageLabel = "LabelMessage";
final String toolResponseLabel = "LabelToolResponse";
final String mediaLabel = "LabelMedia";

final String propertyBase = "spring.ai.chat.memory.neo4j.%s=%s";
this.contextRunner.withPropertyValues("spring.neo4j.uri=" + neo4jContainer.getBoltUrl(),
propertyBase.formatted("sessionlabel", sessionLabel),
propertyBase.formatted("toolcallLabel", toolCallLabel),
propertyBase.formatted("metadatalabel", metadataLabel),
propertyBase.formatted("messagelabel", messageLabel),
propertyBase.formatted("toolresponselabel", toolResponseLabel),
propertyBase.formatted("medialabel", mediaLabel))
.run(context -> {
Neo4jChatMemory chatMemory = context.getBean(Neo4jChatMemory.class);
Neo4jChatMemoryConfig config = chatMemory.getConfig();
assertThat(config.getMessageLabel()).isEqualTo(messageLabel);
assertThat(config.getMediaLabel()).isEqualTo(mediaLabel);
assertThat(config.getMetadataLabel()).isEqualTo(metadataLabel);
assertThat(config.getSessionLabel()).isEqualTo(sessionLabel);
assertThat(config.getToolResponseLabel()).isEqualTo(toolResponseLabel);
assertThat(config.getToolCallLabel()).isEqualTo(toolCallLabel);
});
}



}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.autoconfigure.chat.memory.neo4j;

import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.memory.neo4j.Neo4jChatMemoryConfig;

import java.time.Duration;

import static org.assertj.core.api.Assertions.assertThat;

/**
* @author Enrico Rampazzo
* @since 1.0.0
*/
class Neo4jChatMemoryPropertiesTest {

@Test
void defaultValues() {
var props = new Neo4jChatMemoryProperties();
assertThat(props.getMediaLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_MEDIA_LABEL);
assertThat(props.getMessageLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_MESSAGE_LABEL);
assertThat(props.getMetadataLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_METADATA_LABEL);
assertThat(props.getSessionLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_SESSION_LABEL);
assertThat(props.getToolCallLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_TOOL_CALL_LABEL);
assertThat(props.getToolResponseLabel()).isEqualTo(Neo4jChatMemoryConfig.DEFAULT_TOOL_RESPONSE_LABEL);
}
}
Loading