Fully working LDA-based deck generation for Standard and Modern

(cherry picked from commit 892ae23)
This commit is contained in:
austinio7116
2018-05-08 21:17:43 +01:00
committed by maustin
parent a25592ebfd
commit 3f8651f586
36 changed files with 1101 additions and 132 deletions

View File

@@ -1,204 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.dataset;
import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import forge.card.CardRulesPredicates;
import forge.deck.Deck;
import forge.deck.io.DeckStorage;
import forge.game.GameFormat;
import forge.item.PaperCard;
import forge.properties.ForgeConstants;
import forge.util.storage.IStorage;
import forge.util.storage.StorageImmediatelySerialized;
import java.io.*;
import java.util.*;
import java.util.stream.Collectors;
/**
* This class is immutable.
*/
public final class BagOfWords {
    private final int numDocs;
    private final int numVocabs;
    private final int numNNZ;
    private final int numWords;
    private final Vocabularies vocabs;
    // docID -> the vocab-ID sequence of the doc (one entry per card copy in the main deck)
    private final Map<Integer, List<Integer>> words;
    // docID -> the doc length (number of counted cards in that deck)
    private final Map<Integer, Integer> docLength;

    public Vocabularies getVocabs() {
        return vocabs;
    }

    /**
     * Build the bag-of-words dataset from the stored generator decks of the
     * given format. Each format-legal deck becomes one "document"; each card
     * name in the format's card pool becomes one "vocabulary" term.
     *
     * @param format the game format whose decks and card pool are used
     * @throws Exception propagated from deck storage / format handling
     * @throws NullPointerException format is null
     */
    public BagOfWords(GameFormat format) throws Exception {
        IStorage<Deck> decks = new StorageImmediatelySerialized<Deck>("Generator",
                new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR + ForgeConstants.PATH_SEPARATOR + format.getName()),
                        ForgeConstants.DECK_GEN_DIR, false),
                true);
        List<PaperCard> cardList = format.getAllCards();
        List<Deck> legalDecks = new ArrayList<>();
        for (Deck deck : decks) {
            if (format.isDeckLegal(deck)) {
                legalDecks.add(deck);
            }
        }

        this.words = new HashMap<>();
        this.docLength = new HashMap<>();
        List<Vocabulary> vocabList = new ArrayList<>();
        int numNNZ = 0;
        int numWords = 0;

        // card name -> vocab ID; IDs follow the card-list order
        Map<String, Integer> cardIntegerMap = new HashMap<>();
        for (int i = 0; i < cardList.size(); ++i) {
            cardIntegerMap.put(cardList.get(i).getName(), i);
            vocabList.add(new Vocabulary(i, cardList.get(i).getName()));
        }
        this.vocabs = new Vocabularies(vocabList);

        int deckID = 0;
        for (Deck deck : legalDecks) {
            // one non-zero matrix entry per distinct card in the main deck
            for (Iterator<Map.Entry<PaperCard, Integer>> it = deck.getMain().iterator(); it.hasNext(); it.next()) {
                numNNZ++;
            }
            List<Integer> cardNumbers = new ArrayList<>();
            for (PaperCard card : deck.getMain().toFlatList()) {
                Integer vocabID = cardIntegerMap.get(card.getName());
                if (vocabID == null) {
                    // previously a null was appended here, which caused NPEs downstream;
                    // skip the unknown card instead
                    System.out.println(card.getName() + " is missing!!");
                    continue;
                }
                cardNumbers.add(vocabID);
            }
            words.put(deckID, cardNumbers);
            numWords += cardNumbers.size();
            docLength.put(deckID, cardNumbers.size());
            deckID++;
        }

        this.numDocs = legalDecks.size();
        this.numVocabs = cardList.size();
        this.numNNZ = numNNZ;
        this.numWords = numWords;
    }

    public int getNumDocs() {
        return numDocs;
    }

    /**
     * Get the length of the document.
     * @param docID 0-based document (deck) ID
     * @return length of the document
     * @throws IllegalArgumentException docID < 0 || #documents <= docID
     */
    public int getDocLength(int docID) {
        // old upper-bound check used '<', letting docID == numDocs through to an NPE
        if (docID < 0 || getNumDocs() <= docID) {
            throw new IllegalArgumentException();
        }
        return docLength.get(docID);
    }

    /**
     * Get the unmodifiable list of words in the document.
     * @param docID 0-based document (deck) ID
     * @return the unmodifiable list of vocab IDs
     * @throws IllegalArgumentException docID < 0 || #documents <= docID
     */
    public List<Integer> getWords(final int docID) {
        if (docID < 0 || getNumDocs() <= docID) {
            throw new IllegalArgumentException();
        }
        return Collections.unmodifiableList(words.get(docID));
    }

    public int getNumVocabs() {
        return numVocabs;
    }

    public int getNumNNZ() {
        return numNNZ;
    }

    public int getNumWords() {
        return numWords;
    }
}

View File

@@ -1,81 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.dataset;
import forge.game.GameFormat;
import forge.model.FModel;
import java.util.List;
public class Dataset {
    private final BagOfWords bow;
    private final Vocabularies vocabs;

    /**
     * Build a dataset from all format-legal generator decks.
     * @param format the game format to read decks for
     * @throws Exception propagated from BagOfWords construction
     */
    public Dataset(GameFormat format) throws Exception {
        bow = new BagOfWords(format);
        vocabs = bow.getVocabs();
    }

    /**
     * Wrap an existing bag-of-words.
     * The vocabularies are taken from the bag-of-words itself; the old code
     * left them null, making get()/size()/getVocabularies() throw NPEs.
     * @param bow the bag-of-words to wrap
     */
    public Dataset(BagOfWords bow) {
        this.bow = bow;
        this.vocabs = bow.getVocabs();
    }

    public BagOfWords getBow() {
        return bow;
    }

    public int getNumDocs() {
        return bow.getNumDocs();
    }

    public int getDocLength(int docID) {
        return bow.getDocLength(docID);
    }

    public List<Integer> getWords(int docID) {
        return bow.getWords(docID);
    }

    public int getNumVocabs() {
        return bow.getNumVocabs();
    }

    public int getNumNNZ() {
        return bow.getNumNNZ();
    }

    public int getNumWords() {
        return bow.getNumWords();
    }

    public Vocabulary get(int id) {
        return vocabs.get(id);
    }

    public int size() {
        return vocabs.size();
    }

    public Vocabularies getVocabularies() {
        return vocabs;
    }

    public List<Vocabulary> getVocabularyList() {
        return vocabs.getVocabularyList();
    }
}

View File

@@ -1,45 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.dataset;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
public class Vocabularies {
    // defensive copy of the vocabularies, indexed by vocab ID
    private final List<Vocabulary> vocabs;

    /**
     * @param vocabs the vocabulary list; copied so later external mutation
     *               cannot change this instance (it previously stored the
     *               caller's mutable list while exposing an unmodifiable view)
     */
    public Vocabularies(List<Vocabulary> vocabs) {
        this.vocabs = vocabs.stream().collect(Collectors.toList());
    }

    public Vocabulary get(int id) {
        return vocabs.get(id);
    }

    public int size() {
        return vocabs.size();
    }

    /** @return read-only view of the vocabulary list */
    public List<Vocabulary> getVocabularyList() {
        return Collections.unmodifiableList(vocabs);
    }
}

View File

@@ -1,37 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.dataset;
public class Vocabulary {
    // numeric vocab ID assigned at dataset construction
    private final int id;
    // the term text (a card name); never null, never reassigned
    private final String vocabulary;

    /**
     * @param id numeric vocabulary ID
     * @param vocabulary the term text; must not be null
     * @throws NullPointerException vocabulary is null
     */
    public Vocabulary(int id, String vocabulary) {
        if (vocabulary == null) throw new NullPointerException();
        this.id = id;
        this.vocabulary = vocabulary;
    }

    public int id() {
        return id;
    }

    @Override
    public String toString() {
        return vocabulary;
    }
}

