-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
haesleinhuepf
committed
Aug 23, 2020
1 parent
ae3f347
commit e91da23
Showing
10 changed files
with
174 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
49 changes: 49 additions & 0 deletions
49
src/main/java/net/haesleinhuepf/clijx/incubator/optimize/GradientDescentOptimizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package net.haesleinhuepf.clijx.incubator.optimize; | ||
|
||
import org.apache.commons.math3.analysis.MultivariateFunction; | ||
import org.apache.commons.math3.optim.InitialGuess; | ||
import org.apache.commons.math3.optim.MaxEval; | ||
import org.apache.commons.math3.optim.PointValuePair; | ||
import org.apache.commons.math3.optim.SimpleValueChecker; | ||
import org.apache.commons.math3.optim.nonlinear.scalar.*; | ||
import org.apache.commons.math3.optim.nonlinear.scalar.gradient.NonLinearConjugateGradientOptimizer; | ||
import org.apache.commons.math3.random.GaussianRandomGenerator; | ||
import org.apache.commons.math3.random.JDKRandomGenerator; | ||
import org.apache.commons.math3.random.RandomVectorGenerator; | ||
import org.apache.commons.math3.random.UncorrelatedRandomVectorGenerator; | ||
|
||
import java.util.Arrays; | ||
|
||
import static net.haesleinhuepf.clijx.incubator.optimize.OptimizationUtilities.range; | ||
|
||
public class GradientDescentOptimizer implements Optimizer { | ||
int iterations = 6; | ||
public GradientDescentOptimizer() { | ||
|
||
} | ||
public GradientDescentOptimizer(int iterations) { | ||
this.iterations = iterations; | ||
} | ||
|
||
@Override | ||
public double[] optimize(double[] current, Workflow workflow, int[] parameter_index_map, MultivariateFunction fitness) { | ||
GradientMultivariateOptimizer underlying = new NonLinearConjugateGradientOptimizer(NonLinearConjugateGradientOptimizer.Formula.POLAK_RIBIERE, new SimpleValueChecker(1e-10, 1e-10)); | ||
|
||
for (int i = 0; i < iterations; i++) { | ||
|
||
double[] stdDev = range(current.length, workflow.getNumericParameterNames(), parameter_index_map, Math.pow(2, iterations / 2 - i - 1)); | ||
|
||
RandomVectorGenerator generator = new UncorrelatedRandomVectorGenerator(current, stdDev, new GaussianRandomGenerator(new JDKRandomGenerator())); | ||
int nbStarts = 10; | ||
MultiStartMultivariateOptimizer optimizer = new MultiStartMultivariateOptimizer(underlying, nbStarts, generator); | ||
|
||
PointValuePair solution = optimizer.optimize(new MaxEval(1000), new ObjectiveFunction(fitness), new ObjectiveFunctionGradient(new GradientOfMultivariateFunction(fitness, stdDev)), GoalType.MINIMIZE, new InitialGuess(current)); | ||
|
||
current = solution.getKey(); | ||
System.out.println("Intermediate optimum: " + Arrays.toString(current)); | ||
} | ||
|
||
return current; | ||
} | ||
|
||
} |
33 changes: 33 additions & 0 deletions
33
src/main/java/net/haesleinhuepf/clijx/incubator/optimize/GradientOfMultivariateFunction.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
package net.haesleinhuepf.clijx.incubator.optimize; | ||
|
||
import org.apache.commons.math3.analysis.MultivariateFunction; | ||
import org.apache.commons.math3.analysis.MultivariateVectorFunction; | ||
|
||
public class GradientOfMultivariateFunction implements MultivariateVectorFunction { | ||
private MultivariateFunction function; | ||
private double[] steps; | ||
|
||
public GradientOfMultivariateFunction(MultivariateFunction function, double[] steps) { | ||
this.function = function; | ||
this.steps = steps; | ||
} | ||
|
||
@Override | ||
public double[] value(double[] point) throws IllegalArgumentException { | ||
double[] result = new double[point.length]; | ||
for (int i = 0; i < result.length; i++) { | ||
double[] input = new double[point.length]; | ||
|
||
System.arraycopy(point, 0, input, 0, input.length); | ||
input[i] -= steps[i]; | ||
double a = function.value(input); | ||
|
||
System.arraycopy(point, 0, input, 0, input.length); | ||
input[i] += steps[i]; | ||
double b = function.value(input); | ||
|
||
result[i] = (a - b) / steps[i] / 2; | ||
} | ||
return result; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
7 changes: 7 additions & 0 deletions
7
src/main/java/net/haesleinhuepf/clijx/incubator/optimize/Optimizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
package net.haesleinhuepf.clijx.incubator.optimize; | ||
|
||
import org.apache.commons.math3.analysis.MultivariateFunction; | ||
|
||
public interface Optimizer { | ||
double[] optimize(double[] current, Workflow workflow, int[] parameter_index_map, MultivariateFunction fitness); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
src/main/java/net/haesleinhuepf/clijx/incubator/optimize/SimplexOptimizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
package net.haesleinhuepf.clijx.incubator.optimize; | ||
|
||
import net.imagej.ops.Op; | ||
import org.apache.commons.math3.analysis.MultivariateFunction; | ||
import org.apache.commons.math3.optim.InitialGuess; | ||
import org.apache.commons.math3.optim.MaxEval; | ||
import org.apache.commons.math3.optim.PointValuePair; | ||
import org.apache.commons.math3.optim.nonlinear.scalar.GoalType; | ||
import org.apache.commons.math3.optim.nonlinear.scalar.ObjectiveFunction; | ||
import org.apache.commons.math3.optim.nonlinear.scalar.noderiv.NelderMeadSimplex; | ||
|
||
import java.util.Arrays; | ||
|
||
import static net.haesleinhuepf.clijx.incubator.optimize.OptimizationUtilities.range; | ||
|
||
public class SimplexOptimizer implements Optimizer { | ||
int iterations = 6; | ||
public SimplexOptimizer() { | ||
|
||
} | ||
public SimplexOptimizer(int iterations) { | ||
this.iterations = iterations; | ||
} | ||
|
||
@Override | ||
public double[] optimize(double[] current, Workflow workflow, int[] parameter_index_map, MultivariateFunction fitness) { | ||
org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer optimizer = new org.apache.commons.math3.optim.nonlinear.scalar.noderiv.SimplexOptimizer(-1, 1e-5); | ||
for (int i = 0; i < iterations; i++) { | ||
NelderMeadSimplex simplex = makeSimplexOptimizer(current.length, workflow.getNumericParameterNames(), parameter_index_map, Math.pow(2, iterations / 2 - i - 1)); | ||
//double[] lowerBounds = new double[simplex.getDimension()]; | ||
//double[] upperBounds = new double[simplex.getDimension()]; | ||
//for (int b = 0; b < upperBounds.length; b++) { | ||
// upperBounds[b] = Double.MAX_VALUE; | ||
//} | ||
//, new SimpleBounds(lowerBounds, upperBounds) | ||
PointValuePair solution = optimizer.optimize(new MaxEval(1000), new InitialGuess(current), simplex, new ObjectiveFunction(fitness), GoalType.MINIMIZE); | ||
|
||
current = solution.getKey(); | ||
System.out.println("Intermediate optimum: " + Arrays.toString(current)); | ||
} | ||
return current; | ||
} | ||
|
||
private static NelderMeadSimplex makeSimplexOptimizer(int numDimensions, String[] numericParameterNames, int[] parameter_index_map, double factor) { | ||
double[] steps = range(numDimensions, numericParameterNames, parameter_index_map, factor); | ||
|
||
System.out.println("Step lengths: " + Arrays.toString(steps) ); | ||
|
||
NelderMeadSimplex simplex = new NelderMeadSimplex(steps); | ||
return simplex; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters