Skip to content

Commit

Permalink
refactoring: Optimiziers
Browse files Browse the repository at this point in the history
  • Loading branch information
haesleinhuepf committed Aug 23, 2020
1 parent ae3f347 commit e91da23
Show file tree
Hide file tree
Showing 10 changed files with 174 additions and 36 deletions.
3 changes: 2 additions & 1 deletion src/main/java/net/haesleinhuepf/IncubatorPlayground.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ public static void main(String... args) throws FileNotFoundException {

//ImagePlus imp = IJ.openImage("C:/structure/data/spim_TL18_Angle0-1.tif");
//ImagePlus imp = IJ.openImage("D:/structure/data/Irene/ISB200714_well5_1pos_3h_MyosinGFP-small.tif");
ImagePlus imp = IJ.openImage("C:/structure/data/mitosis.tif");
//ImagePlus imp = IJ.openImage("C:/structure/data/mitosis.tif");
ImagePlus imp = IJ.openImage("C:/structure/data/blobs.tif");
imp.show();

//ImagePlus imp1 = IJ.openImage("C:\\structure\\teaching\\lecture_applied_bioimage_analysis_2020\\12_Volumetric_image_data\\data\\000200.raw.tif");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@
import net.haesleinhuepf.clij2.utilities.IsCategorized;
import net.haesleinhuepf.clijx.incubator.interactive.handcrafted.Crop;
import net.haesleinhuepf.clijx.incubator.interactive.handcrafted.ExtractChannel;
import net.haesleinhuepf.clijx.incubator.optimize.AnnotationTool;
import net.haesleinhuepf.clijx.incubator.optimize.BinaryImageFitnessFunction;
import net.haesleinhuepf.clijx.incubator.optimize.OptimizationUtilities;
import net.haesleinhuepf.clijx.incubator.optimize.Workflow;
import net.haesleinhuepf.clijx.incubator.optimize.*;
import net.haesleinhuepf.clijx.incubator.scriptgenerator.*;
import net.haesleinhuepf.clijx.incubator.services.CLIJMacroPluginService;
import net.haesleinhuepf.clijx.incubator.utilities.*;
Expand All @@ -42,7 +39,6 @@
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer;
import org.apache.commons.math3.optim.univariate.UnivariatePointValuePair;

import java.awt.*;
Expand Down Expand Up @@ -562,13 +558,19 @@ protected PopupMenu buildPopup(MouseEvent e, ImagePlus my_source, ImagePlus my_t
// -------------------------------------------------------------------------------------------------------------
Menu more_actions = new Menu("More actions");
if (IncubatorUtilities.resultIsBinaryImage(this)) {
addMenuAction(more_actions, "Optimize parameters (auto)", (a) -> {
optimize(false);
addMenuAction(more_actions, "Optimize parameters (simplex, auto)", (a) -> {
optimize(new SimplexOptimizer(), false);
});
addMenuAction(more_actions, "Optimize parameters (config)", (a) -> {
optimize(true);
addMenuAction(more_actions, "Optimize parameters (gradient descent, auto)", (a) -> {
optimize(new GradientDescentOptimizer(), false);
});
more_actions.add("-");
addMenuAction(more_actions, "Optimize parameters (simplex, configurable)", (a) -> {
optimize(new SimplexOptimizer((int)IJ.getNumber( "Range",6 )), true);
});
addMenuAction(more_actions, "Optimize parameters (gradient descent, configurable)", (a) -> {
optimize(new GradientDescentOptimizer((int)IJ.getNumber( "Range",6 )), true);
});
}

menu.add(more_actions);
Expand Down Expand Up @@ -932,14 +934,17 @@ public String getName() {
return plugin.getName().replace("CLIJ2_", "").replace("CLIJx_", "");
}

public void optimize(boolean show_gui) {
public void optimize(Optimizer optimizer, boolean show_gui) {
CLIJ2 clij2 = CLIJx.getInstance();

// -------------------------------------------------------------------------------------------------------------
// determine ground truth
RoiManager rm = RoiManager.getRoiManager();
if (rm.getCount() == 0) {
IJ.log("Please define reference ROIs in the ROI Manager.\nThese ROIs should have names starting with 'p' for positive and 'n' for negative.");
IJ.log("Please define reference ROIs in the ROI Manager.\n\n" +
"These ROIs should have names starting with 'p' for positive and 'n' for negative.\n\n" +
"The just activated annotation tool can help you with that: Use the left mouse button to annotate positive regions.\n" +
"Additionally hold CTRL/CMD to annotate negative (background) regions.");
Toolbar.addPlugInTool(new AnnotationTool());
return;
}
Expand Down Expand Up @@ -981,26 +986,12 @@ public void optimize(boolean show_gui) {
mask
);

SimplexOptimizer optimizer = new SimplexOptimizer(-1, 1e-5);

double[] current = f.getCurrent();
System.out.println("Initial: " + Arrays.toString(current));

int iterations = 6;
for (int i = 0; i < iterations; i++) {
//current = Optimizers.optimizeSimplex(current, workflow, parameter_index_map, f);
current = optimizer.optimize(current, workflow, parameter_index_map, f);

NelderMeadSimplex simplex = OptimizationUtilities.makeOptimizer(f.getNumDimensions(), workflow.getNumericParameterNames(), parameter_index_map, Math.pow(2, iterations / 2 - i - 1));
//double[] lowerBounds = new double[simplex.getDimension()];
//double[] upperBounds = new double[simplex.getDimension()];
//for (int b = 0; b < upperBounds.length; b++) {
// upperBounds[b] = Double.MAX_VALUE;
//}
//, new SimpleBounds(lowerBounds, upperBounds)
PointValuePair solution = optimizer.optimize(new MaxEval(1000), new InitialGuess(current), simplex, new ObjectiveFunction(f), GoalType.MINIMIZE);

current = solution.getKey();
System.out.println("Intermediate optimum: " + Arrays.toString(current));
}

System.out.println("Optimum: ");
f.value(current);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import net.haesleinhuepf.clij.clearcl.ClearCLBuffer;
import net.haesleinhuepf.clij2.CLIJ2;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;

import java.util.Arrays;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package net.haesleinhuepf.clijx.incubator.optimize;

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.SimpleValueChecker;
import org.apache.commons.math3.optim.nonlinear.scalar.*;
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer;
import org.apache.commons.math3.random.GaussianRandomGenerator;
import org.apache.commons.math3.random.JDKRandomGenerator;
import org.apache.commons.math3.random.RandomVectorGenerator;
import org.apache.commons.math3.random.UncorrelatedRandomVectorGenerator;

import java.util.Arrays;

import static net.haesleinhuepf.clijx.incubator.optimize.OptimizationUtilities.range;

public class GradientDescentOptimizer implements Optimizer {
int iterations = 6;
public GradientDescentOptimizer() {

}
public GradientDescentOptimizer(int iterations) {
this.iterations = iterations;
}

@Override
public double[] optimize(double[] current, Workflow workflow, int[] parameter_index_map, MultivariateFunction fitness) {
GradientMultivariateOptimizer underlying = new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE, new SimpleValueChecker(1e-10, 1e-10));

for (int i = 0; i < iterations; i++) {

double[] stdDev = range(current.length, workflow.getNumericParameterNames(), parameter_index_map, Math.pow(2, iterations / 2 - i - 1));

RandomVectorGenerator generator = new UncorrelatedRandomVectorGenerator(current, stdDev, new GaussianRandomGenerator(new JDKRandomGenerator()));
int nbStarts = 10;
MultiStartMultivariateOptimizer optimizer = new MultiStartMultivariateOptimizer(underlying, nbStarts, generator);

PointValuePair solution = optimizer.optimize(new MaxEval(1000), new ObjectiveFunction(fitness), new ObjectiveFunctionGradient(new GradientOfMultivariateFunction(fitness, stdDev)), GoalType.MINIMIZE, new InitialGuess(current));

current = solution.getKey();
System.out.println("Intermediate optimum: " + Arrays.toString(current));
}

return current;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
package net.haesleinhuepf.clijx.incubator.optimize;

import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.analysis.MultivariateVectorFunction;

public class GradientOfMultivariateFunction implements MultivariateVectorFunction {
private MultivariateFunction function;
private double[] steps;

public GradientOfMultivariateFunction(MultivariateFunction function, double[] steps) {
this.function = function;
this.steps = steps;
}

@Override
public double[] value(double[] point) throws IllegalArgumentException {
double[] result = new double[point.length];
for (int i = 0; i < result.length; i++) {
double[] input = new double[point.length];

System.arraycopy(point, 0, input, 0, input.length);
input[i] -= steps[i];
double a = function.value(input);

System.arraycopy(point, 0, input, 0, input.length);
input[i] += steps[i];
double b = function.value(input);

result[i] = (a - b) / steps[i] / 2;
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,22 @@ public static int hammingStringDistance(String name1, String name2) {
}


public static NelderMeadSimplex makeOptimizer(int numDimensions, String[] numericParameterNames, int[] parameter_index_map, double factor) {
static double[] range(int numDimensions, String[] numericParameterNames, int[] parameter_index_map, double factor) {
double[] steps = new double[numDimensions];
for (int i = 0; i < steps.length; i++) {
steps[i] = 1;
}

for (int j = 0; j < parameter_index_map.length; j++) {
int i = parameter_index_map[j];
System.out.println("param: " + j + " -> " + i);
if (i >= 0) {
steps[i] = IncubatorUtilities.parmeterNameToStepSizeSuggestion(numericParameterNames[i], true) * factor;
System.out.println("name: " + numericParameterNames[j]);
steps[i] = IncubatorUtilities.parmeterNameToStepSizeSuggestion(numericParameterNames[j], true) * factor;
System.out.println("step: " + steps[i]);
}
}

System.out.println("Step lengths: " + Arrays.toString(steps) );

NelderMeadSimplex simplex = new NelderMeadSimplex(steps);
return simplex;
return steps;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package net.haesleinhuepf.clijx.incubator.optimize;

import org.apache.commons.math3.analysis.MultivariateFunction;

public interface Optimizer {
double[] optimize(double[] current, Workflow workflow, int[] parameter_index_map, MultivariateFunction fitness);
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import java.util.ArrayList;
import java.util.Arrays;

public class OptimizerPlayground {
class OptimizerPlayground {
public static void main(String[] args) {

new ImageJ();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package net.haesleinhuepf.clijx.incubator.optimize;

import net.imagej.ops.Op;
import org.apache.commons.math3.analysis.MultivariateFunction;
import org.apache.commons.math3.optim.InitialGuess;
import org.apache.commons.math3.optim.MaxEval;
import org.apache.commons.math3.optim.PointValuePair;
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType;
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction;
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex;

import java.util.Arrays;

import static net.haesleinhuepf.clijx.incubator.optimize.OptimizationUtilities.range;

public class SimplexOptimizer implements Optimizer {
int iterations = 6;
public SimplexOptimizer() {

}
public SimplexOptimizer(int iterations) {
this.iterations = iterations;
}

@Override
public double[] optimize(double[] current, Workflow workflow, int[] parameter_index_map, MultivariateFunction fitness) {
org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer optimizer = new org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer(-1, 1e-5);
for (int i = 0; i < iterations; i++) {
NelderMeadSimplex simplex = makeSimplexOptimizer(current.length, workflow.getNumericParameterNames(), parameter_index_map, Math.pow(2, iterations / 2 - i - 1));
//double[] lowerBounds = new double[simplex.getDimension()];
//double[] upperBounds = new double[simplex.getDimension()];
//for (int b = 0; b < upperBounds.length; b++) {
// upperBounds[b] = Double.MAX_VALUE;
//}
//, new SimpleBounds(lowerBounds, upperBounds)
PointValuePair solution = optimizer.optimize(new MaxEval(1000), new InitialGuess(current), simplex, new ObjectiveFunction(fitness), GoalType.MINIMIZE);

current = solution.getKey();
System.out.println("Intermediate optimum: " + Arrays.toString(current));
}
return current;
}

private static NelderMeadSimplex makeSimplexOptimizer(int numDimensions, String[] numericParameterNames, int[] parameter_index_map, double factor) {
double[] steps = range(numDimensions, numericParameterNames, parameter_index_map, factor);

System.out.println("Step lengths: " + Arrays.toString(steps) );

NelderMeadSimplex simplex = new NelderMeadSimplex(steps);
return simplex;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,9 @@ public static double parmeterNameToStepSizeSuggestion(String parameterName, bool
if (parameterName.toLowerCase().contains("long range")) {
return small_step ? 64 : 256;
}
if (parameterName.toLowerCase().contains("constant")) {
return small_step ? 10 : 100;
}
return small_step ? 1 : 10;
}

Expand Down

0 comments on commit e91da23

Please sign in to comment.