View File

@@ -1,59 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.examples;
import java.util.List;
import com.google.common.base.Function;
import forge.GuiBase;
import forge.GuiDesktop;
import forge.deck.generate.lda.lda.LDA;
import static forge.deck.generate.lda.lda.inference.InferenceMethod.*;
import forge.model.FModel;
import forge.properties.ForgePreferences;
import org.apache.commons.lang3.tuple.Pair;
import forge.deck.generate.lda.dataset.Dataset;
public class Example {
    /**
     * Trains an LDA model over the Standard-format generator decks and prints
     * the perplexity followed by the top-ranked cards of each topic.
     */
    public static void main(String[] args) throws Exception {
        GuiBase.setInterface(new GuiDesktop());
        FModel.initialize(null, new Function<ForgePreferences, Void>() {
            @Override
            public Void apply(ForgePreferences preferences) {
                // card scripts must be fully loaded so the card pool is complete
                preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false);
                return null;
            }
        });
        Dataset dataset = new Dataset(FModel.getFormats().getStandard());
        final int numTopics = 50;
        LDA lda = new LDA(0.1, 0.1, numTopics, dataset, CGS);
        lda.run();
        System.out.println(lda.computePerplexity(dataset));
        for (int t = 0; t < numTopics; ++t) {
            List<Pair<String, Double>> highRankVocabs = lda.getVocabsSortedByPhi(t);
            System.out.print("t" + t + ": ");
            // guard against fewer than 20 vocabs to avoid IndexOutOfBoundsException
            final int limit = Math.min(20, highRankVocabs.size());
            for (int i = 0; i < limit; ++i) {
                Pair<String, Double> vocab = highRankVocabs.get(i);
                System.out.println("[" + vocab.getLeft() + "," + vocab.getRight() + "],");
            }
            System.out.println();
        }
    }
}

View File

@@ -1,46 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
class Alpha {
private List<Double> alphas;
Alpha(double alpha, int numTopics) {
if (alpha <= 0.0 || numTopics <= 0) {
throw new IllegalArgumentException();
}
this.alphas = Stream.generate(() -> alpha)
.limit(numTopics)
.collect(Collectors.toList());
}
double get(int i) {
return alphas.get(i);
}
void set(int i, double value) {
alphas.set(i, value);
}
double sum() {
return alphas.stream().reduce(Double::sum).get();
}
}

View File

@@ -1,54 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
class Beta {
    // topic-vocab hyperparameter values; a single-element list in the
    // symmetric case, one entry per vocab otherwise. Never reassigned.
    private final List<Double> betas;

    /**
     * Asymmetric beta: one entry per vocabulary term, all initialized equal.
     * @throws IllegalArgumentException beta <= 0.0 or numVocabs <= 0
     */
    Beta(double beta, int numVocabs) {
        if (beta <= 0.0 || numVocabs <= 0) {
            throw new IllegalArgumentException();
        }
        this.betas = Stream.generate(() -> beta)
                           .limit(numVocabs)
                           .collect(Collectors.toList());
    }

    /**
     * Symmetric beta: a single shared value.
     * @throws IllegalArgumentException beta <= 0.0
     */
    Beta(double beta) {
        if (beta <= 0.0) {
            throw new IllegalArgumentException();
        }
        // Arrays.asList is fixed-size but element set() is allowed, which is all we need
        this.betas = Arrays.asList(beta);
    }

    /** Shorthand for the symmetric (single-value) case. */
    double get() {
        return get(0);
    }

    double get(int i) {
        return betas.get(i);
    }

    void set(int i, double value) {
        betas.set(i, value);
    }
}

View File

@@ -1,60 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda;
class Hyperparameters {
    private final Alpha alpha; // doc-topic hyperparameters, one per topic
    private final Beta beta;   // topic-vocab hyperparameter(s)

    /** Asymmetric beta: one value per vocabulary term. */
    Hyperparameters(double alpha, double beta, int numTopics, int numVocabs) {
        this.alpha = new Alpha(alpha, numTopics);
        this.beta = new Beta(beta, numVocabs);
    }

    /** Symmetric beta: a single shared value. */
    Hyperparameters(double alpha, double beta, int numTopics) {
        this.alpha = new Alpha(alpha, numTopics);
        this.beta = new Beta(beta);
    }

    double alpha(int i) {
        return alpha.get(i);
    }

    double sumAlpha() {
        return alpha.sum();
    }

    /** Symmetric beta value (element 0). */
    double beta() {
        return beta.get();
    }

    double beta(int i) {
        return beta.get(i);
    }

    void setAlpha(int i, double value) {
        alpha.set(i, value);
    }

    void setBeta(int i, double value) {
        beta.set(i, value);
    }

    void setBeta(double value) {
        beta.set(0, value);
    }
}

View File

@@ -1,184 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda;
import java.io.IOException;
import java.util.List;
import forge.deck.generate.lda.lda.inference.Inference;
import forge.deck.generate.lda.lda.inference.InferenceFactory;
import forge.deck.generate.lda.lda.inference.InferenceMethod;
import forge.deck.generate.lda.lda.inference.InferenceProperties;
import org.apache.commons.lang3.tuple.Pair;
import forge.deck.generate.lda.dataset.BagOfWords;
import forge.deck.generate.lda.dataset.Dataset;
import forge.deck.generate.lda.dataset.Vocabularies;
public class LDA {
    private final Hyperparameters hyperparameters;
    private final int numTopics;
    private final Dataset dataset;
    private final Inference inference;
    private final InferenceProperties properties;
    private boolean trained;

    /**
     * @param alpha doc-topic hyperparameter
     * @param beta topic-vocab hyperparameter
     * @param numTopics the number of topics
     * @param dataset dataset
     * @param method inference method
     */
    public LDA(final double alpha, final double beta, final int numTopics,
               final Dataset dataset, InferenceMethod method) {
        this(alpha, beta, numTopics, dataset, method, InferenceProperties.PROPERTIES_FILE_NAME);
    }

    // NOTE(review): propertiesFilePath is accepted but never used; kept for
    // signature compatibility with existing callers.
    LDA(final double alpha, final double beta, final int numTopics,
        final Dataset dataset, InferenceMethod method, String propertiesFilePath) {
        this.hyperparameters = new Hyperparameters(alpha, beta, numTopics);
        this.numTopics = numTopics;
        this.dataset = dataset;
        this.inference = InferenceFactory.getInstance(method);
        this.properties = new InferenceProperties();
        this.trained = false;
        // fixed seed and iteration count keep deck generation reproducible
        properties.setSeed(123L);
        properties.setNumIteration(100);
    }

    /**
     * Get the vocabulary from its ID.
     * @param vocabID
     * @return the vocabulary
     * @throws IllegalArgumentException vocabID < 0 || #vocabs <= vocabID
     */
    public String getVocab(int vocabID) {
        // IDs are 0-based; the old '<' upper bound let vocabID == numVocabs through
        if (vocabID < 0 || dataset.getNumVocabs() <= vocabID) {
            throw new IllegalArgumentException();
        }
        return dataset.get(vocabID).toString();
    }

    /**
     * Run model inference.
     */
    public void run() {
        if (properties == null) inference.setUp(this);
        else inference.setUp(this, properties);
        inference.run();
        trained = true;
    }

    /**
     * Get hyperparameter alpha corresponding to topic.
     * @param topic
     * @return alpha corresponding to topicID
     * @throws ArrayIndexOutOfBoundsException topic < 0 || #topics <= topic
     */
    public double getAlpha(final int topic) {
        if (topic < 0 || numTopics <= topic) {
            throw new ArrayIndexOutOfBoundsException(topic);
        }
        return hyperparameters.alpha(topic);
    }

    public double getSumAlpha() {
        return hyperparameters.sumAlpha();
    }

    public double getBeta() {
        return hyperparameters.beta();
    }

    public int getNumTopics() {
        return numTopics;
    }

    public BagOfWords getBow() {
        return dataset.getBow();
    }

    /**
     * Get the value of doc-topic probability \theta_{docID, topicID}.
     * @param docID
     * @param topicID
     * @return the value of doc-topic probability
     * @throws IllegalArgumentException docID < 0 || #docs <= docID || topicID < 0 || #topics <= topicID
     * @throws IllegalStateException call this method when the inference has not been finished yet
     */
    public double getTheta(final int docID, final int topicID) {
        if (docID < 0 || dataset.getNumDocs() <= docID
                || topicID < 0 || numTopics <= topicID) {
            throw new IllegalArgumentException();
        }
        if (!trained) {
            throw new IllegalStateException();
        }
        return inference.getTheta(docID, topicID);
    }

    /**
     * Get the value of topic-vocab probability \phi_{topicID, vocabID}.
     * @param topicID
     * @param vocabID
     * @return the value of topic-vocab probability
     * @throws IllegalArgumentException topicID < 0 || #topics <= topicID || vocabID < 0
     * @throws IllegalStateException call this method when the inference has not been finished yet
     */
    public double getPhi(final int topicID, final int vocabID) {
        if (topicID < 0 || numTopics <= topicID || vocabID < 0) {
            throw new IllegalArgumentException();
        }
        if (!trained) {
            throw new IllegalStateException();
        }
        return inference.getPhi(topicID, vocabID);
    }

    public Vocabularies getVocabularies() {
        return dataset.getVocabularies();
    }

    public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
        return inference.getVocabsSortedByPhi(topicID);
    }

    /**
     * Compute the perplexity of trained LDA for the test bag-of-words dataset.
     * @param testDataset
     * @return the perplexity for the test bag-of-words dataset
     */
    public double computePerplexity(Dataset testDataset) {
        double loglikelihood = 0.0;
        for (int d = 0; d < testDataset.getNumDocs(); ++d) {
            for (Integer w : testDataset.getWords(d)) {
                double sum = 0.0;
                for (int t = 0; t < getNumTopics(); ++t) {
                    sum += getTheta(d, t) * getPhi(t, w.intValue());
                }
                loglikelihood += Math.log(sum);
            }
        }
        return Math.exp(-loglikelihood / testDataset.getNumWords());
    }
}

View File

@@ -1,62 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference;
import java.util.List;
import forge.deck.generate.lda.lda.LDA;
import org.apache.commons.lang3.tuple.Pair;
public interface Inference {
    /**
     * Prepare the inference for the given model using default configuration.
     * @param lda the model to run inference on
     */
    void setUp(LDA lda);

    /**
     * Prepare the inference for the given model, reading the configuration
     * (e.g. seed, iteration count) from the supplied properties.
     * @param lda the model to run inference on
     * @param properties inference configuration
     */
    void setUp(LDA lda, InferenceProperties properties);

    /** Run model inference. */
    void run();

    /**
     * Get the value of doc-topic probability \theta_{docID, topicID}.
     * @param docID
     * @param topicID
     * @return the value of doc-topic probability
     */
    double getTheta(int docID, int topicID);

    /**
     * Get the value of topic-vocab probability \phi_{topicID, vocabID}.
     * @param topicID
     * @param vocabID
     * @return the value of topic-vocab probability
     */
    double getPhi(int topicID, int vocabID);

    /**
     * @param topicID
     * @return vocab/probability pairs for the topic, ranked by phi
     *         (presumably descending — confirm against the implementation)
     */
    List<Pair<String, Double>> getVocabsSortedByPhi(int topicID);
}

View File

@@ -1,37 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference;
public class InferenceFactory {
    private InferenceFactory() {}

    /**
     * Get the Inference instance specified by the argument.
     * @param method the inference method whose class name is instantiated reflectively
     * @return the instance which implements Inference
     * @throws IllegalStateException the class cannot be found or instantiated
     */
    public static Inference getInstance(InferenceMethod method) {
        try {
            // getDeclaredConstructor().newInstance() replaces the deprecated
            // Class.newInstance(), which bypassed checked-exception wrapping
            return (Inference) Class.forName(method.toString())
                                    .getDeclaredConstructor()
                                    .newInstance();
        } catch (ReflectiveOperationException roe) {
            // previously the exception was printed and null returned, deferring
            // the failure to a later NullPointerException with no cause attached
            throw new IllegalStateException("Cannot instantiate inference: " + method, roe);
        }
    }
}

View File

@@ -1,34 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference;
public enum InferenceMethod {
    /** Collapsed Gibbs sampling. */
    CGS("forge.deck.generate.lda.lda.inference.internal.CollapsedGibbsSampler");
    // add further inference implementations above as needed

    // fully-qualified name of the class implementing this method
    private final String className;

    InferenceMethod(final String className) {
        this.className = className;
    }

    /** @return the fully-qualified implementation class name */
    @Override
    public String toString() {
        return className;
    }
}

View File

@@ -1,70 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.Properties;
public class InferenceProperties {
    /** Default properties file name. */
    public static final String PROPERTIES_FILE_NAME = "lda.properties";

    PropertiesLoader loader = new PropertiesLoader();
    private final Properties properties;

    public InferenceProperties() {
        this.properties = new Properties();
    }

    public void setSeed(Long seed){
        properties.setProperty("seed",seed.toString());
    }

    public void setNumIteration(Integer numIteration){
        properties.setProperty("numIteration",numIteration.toString());
    }

    /**
     * Load properties.
     * @param fileName
     * @throws IOException
     * @throws NullPointerException fileName is null
     */
    public void load(String fileName) throws IOException {
        if (fileName == null) throw new NullPointerException();
        // try-with-resources: the old code never closed the stream
        try (InputStream stream = loader.getInputStream(fileName)) {
            properties.load(stream);
        }
    }

    /**
     * @return the configured seed, or null when unset — callers (e.g. the
     *         collapsed Gibbs sampler) null-check this; the old code threw
     *         NumberFormatException on a missing property instead
     */
    public Long seed() {
        final String value = properties.getProperty("seed");
        return value == null ? null : Long.valueOf(value);
    }

    /** @return the configured iteration count, or null when unset */
    public Integer numIteration() {
        final String value = properties.getProperty("numIteration");
        return value == null ? null : Integer.valueOf(value);
    }
}
/** Thin indirection over FileInputStream so tests can substitute a loader. */
class PropertiesLoader {
    public InputStream getInputStream(String fileName) throws FileNotFoundException {
        if (fileName == null) throw new NullPointerException();
        return new FileInputStream(fileName);
    }
}

View File

@@ -1,65 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
class AssignmentCounter {
    // counter.get(id) = number of assignments currently recorded for id
    private final List<Integer> counter;

    /**
     * @param size number of distinct IDs to track
     * @throws IllegalArgumentException size <= 0
     */
    AssignmentCounter(int size) {
        if (size <= 0) throw new IllegalArgumentException();
        this.counter = IntStream.range(0, size)
                                .map(i -> 0)
                                .boxed()
                                .collect(Collectors.toList());
    }

    int size() {
        return counter.size();
    }

    /** @throws IllegalArgumentException id out of range */
    int get(int id) {
        if (id < 0 || counter.size() <= id) {
            throw new IllegalArgumentException();
        }
        return counter.get(id);
    }

    /** Total of all counts; mapToInt avoids the old Optional.get() on reduce. */
    int getSum() {
        return counter.stream().mapToInt(Integer::intValue).sum();
    }

    /** @throws IllegalArgumentException id out of range */
    void increment(int id) {
        if (id < 0 || counter.size() <= id) {
            throw new IllegalArgumentException();
        }
        counter.set(id, counter.get(id) + 1);
    }

    /**
     * @throws IllegalArgumentException id out of range
     * @throws IllegalStateException the count at id is already zero
     */
    void decrement(int id) {
        if (id < 0 || counter.size() <= id) {
            throw new IllegalArgumentException();
        }
        if (counter.get(id) == 0) {
            throw new IllegalStateException();
        }
        counter.set(id, counter.get(id) - 1);
    }
}

View File

@@ -1,220 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import forge.deck.generate.lda.lda.LDA;
import forge.deck.generate.lda.lda.inference.Inference;
import forge.deck.generate.lda.lda.inference.InferenceProperties;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import forge.deck.generate.lda.dataset.Vocabulary;
public class CollapsedGibbsSampler implements Inference {
private LDA lda;
private Topics topics;
private Documents documents;
private int numIteration;
private static final long DEFAULT_SEED = 0L;
private static final int DEFAULT_NUM_ITERATION = 100;
// ready for Gibbs sampling
private boolean ready;
public CollapsedGibbsSampler() {
ready = false;
}
@Override
public void setUp(LDA lda, InferenceProperties properties) {
if (properties == null) {
setUp(lda);
return;
}
this.lda = lda;
initialize(this.lda);
final long seed = properties.seed() != null ? properties.seed() : DEFAULT_SEED;
initializeTopicAssignment(seed);
this.numIteration
= properties.numIteration() != null ? properties.numIteration() : DEFAULT_NUM_ITERATION;
this.ready = true;
}
@Override
public void setUp(LDA lda) {
if (lda == null) throw new NullPointerException();
this.lda = lda;
initialize(this.lda);
initializeTopicAssignment(DEFAULT_SEED);
this.numIteration = DEFAULT_NUM_ITERATION;
this.ready = true;
}
private void initialize(LDA lda) {
assert lda != null;
this.topics = new Topics(lda);
this.documents = new Documents(lda);
}
public boolean isReady() {
return ready;
}
public int getNumIteration() {
return numIteration;
}
public void setNumIteration(final int numIteration) {
this.numIteration = numIteration;
}
@Override
public void run() {
if (!ready) {
throw new IllegalStateException("instance has not set up yet");
}
for (int i = 1; i <= numIteration; ++i) {
System.out.println("Iteraion " + i + ".");
runSampling();
}
}
/**
* Run collapsed Gibbs sampling [Griffiths and Steyvers 2004].
*/
void runSampling() {
for (Document d : documents.getDocuments()) {
for (int w = 0; w < d.getDocLength(); ++w) {
final Topic oldTopic = topics.get(d.getTopicID(w));
d.decrementTopicCount(oldTopic.id());
final Vocabulary v = d.getVocabulary(w);
oldTopic.decrementVocabCount(v.id());
IntegerDistribution distribution
= getFullConditionalDistribution(lda.getNumTopics(), d.id(), v.id());
final int newTopicID = distribution.sample();
d.setTopicID(w, newTopicID);
d.incrementTopicCount(newTopicID);
final Topic newTopic = topics.get(newTopicID);
newTopic.incrementVocabCount(v.id());
}
}
}
/**
* Get the full conditional distribution over topics.
* docID and vocabID are passed to this distribution for parameters.
* @param numTopics
* @param docID
* @param vocabID
* @return the integer distribution over topics
*/
IntegerDistribution getFullConditionalDistribution(final int numTopics, final int docID, final int vocabID) {
int[] topics = IntStream.range(0, numTopics).toArray();
double[] probabilities = Arrays.stream(topics)
.mapToDouble(t -> getTheta(docID, t) * getPhi(t, vocabID))
.toArray();
return new EnumeratedIntegerDistribution(topics, probabilities);
}
/**
* Initialize the topic assignment.
* @param seed the seed of a pseudo random number generator
*/
void initializeTopicAssignment(final long seed) {
documents.initializeTopicAssignment(topics, seed);
}
/**
 * Get the count of topicID assigned to docID.
 * NOTE(review): the guard rejects docID == 0 (and accepts docID == numDocs),
 * which implies 1-based document IDs, while Documents builds its list from
 * index 0 — confirm which convention callers actually use.
 * @param docID the document ID
 * @param topicID the topic ID (0-based, in [0, #topics))
 * @return the count of topicID assigned to docID
 * @throws IllegalStateException if setup has not completed
 * @throws IllegalArgumentException if either ID fails the range guard
 */
int getDTCount(final int docID, final int topicID) {
if (!ready) throw new IllegalStateException();
if (docID <= 0 || lda.getBow().getNumDocs() < docID
|| topicID < 0 || lda.getNumTopics() <= topicID) {
throw new IllegalArgumentException();
}
return documents.getTopicCount(docID, topicID);
}
/**
 * Get the count of vocabID assigned to topicID.
 * The guard rejects vocabID <= 0, i.e. vocabulary IDs are treated as
 * 1-based here (bag-of-words convention) — presumably; verify against the
 * BagOfWords reader.
 * @param topicID the topic ID (0-based)
 * @param vocabID the vocabulary ID (1-based per the guard)
 * @return the count of vocabID assigned to topicID
 * @throws IllegalStateException if setup has not completed
 * @throws IllegalArgumentException if either ID fails the range guard
 */
int getTVCount(final int topicID, final int vocabID) {
if (!ready) throw new IllegalStateException();
if (topicID < 0 || lda.getNumTopics() <= topicID || vocabID <= 0) {
throw new IllegalArgumentException();
}
final Topic topic = topics.get(topicID);
return topic.getVocabCount(vocabID);
}
/**
 * Get the sum of counts of vocabs assigned to topicID
 * (the topic-vocab count summed over all vocabs).
 * @param topicID the topic ID (0-based)
 * @return the sum of counts of vocabs assigned to topicID
 * @throws IllegalStateException if setup has not completed
 * @throws IllegalArgumentException topicID < 0 || #topic <= topicID
 */
int getTSumCount(final int topicID) {
    // Added the readiness guard for consistency with the sibling accessors
    // (getDTCount/getTVCount): counts are undefined before setup.
    if (!ready) throw new IllegalStateException();
    if (topicID < 0 || lda.getNumTopics() <= topicID) {
        throw new IllegalArgumentException();
    }
    final Topic topic = topics.get(topicID);
    return topic.getSumCount();
}
/**
 * Current estimate of theta(doc, topic): the smoothed proportion of the
 * document's words assigned to the topic, with Dirichlet prior alpha.
 * @throws IllegalStateException if setup has not completed
 */
@Override
public double getTheta(final int docID, final int topicID) {
if (!ready) throw new IllegalStateException();
return documents.getTheta(docID, topicID, lda.getAlpha(topicID), lda.getSumAlpha());
}
/**
 * Current estimate of phi(topic, vocab): the smoothed probability of the
 * vocabulary entry under the topic, with symmetric prior beta.
 * @throws IllegalStateException if setup has not completed
 */
@Override
public double getPhi(int topicID, int vocabID) {
if (!ready) throw new IllegalStateException();
return topics.getPhi(topicID, vocabID, lda.getBeta());
}
/**
 * List every vocabulary term with its phi under topicID, sorted by
 * descending probability.
 * @throws IllegalStateException if setup has not completed
 */
@Override
public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
    // Added the readiness guard for consistency with getTheta/getPhi:
    // phi is undefined before setup completes.
    if (!ready) throw new IllegalStateException();
    return topics.getVocabsSortedByPhi(topicID, lda.getVocabularies(), lda.getBeta());
}
}

View File

@@ -1,86 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
import java.util.List;
import forge.deck.generate.lda.dataset.Vocabulary;
/**
 * One document of the corpus: the ordered word (vocabulary) sequence, the
 * topic currently assigned to each word position, and per-topic counts.
 */
public class Document {
    private final int id;
    private TopicCounter topicCounter;
    private Words wordSequence;
    private TopicAssignment topicAssignment;

    /**
     * @param id non-negative document ID
     * @param numTopics positive number of topics
     * @param words the document's word sequence
     * @throws IllegalArgumentException if id is negative or numTopics is not positive
     */
    Document(int id, int numTopics, List<Vocabulary> words) {
        if (id < 0 || numTopics <= 0) throw new IllegalArgumentException();
        this.id = id;
        this.topicCounter = new TopicCounter(numTopics);
        this.wordSequence = new Words(words);
        this.topicAssignment = new TopicAssignment();
    }

    /** @return this document's ID */
    int id() {
        return id;
    }

    /** @return how many word positions are currently assigned to topicID */
    int getTopicCount(int topicID) {
        return topicCounter.getTopicCount(topicID);
    }

    /** @return the number of word positions in this document */
    int getDocLength() {
        return wordSequence.getNumWords();
    }

    void incrementTopicCount(int topicID) {
        topicCounter.incrementTopicCount(topicID);
    }

    void decrementTopicCount(int topicID) {
        topicCounter.decrementTopicCount(topicID);
    }

    /**
     * Assign a uniformly random topic to every word position, then rebuild
     * the per-topic counts to match that random assignment.
     * @param seed pseudo-random generator seed
     */
    void initializeTopicAssignment(long seed) {
        topicAssignment.initialize(getDocLength(), topicCounter.size(), seed);
        for (int position = 0; position < getDocLength(); ++position) {
            incrementTopicCount(topicAssignment.get(position));
        }
    }

    /** @return the topic currently assigned to the word at position wordID */
    int getTopicID(int wordID) {
        return topicAssignment.get(wordID);
    }

    void setTopicID(int wordID, int topicID) {
        topicAssignment.set(wordID, topicID);
    }

    /** @return the vocabulary entry at position wordID */
    Vocabulary getVocabulary(int wordID) {
        return wordSequence.get(wordID);
    }

    /** @return a read-only view of the word sequence */
    List<Vocabulary> getWords() {
        return wordSequence.getWords();
    }

    /**
     * Smoothed per-document topic proportion:
     * (count(topic) + alpha) / (docLength + sumAlpha).
     * @throws IllegalArgumentException on a negative topic or non-positive prior
     */
    double getTheta(int topicID, double alpha, double sumAlpha) {
        if (topicID < 0 || alpha <= 0.0 || sumAlpha <= 0.0) {
            throw new IllegalArgumentException();
        }
        final double numerator = getTopicCount(topicID) + alpha;
        final double denominator = getDocLength() + sumAlpha;
        return numerator / denominator;
    }
}

View File

@@ -1,100 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import forge.deck.generate.lda.lda.LDA;
import forge.deck.generate.lda.dataset.BagOfWords;
import forge.deck.generate.lda.dataset.Vocabularies;
import forge.deck.generate.lda.dataset.Vocabulary;
/**
 * The collection of all documents for one LDA run, built from the
 * bag-of-words data. Delegates per-document operations to {@link Document}.
 */
class Documents {
    private List<Document> documents;

    /**
     * Build one Document per bag-of-words entry (IDs 0..numDocs-1).
     * @throws NullPointerException if lda is null
     */
    Documents(LDA lda) {
        if (lda == null) throw new NullPointerException();
        documents = new ArrayList<>();
        for (int d = 0; d < lda.getBow().getNumDocs(); ++d) {
            List<Vocabulary> vocabList = getVocabularyList(d, lda.getBow(), lda.getVocabularies());
            Document doc = new Document(d, lda.getNumTopics(), vocabList);
            documents.add(doc);
        }
    }

    /**
     * Resolve a document's word-ID sequence into Vocabulary objects.
     * Fixed: document IDs are 0-based here (see the constructor loop), so the
     * old "assert docID > 0" rejected the first document under -ea.
     * Also removed the leftover per-document debug printlns.
     */
    List<Vocabulary> getVocabularyList(int docID, BagOfWords bow, Vocabularies vocabs) {
        assert docID >= 0 && bow != null && vocabs != null;
        return bow.getWords(docID).stream()
            .map(id -> vocabs.get(id))
            .collect(Collectors.toList());
    }

    int getTopicID(int docID, int wordID) {
        return documents.get(docID).getTopicID(wordID);
    }

    void setTopicID(int docID, int wordID, int topicID) {
        documents.get(docID).setTopicID(wordID, topicID);
    }

    Vocabulary getVocab(int docID, int wordID) {
        return documents.get(docID).getVocabulary(wordID);
    }

    List<Vocabulary> getWords(int docID) {
        return documents.get(docID).getWords();
    }

    /** @return a read-only view of all documents */
    List<Document> getDocuments() {
        return Collections.unmodifiableList(documents);
    }

    void incrementTopicCount(int docID, int topicID) {
        documents.get(docID).incrementTopicCount(topicID);
    }

    void decrementTopicCount(int docID, int topicID) {
        documents.get(docID).decrementTopicCount(topicID);
    }

    int getTopicCount(int docID, int topicID) {
        return documents.get(docID).getTopicCount(topicID);
    }

    /**
     * Per-document topic proportion, delegated to the document.
     * Fixed off-by-one in the bound: docID == size() previously slipped past
     * the guard and raised IndexOutOfBoundsException instead.
     */
    double getTheta(int docID, int topicID, double alpha, double sumAlpha) {
        if (docID < 0 || documents.size() <= docID) throw new IllegalArgumentException();
        return documents.get(docID).getTheta(topicID, alpha, sumAlpha);
    }

    /**
     * Randomly assign a topic to every word of every document, mirroring each
     * assignment into the topic-vocab count tables.
     */
    void initializeTopicAssignment(Topics topics, long seed) {
        for (Document d : getDocuments()) {
            d.initializeTopicAssignment(seed);
            for (int w = 0; w < d.getDocLength(); ++w) {
                final int topicID = d.getTopicID(w);
                final Topic topic = topics.get(topicID);
                final Vocabulary vocab = d.getVocabulary(w);
                topic.incrementVocabCount(vocab.id());
            }
        }
    }
}

View File

@@ -1,55 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
/**
 * One latent topic: its ID plus, per vocabulary entry, the number of word
 * occurrences currently assigned to this topic.
 */
class Topic {
    private final int id;
    private final int numVocabs;
    private VocabularyCounter vocabCounter;

    /**
     * @param id non-negative topic ID
     * @param numVocabs positive vocabulary size
     * @throws IllegalArgumentException on a negative id or non-positive numVocabs
     */
    Topic(int id, int numVocabs) {
        if (id < 0 || numVocabs <= 0) throw new IllegalArgumentException();
        this.id = id;
        this.numVocabs = numVocabs;
        this.vocabCounter = new VocabularyCounter(numVocabs);
    }

    /** @return this topic's ID */
    int id() {
        return id;
    }

    /** @return the assignment count for vocabID under this topic */
    int getVocabCount(int vocabID) {
        return vocabCounter.getVocabCount(vocabID);
    }

    /** @return the total number of word occurrences assigned to this topic */
    int getSumCount() {
        return vocabCounter.getSumCount();
    }

    void incrementVocabCount(int vocabID) {
        vocabCounter.incrementVocabCount(vocabID);
    }

    void decrementVocabCount(int vocabID) {
        vocabCounter.decrementVocabCount(vocabID);
    }

    /**
     * Smoothed per-topic word probability:
     * (count(vocab) + beta) / (sumCount + beta * numVocabs).
     * @throws IllegalArgumentException on a negative vocabID or non-positive beta
     */
    double getPhi(int vocabID, double beta) {
        if (vocabID < 0 || beta <= 0) throw new IllegalArgumentException();
        final double numerator = getVocabCount(vocabID) + beta;
        final double denominator = getSumCount() + beta * numVocabs;
        return numerator / denominator;
    }
}

View File

@@ -1,60 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
class TopicAssignment {
private List<Integer> topicAssignment;
private boolean ready;
TopicAssignment() {
topicAssignment = new ArrayList<>();
ready = false;
}
void set(int wordID, int topicID) {
if (!ready) throw new IllegalStateException();
if (wordID < 0 || topicAssignment.size() <= wordID || topicID < 0) {
throw new IllegalArgumentException();
}
topicAssignment.set(wordID, topicID);
}
int get(int wordID) {
if (!ready) throw new IllegalStateException();
if (wordID < 0 || topicAssignment.size() <= wordID) {
throw new IllegalArgumentException();
}
return topicAssignment.get(wordID);
}
void initialize(int docLength, int numTopics, long seed) {
if (docLength <= 0 || numTopics <= 0) {
throw new IllegalArgumentException();
}
Random random = new Random(seed);
topicAssignment = random.ints(docLength, 0, numTopics)
.boxed()
.collect(Collectors.toList());
ready = true;
}
}

