diff --git a/src/main/java/com/shapesecurity/salvation/data/Policy.java b/src/main/java/com/shapesecurity/salvation/data/Policy.java index 8876aca8..6cd26434 100644 --- a/src/main/java/com/shapesecurity/salvation/data/Policy.java +++ b/src/main/java/com/shapesecurity/salvation/data/Policy.java @@ -171,7 +171,7 @@ private void optimise() { if (directive instanceof SourceListDirective) { SourceListDirective sourceListDirective = (SourceListDirective) directive; Optional star = - sourceListDirective.values().filter(x -> x instanceof HostSource && ((HostSource) x).isWildcard()) + sourceListDirective.values().filter(x -> x instanceof HostSource && ((HostSource) x).isTLDWildcard()) .findAny(); if (star.isPresent()) { Set newSources = sourceListDirective.values() diff --git a/src/main/java/com/shapesecurity/salvation/directiveValues/HostSource.java b/src/main/java/com/shapesecurity/salvation/directiveValues/HostSource.java index e56d6926..593cbb1d 100644 --- a/src/main/java/com/shapesecurity/salvation/directiveValues/HostSource.java +++ b/src/main/java/com/shapesecurity/salvation/directiveValues/HostSource.java @@ -1,14 +1,5 @@ package com.shapesecurity.salvation.directiveValues; -import com.shapesecurity.salvation.Constants; -import com.shapesecurity.salvation.data.GUID; -import com.shapesecurity.salvation.data.Origin; -import com.shapesecurity.salvation.data.SchemeHostPortTriple; -import com.shapesecurity.salvation.data.URI; -import com.shapesecurity.salvation.interfaces.MatchesSource; - -import javax.annotation.Nonnull; -import javax.annotation.Nullable; import java.io.UnsupportedEncodingException; import java.net.URLDecoder; import java.util.ArrayList; @@ -17,6 +8,16 @@ import java.util.Objects; import java.util.regex.Matcher; +import javax.annotation.Nonnull; +import javax.annotation.Nullable; + +import com.shapesecurity.salvation.Constants; +import com.shapesecurity.salvation.data.GUID; +import com.shapesecurity.salvation.data.Origin; +import com.shapesecurity.salvation.data.SchemeHostPortTriple; +import com.shapesecurity.salvation.data.URI; +import com.shapesecurity.salvation.interfaces.MatchesSource; + public class HostSource implements SourceExpression, AncestorSource, MatchesSource { public static final HostSource WILDCARD = new HostSource(null, "*", Constants.EMPTY_PORT, null); @@ -46,7 +47,7 @@ public boolean equals(@Nullable Object other) { return false; } HostSource otherPrime = (HostSource) other; - if (this.isWildcard() && otherPrime.isWildcard()) { + if (this.isTLDWildcard() && otherPrime.isTLDWildcard()) { return true; } @@ -75,10 +76,26 @@ public int hashCode() { return h; } - public boolean isWildcard() { + public boolean isTLDWildcard() { return this.host.equals("*") && this.scheme == null && this.port == Constants.EMPTY_PORT && this.path == null; } + public boolean isSubdomainWildcard() { + return this.host.startsWith("*."); + } + + public boolean subsumes(@Nonnull HostSource other) { + if (this.isTLDWildcard()) return true; + if (this.isSubdomainWildcard() && !other.isTLDWildcard() && !other.isSubdomainWildcard()) { + return Objects.equals(this.scheme != null ? this.scheme.toLowerCase() : null, + other.scheme != null ? other.scheme.toLowerCase() : null) && + other.host.toLowerCase().endsWith(this.host.substring(1).toLowerCase()) && + this.port == other.port && + Objects.equals(this.path, other.path); + } + return this.equals(other); + } + public static boolean hostMatches(@Nonnull String hostA, @Nonnull String hostB) { if (hostA.startsWith("*")) { String remaining = hostA.substring(1); @@ -121,12 +138,12 @@ public boolean matchesSource(@Nonnull Origin origin, @Nonnull GUID resource) { public boolean matchesSource(@Nonnull Origin origin, @Nonnull URI resource) { if (origin instanceof GUID) { // wildcard matches a network scheme - return this.isWildcard() && resource.isNetworkScheme(); + return this.isTLDWildcard() && resource.isNetworkScheme(); } else if (!(origin instanceof SchemeHostPortTriple)) { return false; } SchemeHostPortTriple shpOrigin = (SchemeHostPortTriple) origin; - if (this.isWildcard()) { + if (this.isTLDWildcard()) { return resource.isNetworkScheme() || shpOrigin.scheme.matches(resource.scheme); } boolean schemeMatches; diff --git a/src/main/java/com/shapesecurity/salvation/directives/Directive.java b/src/main/java/com/shapesecurity/salvation/directives/Directive.java index 649d8c75..f98562ec 100644 --- a/src/main/java/com/shapesecurity/salvation/directives/Directive.java +++ b/src/main/java/com/shapesecurity/salvation/directives/Directive.java @@ -130,11 +130,9 @@ private static Set union(@Nonnull Set a, @Nonnull Set b) { set.remove(None.INSTANCE); - Optional star = set.stream().filter(x -> x instanceof HostSource && ((HostSource) x).isWildcard()).findAny(); - if (star.isPresent()) { - set.removeIf(y -> y instanceof HostSource); - set.add(star.get()); - } + set.stream().filter(x -> x instanceof HostSource).collect(Collectors.toList()).forEach(x -> { + set.removeIf(y -> y != x && y instanceof HostSource && ((HostSource) x).subsumes((HostSource) y)); + }); return set; } @@ -159,14 +157,14 @@ private static Set intersect(@Nonnull Set a, @Nonnull Set b) { return set; } - Optional star = b.stream().filter(x -> x instanceof HostSource && ((HostSource) x).isWildcard()).findAny(); + Optional star = b.stream().filter(x -> x instanceof HostSource && ((HostSource) x).isTLDWildcard()).findAny(); if (star.isPresent()) { set.addAll(a); return set; } for (T x : a) { - if (x instanceof HostSource && ((HostSource) x).isWildcard()) { + if (x instanceof HostSource && ((HostSource) x).isTLDWildcard()) { set.clear(); set.addAll(b); return set; diff --git a/src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java b/src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java index f974a41e..b870d8b5 100644 --- a/src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java +++ b/src/test/java/com/shapesecurity/salvation/PolicyMergeTest.java @@ -672,4 +672,58 @@ public void testUnionNone() { p.union(q); assertEquals("script-src *", p.show()); } + + @Test + public void testUnionTLDWildcard() { + Policy p = parse("default-src *"); + Policy q = parse("default-src http://b.c.atest.com"); + p.union(q); + assertEquals("default-src *", p.show()); + + p = parse("default-src *"); + q = parse("default-src http://*.c.atest.com"); + p.union(q); + assertEquals("default-src *", p.show()); + } + + @Test + public void testUnionSubdomainWildcard() { + Policy p = parse("default-src http://*.atest.com"); + Policy q = parse("default-src http://b.c.atest.com"); + p.union(q); + assertEquals("default-src http://*.atest.com", p.show()); + + p = parse("default-src http://*.b.atest.com"); + q = parse("default-src http://x.b.atest.com"); + p.union(q); + assertEquals("default-src http://*.b.atest.com", p.show()); + + p = parse("default-src https://*.b.atest.com:8443"); + q = parse("default-src https://x.b.atest.com:8443"); + p.union(q); + assertEquals("default-src https://*.b.atest.com:8443", p.show()); + + p = parse("default-src https://*.b.atest.com:8443/a"); + q = parse("default-src https://x.b.atest.com:8443/a"); + p.union(q); + assertEquals("default-src https://*.b.atest.com:8443/a", p.show()); + } + + @Test + public void testUnionFails() { + Policy p = parse("default-src https://*.atest.com"); + Policy q = parse("default-src http://b.c.atest.com"); + p.union(q); + assertEquals("default-src https://*.atest.com http://b.c.atest.com", p.show()); + + p = parse("default-src http://*.atest.com:80"); + q = parse("default-src http://b.c.atest.com:8080"); + p.union(q); + assertEquals("default-src http://*.atest.com http://b.c.atest.com:8080", p.show()); + + p = parse("default-src http://*.atest.com/a"); + q = parse("default-src http://b.c.atest.com/a/b"); + p.union(q); + assertEquals("default-src http://*.atest.com/a http://b.c.atest.com/a/b", p.show()); + } }