diff --git a/jskat-base/src/main/java/org/jskat/control/iss/IssGameExtractor.java b/jskat-base/src/main/java/org/jskat/control/iss/IssGameExtractor.java index 2539f64c..d41abb80 100644 --- a/jskat-base/src/main/java/org/jskat/control/iss/IssGameExtractor.java +++ b/jskat-base/src/main/java/org/jskat/control/iss/IssGameExtractor.java @@ -45,7 +45,7 @@ private void filterGameDatabase(final Predicate predicate, final S try (final Stream stream = Files.lines(Paths.get(sourceFileName))) { final AtomicInteger count = new AtomicInteger(); final var filteredGames = stream - .skip(9_000_000) + .skip(6_000_000) .peek(logProgress(count)) //.peek(System.out::println) // TODO: fix parsing of filtered games @@ -78,7 +78,7 @@ private void filterGameDatabase(final Predicate predicate, final S .filter(predicate) //.map(SkatGameData::toString) .map(NETWORK_INPUTS) - .limit(100) + .limit(100_000) .collect(Collectors.toList()); final var lines = new ArrayList(); diff --git a/jskat-base/src/main/kotlin/org/jskat/ai/deeplearning/BidNetTrainer.kt b/jskat-base/src/main/kotlin/org/jskat/ai/deeplearning/BidNetTrainer.kt index abaa480c..e386546e 100644 --- a/jskat-base/src/main/kotlin/org/jskat/ai/deeplearning/BidNetTrainer.kt +++ b/jskat-base/src/main/kotlin/org/jskat/ai/deeplearning/BidNetTrainer.kt @@ -11,7 +11,7 @@ import ai.djl.nn.core.Linear import ai.djl.training.DefaultTrainingConfig import ai.djl.training.EasyTrain import ai.djl.training.evaluator.Accuracy -import ai.djl.training.initializer.NormalInitializer +import ai.djl.training.initializer.UniformInitializer import ai.djl.training.listener.TrainingListener import ai.djl.training.loss.Loss import ai.djl.training.optimizer.Optimizer @@ -25,7 +25,7 @@ fun main() { val batchSize = 10_000 val builder = DataFrameDataSet.Builder() builder.filePath = "data/kermit_games.csv" - builder.setSampling(10, true) + builder.setSampling(batchSize, true) builder.addCategoricalFeature("declarer", true) listOf( "♣A", "♣T", "♣K", "♣Q", "♣J", "♣9", "♣8", "♣7", @@ -51,13 +51,13 @@ fun main() { val featurizer2 = dataSet.labels[0].featurizer println("Featurizer gameType: ${featurizer2.dataRequired()}") - val defeature2 = featurizer2.deFeaturize(floatArrayOf(1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.7f)) + val defeature2 = featurizer2.deFeaturize(floatArrayOf(1.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.7f)) println("DeFeature: $defeature2") if (defeature2 is Classifications) { println(defeature2.best().className) } - val trainTest = dataSet.randomSplit(8, 2) + val trainTest = dataSet.randomSplit(80, 20) val training = trainTest[0] val test = trainTest[1] @@ -81,12 +81,16 @@ fun main() { // block.add(Activation::relu) // block.add(Linear.builder().setUnits(512).build()) // block.add(Activation::relu) - block.add(Linear.builder().setUnits(256).build()) - block.add(Activation::relu) +// block.add(Linear.builder().setUnits(256).build()) +// block.add(Activation::relu) block.add(Linear.builder().setUnits(128).build()) block.add(Activation::relu) + block.add(Linear.builder().setUnits(64).build()) + block.add(Activation::relu) + block.add(Linear.builder().setUnits(32).build()) + block.add(Activation::relu) block.add(Linear.builder().setUnits(outputSize).build()) - block.setInitializer(NormalInitializer(), Parameter.Type.WEIGHT) + block.setInitializer(UniformInitializer(), Parameter.Type.WEIGHT) val model = Model.newInstance("bidnet") model.block = block