Skip to content

Commit

Permalink
stage 2 works. need to fix classpath for stage 1 test.
Browse files Browse the repository at this point in the history
  • Loading branch information
evanchooly committed Nov 12, 2024
1 parent d9f5b7c commit b9e37c6
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 204 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public MethodInvocation visitMethodInvocation(MethodInvocation methodInvocation,
if (matchers.stream().anyMatch(matcher -> matcher.matches(methodInvocation))) {
return super.visitMethodInvocation(methodInvocation
.withName(methodInvocation.getName().withSimpleName("pipeline")),
context);
context);
} else {
return super.visitMethodInvocation(methodInvocation, context);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,40 +1,21 @@
package dev.morphia.rewrite.recipes;

import java.util.ArrayList;

import org.openrewrite.ExecutionContext;
import org.openrewrite.Preconditions;
import org.openrewrite.Recipe;
import org.openrewrite.TreeVisitor;
import org.openrewrite.java.JavaIsoVisitor;
import org.openrewrite.java.JavaTemplate;
import org.openrewrite.java.MethodMatcher;
import org.openrewrite.java.search.UsesMethod;
import org.openrewrite.java.tree.Expression;
import org.openrewrite.java.tree.J;
import org.openrewrite.java.tree.J.Block;
import org.openrewrite.java.tree.J.Identifier;
import org.openrewrite.java.tree.J.MethodInvocation;
import org.openrewrite.java.tree.JContainer;
import org.openrewrite.java.tree.JRightPadded;
import org.openrewrite.java.tree.Space;
import org.openrewrite.java.tree.Statement;
import org.openrewrite.marker.Markers;

import java.lang.annotation.ElementType;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;

import static java.util.List.of;
import static org.openrewrite.java.tree.JRightPadded.build;

public class PipelineRewriteStage2 extends Recipe {

static final String AGGREGATION = "dev.morphia.aggregation.Aggregation";

static final MethodMatcher PIPELINE = new MethodMatcher(PipelineRewriteStage2.AGGREGATION + " pipeline(..)");
private static final String AGGREGATION = "dev.morphia.aggregation.Aggregation";

private final JavaTemplate pipelineTemplate = null; //JavaTemplate.builder("pipeline(..)").contextSensitive().build();
private static final MethodMatcher PIPELINE = new MethodMatcher(PipelineRewriteStage2.AGGREGATION + " pipeline(..)");

@Override
public String getDisplayName() {
Expand All @@ -51,100 +32,24 @@ public TreeVisitor<?, ExecutionContext> getVisitor() {
return Preconditions.check(new UsesMethod<>(PIPELINE), new JavaIsoVisitor<>() {

public MethodInvocation visitMethodInvocation(MethodInvocation method, ExecutionContext p) {
// exit if method doesn't match isEqualTo(..)
if (!PIPELINE.matches(method.getSelect()/*.getMethodType()*/)) {
if (!PIPELINE.matches(method.getSelect())) {
return method;
}

var arguments = new ArrayList<JRightPadded<Expression>>();
var select = method.getSelect();
while (PIPELINE.matches(select)) {
// System.out.println("select = " + select);
J.MethodInvocation invocation = (J.MethodInvocation) select;
arguments.add(build(invocation.getArguments().get(0)));
select = invocation.getSelect();
}
System.out.println("done: select = " + select);
Collections.reverse(arguments);
Markers markers = new Markers(UUID.randomUUID(), of());
Identifier identifier = (Identifier) select;
Space prefix = method.getPrefix();
var newInvocation = new MethodInvocation(
UUID.randomUUID(), prefix, markers, build(method.getSelect()), null, identifier,
JContainer.build(arguments),
method.getMethodType());
return newInvocation;
}
/*
@Override
public J.Block visitBlock(J.Block block, ExecutionContext ctx) {
J.Block bl = super.visitBlock(block, ctx);
return bl.withStatements(rewritePipelineStatements(bl));
}
*/

/*
private MethodInvocation rewritePipelineStatements(Block bl) {
List<Statement> statements = new ArrayList<>();
for (var statement : bl.getStatements()) {
if (statement instanceof J.MethodInvocation && isPipeline(statement)) {
List<Expression> arguments = new ArrayList<>();
J.MethodInvocation pipeline = (J.MethodInvocation) statement;
var select = pipeline.getSelect();
while (PIPELINE.matches(select)) {
System.out.println("select = " + select);
J.MethodInvocation invocation = (J.MethodInvocation) select;
arguments.add(invocation.getArguments().get(0));
select = invocation.getSelect();
}
System.out.println("done: select = " + select);
Collections.reverse(arguments);
// Markers markers = new Markers(UUID.randomUUID(), of());
// Identifier identifier = (Identifier) select;
// Space prefix = statement.getPrefix();
// var newInvocation = new MethodInvocation(
// UUID.randomUUID(), prefix, markers, null, null, identifier, arguments,
// ((MethodInvocation) statement).getMethodType() );
MethodInvocation m = pipelineTemplate.apply(getCursor(), statement.getCoordinates().replace());
return m;
statements.add((newInvocation));
} else {
statements.add(statement);
}
}
return statements;
}
*/

private boolean isPipeline(Statement statement) {
J.MethodInvocation methodInvocation = (J.MethodInvocation) statement;
// Only match method invocations where the select is an assertThat, containing a non-method call argument
if (PIPELINE.matches(methodInvocation.getSelect())) {
J.MethodInvocation invocation = (J.MethodInvocation) methodInvocation.getSelect();
if (invocation != null && PIPELINE.matches(invocation.getSelect())) {
return true;
var updated = method;
while (PIPELINE.matches(updated.getSelect()) && PIPELINE.matches(((MethodInvocation) updated.getSelect()).getSelect())) {
var select = updated.getSelect();
MethodInvocation invocation = (MethodInvocation) select;
if (PIPELINE.matches(invocation.getSelect())) {
MethodInvocation parent = (MethodInvocation) invocation.getSelect();
var args = new ArrayList<>(parent.getArguments());
args.addAll(invocation.getArguments());
updated = updated.withSelect(((MethodInvocation) invocation.getSelect()).withArguments(args));
}
}
return false;
return updated;
}

private J.MethodInvocation getCollapsedAssertThat(List<Statement> consecutiveAssertThatStatement) {
assert !consecutiveAssertThatStatement.isEmpty();
Space originalPrefix = consecutiveAssertThatStatement.get(0).getPrefix();
String continuationIndent = originalPrefix.getIndent().contains("\t") ? "\t\t" : " ";
Space indentedNewline = Space.format(originalPrefix.getLastWhitespace().replaceAll("^\\s+\n", "\n") +
continuationIndent);
J.MethodInvocation collapsed = null;
for (Statement st : consecutiveAssertThatStatement) {
J.MethodInvocation assertion = (J.MethodInvocation) st;
J.MethodInvocation assertThat = (J.MethodInvocation) assertion.getSelect();
assert assertThat != null;
J.MethodInvocation newSelect = collapsed == null ? assertThat : collapsed;
collapsed = assertion.getPadding().withSelect(build((Expression) newSelect.withPrefix(Space.EMPTY))
.withAfter(indentedNewline));
}
return collapsed.withPrefix(originalPrefix);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ public abstract class Morphia2RewriteTest extends MorphiaRewriteTest {
@Override
protected @NotNull String findMorphiaCore() {
var core = runtimeClasspath.stream()
.filter(uri -> {
String string = uri.toString();
return string.contains("morphia") && string.contains("core");
})
.findFirst().orElseThrow().toString();
.filter(uri -> {
String string = uri.toString();
return string.contains("morphia") && string.contains("core");
})
.findFirst().orElseThrow().toString();

final String artifact = core.contains("morphia-core") ? "morphia-core" : "morphia/core";
return artifact;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package dev.morphia.rewrite.recipes.test;

import io.github.classgraph.ClassGraph;
import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

import org.jetbrains.annotations.NotNull;
import org.openrewrite.Recipe;
import org.openrewrite.java.JavaParser;
import org.openrewrite.test.RecipeSpec;
import org.openrewrite.test.RewriteTest;

import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
import io.github.classgraph.ClassGraph;

public abstract class MorphiaRewriteTest implements RewriteTest {
protected List<URI> runtimeClasspath = new ClassGraph().disableNestedJarScanning().getClasspathURIs();
Expand All @@ -32,9 +33,9 @@ public void defaults(RecipeSpec spec) {
@NotNull
protected List<String> findMongoArtifacts() {
List<String> classpath = runtimeClasspath.stream()
.filter(uri -> uri.toString().contains("mongodb") || uri.toString().contains("bson"))
.map(uri -> new File(uri).getAbsolutePath()/*.getName().replaceAll("-[0-9].*", "")*/)
.collect(ArrayList::new, List::add, List::addAll);
.filter(uri -> uri.toString().contains("mongodb") || uri.toString().contains("bson"))
.map(uri -> new File(uri).getAbsolutePath()/* .getName().replaceAll("-[0-9].*", "") */)
.collect(ArrayList::new, List::add, List::addAll);
return classpath;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,10 @@

import dev.morphia.rewrite.recipes.PipelineRewriteStage1;

import io.github.classgraph.ClassGraph;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Test;
import org.openrewrite.Recipe;

import java.io.File;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;

import static org.openrewrite.java.Assertions.java;

public class PipelineRewriteStage1Test extends Morphia2RewriteTest {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
package dev.morphia.rewrite.recipes.test;

import java.io.File;
import java.nio.file.Path;

import dev.morphia.rewrite.recipes.PipelineRewriteStage2;

import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.Test;
import org.openrewrite.Recipe;
import org.openrewrite.java.JavaParser;
import org.openrewrite.java.JavaParser.Builder;
import org.openrewrite.test.RecipeSpec;

import java.io.File;
import java.nio.file.Path;
import java.util.List;

import static org.openrewrite.java.Assertions.java;

public class PipelineRewriteStage2Test extends MorphiaRewriteTest {
Expand All @@ -30,65 +30,61 @@ public class PipelineRewriteStage2Test extends MorphiaRewriteTest {
@Test
void unwrapStageMethods() {
rewriteRun(
//language=java
java(
"""
import dev.morphia.aggregation.expressions.ComparisonExpressions;
import static dev.morphia.aggregation.expressions.AccumulatorExpressions.sum;
import static dev.morphia.aggregation.stages.Group.group;
import static dev.morphia.aggregation.stages.Group.id;
import static dev.morphia.aggregation.stages.Projection.project;
import static dev.morphia.aggregation.expressions.Expressions.field;
import static dev.morphia.aggregation.expressions.Expressions.value;
import static dev.morphia.aggregation.stages.Sort.sort;
import dev.morphia.aggregation.Aggregation;
import org.bson.Document;
public class UnwrapTest {
public void update(Aggregation<?> aggregation) {
aggregation
.pipeline(group(id("author")).field("count", sum(value(1))))
.pipeline(sort().ascending("1"))
.pipeline(sort().ascending("2"))
.pipeline(sort().ascending("3"))
.pipeline(sort().ascending("4"))
.execute(Document.class);
var dummy = 42;
}
}
""",
"""
import dev.morphia.aggregation.expressions.ComparisonExpressions;
import static dev.morphia.aggregation.expressions.AccumulatorExpressions.sum;
import static dev.morphia.aggregation.stages.Group.group;
import static dev.morphia.aggregation.stages.Group.id;
import static dev.morphia.aggregation.stages.Projection.project;
import static dev.morphia.aggregation.expressions.Expressions.field;
import static dev.morphia.aggregation.expressions.Expressions.value;
import static dev.morphia.aggregation.stages.Sort.sort;
import dev.morphia.aggregation.Aggregation;
import org.bson.Document;
public class UnwrapTest {
public void update(Aggregation<?> aggregation) {
aggregation
.pipeline(
group(id("author")).field("count", sum(value(1))),
sort().ascending("1"),
sort().ascending("2"),
sort().ascending("3"),
sort().ascending("4"))
.execute(Document.class);
}
}
"""));
} @Override
//language=java
java(
"""
import dev.morphia.aggregation.expressions.ComparisonExpressions;
import static dev.morphia.aggregation.expressions.AccumulatorExpressions.sum;
import static dev.morphia.aggregation.stages.Group.group;
import static dev.morphia.aggregation.stages.Group.id;
import static dev.morphia.aggregation.stages.Projection.project;
import static dev.morphia.aggregation.expressions.Expressions.field;
import static dev.morphia.aggregation.expressions.Expressions.value;
import static dev.morphia.aggregation.stages.Sort.sort;
import dev.morphia.aggregation.Aggregation;
import org.bson.Document;
public class UnwrapTest {
public void update(Aggregation<?> aggregation) {
aggregation
.pipeline(group(id("author")).field("count", sum(value(1))))
.pipeline(sort().ascending("1"))
.pipeline(sort().ascending("2"))
.pipeline(sort().ascending("3"))
.pipeline(sort().ascending("4"))
.execute(Document.class);
}
}
""",
"""
import dev.morphia.aggregation.expressions.ComparisonExpressions;
import static dev.morphia.aggregation.expressions.AccumulatorExpressions.sum;
import static dev.morphia.aggregation.stages.Group.group;
import static dev.morphia.aggregation.stages.Group.id;
import static dev.morphia.aggregation.stages.Projection.project;
import static dev.morphia.aggregation.expressions.Expressions.field;
import static dev.morphia.aggregation.expressions.Expressions.value;
import static dev.morphia.aggregation.stages.Sort.sort;
import dev.morphia.aggregation.Aggregation;
import org.bson.Document;
public class UnwrapTest {
public void update(Aggregation<?> aggregation) {
aggregation
.pipeline(group(id("author")).field("count", sum(value(1))),sort().ascending("1"),sort().ascending("2"),sort().ascending("3"),sort().ascending("4"))
.execute(Document.class);
}
}
"""));
}

@Override
protected @NotNull String findMorphiaCore() {
return classesFolder;
return classesFolder;
}

public String[] classpath() {
Expand All @@ -98,11 +94,11 @@ public String[] classpath() {
@Override
public void defaults(RecipeSpec spec) {
Builder<? extends JavaParser, ?> builder = JavaParser.fromJavaVersion()
.addClasspathEntry(Path.of(classesFolder));
.addClasspathEntry(Path.of(classesFolder));
findMongoArtifacts().stream().map(Path::of)
.forEach(builder::addClasspathEntry);
.forEach(builder::addClasspathEntry);
spec.recipe(getRecipe())
.parser(builder);
.parser(builder);
}

@Override
Expand All @@ -111,5 +107,4 @@ protected Recipe getRecipe() {
return new PipelineRewriteStage2();
}


}
Loading

0 comments on commit b9e37c6

Please sign in to comment.