View File

@@ -1,45 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
/**
 * Per-document counts of how many word positions are assigned to each topic.
 * Thin topic-flavoured wrapper around AssignmentCounter.
 */
class TopicCounter {
    private AssignmentCounter counts;

    TopicCounter(int numTopics) {
        this.counts = new AssignmentCounter(numTopics);
    }

    /** @return the number of positions assigned to topicID */
    int getTopicCount(int topicID) {
        return counts.get(topicID);
    }

    /** @return the document length, i.e. the sum of all topic counts */
    int getDocLength() {
        return counts.getSum();
    }

    void incrementTopicCount(int topicID) {
        counts.increment(topicID);
    }

    void decrementTopicCount(int topicID) {
        counts.decrement(topicID);
    }

    /** @return the number of topics tracked */
    int size() {
        return counts.size();
    }
}

View File

@@ -1,86 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import forge.deck.generate.lda.lda.LDA;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import forge.deck.generate.lda.dataset.Vocabularies;
/**
 * The collection of all topics for one LDA run, with delegating accessors
 * for the per-topic counts and phi estimates.
 */
class Topics {
    private List<Topic> topics;

    /**
     * Create one Topic per topic ID, each sized to the vocabulary.
     * @throws NullPointerException if lda is null
     */
    Topics(LDA lda) {
        if (lda == null) throw new NullPointerException();
        topics = new ArrayList<>();
        for (int id = 0; id < lda.getNumTopics(); ++id) {
            topics.add(new Topic(id, lda.getBow().getNumVocabs()));
        }
    }

    /** @return the number of topics */
    int numTopics() {
        return topics.size();
    }

    /** @return the topic with the given ID */
    Topic get(int id) {
        return topics.get(id);
    }

    int getVocabCount(int topicID, int vocabID) {
        return topics.get(topicID).getVocabCount(vocabID);
    }

    int getSumCount(int topicID) {
        return topics.get(topicID).getSumCount();
    }

    void incrementVocabCount(int topicID, int vocabID) {
        topics.get(topicID).incrementVocabCount(vocabID);
    }

    void decrementVocabCount(int topicID, int vocabID) {
        topics.get(topicID).decrementVocabCount(vocabID);
    }

    /**
     * Smoothed probability of vocabID under topicID.
     * @throws IllegalArgumentException on an out-of-range topicID
     */
    double getPhi(int topicID, int vocabID, double beta) {
        if (topicID < 0 || topics.size() <= topicID) throw new IllegalArgumentException();
        return topics.get(topicID).getPhi(vocabID, beta);
    }

    /**
     * List every vocabulary term paired with its phi under topicID, sorted
     * by descending probability.
     * @return an unmodifiable sorted list of (term, phi) pairs
     * @throws IllegalArgumentException on bad topicID, null vocabs, or non-positive beta
     */
    List<Pair<String, Double>> getVocabsSortedByPhi(int topicID, Vocabularies vocabs, final double beta) {
        if (topicID < 0 || topics.size() <= topicID || vocabs == null || beta <= 0.0) {
            throw new IllegalArgumentException();
        }
        final Topic topic = topics.get(topicID);
        List<Pair<String, Double>> pairs = vocabs.getVocabularyList()
                .stream()
                .map(v -> new ImmutablePair<String, Double>(v.toString(), topic.getPhi(v.id(), beta)))
                .collect(Collectors.toList());
        pairs.sort((p1, p2) -> Double.compare(p2.getRight(), p1.getRight()));
        return Collections.unmodifiableList(pairs);
    }
}

View File

@@ -1,46 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
/**
 * Per-topic counts of word occurrences keyed by vocabulary ID, with a
 * running total so the topic's overall size is O(1) to read.
 */
class VocabularyCounter {
    private AssignmentCounter counts;
    private int total;

    VocabularyCounter(int numVocabs) {
        this.counts = new AssignmentCounter(numVocabs);
        this.total = 0;
    }

    /**
     * @return the count for vocabID; IDs beyond the table are treated as
     *         "never seen" and yield 0 rather than an error
     */
    int getVocabCount(int vocabID) {
        return counts.size() < vocabID ? 0 : counts.get(vocabID);
    }

    /** @return the total number of occurrences counted so far */
    int getSumCount() {
        return total;
    }

    void incrementVocabCount(int vocabID) {
        counts.increment(vocabID);
        ++total;
    }

    void decrementVocabCount(int vocabID) {
        counts.decrement(vocabID);
        --total;
    }
}

View File

@@ -1,46 +0,0 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.deck.generate.lda.lda.inference.internal;
import java.util.Collections;
import java.util.List;
import forge.deck.generate.lda.dataset.Vocabulary;
/**
 * The ordered sequence of vocabulary entries making up one document.
 */
class Words {
    private List<Vocabulary> words;

    /** @throws NullPointerException if words is null */
    Words(List<Vocabulary> words) {
        if (words == null) throw new NullPointerException();
        this.words = words;
    }

    /** @return the number of word positions */
    int getNumWords() {
        return words.size();
    }

    /**
     * @return the vocabulary entry at position id
     * @throws IllegalArgumentException on an out-of-range position
     */
    Vocabulary get(int id) {
        final boolean inRange = 0 <= id && id < words.size();
        if (!inRange) {
            throw new IllegalArgumentException();
        }
        return words.get(id);
    }

    /** @return a read-only view of the word sequence */
    List<Vocabulary> getWords() {
        return Collections.unmodifiableList(words);
    }
}

View File

@@ -0,0 +1,139 @@
package forge.planarconquestgenerate;
import com.google.common.base.Function;
import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import forge.GuiBase;
import forge.GuiDesktop;
import forge.LobbyPlayer;
import forge.StaticData;
import forge.card.CardRulesPredicates;
import forge.deck.*;
import forge.deck.io.DeckStorage;
import forge.game.GameFormat;
import forge.game.GameRules;
import forge.game.GameType;
import forge.game.Match;
import forge.game.player.RegisteredPlayer;
import forge.item.PaperCard;
import forge.limited.CardRanker;
import forge.model.FModel;
import forge.player.GamePlayerUtil;
import forge.properties.ForgeConstants;
import forge.properties.ForgePreferences;
import forge.tournament.system.AbstractTournament;
import forge.tournament.system.TournamentPairing;
import forge.tournament.system.TournamentPlayer;
import forge.tournament.system.TournamentSwiss;
import forge.util.AbstractGeneticAlgorithm;
import forge.util.MyRandom;
import forge.util.TextUtil;
import forge.view.SimulateMatch;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
 * Genetic-algorithm deck generator for Planar Conquest commander decks.
 * Differs from the base PlanarConquestGeneraterGA by seeding the population
 * from ranked planeswalkers and by keying mutation/crossover on a deck's
 * commander rather than on a random main-deck card.
 */
