Skip to content

Commit

Permalink
Merge pull request #110 from g-degiorgi/feature/create_cpt
Browse files Browse the repository at this point in the history
added create CPT feature with interval helper methods
  • Loading branch information
davidhuber authored Aug 22, 2023
2 parents 42c5912 + 9490b27 commit 96eb66b
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 0 deletions.
125 changes: 125 additions & 0 deletions src/main/java/ch/idsia/crema/preprocess/creators/CreateCPT.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
package ch.idsia.crema.preprocess.creators;

import ch.idsia.crema.core.Strides;
import ch.idsia.crema.factor.credal.linear.interval.IntervalFactor;
import ch.idsia.crema.factor.credal.linear.interval.IntervalFactorFactory;
import ch.idsia.crema.utility.IndexIterator;

import java.util.Arrays;

public class CreateCPT {

/**
* Border returns the borders of the intervals
*
* @param lower array of lower bounds for all the variables
* ex: double[] lower = new double[]{1.55, 55.0};
* @param upper array of upper bounds for all the variables
* ex: double[] upper = new double[]{1.90, 115.0};
*/
public double[] borders(double[] lower, double[] upper, Op operation) {

//shortcut in case the function is strictly monotone growing
//double[] extremes= new double[]{operation.execute(lower[0], upper[1]), operation.execute(upper[0], lower[1])};
//return Arrays.stream(extremes).sorted().toArray();

double[] results = new double[2 * lower.length];
results[0] = operation.execute(lower[0], lower[1]);
results[1] = operation.execute(lower[0], upper[1]);
results[2] = operation.execute(upper[0], lower[1]);
results[3] = operation.execute(upper[0], upper[1]);

Arrays.sort(results);
double firstElement = results[0];
double lastElement = results[results.length - 1];

return new double[]{firstElement, lastElement};
}

/**
* Method to create a CPT
*
* @param childVar variable of the child
* @param parentsVars array of variables of the parents
* @param childCuts array of cuts for the child
* @param parentCuts array of cuts for the parents
* @param operation operation to be performed with the cuts
* example K(bmi|w,H)
* @return IntervalFactor representing the CPT of the child given the parents
*/
public IntervalFactor create(int childVar, int[] parentsVars, double[] childCuts, double[][] parentCuts, Op operation) {

// root nodes creation
// add child node
int dimChild = childCuts.length + 1;
Strides stridesChild = Strides.var(childVar, dimChild);

// create domain
Strides dom = Strides.empty();
for (int i = 0; i < parentsVars.length; i++) {
dom = dom.and(parentsVars[i], parentCuts[i].length - 1);
}
// add parents nodes
IntervalFactorFactory factory = IntervalFactorFactory.factory().domain(stridesChild, dom);

// create iterator
IndexIterator iterator = dom.getIterator();
//iterate over all possible combinations
while (iterator.hasNext()) {
int[] comb = iterator.getPositions().clone();
// be aware that the structure returned by the method has to be compliant with the child
double[] parentIntervalLower = new double[parentCuts.length];
double[] parentIntervalUpper = new double[parentCuts.length];

for (int i = 0; i < parentCuts.length; i++) {
parentIntervalLower[i] = parentCuts[i][comb[i]];
parentIntervalUpper[i] = parentCuts[i][comb[i] + 1];
}

double[] intervalBorders = borders(parentIntervalLower, parentIntervalUpper, operation);
//map the integers of the position of the childCuts
int[] intervalNumber = Arrays.stream(intervalBorders).mapToInt(val -> whichPosition(childCuts, val)).toArray();

// the lower is set to an array of zeroes the upper is set to 1 in the position of the interval number
factory.set(new double[dimChild], createUpper(intervalNumber, dimChild), comb);
iterator.next();
}
return factory.get();
}

/**
* @param interval array containing the position
* @param dim dimension of the array to be generated
* @return double array with 1.0 set for every interval value
*/
public double[] createUpper(int[] interval, int dim) {
double[] upper = new double[dim];

// we set to 1.0 all the element relative to the interval
for (int ind : interval) {
upper[ind - 1] = 1.0;
}
return upper;
}

/**
* Method to find the interval containing the specified value
* Specials: if the number is lower than the first cut, it will be placed in the first interval
* if the number is higher than the last cut, it will be placed in the last interval
*
* @param cutsX array of cuts, typically for the discretization
* @param X value that we want to place
* @return integer of the interval containing x
*/
//greedy
public int whichPosition(double[] cutsX, double X) {
int position = 0; //starts from 1
for (int i = 1; i < cutsX.length - 1; i++) {
if (X <= cutsX[i]) {
break;
}
position++;
}
return position + 1;
}
}
6 changes: 6 additions & 0 deletions src/main/java/ch/idsia/crema/preprocess/creators/Op.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
package ch.idsia.crema.preprocess.creators;

