Skip to content

Commit

Permalink
Adds support for entity hierarchies in compare
Browse files Browse the repository at this point in the history
This change allows users to declaratively specify hierarchical entities in their expected utterance results. For example, a user may declare the following:

```json
{
  "text": "Order a pepperoni pizza"
  "intent": "OrderFood",
  "entities": {
    "entity": "FoodItem":
    "startPos": 8,
    "endPos": 22,
    "children": [
      {
        "entity": "Topping",
        "startPos": 8,
        "endPos": 16
      },
      {
        "entity": "FoodType",
        "startPos": 18,
        "endPos": 22
      }
    ]
  }
}
```

This would result in 3 test cases, one for the parent entity (the "FoodItem" entity), and two additional test cases for each of the two nested entities ("FoodItem::Topping" and "FoodItem::FoodType").

Child entity type names are prefixed by their parent entity type names in the format `parentType::childType`. As such, the recursive entity parsing for the LUIS V3 provider has been updated to use this convention.

Fixes microsoft#335
  • Loading branch information
rozele committed Nov 19, 2020
1 parent 6ab6c5b commit 76fa82f
Show file tree
Hide file tree
Showing 7 changed files with 131 additions and 45 deletions.
18 changes: 15 additions & 3 deletions src/NLU.DevOps.Core.Tests/JsonLabeledUtteranceConverterTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
namespace NLU.DevOps.Core.Tests
{
using System;
using System.Collections.Generic;
using System.Linq;
using FluentAssertions;
using FluentAssertions.Json;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Newtonsoft.Json.Serialization;
Expand Down Expand Up @@ -97,13 +97,18 @@ public static void ConvertsUtteranceWithNestedEntities()
{ "entity", "baz" },
{ "startPos", 8 },
{ "endPos", 10 },
{ "foo", new JArray(42) },
{ "bar", null },
{ "baz", 42 },
{ "qux", JValue.CreateUndefined() },
};

var midEntity = new JObject
{
{ "entityType", "bar" },
{ "matchText", "bar baz" },
{ "children", new JArray { leafEntity } },
{ "entityValue", new JObject { { "bar", "qux" } } },
};

var entity = new JObject
Expand All @@ -126,10 +131,17 @@ public static void ConvertsUtteranceWithNestedEntities()
actual.Entities.Count.Should().Be(3);
actual.Entities[0].EntityType.Should().Be("foo");
actual.Entities[0].MatchText.Should().Be(text);
actual.Entities[1].EntityType.Should().Be("bar");
actual.Entities[1].EntityType.Should().Be("foo::bar");
actual.Entities[1].MatchText.Should().Be("bar baz");
actual.Entities[2].EntityType.Should().Be("baz");
actual.Entities[1].EntityValue.Should().BeEquivalentTo(new JObject { { "bar", "qux" } });
actual.Entities[2].EntityType.Should().Be("foo::bar::baz");
actual.Entities[2].MatchText.Should().Be("baz");

var additionalProperties = actual.Entities[2].As<Entity>().AdditionalProperties;
additionalProperties["foo"].As<JToken>().Should().BeEquivalentTo(new JArray(42));
additionalProperties["bar"].Should().BeNull();
additionalProperties["baz"].Should().Be(42);
additionalProperties["qux"].Should().BeNull();
}

private static JsonSerializer CreateSerializer()
Expand Down
3 changes: 2 additions & 1 deletion src/NLU.DevOps.Core.Tests/NLU.DevOps.Core.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
<PackageReference Include="nunit" Version="3.12.0" />
<PackageReference Include="NUnit3TestAdapter" Version="3.13.0" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.2.0" />
<PackageReference Include="FluentAssertions" Version="5.7.0" />
<PackageReference Include="FluentAssertions" Version="5.5.3" />
<PackageReference Include="FluentAssertions.Json" Version="5.0.0" />
</ItemGroup>

<ItemGroup>
Expand Down
130 changes: 100 additions & 30 deletions src/NLU.DevOps.Core/EntityConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
namespace NLU.DevOps.Core
{
using System;
using System.Collections.Generic;
using System.Diagnostics;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;

Expand All @@ -16,42 +18,33 @@ public EntityConverter(string utterance)

private string Utterance { get; }

private string Prefix { get; set; } = string.Empty;

public override Entity ReadJson(JsonReader reader, Type objectType, Entity existingValue, bool hasExistingValue, JsonSerializer serializer)
{
Debug.Assert(!hasExistingValue, "Entity instance can only be constructor initialized.");

var jsonObject = JObject.Load(reader);
return typeof(HierarchicalEntity).IsAssignableFrom(objectType)
? this.ReadHierarchicalEntity(jsonObject, serializer)
: this.ReadEntity(jsonObject, objectType, serializer);
}

public override void WriteJson(JsonWriter writer, Entity value, JsonSerializer serializer)
{
throw new NotImplementedException();
}

private Entity ReadEntity(JObject jsonObject, Type objectType, JsonSerializer serializer)
{
var matchText = jsonObject.Value<string>("matchText");
var matchIndex = jsonObject.Value<int>("matchIndex");
var startPosOrNull = jsonObject.Value<int?>("startPos");
var endPosOrNull = jsonObject.Value<int?>("endPos");
if (matchText == null && startPosOrNull != null && endPosOrNull != null)
if (matchText == null && startPosOrNull.HasValue && endPosOrNull.HasValue)
{
var startPos = startPosOrNull.Value;
var endPos = endPosOrNull.Value;
var length = endPos - startPos + 1;
if (!this.IsValid(startPos, endPos))
{
throw new InvalidOperationException(
$"Invalid start position '{startPos}' or end position '{endPos}' for utterance '{this.Utterance}'.");
}

matchText = this.Utterance.Substring(startPos, length);
(matchText, matchIndex) = this.GetMatchInfo(startPosOrNull.Value, endPosOrNull.Value);
jsonObject.Add("matchText", matchText);
var matchIndex = 0;
var currentPos = 0;
while (true)
{
currentPos = this.Utterance.IndexOf(matchText, currentPos, StringComparison.InvariantCulture);

// Because 'matchText' is derived from the utterance from 'startPos' and 'endPos',
// we are guaranteed to find a match at with index 'startPos'.
if (currentPos == startPos)
{
break;
}

currentPos += length;
matchIndex++;
}

jsonObject.Add("matchIndex", matchIndex);
jsonObject.Remove("startPos");
jsonObject.Remove("endPos");
Expand All @@ -76,9 +69,86 @@ public override Entity ReadJson(JsonReader reader, Type objectType, Entity exist
}
}

public override void WriteJson(JsonWriter writer, Entity value, JsonSerializer serializer)
private HierarchicalEntity ReadHierarchicalEntity(JObject jsonObject, JsonSerializer serializer)
{
throw new NotImplementedException();
var matchText = jsonObject.Value<string>("matchText");
var matchIndex = jsonObject.Value<int>("matchIndex");
var startPosOrNull = jsonObject.Value<int?>("startPos");
var endPosOrNull = jsonObject.Value<int?>("endPos");
if (matchText == null && startPosOrNull.HasValue && endPosOrNull.HasValue)
{
(matchText, matchIndex) = this.GetMatchInfo(startPosOrNull.Value, endPosOrNull.Value);
}

var entityType = jsonObject.Value<string>("entityType") ?? jsonObject.Value<string>("entity");
var childrenJson = jsonObject["children"];
var children = default(IEnumerable<HierarchicalEntity>);
if (childrenJson != null)
{
var prefix = $"{entityType}::";
this.Prefix += prefix;
try
{
children = childrenJson.ToObject<IEnumerable<HierarchicalEntity>>(serializer);
}
finally
{
this.Prefix = this.Prefix.Substring(0, this.Prefix.Length - prefix.Length);
}
}

var entity = new HierarchicalEntity($"{this.Prefix}{entityType}", jsonObject["entityValue"], matchText, matchIndex, children);
foreach (var property in jsonObject)
{
switch (property.Key)
{
case "children":
case "endPos":
case "entity":
case "entityType":
case "entityValue":
case "matchText":
case "matchIndex":
case "startPos":
break;
default:
var value = property.Value is JValue jsonValue ? jsonValue.Value : property.Value;
entity.AdditionalProperties.Add(property.Key, value);
break;
}
}

return entity;
}

private Tuple<string, int> GetMatchInfo(int startPos, int endPos)
{
if (!this.IsValid(startPos, endPos))
{
throw new InvalidOperationException(
$"Invalid start position '{startPos}' or end position '{endPos}' for utterance '{this.Utterance}'.");
}

var length = endPos - startPos + 1;
var matchText = this.Utterance.Substring(startPos, length);
var matchIndex = 0;
var currentPos = 0;
while (true)
{
currentPos = this.Utterance.IndexOf(matchText, currentPos, StringComparison.InvariantCulture);

// Because 'matchText' is derived from the utterance from 'startPos' and 'endPos',
// we are guaranteed to find a match at with index 'startPos'.
if (currentPos == startPos)
{
break;
}

currentPos += length;
matchIndex++;
}

return Tuple.Create(matchText, matchIndex);
}

private bool IsValid(int startPos, int endPos)
Expand Down
2 changes: 1 addition & 1 deletion src/NLU.DevOps.Core/HierarchicalEntity.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace NLU.DevOps.Core
/// <summary>
/// Entity appearing in utterance.
/// </summary>
public class HierarchicalEntity : Entity, IHierarchicalEntity
public sealed class HierarchicalEntity : Entity, IHierarchicalEntity
{
/// <summary>
/// Initializes a new instance of the <see cref="HierarchicalEntity"/> class.
Expand Down
3 changes: 3 additions & 0 deletions src/NLU.DevOps.Core/JsonLabeledUtteranceConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
namespace NLU.DevOps.Core
{
using System;
using System.Diagnostics;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;

Expand All @@ -18,6 +19,8 @@ public class JsonLabeledUtteranceConverter : JsonConverter<JsonLabeledUtterance>
/// <inheritdoc />
public override JsonLabeledUtterance ReadJson(JsonReader reader, Type objectType, JsonLabeledUtterance existingValue, bool hasExistingValue, JsonSerializer serializer)
{
Debug.Assert(!hasExistingValue, "Utterance instance can only be constructor initialized.");

var jsonObject = JObject.Load(reader);
var utterance = jsonObject.Value<string>("text") ?? jsonObject.Value<string>("query");
var entityConverter = new EntityConverter(utterance);
Expand Down
8 changes: 4 additions & 4 deletions src/NLU.DevOps.LuisV3.Tests/LuisNLUTestClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -402,19 +402,19 @@ public static async Task UtteranceWithNestedMLEntity()
result.Text.Should().Be(test);
result.Intent.Should().Be("RequestVacation");
result.Entities.Count.Should().Be(7);
result.Entities[0].EntityType.Should().Be("leave-type");
result.Entities[0].EntityType.Should().Be("vacation-request::leave-type");
result.Entities[0].EntityValue.Should().BeEquivalentTo(@"[ ""sick"" ]");
result.Entities[0].MatchText.Should().Be("sick leave");
result.Entities[0].MatchIndex.Should().Be(0);
result.Entities[1].EntityType.Should().Be("days-number");
result.Entities[1].EntityType.Should().Be("vacation-request::days-duration::days-number");
result.Entities[1].EntityValue.Should().BeEquivalentTo("6");
result.Entities[1].MatchText.Should().Be("6");
result.Entities[1].MatchIndex.Should().Be(0);
result.Entities[2].EntityType.Should().Be("days-duration");
result.Entities[2].EntityType.Should().Be("vacation-request::days-duration");
result.Entities[2].EntityValue.Should().BeEquivalentTo(@"{ ""days-number"": [ 6 ] }");
result.Entities[2].MatchText.Should().Be("6 days");
result.Entities[2].MatchIndex.Should().Be(0);
result.Entities[3].EntityType.Should().Be("start-date");
result.Entities[3].EntityType.Should().Be("vacation-request::start-date");
result.Entities[3].MatchText.Should().Be("starting march 5");
result.Entities[3].MatchIndex.Should().Be(0);
result.Entities[4].EntityType.Should().Be("vacation-request");
Expand Down
12 changes: 6 additions & 6 deletions src/NLU.DevOps.LuisV3/LuisNLUTestClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ private static IEnumerable<IEntity> GetEntities(string utterance, IDictionary<st
return null;
}

IEnumerable<IEntity> getEntitiesForType(string type, object instances, JToken metadata)
IEnumerable<IEntity> getEntitiesForType(string prefix, string type, object instances, JToken metadata)
{
if (instances is JArray instancesJson)
{
Expand All @@ -111,14 +111,14 @@ IEnumerable<IEntity> getEntitiesForType(string type, object instances, JToken me
.Zip(
typeMetadata,
(instance, instanceMetadata) =>
getEntitiesRecursive(type, instance, instanceMetadata))
getEntitiesRecursive(prefix, type, instance, instanceMetadata))
.SelectMany(e => e);
}

return Array.Empty<IEntity>();
}

IEnumerable<IEntity> getEntitiesRecursive(string entityType, JToken entityJson, JToken entityMetadata)
IEnumerable<IEntity> getEntitiesRecursive(string prefix, string entityType, JToken entityJson, JToken entityMetadata)
{
var startIndex = entityMetadata.Value<int>("startIndex");
var length = entityMetadata.Value<int>("length");
Expand All @@ -136,15 +136,15 @@ IEnumerable<IEntity> getEntitiesRecursive(string entityType, JToken entityJson,
if (entityJson is JObject entityJsonObject && entityJsonObject.TryGetValue("$instance", out var innerMetadata))
{
var children = ((IDictionary<string, JToken>)entityJsonObject)
.SelectMany(pair => getEntitiesForType(pair.Key, pair.Value, innerMetadata));
.SelectMany(pair => getEntitiesForType($"{prefix}{entityType}::", pair.Key, pair.Value, innerMetadata));

foreach (var child in children)
{
yield return child;
}
}

yield return new Entity(entityType, entityValue, matchText, matchIndex)
yield return new Entity($"{prefix}{entityType}", entityValue, matchText, matchIndex)
.WithScore(score);
}

Expand All @@ -159,7 +159,7 @@ IEnumerable<IEntity> getEntitiesRecursive(string entityType, JToken entityJson,
}

return entities.SelectMany(pair =>
getEntitiesForType(pair.Key, pair.Value, globalMetadata));
getEntitiesForType(string.Empty, pair.Key, pair.Value, globalMetadata));
}

private static JToken PruneMetadata(JToken json)
Expand Down

0 comments on commit 76fa82f

Please sign in to comment.