public class PlanarConquestCommanderGeneraterGA extends PlanarConquestGeneraterGA {
// Running index used to give every generated deck a unique name.
private int deckCount = 0;
public static void main(String[] args){
test();
}
/**
 * Standalone smoke test: boots the desktop GUI shim and FModel, runs the GA
 * over an XLN+RIX pool, then saves the final population as constructed
 * decks named "GA<rank>_<deckname>".
 */
public static void test(){
GuiBase.setInterface(new GuiDesktop());
FModel.initialize(null, new Function<ForgePreferences, Void>() {
@Override
public Void apply(ForgePreferences preferences) {
// Load card scripts eagerly; the GA touches most of the card pool.
preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false);
return null;
}
});
List<String> sets = new ArrayList<>();
sets.add("XLN");
sets.add("RIX");
// 4 seed cards, 12 decks per card, 16 generations.
PlanarConquestCommanderGeneraterGA ga = new PlanarConquestCommanderGeneraterGA(new GameRules(GameType.Constructed),
new GameFormat("conquest",sets,null),
DeckFormat.PlanarConquest,
4,
12,
16);
ga.run();
List<Deck> winners = ga.listFinalPopulation();
DeckStorage storage = new DeckStorage(new File(ForgeConstants.DECK_CONSTRUCTED_DIR), ForgeConstants.DECK_BASE_DIR);
int i=1;
for(Deck deck:winners){
storage.save(new Deck(deck,"GA"+i+"_"+deck.getName()));
i++;
}
}
public PlanarConquestCommanderGeneraterGA(GameRules rules, GameFormat gameFormat, DeckFormat deckFormat, int cardsToUse, int decksPerCard, int generations){
super(rules,gameFormat,deckFormat,cardsToUse,decksPerCard,generations);
}
/**
 * Build the initial population: rank AI-playable planeswalkers printed in
 * the chosen format, then generate decksPerCard decks around each of the
 * top cardsToUse candidates.
 */
@Override
protected void initializeCards(){
// NOTE(review): the pool is looked up via 'format' (Standard, from the
// superclass) rather than the 'gameFormat' passed in — confirm intended.
standardMap = CardRelationMatrixGenerator.cardPools.get(format.getName());
List<String> cardNames = new ArrayList<>(standardMap.keySet());
List<PaperCard> cards = new ArrayList<>();
for(String cardName:cardNames){
cards.add(StaticData.instance().getCommonCards().getUniqueByName(cardName));
}
// Keep only AI-deck-friendly planeswalkers legal in the game format.
Iterable<PaperCard> filtered= Iterables.filter(cards, Predicates.and(
Predicates.compose(CardRulesPredicates.IS_KEPT_IN_AI_DECKS, PaperCard.FN_GET_RULES),
Predicates.compose(CardRulesPredicates.Presets.IS_PLANESWALKER, PaperCard.FN_GET_RULES),
//Predicates.compose(CardRulesPredicates.Presets.IS_LEGENDARY, PaperCard.FN_GET_RULES),
gameFormat.getFilterPrinted()));
List<PaperCard> filteredList = Lists.newArrayList(filtered);
rankedList = CardRanker.rankCardsInDeck(filteredList);
List<Deck> decks = new ArrayList<>();
for(PaperCard card: rankedList.subList(0,cardsToUse)){
System.out.println(card.getName());
for( int i=0; i<decksPerCard;i++){
decks.add(getDeckForCard(card));
}
}
initializePopulation(decks);
}
/**
 * Generate one commander deck around the given card; the copy is renamed so
 * the name encodes the deck index and the current generation.
 */
@Override
protected Deck getDeckForCard(PaperCard card){
Deck genDeck = DeckgenUtil.buildPlanarConquestCommanderDeck(card, gameFormat, deckFormat);
Deck d = new Deck(genDeck,genDeck.getName()+"_"+deckCount+"_"+generationCount);
deckCount++;
return d;
}
// Commander decks are built from a single seed; the second card is ignored.
@Override
protected Deck getDeckForCard(PaperCard card, PaperCard card2){
return getDeckForCard(card);
}
/**
 * Mutation: rebuild a fresh deck around the parent's commander. Returns
 * null (no offspring) when the commander is absent from the card pool map.
 */
@Override
protected Deck mutateObject(Deck parent1) {
PaperCard allele = parent1.getCommanders().get(0);
if(!standardMap.keySet().contains(allele.getName())){
return null;
}
return getDeckForCard(allele);
}
// Crossover degenerates to mutation here: only parent1's commander is used.
@Override
protected Deck createChild(Deck parent1, Deck parent2) {
PaperCard allele = parent1.getCommanders().get(0);
return getDeckForCard(allele);
}
}

View File

@@ -0,0 +1,277 @@
package forge.planarconquestgenerate;
import com.google.common.base.Function;
import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import forge.GuiBase;
import forge.GuiDesktop;
import forge.LobbyPlayer;
import forge.StaticData;
import forge.card.CardRulesPredicates;
import forge.deck.*;
import forge.deck.io.DeckStorage;
import forge.game.GameFormat;
import forge.game.GameRules;
import forge.game.GameType;
import forge.game.Match;
import forge.game.player.RegisteredPlayer;
import forge.item.PaperCard;
import forge.limited.CardRanker;
import forge.model.FModel;
import forge.player.GamePlayerUtil;
import forge.properties.ForgeConstants;
import forge.properties.ForgePreferences;
import forge.tournament.system.AbstractTournament;
import forge.tournament.system.TournamentPairing;
import forge.tournament.system.TournamentPlayer;
import forge.tournament.system.TournamentSwiss;
import forge.util.AbstractGeneticAlgorithm;
import forge.util.MyRandom;
import forge.util.TextUtil;
import forge.view.SimulateMatch;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
/**
 * Genetic algorithm that evolves Planar Conquest decks. Fitness is
 * evaluated by playing the whole population against itself in a simulated
 * Swiss tournament; mutation and crossover rebuild decks around cards drawn
 * from the parents' main decks.
 */
public class PlanarConquestGeneraterGA extends AbstractGeneticAlgorithm<Deck> {
// Deck group / player roster / tournament state rebuilt on every fitness pass.
private DeckGroup deckGroup;
private List<TournamentPlayer> players = new ArrayList<>();
private TournamentSwiss tourney = null;
// Card pool map used to validate alleles; keyed by card name.
protected Map<String,List<Map.Entry<PaperCard,Integer>>> standardMap;
private GameRules rules;
// NOTE(review): hard-coded to Standard and used for the cardPools lookup in
// initializeCards, independent of the 'gameFormat' constructor argument —
// confirm this is intended.
protected GameFormat format = FModel.getFormats().getStandard();
protected int generations;
protected GameFormat gameFormat;
protected DeckFormat deckFormat;
protected int cardsToUse;
protected int decksPerCard;
// Running index used to give every generated deck a unique name; reset each generation.
private int deckCount = 0;
// Candidate cards ordered by CardRanker score.
protected List<PaperCard> rankedList;
public static void main(String[] args){
test();
}
/**
 * Standalone smoke test: boots the desktop GUI shim and FModel, runs the GA
 * over an XLN+RIX pool, then saves the final population as constructed
 * decks named "GA<rank>_<deckname>".
 */
public static void test(){
GuiBase.setInterface(new GuiDesktop());
FModel.initialize(null, new Function<ForgePreferences, Void>() {
@Override
public Void apply(ForgePreferences preferences) {
// Load card scripts eagerly; the GA touches most of the card pool.
preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false);
return null;
}
});
List<String> sets = new ArrayList<>();
sets.add("XLN");
sets.add("RIX");
// 40 seed cards, 4 decks per card, 10 generations.
PlanarConquestGeneraterGA ga = new PlanarConquestGeneraterGA(new GameRules(GameType.Constructed),
new GameFormat("conquest",sets,null),
DeckFormat.PlanarConquest,
40,
4,
10);
ga.run();
List<Deck> winners = ga.listFinalPopulation();
DeckStorage storage = new DeckStorage(new File(ForgeConstants.DECK_CONSTRUCTED_DIR), ForgeConstants.DECK_BASE_DIR);
int i=1;
for(Deck deck:winners){
storage.save(new Deck(deck,"GA"+i+"_"+deck.getName()));
i++;
}
}
/**
 * @param rules match rules; games-per-match is forced to 7 here
 * @param gameFormat card pool the generated decks must be legal in
 * @param deckFormat deck construction rules (Planar Conquest)
 * @param cardsToUse how many top-ranked seed cards start the population
 * @param decksPerCard how many decks to generate per seed card
 * @param generations number of GA generations to run
 */