@FunctionalInterface
public interface Op {
double execute(double a, double b);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package ch.idsia.crema.preprocess.creators;

import ch.idsia.crema.factor.credal.linear.interval.IntervalFactor;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;

public class CreateCPTTest {

private final CreateCPT creator = new CreateCPT();

@Test
public void testCreate() {

//BMI example
double[] cutsH = new double[]{1.55, 1.60, 1.65, 1.70, 1.75, 1.80, 1.85, 1.90, 1.95, 2.00};
double[] cutsW = new double[]{55.0, 60.0, 65.0, 70.0, 75.0, 80.0, 85.0, 90.0, 95.0, 100.0, 105.0, 110.0, 115.0};
double[] cutsBMI = new double[]{0.0, 15.0, 16, 18.5, 25.0, 30.0, 35.0, 40.0, 100.0};

double[][] parents = new double[][]{cutsW, cutsH};
Op bmi = (w, h) -> w / h / h;

IntervalFactor cpt = creator.create(2, new int[]{0, 1}, cutsBMI, parents, bmi);
//System.out.println(cpt);

// w1 and h1
assertArrayEquals(cpt.getLower(0, 0), new double[10]);
assertArrayEquals(cpt.getUpper(0, 0), creator.createUpper(new int[]{4}, 10));

// w2 and h1
assertArrayEquals(cpt.getLower(1, 0), new double[10]);
assertArrayEquals(cpt.getUpper(1, 0), creator.createUpper(new int[]{4, 5}, 10));

// w5 and h6
assertArrayEquals(cpt.getLower(4, 5), new double[10]);
assertArrayEquals(cpt.getUpper(4, 5), creator.createUpper(new int[]{4}, 10));

// w12 and h9
assertArrayEquals(cpt.getLower(11, 8), new double[10]);
assertArrayEquals(cpt.getUpper(11, 8), creator.createUpper(new int[]{5, 6}, 10));
}

@Test
public void testBorders() {
double[] parentIntervalLower = new double[]{55.0, 1.55};
double[] parentIntervalUpper = new double[]{60.0, 1.60};
//example of the BMI
Op bmi = (w, h) -> w / h / h;

//bmi low = 21.48
//bmi high = 24.97
double tolerance = 0.001;
double[] interval = creator.borders(parentIntervalLower, parentIntervalUpper, bmi);

assertArrayEquals(new double[]{21.484, 24.973}, interval, tolerance);
}

@Test
public void testWhichPosition() {
double[] intervals = new double[]{0.0, 15.0, 16, 18.5, 25.0, 30.0, 35.0, 40.0, 100.0};
assertEquals(1, creator.whichPosition(intervals, -10));
assertEquals(1, creator.whichPosition(intervals, 11));
assertEquals(1, creator.whichPosition(intervals, 15));
assertEquals(3, creator.whichPosition(intervals, 17));
assertEquals(5, creator.whichPosition(intervals, 29));
assertEquals(8, creator.whichPosition(intervals, 41));
assertEquals(8, creator.whichPosition(intervals, 200));
}

@Test
public void testCreateUpper() {
int[] interval = new int[]{3, 4};
double[] result = creator.createUpper(interval, 4);
double[] expected = new double[]{.0, .0, 1.0, 1.0};

assertArrayEquals(result, expected);
}
}

0 comments on commit 96eb66b

Please sign in to comment.