Skip to content

Commit

Permalink
#157 Next try on training data
Browse files Browse the repository at this point in the history
  • Loading branch information
b0n541 committed Sep 9, 2024
1 parent 15b167c commit a1f6639
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private void filterGameDatabase(final Predicate<SkatGameData> predicate, final S
try (final Stream<String> stream = Files.lines(Paths.get(sourceFileName))) {
final AtomicInteger count = new AtomicInteger();
final var filteredGames = stream
.skip(9_000_000)
.skip(6_000_000)
.peek(logProgress(count))
//.peek(System.out::println)
// TODO: fix parsing of filtered games
Expand Down Expand Up @@ -78,7 +78,7 @@ private void filterGameDatabase(final Predicate<SkatGameData> predicate, final S
.filter(predicate)
//.map(SkatGameData::toString)
.map(NETWORK_INPUTS)
.limit(100)
.limit(100_000)
.collect(Collectors.toList());

final var lines = new ArrayList<String>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import ai.djl.nn.core.Linear
import ai.djl.training.DefaultTrainingConfig
import ai.djl.training.EasyTrain
import ai.djl.training.evaluator.Accuracy
import ai.djl.training.initializer.NormalInitializer
import ai.djl.training.initializer.UniformInitializer
import ai.djl.training.listener.TrainingListener
import ai.djl.training.loss.Loss
import ai.djl.training.optimizer.Optimizer
Expand All @@ -25,7 +25,7 @@ fun main() {
val batchSize = 10_000
val builder = DataFrameDataSet.Builder()
builder.filePath = "data/kermit_games.csv"
builder.setSampling(10, true)
builder.setSampling(batchSize, true)
builder.addCategoricalFeature("declarer", true)
listOf(
"♣A", "♣T", "♣K", "♣Q", "♣J", "♣9", "♣8", "♣7",
Expand All @@ -51,13 +51,13 @@ fun main() {

val featurizer2 = dataSet.labels[0].featurizer
println("Featurizer gameType: ${featurizer2.dataRequired()}")
val defeature2 = featurizer2.deFeaturize(floatArrayOf(1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.7f))
val defeature2 = featurizer2.deFeaturize(floatArrayOf(1.0f, 0.1f, 0.2f, 0.3f, 0.4f, 0.7f))
println("DeFeature: $defeature2")
if (defeature2 is Classifications) {
println(defeature2.best<Classifications.Classification?>().className)
}

val trainTest = dataSet.randomSplit(8, 2)
val trainTest = dataSet.randomSplit(80, 20)
val training = trainTest[0]
val test = trainTest[1]

Expand All @@ -81,12 +81,16 @@ fun main() {
// block.add(Activation::relu)
// block.add(Linear.builder().setUnits(512).build())
// block.add(Activation::relu)
block.add(Linear.builder().setUnits(256).build())
block.add(Activation::relu)
// block.add(Linear.builder().setUnits(256).build())
// block.add(Activation::relu)
block.add(Linear.builder().setUnits(128).build())
block.add(Activation::relu)
block.add(Linear.builder().setUnits(64).build())
block.add(Activation::relu)
block.add(Linear.builder().setUnits(32).build())
block.add(Activation::relu)
block.add(Linear.builder().setUnits(outputSize).build())
block.setInitializer(NormalInitializer(), Parameter.Type.WEIGHT)
block.setInitializer(UniformInitializer(), Parameter.Type.WEIGHT)

val model = Model.newInstance("bidnet")
model.block = block
Expand Down

0 comments on commit a1f6639

Please sign in to comment.