public PlanarConquestGeneraterGA(GameRules rules, GameFormat gameFormat, DeckFormat deckFormat, int cardsToUse, int decksPerCard, int generations){
this.rules = rules;
rules.setGamesPerMatch(7);
this.gameFormat = gameFormat;
this.deckFormat = deckFormat;
this.cardsToUse = cardsToUse;
this.decksPerCard = decksPerCard;
this.generations = generations;
initializeCards();
}
/**
 * Build the initial population: rank AI-playable non-land cards printed in
 * the chosen format, then generate decksPerCard decks around each of the
 * top cardsToUse candidates.
 */
protected void initializeCards(){
standardMap = CardRelationMatrixGenerator.cardPools.get(format.getName());
List<String> cardNames = new ArrayList<>(standardMap.keySet());
List<PaperCard> cards = new ArrayList<>();
for(String cardName:cardNames){
cards.add(StaticData.instance().getCommonCards().getUniqueByName(cardName));
}
// Keep only AI-deck-friendly non-lands legal in the game format.
Iterable<PaperCard> filtered= Iterables.filter(cards, Predicates.and(
Predicates.compose(CardRulesPredicates.IS_KEPT_IN_AI_DECKS, PaperCard.FN_GET_RULES),
Predicates.compose(CardRulesPredicates.Presets.IS_NON_LAND, PaperCard.FN_GET_RULES),
gameFormat.getFilterPrinted()));
List<PaperCard> filteredList = Lists.newArrayList(filtered);
setRankedList(CardRanker.rankCardsInDeck(filteredList));
List<Deck> decks = new ArrayList<>();
for(PaperCard card: getRankedList().subList(0,cardsToUse)){
System.out.println(card.getName());
for( int i=0; i<decksPerCard;i++){
decks.add(getDeckForCard(card));
}
}
initializePopulation(decks);
}
protected List<PaperCard> getRankedList(){
return rankedList;
}
protected void setRankedList(List<PaperCard> list){
rankedList = list;
}
/**
 * Generate one deck around the given seed card; the copy is renamed so the
 * name encodes the deck index and the current generation.
 */
protected Deck getDeckForCard(PaperCard card){
Deck genDeck = DeckgenUtil.buildPlanarConquestDeck(card, gameFormat, deckFormat);
Deck d = new Deck(genDeck,genDeck.getName()+"_"+deckCount+"_"+generationCount);
deckCount++;
return d;
}
/** Two-seed variant used by crossover: build a deck around both cards. */
protected Deck getDeckForCard(PaperCard card, PaperCard card2){
Deck genDeck = DeckgenUtil.buildPlanarConquestDeck(card, card2, gameFormat, deckFormat, false);
Deck d = new Deck(genDeck,genDeck.getName()+"_"+deckCount+"_"+generationCount);
deckCount++;
return d;
}
/**
 * Fitness pass: enter every deck in the population into a Swiss tournament
 * against AI pilots, then rewrite the population in the tournament's final
 * standings order (best first, per getAllPlayers ordering).
 */
@Override
protected void evaluateFitness() {
deckGroup = new DeckGroup("SimulatedTournament");
players = new ArrayList<>();
int i=0;
for(Deck d:population) {
deckGroup.addAiDeck(d);
players.add(new TournamentPlayer(GamePlayerUtil.createAiPlayer(d.getName(), 0), i));
++i;
}
tourney = new TournamentSwiss(players, 2);
tourney = runTournament(tourney, rules, players.size(), deckGroup);
// Rebuild the population sorted by tournament result, matching decks to
// players by name.
population = new ArrayList<>();
for (int k = 0; k < tourney.getAllPlayers().size(); k++) {
String deckName = tourney.getAllPlayers().get(k).getPlayer().getName();
for (Deck sortedDeck : deckGroup.getAiDecks()) {
if (sortedDeck.getName().equals(deckName)) {
population.add(sortedDeck);
break;
}
}
}
// Restart per-generation deck numbering.
deckCount=0;
}
/**
 * Mutation: pick one of the first 8 main-deck cards as the new seed and
 * rebuild a deck around it. Returns null (no offspring) when the seed is
 * absent from the card pool map.
 */
@Override
protected Deck mutateObject(Deck parent1) {
PaperCard allele = parent1.getMain().get(MyRandom.getRandom().nextInt(8));
if(!standardMap.keySet().contains(allele.getName())){
return null;
}
return getDeckForCard(allele);
}
/**
 * Crossover: take one seed card from each parent's main deck and build a
 * new deck around the pair. Returns null when either seed is outside the
 * card pool or the two seeds are the same card.
 */
@Override
protected Deck createChild(Deck parent1, Deck parent2) {
PaperCard allele = parent1.getMain().get(MyRandom.getRandom().nextInt(8));
PaperCard allele2 = parent2.getMain().get(MyRandom.getRandom().nextInt(8));
if(!standardMap.keySet().contains(allele.getName())
||!standardMap.keySet().contains(allele2.getName())
||allele.getName().equals(allele2.getName())){
return null;
}
return getDeckForCard(allele,allele2);
}
/** Inject a fresh random deck built around any card from the ranked list. */
@Override
protected Deck expandPool(){
PaperCard seed = getRankedList().get(MyRandom.getRandom().nextInt(getRankedList().size()));
return getDeckForCard(seed);
}
@Override
protected boolean shouldContinue() {
return generationCount<generations;
}
/**
 * Play the given Swiss tournament to completion, simulating each non-bye
 * pairing as a full match. A match that throws more than 5 exceptions is
 * abandoned (its winner is then whatever the match state reports).
 * @return the completed tournament, with final standings available
 */
public TournamentSwiss runTournament(TournamentSwiss tourney, GameRules rules, int numPlayers, DeckGroup deckGroup){
tourney.initializeTournament();
String lastWinner = "";
int curRound = 0;
System.out.println(TextUtil.concatNoSpace("Starting a tournament with ",
String.valueOf(numPlayers), " players over ",
String.valueOf(tourney.getTotalRounds()), " rounds"));
while(!tourney.isTournamentOver()) {
// Print the pairings each time a new round begins.
if (tourney.getActiveRound() != curRound) {
if (curRound != 0) {
System.out.println(TextUtil.concatNoSpace("End Round - ", String.valueOf(curRound)));
}
curRound = tourney.getActiveRound();
System.out.println("");
System.out.println(TextUtil.concatNoSpace("Round ", String.valueOf(curRound) ," Pairings:"));
for(TournamentPairing pairing : tourney.getActivePairings()) {
System.out.println(pairing.outputHeader());
}
System.out.println("");
}
TournamentPairing pairing = tourney.getNextPairing();
List<RegisteredPlayer> regPlayers = AbstractTournament.registerTournamentPlayers(pairing, deckGroup);
StringBuilder sb = new StringBuilder();
sb.append("Round ").append(tourney.getActiveRound()).append(" - ");
sb.append(pairing.outputHeader());
//System.out.println(sb.toString());
if (!pairing.isBye()) {
Match mc = new Match(rules, regPlayers, "TourneyMatch");
int exceptions = 0;
int iGame = 0;
while (!mc.isMatchOver()) {
// play games until the match ends
try{
SimulateMatch.simulateSingleMatch(mc, iGame, false);
iGame++;
} catch(Exception e) {
exceptions++;
System.out.println(e.toString());
if (exceptions > 5) {
System.out.println("Exceeded number of exceptions thrown. Abandoning match...");
break;
} else {
System.out.println("Game threw exception. Abandoning game and continuing...");
}
}
}
// Record the match winner back into the pairing.
LobbyPlayer winner = mc.getWinner().getPlayer();
for (TournamentPlayer tp : pairing.getPairedPlayers()) {
if (winner.equals(tp.getPlayer())) {
pairing.setWinner(tp);
lastWinner = winner.getName();
//System.out.println(TextUtil.concatNoSpace("Match Winner - ", lastWinner, "!"));
//System.out.println("");
break;
}
}
}
tourney.reportMatchCompletion(pairing);
}
tourney.outputTournamentResults();
return tourney;
}
}