mirror of
https://github.com/Card-Forge/forge.git
synced 2025-11-18 03:38:01 +00:00
Fully working LDA based deck generation for Standard and Modern
(cherry picked from commit 892ae23)
This commit is contained in:
@@ -1,204 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.dataset;
|
||||
|
||||
import com.google.common.base.Predicates;
|
||||
import com.google.common.collect.Iterables;
|
||||
import com.google.common.collect.Lists;
|
||||
import forge.card.CardRulesPredicates;
|
||||
import forge.deck.Deck;
|
||||
import forge.deck.io.DeckStorage;
|
||||
import forge.game.GameFormat;
|
||||
import forge.item.PaperCard;
|
||||
import forge.properties.ForgeConstants;
|
||||
import forge.util.storage.IStorage;
|
||||
import forge.util.storage.StorageImmediatelySerialized;
|
||||
|
||||
import java.io.*;
|
||||
import java.util.*;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* This class is immutable.
|
||||
*/
|
||||
public final class BagOfWords {
|
||||
|
||||
private final int numDocs;
|
||||
private final int numVocabs;
|
||||
private final int numNNZ;
|
||||
private final int numWords;
|
||||
|
||||
public Vocabularies getVocabs() {
|
||||
return vocabs;
|
||||
}
|
||||
|
||||
private final Vocabularies vocabs;
|
||||
|
||||
// docID -> the vocabs sequence in the doc
|
||||
private Map<Integer, List<Integer>> words;
|
||||
|
||||
// docID -> the doc length
|
||||
private Map<Integer, Integer> docLength;
|
||||
|
||||
|
||||
/**
|
||||
* Read the bag-of-words dataset.
|
||||
* @throws FileNotFoundException
|
||||
* @throws IOException
|
||||
* @throws Exception
|
||||
* @throws NullPointerException filePath is null
|
||||
*/
|
||||
public BagOfWords(GameFormat format) throws FileNotFoundException, IOException, Exception {
|
||||
IStorage<Deck> decks = new StorageImmediatelySerialized<Deck>("Generator", new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR+ForgeConstants.PATH_SEPARATOR+format.getName()),
|
||||
ForgeConstants.DECK_GEN_DIR, false),
|
||||
true);
|
||||
|
||||
List<PaperCard> cardList = format.getAllCards();
|
||||
|
||||
List<Deck> legalDecks = new ArrayList<>();
|
||||
for(Deck deck:decks){
|
||||
if(format.isDeckLegal(deck)){
|
||||
legalDecks.add(deck);
|
||||
}
|
||||
}
|
||||
|
||||
this.words = new HashMap<>();
|
||||
this.docLength = new HashMap<>();
|
||||
ArrayList<Vocabulary> vocabList = new ArrayList<Vocabulary>();
|
||||
|
||||
int numDocs = legalDecks.size();
|
||||
int numVocabs = cardList.size();
|
||||
int numNNZ = 0;
|
||||
int numWords = 0;
|
||||
|
||||
Map<String, Integer> cardIntegerMap = new HashMap<>();
|
||||
Map<Integer, PaperCard> integerCardMap = new HashMap<>();
|
||||
for (int i=0; i<cardList.size(); ++i){
|
||||
cardIntegerMap.put(cardList.get(i).getName(), i);
|
||||
vocabList.add(new Vocabulary(i,cardList.get(i).getName()));
|
||||
integerCardMap.put(i, cardList.get(i));
|
||||
}
|
||||
|
||||
this.vocabs = new Vocabularies(vocabList);
|
||||
int deckID = 0;
|
||||
for (Deck deck:legalDecks){
|
||||
Iterator<Map.Entry<PaperCard,Integer>> cardIterator = deck.getMain().iterator();
|
||||
while (cardIterator.hasNext()){
|
||||
Map.Entry<PaperCard,Integer> entry = cardIterator.next();
|
||||
numNNZ++;
|
||||
}
|
||||
List<Integer> cardNumbers = new ArrayList<>();
|
||||
for(PaperCard card:deck.getMain().toFlatList()){
|
||||
if(cardIntegerMap.get(card.getName()) == null){
|
||||
System.out.println(card.getName() + " is missing!!");
|
||||
}
|
||||
cardNumbers.add(cardIntegerMap.get(card.getName()));
|
||||
}
|
||||
words.put(deckID,cardNumbers);
|
||||
numWords+=cardNumbers.size();
|
||||
docLength.put(deckID,cardNumbers.size());
|
||||
deckID ++;
|
||||
}
|
||||
|
||||
/*String s = null;
|
||||
while ((s = reader.readLine()) != null) {
|
||||
List<Integer> numbers
|
||||
= Arrays.asList(s.split(" ")).stream().map(Integer::parseInt).collect(Collectors.toList());
|
||||
|
||||
if (numbers.size() == 1) {
|
||||
if (headerCount == 2) numDocs = numbers.get(0);
|
||||
else if (headerCount == 1) numVocabs = numbers.get(0);
|
||||
else if (headerCount == 0) numNNZ = numbers.get(0);
|
||||
--headerCount;
|
||||
continue;
|
||||
}
|
||||
else if (numbers.size() == 3) {
|
||||
final int docID = numbers.get(0);
|
||||
final int vocabID = numbers.get(1);
|
||||
final int count = numbers.get(2);
|
||||
|
||||
// Set up the words container
|
||||
if (!words.containsKey(docID)) {
|
||||
words.put(docID, new ArrayList<>());
|
||||
}
|
||||
for (int c = 0; c < count; ++c) {
|
||||
words.get(docID).add(vocabID);
|
||||
}
|
||||
|
||||
// Set up the doc length map
|
||||
Optional<Integer> currentCount
|
||||
= Optional.ofNullable(docLength.putIfAbsent(docID, count));
|
||||
currentCount.ifPresent(c -> docLength.replace(docID, c + count));
|
||||
|
||||
numWords += count;
|
||||
}
|
||||
else {
|
||||
throw new Exception("Invalid dataset form was detected.");
|
||||
}
|
||||
}
|
||||
reader.close();*/
|
||||
|
||||
this.numDocs = numDocs;
|
||||
this.numVocabs = numVocabs;
|
||||
this.numNNZ = numNNZ;
|
||||
this.numWords = numWords;
|
||||
}
|
||||
|
||||
public int getNumDocs() {
|
||||
return numDocs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the length of the document.
|
||||
* @param docID
|
||||
* @return length of the document
|
||||
* @throws IllegalArgumentException docID <= 0 || #documents < docID
|
||||
*/
|
||||
public int getDocLength(int docID) {
|
||||
if (docID < 0 || getNumDocs() < docID) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
return docLength.get(docID);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the unmodifiable list of words in the document.
|
||||
* @param docID
|
||||
* @return the unmodifiable list of words
|
||||
* @throws IllegalArgumentException docID <= 0 || #documents < docID
|
||||
*/
|
||||
public List<Integer> getWords(final int docID) {
|
||||
if (docID < 0 || getNumDocs() < docID) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return Collections.unmodifiableList(words.get(docID));
|
||||
}
|
||||
|
||||
public int getNumVocabs() {
|
||||
return numVocabs;
|
||||
}
|
||||
|
||||
public int getNumNNZ() {
|
||||
return numNNZ;
|
||||
}
|
||||
|
||||
public int getNumWords() {
|
||||
return numWords;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.dataset;
|
||||
|
||||
import forge.game.GameFormat;
|
||||
import forge.model.FModel;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
public class Dataset {
|
||||
private BagOfWords bow;
|
||||
private Vocabularies vocabs;
|
||||
|
||||
public Dataset(GameFormat format) throws Exception {
|
||||
bow = new BagOfWords(format);
|
||||
vocabs = bow.getVocabs();
|
||||
}
|
||||
|
||||
public Dataset(BagOfWords bow) {
|
||||
this.bow = bow;
|
||||
this.vocabs = null;
|
||||
}
|
||||
|
||||
public BagOfWords getBow() {
|
||||
return bow;
|
||||
}
|
||||
|
||||
public int getNumDocs() {
|
||||
return bow.getNumDocs();
|
||||
}
|
||||
|
||||
public int getDocLength(int docID) {
|
||||
return bow.getDocLength(docID);
|
||||
}
|
||||
|
||||
public List<Integer> getWords(int docID) {
|
||||
return bow.getWords(docID);
|
||||
}
|
||||
|
||||
public int getNumVocabs() {
|
||||
return bow.getNumVocabs();
|
||||
}
|
||||
|
||||
public int getNumNNZ() {
|
||||
return bow.getNumNNZ();
|
||||
}
|
||||
|
||||
public int getNumWords() {
|
||||
return bow.getNumWords();
|
||||
}
|
||||
|
||||
public Vocabulary get(int id) {
|
||||
return vocabs.get(id);
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return vocabs.size();
|
||||
}
|
||||
|
||||
public Vocabularies getVocabularies() {
|
||||
return vocabs;
|
||||
}
|
||||
|
||||
public List<Vocabulary> getVocabularyList() {
|
||||
return vocabs.getVocabularyList();
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.dataset;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Path;
|
||||
import java.nio.file.Paths;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
public class Vocabularies {
|
||||
private List<Vocabulary> vocabs;
|
||||
|
||||
public Vocabularies(List<Vocabulary> vocabs) {
|
||||
this.vocabs = vocabs;
|
||||
}
|
||||
|
||||
public Vocabulary get(int id) {
|
||||
return vocabs.get(id);
|
||||
}
|
||||
|
||||
public int size() {
|
||||
return vocabs.size();
|
||||
}
|
||||
|
||||
public List<Vocabulary> getVocabularyList() {
|
||||
return Collections.unmodifiableList(vocabs);
|
||||
}
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.dataset;
|
||||
|
||||
public class Vocabulary {
|
||||
private final int id;
|
||||
private String vocabulary;
|
||||
|
||||
public Vocabulary(int id, String vocabulary) {
|
||||
if (vocabulary == null) throw new NullPointerException();
|
||||
this.id = id;
|
||||
this.vocabulary = vocabulary;
|
||||
}
|
||||
|
||||
public int id() {
|
||||
return id;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return vocabulary;
|
||||
}
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.examples;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import com.google.common.base.Function;
|
||||
import forge.GuiBase;
|
||||
import forge.GuiDesktop;
|
||||
import forge.deck.generate.lda.lda.LDA;
|
||||
import static forge.deck.generate.lda.lda.inference.InferenceMethod.*;
|
||||
|
||||
import forge.model.FModel;
|
||||
import forge.properties.ForgePreferences;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import forge.deck.generate.lda.dataset.Dataset;
|
||||
|
||||
public class Example {
|
||||
public static void main(String[] args) throws Exception {
|
||||
GuiBase.setInterface(new GuiDesktop());
|
||||
FModel.initialize(null, new Function<ForgePreferences, Void>() {
|
||||
@Override
|
||||
public Void apply(ForgePreferences preferences) {
|
||||
preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
Dataset dataset = new Dataset(FModel.getFormats().getStandard());
|
||||
|
||||
final int numTopics = 50;
|
||||
LDA lda = new LDA(0.1, 0.1, numTopics, dataset, CGS);
|
||||
lda.run();
|
||||
System.out.println(lda.computePerplexity(dataset));
|
||||
|
||||
for (int t = 0; t < numTopics; ++t) {
|
||||
List<Pair<String, Double>> highRankVocabs = lda.getVocabsSortedByPhi(t);
|
||||
System.out.print("t" + t + ": ");
|
||||
for (int i = 0; i < 20; ++i) {
|
||||
System.out.println("[" + highRankVocabs.get(i).getLeft() + "," + highRankVocabs.get(i).getRight() + "],");
|
||||
}
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
class Alpha {
|
||||
private List<Double> alphas;
|
||||
|
||||
Alpha(double alpha, int numTopics) {
|
||||
if (alpha <= 0.0 || numTopics <= 0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.alphas = Stream.generate(() -> alpha)
|
||||
.limit(numTopics)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
double get(int i) {
|
||||
return alphas.get(i);
|
||||
}
|
||||
|
||||
void set(int i, double value) {
|
||||
alphas.set(i, value);
|
||||
}
|
||||
|
||||
double sum() {
|
||||
return alphas.stream().reduce(Double::sum).get();
|
||||
}
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
class Beta {
|
||||
private List<Double> betas;
|
||||
|
||||
Beta(double beta, int numVocabs) {
|
||||
if (beta <= 0.0 || numVocabs <= 0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.betas = Stream.generate(() -> beta)
|
||||
.limit(numVocabs)
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
Beta(double beta) {
|
||||
if (beta <= 0.0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
this.betas = Arrays.asList(beta);
|
||||
}
|
||||
|
||||
double get() {
|
||||
return get(0);
|
||||
}
|
||||
|
||||
double get(int i) {
|
||||
return betas.get(i);
|
||||
}
|
||||
|
||||
void set(int i, double value) {
|
||||
betas.set(i, value);
|
||||
}
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda;
|
||||
|
||||
class Hyperparameters {
|
||||
private Alpha alpha;
|
||||
private Beta beta;
|
||||
|
||||
Hyperparameters(double alpha, double beta, int numTopics, int numVocabs) {
|
||||
this.alpha = new Alpha(alpha, numTopics);
|
||||
this.beta = new Beta(beta, numVocabs);
|
||||
}
|
||||
|
||||
Hyperparameters(double alpha, double beta, int numTopics) {
|
||||
this.alpha = new Alpha(alpha, numTopics);
|
||||
this.beta = new Beta(beta);
|
||||
}
|
||||
|
||||
double alpha(int i) {
|
||||
return alpha.get(i);
|
||||
}
|
||||
|
||||
double sumAlpha() {
|
||||
return alpha.sum();
|
||||
}
|
||||
|
||||
double beta() {
|
||||
return beta.get();
|
||||
}
|
||||
|
||||
double beta(int i) {
|
||||
return beta.get(i);
|
||||
}
|
||||
|
||||
void setAlpha(int i, double value) {
|
||||
alpha.set(i, value);
|
||||
}
|
||||
|
||||
void setBeta(int i, double value) {
|
||||
beta.set(i, value);
|
||||
}
|
||||
|
||||
void setBeta(double value) {
|
||||
beta.set(0, value);
|
||||
}
|
||||
}
|
||||
@@ -1,184 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.List;
|
||||
|
||||
import forge.deck.generate.lda.lda.inference.Inference;
|
||||
import forge.deck.generate.lda.lda.inference.InferenceFactory;
|
||||
import forge.deck.generate.lda.lda.inference.InferenceMethod;
|
||||
import forge.deck.generate.lda.lda.inference.InferenceProperties;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import forge.deck.generate.lda.dataset.BagOfWords;
|
||||
import forge.deck.generate.lda.dataset.Dataset;
|
||||
import forge.deck.generate.lda.dataset.Vocabularies;
|
||||
|
||||
public class LDA {
|
||||
private Hyperparameters hyperparameters;
|
||||
private final int numTopics;
|
||||
private Dataset dataset;
|
||||
private final Inference inference;
|
||||
private InferenceProperties properties;
|
||||
private boolean trained;
|
||||
|
||||
/**
|
||||
* @param alpha doc-topic hyperparameter
|
||||
* @param beta topic-vocab hyperparameter
|
||||
* @param numTopics the number of topics
|
||||
* @param dataset dataset
|
||||
* @param bow bag-of-words
|
||||
* @param method inference method
|
||||
*/
|
||||
public LDA(final double alpha, final double beta, final int numTopics,
|
||||
final Dataset dataset, InferenceMethod method) {
|
||||
this(alpha, beta, numTopics, dataset, method, InferenceProperties.PROPERTIES_FILE_NAME);
|
||||
}
|
||||
|
||||
LDA(final double alpha, final double beta, final int numTopics,
|
||||
final Dataset dataset, InferenceMethod method, String propertiesFilePath) {
|
||||
this.hyperparameters = new Hyperparameters(alpha, beta, numTopics);
|
||||
this.numTopics = numTopics;
|
||||
this.dataset = dataset;
|
||||
this.inference = InferenceFactory.getInstance(method);
|
||||
this.properties = new InferenceProperties();
|
||||
this.trained = false;
|
||||
|
||||
properties.setSeed(123L);
|
||||
properties.setNumIteration(100);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the vocabulary from its ID.
|
||||
* @param vocabID
|
||||
* @return the vocabulary
|
||||
* @throws IllegalArgumentException vocabID <= 0 || the number of vocabularies < vocabID
|
||||
*/
|
||||
public String getVocab(int vocabID) {
|
||||
if (vocabID < 0 || dataset.getNumVocabs() < vocabID) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return dataset.get(vocabID).toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* Run model inference.
|
||||
*/
|
||||
public void run() {
|
||||
if (properties == null) inference.setUp(this);
|
||||
else inference.setUp(this, properties);
|
||||
inference.run();
|
||||
trained = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get hyperparameter alpha corresponding to topic.
|
||||
* @param topic
|
||||
* @return alpha corresponding to topicID
|
||||
* @throws ArrayIndexOutOfBoundsException topic < 0 || #topics <= topic
|
||||
*/
|
||||
public double getAlpha(final int topic) {
|
||||
if (topic < 0 || numTopics < topic) {
|
||||
throw new ArrayIndexOutOfBoundsException(topic);
|
||||
}
|
||||
return hyperparameters.alpha(topic);
|
||||
}
|
||||
|
||||
public double getSumAlpha() {
|
||||
return hyperparameters.sumAlpha();
|
||||
}
|
||||
|
||||
public double getBeta() {
|
||||
return hyperparameters.beta();
|
||||
}
|
||||
|
||||
public int getNumTopics() {
|
||||
return numTopics;
|
||||
}
|
||||
|
||||
public BagOfWords getBow() {
|
||||
return dataset.getBow();
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the value of doc-topic probability \theta_{docID, topicID}.
|
||||
* @param docID
|
||||
* @param topicID
|
||||
* @return the value of doc-topic probability
|
||||
* @throws IllegalArgumentException docID <= 0 || #docs < docID || topicID < 0 || #topics <= topicID
|
||||
* @throws IllegalStateException call this method when the inference has not been finished yet
|
||||
*/
|
||||
public double getTheta(final int docID, final int topicID) {
|
||||
if (docID < 0 || dataset.getNumDocs() < docID
|
||||
|| topicID < 0 || numTopics < topicID) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (!trained) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
|
||||
return inference.getTheta(docID, topicID);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the value of topic-vocab probability \phi_{topicID, vocabID}.
|
||||
* @param topicID
|
||||
* @param vocabID
|
||||
* @return the value of topic-vocab probability
|
||||
* @throws IllegalArgumentException topicID < 0 || #topics <= topicID || vocabID <= 0
|
||||
* @throws IllegalStateException call this method when the inference has not been finished yet
|
||||
*/
|
||||
public double getPhi(final int topicID, final int vocabID) {
|
||||
if (topicID < 0 || numTopics < topicID || vocabID < 0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (!trained) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
|
||||
return inference.getPhi(topicID, vocabID);
|
||||
}
|
||||
|
||||
public Vocabularies getVocabularies() {
|
||||
return dataset.getVocabularies();
|
||||
}
|
||||
|
||||
public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
|
||||
return inference.getVocabsSortedByPhi(topicID);
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the perplexity of trained LDA for the test bag-of-words dataset.
|
||||
* @param testDataset
|
||||
* @return the perplexity for the test bag-of-words dataset
|
||||
*/
|
||||
public double computePerplexity(Dataset testDataset) {
|
||||
double loglikelihood = 0.0;
|
||||
for (int d = 0; d < testDataset.getNumDocs(); ++d) {
|
||||
for (Integer w : testDataset.getWords(d)) {
|
||||
double sum = 0.0;
|
||||
for (int t = 0; t < getNumTopics(); ++t) {
|
||||
sum += getTheta(d, t) * getPhi(t, w.intValue());
|
||||
}
|
||||
loglikelihood += Math.log(sum);
|
||||
}
|
||||
}
|
||||
return Math.exp(-loglikelihood / testDataset.getNumWords());
|
||||
}
|
||||
}
|
||||
@@ -1,62 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import forge.deck.generate.lda.lda.LDA;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
public interface Inference {
|
||||
/**
|
||||
* Set up for inference.
|
||||
* @param lda
|
||||
*/
|
||||
public void setUp(LDA lda);
|
||||
|
||||
/**
|
||||
* Set up for inference.
|
||||
* The configuration is read from properties class.
|
||||
* @param lda
|
||||
* @param properties
|
||||
*/
|
||||
public void setUp(LDA lda, InferenceProperties properties);
|
||||
|
||||
/**
|
||||
* Run model inference.
|
||||
*/
|
||||
public void run();
|
||||
|
||||
/**
|
||||
* Get the value of doc-topic probability \theta_{docID, topicID}.
|
||||
* @param docID
|
||||
* @param topicID
|
||||
* @return the value of doc-topic probability
|
||||
*/
|
||||
public double getTheta(final int docID, final int topicID);
|
||||
|
||||
/**
|
||||
* Get the value of topic-vocab probability \phi_{topicID, vocabID}.
|
||||
* @param topicID
|
||||
* @param vocabID
|
||||
* @return the value of topic-vocab probability
|
||||
*/
|
||||
public double getPhi(final int topicID, final int vocabID);
|
||||
|
||||
public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID);
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference;
|
||||
|
||||
|
||||
public class InferenceFactory {
|
||||
private InferenceFactory() {}
|
||||
|
||||
/**
|
||||
* Get the LDAInference instance specified by the argument
|
||||
* @param method
|
||||
* @return the instance which implements LDAInference
|
||||
*/
|
||||
public static Inference getInstance(InferenceMethod method) {
|
||||
Inference clazz = null;
|
||||
try {
|
||||
clazz = (Inference)Class.forName(method.toString()).newInstance();
|
||||
} catch (ReflectiveOperationException roe) {
|
||||
roe.printStackTrace();
|
||||
}
|
||||
return clazz;
|
||||
}
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference;
|
||||
|
||||
public enum InferenceMethod {
|
||||
CGS("forge.deck.generate.lda.lda.inference.internal.CollapsedGibbsSampler"),
|
||||
// more
|
||||
;
|
||||
|
||||
private String className;
|
||||
|
||||
private InferenceMethod(final String className) {
|
||||
this.className = className;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return className;
|
||||
}
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference;
|
||||
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.IOException;
|
||||
import java.io.InputStream;
|
||||
import java.util.Properties;
|
||||
|
||||
public class InferenceProperties {
|
||||
public static String PROPERTIES_FILE_NAME = "lda.properties";
|
||||
|
||||
PropertiesLoader loader = new PropertiesLoader();
|
||||
private Properties properties;
|
||||
|
||||
public InferenceProperties() {
|
||||
this.properties = new Properties();
|
||||
}
|
||||
|
||||
public void setSeed(Long seed){
|
||||
properties.setProperty("seed",seed.toString());
|
||||
}
|
||||
|
||||
public void setNumIteration(Integer numIteration){
|
||||
properties.setProperty("numIteration",numIteration.toString());
|
||||
}
|
||||
|
||||
/**
|
||||
* Load properties.
|
||||
* @param fileName
|
||||
* @throws IOException
|
||||
* @throws NullPointerException fileName is null
|
||||
*/
|
||||
public void load(String fileName) throws IOException {
|
||||
if (fileName == null) throw new NullPointerException();
|
||||
InputStream stream = loader.getInputStream(fileName);
|
||||
if (stream == null) throw new NullPointerException();
|
||||
properties.load(stream);
|
||||
}
|
||||
|
||||
public Long seed() {
|
||||
return Long.parseLong(properties.getProperty("seed"));
|
||||
}
|
||||
|
||||
public Integer numIteration() {
|
||||
return Integer.parseInt(properties.getProperty("numIteration"));
|
||||
}
|
||||
}
|
||||
|
||||
class PropertiesLoader {
|
||||
public InputStream getInputStream(String fileName) throws FileNotFoundException {
|
||||
if (fileName == null) throw new NullPointerException();
|
||||
return new FileInputStream(fileName);
|
||||
}
|
||||
}
|
||||
@@ -1,65 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
class AssignmentCounter {
|
||||
private List<Integer> counter;
|
||||
|
||||
AssignmentCounter(int size) {
|
||||
if (size <= 0) throw new IllegalArgumentException();
|
||||
this.counter = IntStream.generate(() -> 0)
|
||||
.limit(size)
|
||||
.boxed()
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
int size() {
|
||||
return counter.size();
|
||||
}
|
||||
|
||||
int get(int id) {
|
||||
if (id < 0 || counter.size() <= id) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return counter.get(id);
|
||||
}
|
||||
|
||||
int getSum() {
|
||||
return counter.stream().reduce(Integer::sum).get();
|
||||
}
|
||||
|
||||
void increment(int id) {
|
||||
if (id < 0 || counter.size() <= id) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
counter.set(id, counter.get(id) + 1);
|
||||
}
|
||||
|
||||
void decrement(int id) {
|
||||
if (id < 0 || counter.size() <= id) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
if (counter.get(id) == 0) {
|
||||
throw new IllegalStateException();
|
||||
}
|
||||
counter.set(id, counter.get(id) - 1);
|
||||
}
|
||||
}
|
||||
@@ -1,220 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.stream.IntStream;
|
||||
|
||||
import forge.deck.generate.lda.lda.LDA;
|
||||
import forge.deck.generate.lda.lda.inference.Inference;
|
||||
import forge.deck.generate.lda.lda.inference.InferenceProperties;
|
||||
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
|
||||
import org.apache.commons.math3.distribution.IntegerDistribution;
|
||||
|
||||
import forge.deck.generate.lda.dataset.Vocabulary;
|
||||
|
||||
public class CollapsedGibbsSampler implements Inference {
|
||||
private LDA lda;
|
||||
private Topics topics;
|
||||
private Documents documents;
|
||||
private int numIteration;
|
||||
|
||||
private static final long DEFAULT_SEED = 0L;
|
||||
private static final int DEFAULT_NUM_ITERATION = 100;
|
||||
|
||||
// ready for Gibbs sampling
|
||||
private boolean ready;
|
||||
|
||||
public CollapsedGibbsSampler() {
|
||||
ready = false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setUp(LDA lda, InferenceProperties properties) {
|
||||
if (properties == null) {
|
||||
setUp(lda);
|
||||
return;
|
||||
}
|
||||
|
||||
this.lda = lda;
|
||||
initialize(this.lda);
|
||||
|
||||
final long seed = properties.seed() != null ? properties.seed() : DEFAULT_SEED;
|
||||
initializeTopicAssignment(seed);
|
||||
|
||||
this.numIteration
|
||||
= properties.numIteration() != null ? properties.numIteration() : DEFAULT_NUM_ITERATION;
|
||||
this.ready = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void setUp(LDA lda) {
|
||||
if (lda == null) throw new NullPointerException();
|
||||
|
||||
this.lda = lda;
|
||||
|
||||
initialize(this.lda);
|
||||
initializeTopicAssignment(DEFAULT_SEED);
|
||||
|
||||
this.numIteration = DEFAULT_NUM_ITERATION;
|
||||
this.ready = true;
|
||||
}
|
||||
|
||||
private void initialize(LDA lda) {
|
||||
assert lda != null;
|
||||
this.topics = new Topics(lda);
|
||||
this.documents = new Documents(lda);
|
||||
}
|
||||
|
||||
public boolean isReady() {
|
||||
return ready;
|
||||
}
|
||||
|
||||
public int getNumIteration() {
|
||||
return numIteration;
|
||||
}
|
||||
|
||||
public void setNumIteration(final int numIteration) {
|
||||
this.numIteration = numIteration;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void run() {
|
||||
if (!ready) {
|
||||
throw new IllegalStateException("instance has not set up yet");
|
||||
}
|
||||
|
||||
for (int i = 1; i <= numIteration; ++i) {
|
||||
System.out.println("Iteraion " + i + ".");
|
||||
runSampling();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run collapsed Gibbs sampling [Griffiths and Steyvers 2004].
|
||||
*/
|
||||
void runSampling() {
|
||||
for (Document d : documents.getDocuments()) {
|
||||
for (int w = 0; w < d.getDocLength(); ++w) {
|
||||
final Topic oldTopic = topics.get(d.getTopicID(w));
|
||||
d.decrementTopicCount(oldTopic.id());
|
||||
|
||||
final Vocabulary v = d.getVocabulary(w);
|
||||
oldTopic.decrementVocabCount(v.id());
|
||||
|
||||
IntegerDistribution distribution
|
||||
= getFullConditionalDistribution(lda.getNumTopics(), d.id(), v.id());
|
||||
|
||||
final int newTopicID = distribution.sample();
|
||||
d.setTopicID(w, newTopicID);
|
||||
|
||||
d.incrementTopicCount(newTopicID);
|
||||
final Topic newTopic = topics.get(newTopicID);
|
||||
newTopic.incrementVocabCount(v.id());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the full conditional distribution over topics.
|
||||
* docID and vocabID are passed to this distribution for parameters.
|
||||
* @param numTopics
|
||||
* @param docID
|
||||
* @param vocabID
|
||||
* @return the integer distribution over topics
|
||||
*/
|
||||
IntegerDistribution getFullConditionalDistribution(final int numTopics, final int docID, final int vocabID) {
|
||||
int[] topics = IntStream.range(0, numTopics).toArray();
|
||||
double[] probabilities = Arrays.stream(topics)
|
||||
.mapToDouble(t -> getTheta(docID, t) * getPhi(t, vocabID))
|
||||
.toArray();
|
||||
return new EnumeratedIntegerDistribution(topics, probabilities);
|
||||
}
|
||||
|
||||
/**
|
||||
* Initialize the topic assignment.
|
||||
* @param seed the seed of a pseudo random number generator
|
||||
*/
|
||||
void initializeTopicAssignment(final long seed) {
|
||||
documents.initializeTopicAssignment(topics, seed);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the count of topicID assigned to docID.
|
||||
* @param docID
|
||||
* @param topicID
|
||||
* @return the count of topicID assigned to docID
|
||||
*/
|
||||
int getDTCount(final int docID, final int topicID) {
|
||||
if (!ready) throw new IllegalStateException();
|
||||
if (docID <= 0 || lda.getBow().getNumDocs() < docID
|
||||
|| topicID < 0 || lda.getNumTopics() <= topicID) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return documents.getTopicCount(docID, topicID);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the count of vocabID assigned to topicID.
|
||||
* @param topicID
|
||||
* @param vocabID
|
||||
* @return the count of vocabID assigned to topicID
|
||||
*/
|
||||
int getTVCount(final int topicID, final int vocabID) {
|
||||
if (!ready) throw new IllegalStateException();
|
||||
if (topicID < 0 || lda.getNumTopics() <= topicID || vocabID <= 0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
final Topic topic = topics.get(topicID);
|
||||
return topic.getVocabCount(vocabID);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the sum of counts of vocabs assigned to topicID.
|
||||
* This is the sum of topic-vocab count over vocabs.
|
||||
* @param topicID
|
||||
* @return the sum of counts of vocabs assigned to topicID
|
||||
* @throws IllegalArgumentException topicID < 0 || #topic <= topicID
|
||||
*/
|
||||
int getTSumCount(final int topicID) {
|
||||
if (topicID < 0 || lda.getNumTopics() <= topicID) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
final Topic topic = topics.get(topicID);
|
||||
return topic.getSumCount();
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getTheta(final int docID, final int topicID) {
|
||||
if (!ready) throw new IllegalStateException();
|
||||
return documents.getTheta(docID, topicID, lda.getAlpha(topicID), lda.getSumAlpha());
|
||||
}
|
||||
|
||||
@Override
|
||||
public double getPhi(int topicID, int vocabID) {
|
||||
if (!ready) throw new IllegalStateException();
|
||||
return topics.getPhi(topicID, vocabID, lda.getBeta());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
|
||||
return topics.getVocabsSortedByPhi(topicID, lda.getVocabularies(), lda.getBeta());
|
||||
}
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import forge.deck.generate.lda.dataset.Vocabulary;
|
||||
|
||||
public class Document {
|
||||
private final int id;
|
||||
private TopicCounter topicCount;
|
||||
private Words words;
|
||||
private TopicAssignment assignment;
|
||||
|
||||
Document(int id, int numTopics, List<Vocabulary> words) {
|
||||
if (id < 0 || numTopics <= 0) throw new IllegalArgumentException();
|
||||
this.id = id;
|
||||
this.topicCount = new TopicCounter(numTopics);
|
||||
this.words = new Words(words);
|
||||
this.assignment = new TopicAssignment();
|
||||
}
|
||||
|
||||
int id() {
|
||||
return id;
|
||||
}
|
||||
|
||||
int getTopicCount(int topicID) {
|
||||
return topicCount.getTopicCount(topicID);
|
||||
}
|
||||
|
||||
int getDocLength() {
|
||||
return words.getNumWords();
|
||||
}
|
||||
|
||||
void incrementTopicCount(int topicID) {
|
||||
topicCount.incrementTopicCount(topicID);
|
||||
}
|
||||
|
||||
void decrementTopicCount(int topicID) {
|
||||
topicCount.decrementTopicCount(topicID);
|
||||
}
|
||||
|
||||
void initializeTopicAssignment(long seed) {
|
||||
assignment.initialize(getDocLength(), topicCount.size(), seed);
|
||||
for (int w = 0; w < getDocLength(); ++w) {
|
||||
incrementTopicCount(assignment.get(w));
|
||||
}
|
||||
}
|
||||
|
||||
int getTopicID(int wordID) {
|
||||
return assignment.get(wordID);
|
||||
}
|
||||
|
||||
void setTopicID(int wordID, int topicID) {
|
||||
assignment.set(wordID, topicID);
|
||||
}
|
||||
|
||||
Vocabulary getVocabulary(int wordID) {
|
||||
return words.get(wordID);
|
||||
}
|
||||
|
||||
List<Vocabulary> getWords() {
|
||||
return words.getWords();
|
||||
}
|
||||
|
||||
double getTheta(int topicID, double alpha, double sumAlpha) {
|
||||
if (topicID < 0 || alpha <= 0.0 || sumAlpha <= 0.0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return (getTopicCount(topicID) + alpha) / (getDocLength() + sumAlpha);
|
||||
}
|
||||
}
|
||||
@@ -1,100 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import forge.deck.generate.lda.lda.LDA;
|
||||
import forge.deck.generate.lda.dataset.BagOfWords;
|
||||
import forge.deck.generate.lda.dataset.Vocabularies;
|
||||
import forge.deck.generate.lda.dataset.Vocabulary;
|
||||
|
||||
class Documents {
|
||||
private List<Document> documents;
|
||||
|
||||
Documents(LDA lda) {
|
||||
if (lda == null) throw new NullPointerException();
|
||||
|
||||
documents = new ArrayList<>();
|
||||
for (int d = 0; d < lda.getBow().getNumDocs(); ++d) {
|
||||
List<Vocabulary> vocabList = getVocabularyList(d, lda.getBow(), lda.getVocabularies());
|
||||
Document doc = new Document(d, lda.getNumTopics(), vocabList);
|
||||
documents.add(doc);
|
||||
}
|
||||
}
|
||||
|
||||
List<Vocabulary> getVocabularyList(int docID, BagOfWords bow, Vocabularies vocabs) {
|
||||
assert docID > 0 && bow != null && vocabs != null;
|
||||
System.out.println(docID);
|
||||
System.out.println(bow.getWords(docID).toString());
|
||||
return bow.getWords(docID).stream()
|
||||
.map(id -> vocabs.get(id))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
int getTopicID(int docID, int wordID) {
|
||||
return documents.get(docID).getTopicID(wordID);
|
||||
}
|
||||
|
||||
void setTopicID(int docID, int wordID, int topicID) {
|
||||
documents.get(docID).setTopicID(wordID, topicID);
|
||||
}
|
||||
|
||||
Vocabulary getVocab(int docID, int wordID) {
|
||||
return documents.get(docID).getVocabulary(wordID);
|
||||
}
|
||||
|
||||
List<Vocabulary> getWords(int docID) {
|
||||
return documents.get(docID).getWords();
|
||||
}
|
||||
|
||||
List<Document> getDocuments() {
|
||||
return Collections.unmodifiableList(documents);
|
||||
}
|
||||
|
||||
void incrementTopicCount(int docID, int topicID) {
|
||||
documents.get(docID).incrementTopicCount(topicID);
|
||||
}
|
||||
|
||||
void decrementTopicCount(int docID, int topicID) {
|
||||
documents.get(docID).decrementTopicCount(topicID);
|
||||
}
|
||||
|
||||
int getTopicCount(int docID, int topicID) {
|
||||
return documents.get(docID).getTopicCount(topicID);
|
||||
}
|
||||
|
||||
double getTheta(int docID, int topicID, double alpha, double sumAlpha) {
|
||||
if (docID < 0 || documents.size() < docID) throw new IllegalArgumentException();
|
||||
return documents.get(docID).getTheta(topicID, alpha, sumAlpha);
|
||||
}
|
||||
|
||||
void initializeTopicAssignment(Topics topics, long seed) {
|
||||
for (Document d : getDocuments()) {
|
||||
d.initializeTopicAssignment(seed);
|
||||
for (int w = 0; w < d.getDocLength(); ++w) {
|
||||
final int topicID = d.getTopicID(w);
|
||||
final Topic topic = topics.get(topicID);
|
||||
final Vocabulary vocab = d.getVocabulary(w);
|
||||
topic.incrementVocabCount(vocab.id());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
class Topic {
|
||||
private final int id;
|
||||
private final int numVocabs;
|
||||
private VocabularyCounter counter;
|
||||
|
||||
Topic(int id, int numVocabs) {
|
||||
if (id < 0 || numVocabs <= 0) throw new IllegalArgumentException();
|
||||
this.id = id;
|
||||
this.numVocabs = numVocabs;
|
||||
this.counter = new VocabularyCounter(numVocabs);
|
||||
}
|
||||
|
||||
int id() {
|
||||
return id;
|
||||
}
|
||||
|
||||
int getVocabCount(int vocabID) {
|
||||
return counter.getVocabCount(vocabID);
|
||||
}
|
||||
|
||||
int getSumCount() {
|
||||
return counter.getSumCount();
|
||||
}
|
||||
|
||||
void incrementVocabCount(int vocabID) {
|
||||
counter.incrementVocabCount(vocabID);
|
||||
}
|
||||
|
||||
void decrementVocabCount(int vocabID) {
|
||||
counter.decrementVocabCount(vocabID);
|
||||
}
|
||||
|
||||
double getPhi(int vocabID, double beta) {
|
||||
if (vocabID < 0 || beta <= 0) throw new IllegalArgumentException();
|
||||
return (getVocabCount(vocabID) + beta) / (getSumCount() + beta * numVocabs);
|
||||
}
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
class TopicAssignment {
|
||||
private List<Integer> topicAssignment;
|
||||
private boolean ready;
|
||||
|
||||
TopicAssignment() {
|
||||
topicAssignment = new ArrayList<>();
|
||||
ready = false;
|
||||
}
|
||||
|
||||
void set(int wordID, int topicID) {
|
||||
if (!ready) throw new IllegalStateException();
|
||||
if (wordID < 0 || topicAssignment.size() <= wordID || topicID < 0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
topicAssignment.set(wordID, topicID);
|
||||
}
|
||||
|
||||
int get(int wordID) {
|
||||
if (!ready) throw new IllegalStateException();
|
||||
if (wordID < 0 || topicAssignment.size() <= wordID) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return topicAssignment.get(wordID);
|
||||
}
|
||||
|
||||
void initialize(int docLength, int numTopics, long seed) {
|
||||
if (docLength <= 0 || numTopics <= 0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
Random random = new Random(seed);
|
||||
topicAssignment = random.ints(docLength, 0, numTopics)
|
||||
.boxed()
|
||||
.collect(Collectors.toList());
|
||||
ready = true;
|
||||
}
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
class TopicCounter {
|
||||
private AssignmentCounter topicCount;
|
||||
|
||||
TopicCounter(int numTopics) {
|
||||
this.topicCount = new AssignmentCounter(numTopics);
|
||||
}
|
||||
|
||||
int getTopicCount(int topicID) {
|
||||
return topicCount.get(topicID);
|
||||
}
|
||||
|
||||
int getDocLength() {
|
||||
return topicCount.getSum();
|
||||
}
|
||||
|
||||
void incrementTopicCount(int topicID) {
|
||||
topicCount.increment(topicID);
|
||||
}
|
||||
|
||||
void decrementTopicCount(int topicID) {
|
||||
topicCount.decrement(topicID);
|
||||
}
|
||||
|
||||
int size() {
|
||||
return topicCount.size();
|
||||
}
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import forge.deck.generate.lda.lda.LDA;
|
||||
|
||||
import org.apache.commons.lang3.tuple.ImmutablePair;
|
||||
import org.apache.commons.lang3.tuple.Pair;
|
||||
|
||||
import forge.deck.generate.lda.dataset.Vocabularies;
|
||||
|
||||
class Topics {
|
||||
private List<Topic> topics;
|
||||
|
||||
Topics(LDA lda) {
|
||||
if (lda == null) throw new NullPointerException();
|
||||
|
||||
topics = new ArrayList<>();
|
||||
for (int t = 0; t < lda.getNumTopics(); ++t) {
|
||||
topics.add(new Topic(t, lda.getBow().getNumVocabs()));
|
||||
}
|
||||
}
|
||||
|
||||
int numTopics() {
|
||||
return topics.size();
|
||||
}
|
||||
|
||||
Topic get(int id) {
|
||||
return topics.get(id);
|
||||
}
|
||||
|
||||
int getVocabCount(int topicID, int vocabID) {
|
||||
return topics.get(topicID).getVocabCount(vocabID);
|
||||
}
|
||||
|
||||
int getSumCount(int topicID) {
|
||||
return topics.get(topicID).getSumCount();
|
||||
}
|
||||
|
||||
void incrementVocabCount(int topicID, int vocabID) {
|
||||
topics.get(topicID).incrementVocabCount(vocabID);
|
||||
}
|
||||
|
||||
void decrementVocabCount(int topicID, int vocabID) {
|
||||
topics.get(topicID).decrementVocabCount(vocabID);
|
||||
}
|
||||
|
||||
double getPhi(int topicID, int vocabID, double beta) {
|
||||
if (topicID < 0 || topics.size() <= topicID) throw new IllegalArgumentException();
|
||||
return topics.get(topicID).getPhi(vocabID, beta);
|
||||
}
|
||||
|
||||
List<Pair<String, Double>> getVocabsSortedByPhi(int topicID, Vocabularies vocabs, final double beta) {
|
||||
if (topicID < 0 || topics.size() <= topicID || vocabs == null || beta <= 0.0) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
|
||||
Topic topic = topics.get(topicID);
|
||||
List<Pair<String, Double>> vocabProbPairs
|
||||
= vocabs.getVocabularyList()
|
||||
.stream()
|
||||
.map(v -> new ImmutablePair<String, Double>(v.toString(), topic.getPhi(v.id(), beta)))
|
||||
.sorted((p1, p2) -> Double.compare(p2.getRight(), p1.getRight()))
|
||||
.collect(Collectors.toList());
|
||||
return Collections.unmodifiableList(vocabProbPairs);
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
class VocabularyCounter {
|
||||
private AssignmentCounter vocabCount;
|
||||
private int sumCount;
|
||||
|
||||
VocabularyCounter(int numVocabs) {
|
||||
this.vocabCount = new AssignmentCounter(numVocabs);
|
||||
this.sumCount = 0;
|
||||
}
|
||||
|
||||
int getVocabCount(int vocabID) {
|
||||
if (vocabCount.size() < vocabID) return 0;
|
||||
else return vocabCount.get(vocabID);
|
||||
}
|
||||
|
||||
int getSumCount() {
|
||||
return sumCount;
|
||||
}
|
||||
|
||||
void incrementVocabCount(int vocabID) {
|
||||
vocabCount.increment(vocabID);
|
||||
++sumCount;
|
||||
}
|
||||
|
||||
void decrementVocabCount(int vocabID) {
|
||||
vocabCount.decrement(vocabID);
|
||||
--sumCount;
|
||||
}
|
||||
}
|
||||
@@ -1,46 +0,0 @@
|
||||
/*
|
||||
* Copyright 2015 Kohei Yamamoto
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package forge.deck.generate.lda.lda.inference.internal;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import forge.deck.generate.lda.dataset.Vocabulary;
|
||||
|
||||
class Words {
|
||||
private List<Vocabulary> words;
|
||||
|
||||
Words(List<Vocabulary> words) {
|
||||
if (words == null) throw new NullPointerException();
|
||||
this.words = words;
|
||||
}
|
||||
|
||||
int getNumWords() {
|
||||
return words.size();
|
||||
}
|
||||
|
||||
Vocabulary get(int id) {
|
||||
if (id < 0 || words.size() <= id) {
|
||||
throw new IllegalArgumentException();
|
||||
}
|
||||
return words.get(id);
|
||||
}
|
||||
|
||||
List<Vocabulary> getWords() {
|
||||
return Collections.unmodifiableList(words);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
package forge.planarconquestgenerate;
|
||||
|
||||
import com.google.common.base.Function;
|
||||
import com.google.common.base.Predicates;
|
||||
import com.google.common.collect.Iterables;
|
||||
import com.google.common.collect.Lists;
|
||||
import forge.GuiBase;
|
||||
import forge.GuiDesktop;
|
||||
import forge.LobbyPlayer;
|
||||
import forge.StaticData;
|
||||
import forge.card.CardRulesPredicates;
|
||||
import forge.deck.*;
|
||||
import forge.deck.io.DeckStorage;
|
||||
import forge.game.GameFormat;
|
||||
import forge.game.GameRules;
|
||||
import forge.game.GameType;
|
||||
import forge.game.Match;
|
||||
import forge.game.player.RegisteredPlayer;
|
||||
import forge.item.PaperCard;
|
||||
import forge.limited.CardRanker;
|
||||
import forge.model.FModel;
|
||||
import forge.player.GamePlayerUtil;
|
||||
import forge.properties.ForgeConstants;
|
||||
import forge.properties.ForgePreferences;
|
||||
import forge.tournament.system.AbstractTournament;
|
||||
import forge.tournament.system.TournamentPairing;
|
||||
import forge.tournament.system.TournamentPlayer;
|
||||
import forge.tournament.system.TournamentSwiss;
|
||||
import forge.util.AbstractGeneticAlgorithm;
|
||||
import forge.util.MyRandom;
|
||||
import forge.util.TextUtil;
|
||||
import forge.view.SimulateMatch;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class PlanarConquestCommanderGeneraterGA extends PlanarConquestGeneraterGA {
|
||||
|
||||
|
||||
private int deckCount = 0;
|
||||
|
||||
public static void main(String[] args){
|
||||
test();
|
||||
}
|
||||
|
||||
public static void test(){
|
||||
|
||||
GuiBase.setInterface(new GuiDesktop());
|
||||
FModel.initialize(null, new Function<ForgePreferences, Void>() {
|
||||
@Override
|
||||
public Void apply(ForgePreferences preferences) {
|
||||
preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
List<String> sets = new ArrayList<>();
|
||||
sets.add("XLN");
|
||||
sets.add("RIX");
|
||||
|
||||
PlanarConquestCommanderGeneraterGA ga = new PlanarConquestCommanderGeneraterGA(new GameRules(GameType.Constructed),
|
||||
new GameFormat("conquest",sets,null),
|
||||
DeckFormat.PlanarConquest,
|
||||
4,
|
||||
12,
|
||||
16);
|
||||
ga.run();
|
||||
List<Deck> winners = ga.listFinalPopulation();
|
||||
|
||||
DeckStorage storage = new DeckStorage(new File(ForgeConstants.DECK_CONSTRUCTED_DIR), ForgeConstants.DECK_BASE_DIR);
|
||||
int i=1;
|
||||
for(Deck deck:winners){
|
||||
storage.save(new Deck(deck,"GA"+i+"_"+deck.getName()));
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
public PlanarConquestCommanderGeneraterGA(GameRules rules, GameFormat gameFormat, DeckFormat deckFormat, int cardsToUse, int decksPerCard, int generations){
|
||||
super(rules,gameFormat,deckFormat,cardsToUse,decksPerCard,generations);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void initializeCards(){
|
||||
standardMap = CardRelationMatrixGenerator.cardPools.get(format.getName());
|
||||
List<String> cardNames = new ArrayList<>(standardMap.keySet());
|
||||
List<PaperCard> cards = new ArrayList<>();
|
||||
for(String cardName:cardNames){
|
||||
cards.add(StaticData.instance().getCommonCards().getUniqueByName(cardName));
|
||||
}
|
||||
|
||||
Iterable<PaperCard> filtered= Iterables.filter(cards, Predicates.and(
|
||||
Predicates.compose(CardRulesPredicates.IS_KEPT_IN_AI_DECKS, PaperCard.FN_GET_RULES),
|
||||
Predicates.compose(CardRulesPredicates.Presets.IS_PLANESWALKER, PaperCard.FN_GET_RULES),
|
||||
//Predicates.compose(CardRulesPredicates.Presets.IS_LEGENDARY, PaperCard.FN_GET_RULES),
|
||||
gameFormat.getFilterPrinted()));
|
||||
|
||||
List<PaperCard> filteredList = Lists.newArrayList(filtered);
|
||||
rankedList = CardRanker.rankCardsInDeck(filteredList);
|
||||
List<Deck> decks = new ArrayList<>();
|
||||
for(PaperCard card: rankedList.subList(0,cardsToUse)){
|
||||
System.out.println(card.getName());
|
||||
for( int i=0; i<decksPerCard;i++){
|
||||
decks.add(getDeckForCard(card));
|
||||
}
|
||||
}
|
||||
initializePopulation(decks);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Deck getDeckForCard(PaperCard card){
|
||||
Deck genDeck = DeckgenUtil.buildPlanarConquestCommanderDeck(card, gameFormat, deckFormat);
|
||||
Deck d = new Deck(genDeck,genDeck.getName()+"_"+deckCount+"_"+generationCount);
|
||||
deckCount++;
|
||||
return d;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Deck getDeckForCard(PaperCard card, PaperCard card2){
|
||||
return getDeckForCard(card);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Deck mutateObject(Deck parent1) {
|
||||
PaperCard allele = parent1.getCommanders().get(0);
|
||||
if(!standardMap.keySet().contains(allele.getName())){
|
||||
return null;
|
||||
}
|
||||
return getDeckForCard(allele);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Deck createChild(Deck parent1, Deck parent2) {
|
||||
PaperCard allele = parent1.getCommanders().get(0);
|
||||
return getDeckForCard(allele);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,277 @@
|
||||
package forge.planarconquestgenerate;
|
||||
|
||||
import com.google.common.base.Function;
|
||||
import com.google.common.base.Predicates;
|
||||
import com.google.common.collect.Iterables;
|
||||
import com.google.common.collect.Lists;
|
||||
import forge.GuiBase;
|
||||
import forge.GuiDesktop;
|
||||
import forge.LobbyPlayer;
|
||||
import forge.StaticData;
|
||||
import forge.card.CardRulesPredicates;
|
||||
import forge.deck.*;
|
||||
import forge.deck.io.DeckStorage;
|
||||
import forge.game.GameFormat;
|
||||
import forge.game.GameRules;
|
||||
import forge.game.GameType;
|
||||
import forge.game.Match;
|
||||
import forge.game.player.RegisteredPlayer;
|
||||
import forge.item.PaperCard;
|
||||
import forge.limited.CardRanker;
|
||||
import forge.model.FModel;
|
||||
import forge.player.GamePlayerUtil;
|
||||
import forge.properties.ForgeConstants;
|
||||
import forge.properties.ForgePreferences;
|
||||
import forge.tournament.system.AbstractTournament;
|
||||
import forge.tournament.system.TournamentPairing;
|
||||
import forge.tournament.system.TournamentPlayer;
|
||||
import forge.tournament.system.TournamentSwiss;
|
||||
import forge.util.AbstractGeneticAlgorithm;
|
||||
import forge.util.MyRandom;
|
||||
import forge.util.TextUtil;
|
||||
import forge.view.SimulateMatch;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public class PlanarConquestGeneraterGA extends AbstractGeneticAlgorithm<Deck> {
|
||||
|
||||
private DeckGroup deckGroup;
|
||||
private List<TournamentPlayer> players = new ArrayList<>();
|
||||
private TournamentSwiss tourney = null;
|
||||
protected Map<String,List<Map.Entry<PaperCard,Integer>>> standardMap;
|
||||
private GameRules rules;
|
||||
protected GameFormat format = FModel.getFormats().getStandard();
|
||||
protected int generations;
|
||||
protected GameFormat gameFormat;
|
||||
protected DeckFormat deckFormat;
|
||||
protected int cardsToUse;
|
||||
protected int decksPerCard;
|
||||
private int deckCount = 0;
|
||||
protected List<PaperCard> rankedList;
|
||||
|
||||
public static void main(String[] args){
|
||||
test();
|
||||
}
|
||||
|
||||
public static void test(){
|
||||
|
||||
GuiBase.setInterface(new GuiDesktop());
|
||||
FModel.initialize(null, new Function<ForgePreferences, Void>() {
|
||||
@Override
|
||||
public Void apply(ForgePreferences preferences) {
|
||||
preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false);
|
||||
return null;
|
||||
}
|
||||
});
|
||||
List<String> sets = new ArrayList<>();
|
||||
sets.add("XLN");
|
||||
sets.add("RIX");
|
||||
|
||||
PlanarConquestGeneraterGA ga = new PlanarConquestGeneraterGA(new GameRules(GameType.Constructed),
|
||||
new GameFormat("conquest",sets,null),
|
||||
DeckFormat.PlanarConquest,
|
||||
40,
|
||||
4,
|
||||
10);
|
||||
ga.run();
|
||||
List<Deck> winners = ga.listFinalPopulation();
|
||||
|
||||
DeckStorage storage = new DeckStorage(new File(ForgeConstants.DECK_CONSTRUCTED_DIR), ForgeConstants.DECK_BASE_DIR);
|
||||
int i=1;
|
||||
for(Deck deck:winners){
|
||||
storage.save(new Deck(deck,"GA"+i+"_"+deck.getName()));
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
public PlanarConquestGeneraterGA(GameRules rules, GameFormat gameFormat, DeckFormat deckFormat, int cardsToUse, int decksPerCard, int generations){
|
||||
this.rules = rules;
|
||||
rules.setGamesPerMatch(7);
|
||||
this.gameFormat = gameFormat;
|
||||
this.deckFormat = deckFormat;
|
||||
this.cardsToUse = cardsToUse;
|
||||
this.decksPerCard = decksPerCard;
|
||||
this.generations = generations;
|
||||
initializeCards();
|
||||
}
|
||||
|
||||
|
||||
protected void initializeCards(){
|
||||
standardMap = CardRelationMatrixGenerator.cardPools.get(format.getName());
|
||||
List<String> cardNames = new ArrayList<>(standardMap.keySet());
|
||||
List<PaperCard> cards = new ArrayList<>();
|
||||
for(String cardName:cardNames){
|
||||
cards.add(StaticData.instance().getCommonCards().getUniqueByName(cardName));
|
||||
}
|
||||
|
||||
Iterable<PaperCard> filtered= Iterables.filter(cards, Predicates.and(
|
||||
Predicates.compose(CardRulesPredicates.IS_KEPT_IN_AI_DECKS, PaperCard.FN_GET_RULES),
|
||||
Predicates.compose(CardRulesPredicates.Presets.IS_NON_LAND, PaperCard.FN_GET_RULES),
|
||||
gameFormat.getFilterPrinted()));
|
||||
|
||||
List<PaperCard> filteredList = Lists.newArrayList(filtered);
|
||||
setRankedList(CardRanker.rankCardsInDeck(filteredList));
|
||||
List<Deck> decks = new ArrayList<>();
|
||||
for(PaperCard card: getRankedList().subList(0,cardsToUse)){
|
||||
System.out.println(card.getName());
|
||||
for( int i=0; i<decksPerCard;i++){
|
||||
decks.add(getDeckForCard(card));
|
||||
}
|
||||
}
|
||||
initializePopulation(decks);
|
||||
}
|
||||
|
||||
protected List<PaperCard> getRankedList(){
|
||||
return rankedList;
|
||||
}
|
||||
|
||||
protected void setRankedList(List<PaperCard> list){
|
||||
rankedList = list;
|
||||
}
|
||||
|
||||
protected Deck getDeckForCard(PaperCard card){
|
||||
Deck genDeck = DeckgenUtil.buildPlanarConquestDeck(card, gameFormat, deckFormat);
|
||||
Deck d = new Deck(genDeck,genDeck.getName()+"_"+deckCount+"_"+generationCount);
|
||||
deckCount++;
|
||||
return d;
|
||||
}
|
||||
|
||||
protected Deck getDeckForCard(PaperCard card, PaperCard card2){
|
||||
Deck genDeck = DeckgenUtil.buildPlanarConquestDeck(card, card2, gameFormat, deckFormat, false);
|
||||
Deck d = new Deck(genDeck,genDeck.getName()+"_"+deckCount+"_"+generationCount);
|
||||
deckCount++;
|
||||
return d;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void evaluateFitness() {
|
||||
deckGroup = new DeckGroup("SimulatedTournament");
|
||||
players = new ArrayList<>();
|
||||
int i=0;
|
||||
for(Deck d:population) {
|
||||
deckGroup.addAiDeck(d);
|
||||
players.add(new TournamentPlayer(GamePlayerUtil.createAiPlayer(d.getName(), 0), i));
|
||||
++i;
|
||||
}
|
||||
tourney = new TournamentSwiss(players, 2);
|
||||
tourney = runTournament(tourney, rules, players.size(), deckGroup);
|
||||
population = new ArrayList<>();
|
||||
for (int k = 0; k < tourney.getAllPlayers().size(); k++) {
|
||||
String deckName = tourney.getAllPlayers().get(k).getPlayer().getName();
|
||||
for (Deck sortedDeck : deckGroup.getAiDecks()) {
|
||||
if (sortedDeck.getName().equals(deckName)) {
|
||||
population.add(sortedDeck);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
deckCount=0;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Deck mutateObject(Deck parent1) {
|
||||
PaperCard allele = parent1.getMain().get(MyRandom.getRandom().nextInt(8));
|
||||
if(!standardMap.keySet().contains(allele.getName())){
|
||||
return null;
|
||||
}
|
||||
return getDeckForCard(allele);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Deck createChild(Deck parent1, Deck parent2) {
|
||||
PaperCard allele = parent1.getMain().get(MyRandom.getRandom().nextInt(8));
|
||||
PaperCard allele2 = parent2.getMain().get(MyRandom.getRandom().nextInt(8));
|
||||
if(!standardMap.keySet().contains(allele.getName())
|
||||
||!standardMap.keySet().contains(allele2.getName())
|
||||
||allele.getName().equals(allele2.getName())){
|
||||
return null;
|
||||
}
|
||||
return getDeckForCard(allele,allele2);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Deck expandPool(){
|
||||
PaperCard seed = getRankedList().get(MyRandom.getRandom().nextInt(getRankedList().size()));
|
||||
return getDeckForCard(seed);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean shouldContinue() {
|
||||
return generationCount<generations;
|
||||
}
|
||||
|
||||
|
||||
public TournamentSwiss runTournament(TournamentSwiss tourney, GameRules rules, int numPlayers, DeckGroup deckGroup){
|
||||
tourney.initializeTournament();
|
||||
|
||||
String lastWinner = "";
|
||||
int curRound = 0;
|
||||
System.out.println(TextUtil.concatNoSpace("Starting a tournament with ",
|
||||
String.valueOf(numPlayers), " players over ",
|
||||
String.valueOf(tourney.getTotalRounds()), " rounds"));
|
||||
while(!tourney.isTournamentOver()) {
|
||||
if (tourney.getActiveRound() != curRound) {
|
||||
if (curRound != 0) {
|
||||
System.out.println(TextUtil.concatNoSpace("End Round - ", String.valueOf(curRound)));
|
||||
}
|
||||
curRound = tourney.getActiveRound();
|
||||
System.out.println("");
|
||||
System.out.println(TextUtil.concatNoSpace("Round ", String.valueOf(curRound) ," Pairings:"));
|
||||
|
||||
for(TournamentPairing pairing : tourney.getActivePairings()) {
|
||||
System.out.println(pairing.outputHeader());
|
||||
}
|
||||
System.out.println("");
|
||||
}
|
||||
|
||||
TournamentPairing pairing = tourney.getNextPairing();
|
||||
List<RegisteredPlayer> regPlayers = AbstractTournament.registerTournamentPlayers(pairing, deckGroup);
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
sb.append("Round ").append(tourney.getActiveRound()).append(" - ");
|
||||
sb.append(pairing.outputHeader());
|
||||
//System.out.println(sb.toString());
|
||||
|
||||
if (!pairing.isBye()) {
|
||||
Match mc = new Match(rules, regPlayers, "TourneyMatch");
|
||||
|
||||
int exceptions = 0;
|
||||
int iGame = 0;
|
||||
while (!mc.isMatchOver()) {
|
||||
// play games until the match ends
|
||||
try{
|
||||
SimulateMatch.simulateSingleMatch(mc, iGame, false);
|
||||
iGame++;
|
||||
} catch(Exception e) {
|
||||
exceptions++;
|
||||
System.out.println(e.toString());
|
||||
if (exceptions > 5) {
|
||||
System.out.println("Exceeded number of exceptions thrown. Abandoning match...");
|
||||
break;
|
||||
} else {
|
||||
System.out.println("Game threw exception. Abandoning game and continuing...");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
LobbyPlayer winner = mc.getWinner().getPlayer();
|
||||
for (TournamentPlayer tp : pairing.getPairedPlayers()) {
|
||||
if (winner.equals(tp.getPlayer())) {
|
||||
pairing.setWinner(tp);
|
||||
lastWinner = winner.getName();
|
||||
//System.out.println(TextUtil.concatNoSpace("Match Winner - ", lastWinner, "!"));
|
||||
//System.out.println("");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tourney.reportMatchCompletion(pairing);
|
||||
}
|
||||
tourney.outputTournamentResults();
|
||||
return tourney;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user