From 185c4c09c9b013b033c24f8956cf656094215b4f Mon Sep 17 00:00:00 2001 From: Hans Mackowiak Date: Sun, 11 Apr 2021 05:51:11 +0000 Subject: [PATCH] Changed card-based deck generation to archetype based making better use of the new LDA models. Decks can now be selected by archetype with names generated from the source decklists. Archetypes are ordered by popularity. --- .../forge/deck/ArchetypeDeckGenerator.java | 6 +- .../src/main/java/forge/deck/DeckgenUtil.java | 2 - .../main/java/forge/deck/io/Archetype.java | 1 + forge-lda/pom.xml | 37 ++ .../java/forge/lda/LDAModelGenetrator.java | 389 ++++++++++++++++++ .../java/forge/lda/dataset/BagOfWords.java | 210 ++++++++++ .../main/java/forge/lda/dataset/Dataset.java | 80 ++++ .../java/forge/lda/dataset/Vocabularies.java | 40 ++ .../java/forge/lda/dataset/Vocabulary.java | 37 ++ .../src/main/java/forge/lda/lda/Alpha.java | 46 +++ .../src/main/java/forge/lda/lda/Beta.java | 54 +++ .../java/forge/lda/lda/Hyperparameters.java | 60 +++ .../src/main/java/forge/lda/lda/LDA.java | 183 ++++++++ .../forge/lda/lda/inference/Inference.java | 62 +++ .../lda/lda/inference/InferenceFactory.java | 37 ++ .../lda/lda/inference/InferenceMethod.java | 36 ++ .../lda/inference/InferenceProperties.java | 70 ++++ .../inference/internal/AssignmentCounter.java | 65 +++ .../internal/CollapsedGibbsSampler.java | 220 ++++++++++ .../lda/lda/inference/internal/Document.java | 86 ++++ .../lda/lda/inference/internal/Documents.java | 100 +++++ .../lda/lda/inference/internal/Topic.java | 55 +++ .../inference/internal/TopicAssignment.java | 60 +++ .../lda/inference/internal/TopicCounter.java | 45 ++ .../lda/lda/inference/internal/Topics.java | 85 ++++ .../inference/internal/VocabularyCounter.java | 46 +++ .../lda/lda/inference/internal/Words.java | 46 +++ pom.xml | 1 + 28 files changed, 2154 insertions(+), 5 deletions(-) create mode 100644 forge-lda/pom.xml create mode 100644 forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java create mode 100644 forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java create mode 100644 forge-lda/src/main/java/forge/lda/dataset/Dataset.java create mode 100644 forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java create mode 100644 forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/Alpha.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/Beta.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/LDA.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/Inference.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/InferenceFactory.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/InferenceMethod.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/InferenceProperties.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/AssignmentCounter.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/CollapsedGibbsSampler.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/Document.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/Documents.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/Topic.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicAssignment.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicCounter.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/Topics.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/VocabularyCounter.java create mode 100644 forge-lda/src/main/java/forge/lda/lda/inference/internal/Words.java diff --git a/forge-gui/src/main/java/forge/deck/ArchetypeDeckGenerator.java b/forge-gui/src/main/java/forge/deck/ArchetypeDeckGenerator.java index 37f69fa44e8..f53e5f2597b 100644 --- a/forge-gui/src/main/java/forge/deck/ArchetypeDeckGenerator.java +++ b/forge-gui/src/main/java/forge/deck/ArchetypeDeckGenerator.java @@ -17,9 +17,9 @@ import forge.model.FModel; public class ArchetypeDeckGenerator extends DeckProxy implements Comparable { public static List getMatrixDecks(GameFormat format, boolean isForAi){ final List decks = new ArrayList<>(); - for(Archetype archetype: CardArchetypeLDAGenerator.ldaArchetypes.get(format.getName())) { - decks.add(new ArchetypeDeckGenerator(archetype, format, isForAi)); - } + for(Archetype archetype: CardArchetypeLDAGenerator.ldaArchetypes.get(format.getName())) { + decks.add(new ArchetypeDeckGenerator(archetype, format, isForAi)); + } return decks; } diff --git a/forge-gui/src/main/java/forge/deck/DeckgenUtil.java b/forge-gui/src/main/java/forge/deck/DeckgenUtil.java index 6de099da5ae..76e07081e85 100644 --- a/forge-gui/src/main/java/forge/deck/DeckgenUtil.java +++ b/forge-gui/src/main/java/forge/deck/DeckgenUtil.java @@ -331,12 +331,10 @@ public class DeckgenUtil { return deck; } - /** * @param selection {@link java.lang.String} array * @return {@link forge.deck.Deck} */ - public static Deck buildColorDeck(List selection, Predicate formatFilter, boolean forAi) { try { final Deck deck; diff --git a/forge-gui/src/main/java/forge/deck/io/Archetype.java b/forge-gui/src/main/java/forge/deck/io/Archetype.java index 8bade9a7faa..4f3bea8e0d6 100644 --- a/forge-gui/src/main/java/forge/deck/io/Archetype.java +++ b/forge-gui/src/main/java/forge/deck/io/Archetype.java @@ -1,5 +1,6 @@ package forge.deck.io; + import java.io.Serializable; import java.util.List; diff --git a/forge-lda/pom.xml b/forge-lda/pom.xml new file mode 100644 index 00000000000..bdda207b546 --- /dev/null +++ b/forge-lda/pom.xml @@ -0,0 +1,37 @@ + + 4.0.0 + + + forge + forge + 1.6.40-SNAPSHOT + + + forge-lda + jar + Forge LDA + + + 0 + 0 + 0 + + + + + forge + forge-gui + ${project.version} + + + org.apache.commons + commons-lang3 + 3.8.1 + + + org.apache.commons + commons-math3 + 3.6.1 + + + \ No newline at end of file diff --git a/forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java b/forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java new file mode 100644 index 00000000000..1e0ce8b9202 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java @@ -0,0 +1,389 @@ +package forge.lda; + +import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +//import forge.GuiDesktop; +import forge.StaticData; +import forge.card.CardRules; +import forge.card.CardRulesPredicates; +import forge.deck.Deck; +import forge.deck.DeckFormat; +import forge.deck.io.Archetype; +import forge.deck.io.CardThemedLDAIO; +import forge.deck.io.DeckStorage; +import forge.lda.dataset.Dataset; +import forge.lda.lda.LDA; +import forge.game.GameFormat; +import forge.item.PaperCard; +import forge.localinstance.properties.ForgeConstants; +import forge.localinstance.properties.ForgePreferences; +import forge.model.FModel; +import forge.util.storage.IStorage; +import forge.util.storage.StorageImmediatelySerialized; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.tuple.Pair; + +import java.io.File; +import java.util.*; + +import static forge.lda.lda.inference.InferenceMethod.CGS; + +/** + * Created by maustin on 09/05/2017. + */ +public final class LDAModelGenetrator { + + public static Map>>>> ldaPools = new HashMap<>(); + public static Map> ldaArchetypes = new HashMap<>(); + + + public static final void main(String[] args){ + //GuiBase.setInterface(new GuiDesktop()); + FModel.initialize(null, new Function() { + @Override + public Void apply(ForgePreferences preferences) { + preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false); + return null; + } + }); + initialize(); + } + + public static boolean initialize(){ + List formatStrings = new ArrayList<>(); + formatStrings.add(FModel.getFormats().getStandard().getName()); + formatStrings.add(FModel.getFormats().getPioneer().getName()); + formatStrings.add(FModel.getFormats().getModern().getName()); + formatStrings.add("Legacy"); + formatStrings.add("Vintage"); + formatStrings.add(DeckFormat.Commander.toString()); + + for (String formatString : formatStrings){ + if(!initializeFormat(formatString)){ + return false; + } + } + + return true; + } + + /** Try to load matrix .dat files, otherwise check for deck folders and build .dat, otherwise return false **/ + public static boolean initializeFormat(String format){ + Map>>> formatMap = CardThemedLDAIO.loadLDA(format); + List lda = CardThemedLDAIO.loadRawLDA(format); + if(formatMap==null) { + try { + if(lda==null) { + if (CardThemedLDAIO.getMatrixFolder(format).exists()) { + if (format.equals(FModel.getFormats().getStandard().getName())) { + lda = initializeFormat(FModel.getFormats().getStandard()); + } else if (format.equals(FModel.getFormats().getModern().getName())) { + lda = initializeFormat(FModel.getFormats().getModern()); + } else if (format != DeckFormat.Commander.toString()) { + lda = initializeFormat(FModel.getFormats().get(format)); + } + CardThemedLDAIO.saveRawLDA(format, lda); + } else { + return false; + } + } + if (format.equals(FModel.getFormats().getStandard().getName())) { + formatMap = loadFormat(FModel.getFormats().getStandard(), lda); + } else if (format.equals(FModel.getFormats().getModern().getName())) { + formatMap = loadFormat(FModel.getFormats().getModern(), lda); + } else if (format != DeckFormat.Commander.toString()) { + formatMap = loadFormat(FModel.getFormats().get(format), lda);; + } + CardThemedLDAIO.saveLDA(format, formatMap); + }catch (Exception e){ + e.printStackTrace(); + return false; + } + } + ldaPools.put(format, formatMap); + ldaArchetypes.put(format, lda); + return true; + } + + public static Map>>> loadFormat(GameFormat format,List lda) throws Exception{ + + List>> topics = new ArrayList<>(); + Set cards = new HashSet(); + for (int t = 0; t < lda.size(); ++t) { + List> topic = new ArrayList<>(); + Set topicCards = new HashSet<>(); + List> highRankVocabs = lda.get(t).getCardProbabilities(); + if (highRankVocabs.get(0).getRight()<=0.01d){ + continue; + } + System.out.print("t" + t + ": "); + int i = 0; + while (topic.size()<=40&&i=0.005d) { + topicCards.add(cardName); + } + System.out.println("[" + highRankVocabs.get(i).getLeft() + "," + highRankVocabs.get(i).getRight() + "],"); + topic.add(highRankVocabs.get(i)); + } + i++; + } + System.out.println(); + if(topic.size()>18) { + cards.addAll(topicCards); + topics.add(topic); + } + } + Map>>> cardTopicMap = new HashMap<>(); + for (String card:cards){ + List>> cardTopics = new ArrayList<>(); + for( List> topic:topics){ + if(topicContains(card,topic)){ + cardTopics.add(topic); + } + } + cardTopicMap.put(card,cardTopics); + } + return cardTopicMap; + } + + public static List initializeFormat(GameFormat format) throws Exception{ + Dataset dataset = new Dataset(format); + + //estimate number of topics to attempt to find using power law + final int numTopics = Float.valueOf(347f*dataset.getNumDocs()/(2892f + dataset.getNumDocs())).intValue(); + System.out.println("Num Topics = " + numTopics); + LDA lda = new LDA(0.1, 0.1, numTopics, dataset, CGS); + lda.run(); + System.out.println(lda.computePerplexity(dataset)); + + //sort decks by topic + Map> topicDecks = new HashMap<>(); + + int deckNum=0; + for(Deck deck: dataset.getBow().getLegalDecks()){ + double maxTheta = 0; + int mainTopic = 0; + for (int t = 0; t < lda.getNumTopics(); ++t){ + double theta = lda.getTheta(deckNum,t); + if (theta > maxTheta){ + maxTheta = theta; + mainTopic = t; + } + } + if(topicDecks.containsKey(mainTopic)){ + topicDecks.get(mainTopic).add(deck); + }else{ + List decks = new ArrayList<>(); + decks.add(deck); + topicDecks.put(mainTopic,decks); + } + ++deckNum; + } + + + List unfilteredTopics = new ArrayList<>(); + for (int t = 0; t < lda.getNumTopics(); ++t) { + List> highRankVocabs = lda.getVocabsSortedByPhi(t); + Double min = 1d; + for(Pair p:highRankVocabs){ + if(p.getRight()> topRankVocabs = new ArrayList<>(); + for(Pair p:highRankVocabs){ + if(p.getRight()>min){ + topRankVocabs.add(p); + } + } + + //generate names for topics + List decks = topicDecks.get(t); + if(decks==null){ + continue; + } + LinkedHashMap wordCounts = new LinkedHashMap<>(); + for( Deck deck: decks){ + String name = deck.getName().replaceAll(".* Version - ","").replaceAll(" \\((Modern|Pioneer|Standard|Legacy|Vintage), #[0-9]+\\)",""); + name = name.replaceAll("\\(Modern|Pioneer|Standard|Legacy|Vintage|Fuck|Shit|Cunt\\)",""); + String[] tokens = name.split(" "); + for(String rawtoken: tokens){ + String token = rawtoken.toLowerCase(); + if (token.matches("[0-9]+")) { + //skip just numbers as not useful + continue; + } + if(wordCounts.containsKey(token)){ + wordCounts.put(token, wordCounts.get(token)+1); + }else{ + wordCounts.put(token, 1); + } + } + } + Map sortedWordCounts = sortByValue(wordCounts); + + List topWords = new ArrayList<>(); + Iterator wordIterator = sortedWordCounts.keySet().iterator(); + while(wordIterator.hasNext() && topWords.size() < 3){ + topWords.add(wordIterator.next()); + } + StringJoiner sb = new StringJoiner(" "); + for(String word: wordCounts.keySet()){ + if(topWords.contains(word)){ + sb.add(word); + } + } + String deckName = sb.toString(); + System.out.println("============ " + deckName); + System.out.println(decks.toString()); + + unfilteredTopics.add(new Archetype(topRankVocabs,deckName,decks.size())); + } + Comparator archetypeComparator = new Comparator() { + @Override + public int compare(Archetype o1, Archetype o2) { + return o2.getDeckCount().compareTo(o1.getDeckCount()); + } + }; + + Collections.sort(unfilteredTopics,archetypeComparator); + return unfilteredTopics; + } + + + + private static Map sortByValue(Map map) { + List> list = new LinkedList<>(map.entrySet()); + Collections.sort(list, new Comparator() { + @SuppressWarnings("unchecked") + public int compare(Object o1, Object o2) { + return ((Comparable) ((Map.Entry) (o2)).getValue()).compareTo(((Map.Entry) (o1)).getValue()); + } + }); + + Map result = new LinkedHashMap<>(); + for (Iterator> it = list.iterator(); it.hasNext();) { + Map.Entry entry = (Map.Entry) it.next(); + result.put(entry.getKey(), entry.getValue()); + } + + return result; + } + + public static boolean topicContains(String card, List> topic){ + for(Pair pair:topic){ + if(pair.getLeft().equals(card)){ + return true; + } + } + return false; + } + + public static HashMap>> initializeCommanderFormat(){ + + IStorage decks = new StorageImmediatelySerialized("Generator", + new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR,DeckFormat.Commander.toString()), + ForgeConstants.DECK_GEN_DIR, false), + true); + + //get all cards + final Iterable cards = Iterables.filter(FModel.getMagicDb().getCommonCards().getUniqueCards() + , Predicates.compose(Predicates.not(CardRulesPredicates.Presets.IS_BASIC_LAND_NOT_WASTES), PaperCard.FN_GET_RULES)); + List cardList = Lists.newArrayList(cards); + cardList.add(FModel.getMagicDb().getCommonCards().getCard("Wastes")); + Map cardIntegerMap = new HashMap<>(); + Map integerCardMap = new HashMap<>(); + Map legendIntegerMap = new HashMap<>(); + Map integerLegendMap = new HashMap<>(); + //generate lookups for cards to link card names to matrix columns + for (int i=0; i legends = Lists.newArrayList(Iterables.filter(cardList,Predicates.compose( + new Predicate() { + @Override + public boolean apply(CardRules rules) { + return DeckFormat.Commander.isLegalCommander(rules); + } + }, PaperCard.FN_GET_RULES))); + + //generate lookups for legends to link commander names to matrix rows + for (int i=0; i>> cardPools = new HashMap<>(); + for (PaperCard card:legends){ + int col=legendIntegerMap.get(card.getName()); + int[] distances = matrix[col]; + int max = (Integer) Collections.max(Arrays.asList(ArrayUtils.toObject(distances))); + if (max>0) { + List> deckPool=new ArrayList<>(); + for(int k=0;k0){ + deckPool.add(new AbstractMap.SimpleEntry(integerCardMap.get(k),matrix[col][k])); + } + } + cardPools.put(card.getName(), deckPool); + } + } + return cardPools; + } + + //update the matrix by incrementing the connectivity count for each card in the deck + public static void updateLegendMatrix(Deck deck, PaperCard legend, Map cardIntegerMap, + Map legendIntegerMap, int[][] matrix){ + for (PaperCard pairCard:Iterables.filter(deck.getMain().toFlatList(), + Predicates.compose(Predicates.not(CardRulesPredicates.Presets.IS_BASIC_LAND_NOT_WASTES), PaperCard.FN_GET_RULES))){ + if (!pairCard.getName().equals(legend.getName())){ + try { + int old = matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(pairCard.getName())]; + matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(pairCard.getName())] = old + 1; + }catch (NullPointerException ne){ + //Todo: Not sure what was failing here + ne.printStackTrace(); + } + } + + } + //add partner commanders to matrix + if(deck.getCommanders().size()>1){ + for(PaperCard partner:deck.getCommanders()){ + if(!partner.equals(legend)){ + int old = matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(partner.getName())]; + matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(partner.getName())] = old + 1; + } + } + } + } + +} diff --git a/forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java b/forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java new file mode 100644 index 00000000000..d142496c8e5 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java @@ -0,0 +1,210 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.dataset; + +import forge.deck.Deck; +import forge.deck.io.DeckStorage; +import forge.game.GameFormat; +import forge.item.PaperCard; +import forge.localinstance.properties.ForgeConstants; +import forge.util.storage.IStorage; +import forge.util.storage.StorageImmediatelySerialized; + +import java.io.*; +import java.util.*; + +/** + * This class is immutable. + */ +public final class BagOfWords { + + private final int numDocs; + private final int numVocabs; + private final int numNNZ; + private final int numWords; + private List legalDecks; + + public Vocabularies getVocabs() { + return vocabs; + } + + private final Vocabularies vocabs; + + // docID -> the vocabs sequence in the doc + private Map> words; + + // docID -> the doc length + private Map docLength; + + + /** + * Read the bag-of-words dataset. + * @throws FileNotFoundException + * @throws IOException + * @throws Exception + * @throws NullPointerException filePath is null + */ + public BagOfWords(GameFormat format) throws FileNotFoundException, IOException, Exception { + IStorage decks = new StorageImmediatelySerialized("Generator", new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR+ForgeConstants.PATH_SEPARATOR+format.getName()), + ForgeConstants.DECK_GEN_DIR, false), + true); + + Set cardSet = new HashSet<>(); + legalDecks = new ArrayList<>(); + for(Deck deck:decks){ + try { + if (format.isDeckLegal(deck) && deck.getMain().toFlatList().size() == 60) { + legalDecks.add(deck); + for (PaperCard card : deck.getMain().toFlatList()) { + cardSet.add(card); + } + } + }catch(Exception e){ + System.out.println("Skipping deck "+deck.getName()); + } + } + List cardList = new ArrayList<>(cardSet); + + this.words = new HashMap<>(); + this.docLength = new HashMap<>(); + ArrayList vocabList = new ArrayList(); + + int numDocs = legalDecks.size(); + int numVocabs = cardList.size(); + int numNNZ = 0; + int numWords = 0; + + Map cardIntegerMap = new HashMap<>(); + Map integerCardMap = new HashMap<>(); + for (int i=0; i cardNumbers = new ArrayList<>(); + for(PaperCard card:deck.getMain().toFlatList()){ + if(cardIntegerMap.get(card.getName()) == null){ + System.out.println(card.getName() + " is missing!!"); + } + cardNumbers.add(cardIntegerMap.get(card.getName())); + } + words.put(deckID,cardNumbers); + numWords+=cardNumbers.size(); + docLength.put(deckID,cardNumbers.size()); + deckID ++; + } + + /*String s = null; + while ((s = reader.readLine()) != null) { + List numbers + = Arrays.asList(s.split(" ")).stream().map(Integer::parseInt).collect(Collectors.toList()); + + if (numbers.size() == 1) { + if (headerCount == 2) numDocs = numbers.get(0); + else if (headerCount == 1) numVocabs = numbers.get(0); + else if (headerCount == 0) numNNZ = numbers.get(0); + --headerCount; + continue; + } + else if (numbers.size() == 3) { + final int docID = numbers.get(0); + final int vocabID = numbers.get(1); + final int count = numbers.get(2); + + // Set up the words container + if (!words.containsKey(docID)) { + words.put(docID, new ArrayList<>()); + } + for (int c = 0; c < count; ++c) { + words.get(docID).add(vocabID); + } + + // Set up the doc length map + Optional currentCount + = Optional.ofNullable(docLength.putIfAbsent(docID, count)); + currentCount.ifPresent(c -> docLength.replace(docID, c + count)); + + numWords += count; + } + else { + throw new Exception("Invalid dataset form was detected."); + } + } + reader.close();*/ + + this.numDocs = numDocs; + this.numVocabs = numVocabs; + this.numNNZ = numNNZ; + this.numWords = numWords; + + System.out.println("Num Decks" + this.numDocs); + System.out.println("Num Vocabs" + this.numVocabs); + System.out.println("Num NNZ" + this.numNNZ); + System.out.println("Num Cards" + this.numWords); + } + + public int getNumDocs() { + return numDocs; + } + + /** + * Get the length of the document. + * @param docID + * @return length of the document + * @throws IllegalArgumentException docID <= 0 || #documents < docID + */ + public int getDocLength(int docID) { + if (docID < 0 || getNumDocs() < docID) { + throw new IllegalArgumentException(); + } + + return docLength.get(docID); + } + + /** + * Get the unmodifiable list of words in the document. + * @param docID + * @return the unmodifiable list of words + * @throws IllegalArgumentException docID <= 0 || #documents < docID + */ + public List getWords(final int docID) { + if (docID < 0 || getNumDocs() < docID) { + throw new IllegalArgumentException(); + } + return Collections.unmodifiableList(words.get(docID)); + } + + public int getNumVocabs() { + return numVocabs; + } + + public int getNumNNZ() { + return numNNZ; + } + + public int getNumWords() { + return numWords; + } + + public List getLegalDecks() { return legalDecks; } +} + diff --git a/forge-lda/src/main/java/forge/lda/dataset/Dataset.java b/forge-lda/src/main/java/forge/lda/dataset/Dataset.java new file mode 100644 index 00000000000..0a35f73bab4 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/dataset/Dataset.java @@ -0,0 +1,80 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.dataset; + +import forge.game.GameFormat; + +import java.util.List; + +public class Dataset { + private BagOfWords bow; + private Vocabularies vocabs; + + public Dataset(GameFormat format) throws Exception { + bow = new BagOfWords(format); + vocabs = bow.getVocabs(); + } + + public Dataset(BagOfWords bow) { + this.bow = bow; + this.vocabs = null; + } + + public BagOfWords getBow() { + return bow; + } + + public int getNumDocs() { + return bow.getNumDocs(); + } + + public int getDocLength(int docID) { + return bow.getDocLength(docID); + } + + public List getWords(int docID) { + return bow.getWords(docID); + } + + public int getNumVocabs() { + return bow.getNumVocabs(); + } + + public int getNumNNZ() { + return bow.getNumNNZ(); + } + + public int getNumWords() { + return bow.getNumWords(); + } + + public Vocabulary get(int id) { + return vocabs.get(id); + } + + public int size() { + return vocabs.size(); + } + + public Vocabularies getVocabularies() { + return vocabs; + } + + public List getVocabularyList() { + return vocabs.getVocabularyList(); + } +} diff --git a/forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java b/forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java new file mode 100644 index 00000000000..4d32c8c19a2 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java @@ -0,0 +1,40 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.dataset; + +import java.util.Collections; +import java.util.List; + +public class Vocabularies { + private List vocabs; + + public Vocabularies(List vocabs) { + this.vocabs = vocabs; + } + + public Vocabulary get(int id) { + return vocabs.get(id); + } + + public int size() { + return vocabs.size(); + } + + public List getVocabularyList() { + return Collections.unmodifiableList(vocabs); + } +} diff --git a/forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java b/forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java new file mode 100644 index 00000000000..2420acf7f41 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java @@ -0,0 +1,37 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.dataset; + +public class Vocabulary { + private final int id; + private String vocabulary; + + public Vocabulary(int id, String vocabulary) { + if (vocabulary == null) throw new NullPointerException(); + this.id = id; + this.vocabulary = vocabulary; + } + + public int id() { + return id; + } + + @Override + public String toString() { + return vocabulary; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/Alpha.java b/forge-lda/src/main/java/forge/lda/lda/Alpha.java new file mode 100644 index 00000000000..68880d18b2b --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/Alpha.java @@ -0,0 +1,46 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +class Alpha { + private List alphas; + + Alpha(double alpha, int numTopics) { + if (alpha <= 0.0 || numTopics <= 0) { + throw new IllegalArgumentException(); + } + this.alphas = Stream.generate(() -> alpha) + .limit(numTopics) + .collect(Collectors.toList()); + } + + double get(int i) { + return alphas.get(i); + } + + void set(int i, double value) { + alphas.set(i, value); + } + + double sum() { + return alphas.stream().reduce(Double::sum).get(); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/Beta.java b/forge-lda/src/main/java/forge/lda/lda/Beta.java new file mode 100644 index 00000000000..8471aadb8a3 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/Beta.java @@ -0,0 +1,54 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +class Beta { + private List betas; + + Beta(double beta, int numVocabs) { + if (beta <= 0.0 || numVocabs <= 0) { + throw new IllegalArgumentException(); + } + this.betas = Stream.generate(() -> beta) + .limit(numVocabs) + .collect(Collectors.toList()); + } + + Beta(double beta) { + if (beta <= 0.0) { + throw new IllegalArgumentException(); + } + this.betas = Arrays.asList(beta); + } + + double get() { + return get(0); + } + + double get(int i) { + return betas.get(i); + } + + void set(int i, double value) { + betas.set(i, value); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java b/forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java new file mode 100644 index 00000000000..a468535de7b --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java @@ -0,0 +1,60 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda; + +class Hyperparameters { + private Alpha alpha; + private Beta beta; + + Hyperparameters(double alpha, double beta, int numTopics, int numVocabs) { + this.alpha = new Alpha(alpha, numTopics); + this.beta = new Beta(beta, numVocabs); + } + + Hyperparameters(double alpha, double beta, int numTopics) { + this.alpha = new Alpha(alpha, numTopics); + this.beta = new Beta(beta); + } + + double alpha(int i) { + return alpha.get(i); + } + + double sumAlpha() { + return alpha.sum(); + } + + double beta() { + return beta.get(); + } + + double beta(int i) { + return beta.get(i); + } + + void setAlpha(int i, double value) { + alpha.set(i, value); + } + + void setBeta(int i, double value) { + beta.set(i, value); + } + + void setBeta(double value) { + beta.set(0, value); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/LDA.java b/forge-lda/src/main/java/forge/lda/lda/LDA.java new file mode 100644 index 00000000000..f07a19749a0 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/LDA.java @@ -0,0 +1,183 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda; + +import java.util.List; + +import forge.lda.dataset.BagOfWords; +import forge.lda.dataset.Dataset; +import forge.lda.dataset.Vocabularies; + +import forge.lda.lda.inference.Inference; +import forge.lda.lda.inference.InferenceFactory; +import forge.lda.lda.inference.InferenceMethod; +import forge.lda.lda.inference.InferenceProperties; + +import org.apache.commons.lang3.tuple.Pair; + +public class LDA { + private Hyperparameters hyperparameters; + private final int numTopics; + private Dataset dataset; + private final Inference inference; + private InferenceProperties properties; + private boolean trained; + + /** + * @param alpha doc-topic hyperparameter + * @param beta topic-vocab hyperparameter + * @param numTopics the number of topics + * @param dataset dataset + * @param bow bag-of-words + * @param method inference method + */ + public LDA(final double alpha, final double beta, final int numTopics, + final Dataset dataset, InferenceMethod method) { + this(alpha, beta, numTopics, dataset, method, InferenceProperties.PROPERTIES_FILE_NAME); + } + + LDA(final double alpha, final double beta, final int numTopics, + final Dataset dataset, InferenceMethod method, String propertiesFilePath) { + this.hyperparameters = new Hyperparameters(alpha, beta, numTopics); + this.numTopics = numTopics; + this.dataset = dataset; + this.inference = InferenceFactory.getInstance(method); + this.properties = new InferenceProperties(); + this.trained = false; + + properties.setSeed(123L); + properties.setNumIteration(100); + } + + /** + * Get the vocabulary from its ID. + * @param vocabID + * @return the vocabulary + * @throws IllegalArgumentException vocabID <= 0 || the number of vocabularies < vocabID + */ + public String getVocab(int vocabID) { + if (vocabID < 0 || dataset.getNumVocabs() < vocabID) { + throw new IllegalArgumentException(); + } + return dataset.get(vocabID).toString(); + } + + /** + * Run model inference. + */ + public void run() { + if (properties == null) inference.setUp(this); + else inference.setUp(this, properties); + inference.run(); + trained = true; + } + + /** + * Get hyperparameter alpha corresponding to topic. + * @param topic + * @return alpha corresponding to topicID + * @throws ArrayIndexOutOfBoundsException topic < 0 || #topics <= topic + */ + public double getAlpha(final int topic) { + if (topic < 0 || numTopics < topic) { + throw new ArrayIndexOutOfBoundsException(topic); + } + return hyperparameters.alpha(topic); + } + + public double getSumAlpha() { + return hyperparameters.sumAlpha(); + } + + public double getBeta() { + return hyperparameters.beta(); + } + + public int getNumTopics() { + return numTopics; + } + + public BagOfWords getBow() { + return dataset.getBow(); + } + + /** + * Get the value of doc-topic probability \theta_{docID, topicID}. + * @param docID + * @param topicID + * @return the value of doc-topic probability + * @throws IllegalArgumentException docID <= 0 || #docs < docID || topicID < 0 || #topics <= topicID + * @throws IllegalStateException call this method when the inference has not been finished yet + */ + public double getTheta(final int docID, final int topicID) { + if (docID < 0 || dataset.getNumDocs() < docID + || topicID < 0 || numTopics < topicID) { + throw new IllegalArgumentException(); + } + if (!trained) { + throw new IllegalStateException(); + } + + return inference.getTheta(docID, topicID); + } + + /** + * Get the value of topic-vocab probability \phi_{topicID, vocabID}. + * @param topicID + * @param vocabID + * @return the value of topic-vocab probability + * @throws IllegalArgumentException topicID < 0 || #topics <= topicID || vocabID <= 0 + * @throws IllegalStateException call this method when the inference has not been finished yet + */ + public double getPhi(final int topicID, final int vocabID) { + if (topicID < 0 || numTopics < topicID || vocabID < 0) { + throw new IllegalArgumentException(); + } + if (!trained) { + throw new IllegalStateException(); + } + + return inference.getPhi(topicID, vocabID); + } + + public Vocabularies getVocabularies() { + return dataset.getVocabularies(); + } + + public List> getVocabsSortedByPhi(int topicID) { + return inference.getVocabsSortedByPhi(topicID); + } + + /** + * Compute the perplexity of trained LDA for the test bag-of-words dataset. + * @param testDataset + * @return the perplexity for the test bag-of-words dataset + */ + public double computePerplexity(Dataset testDataset) { + double loglikelihood = 0.0; + for (int d = 0; d < testDataset.getNumDocs(); ++d) { + for (Integer w : testDataset.getWords(d)) { + double sum = 0.0; + for (int t = 0; t < getNumTopics(); ++t) { + sum += getTheta(d, t) * getPhi(t, w.intValue()); + } + loglikelihood += Math.log(sum); + } + } + return Math.exp(-loglikelihood / testDataset.getNumWords()); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/Inference.java b/forge-lda/src/main/java/forge/lda/lda/inference/Inference.java new file mode 100644 index 00000000000..89b6ffdf2d1 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/Inference.java @@ -0,0 +1,62 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference; + +import java.util.List; + +import forge.lda.lda.LDA; + +import org.apache.commons.lang3.tuple.Pair; + +public interface Inference { + /** + * Set up for inference. + * @param lda + */ + public void setUp(LDA lda); + + /** + * Set up for inference. + * The configuration is read from properties class. + * @param lda + * @param properties + */ + public void setUp(LDA lda, InferenceProperties properties); + + /** + * Run model inference. + */ + public void run(); + + /** + * Get the value of doc-topic probability \theta_{docID, topicID}. + * @param docID + * @param topicID + * @return the value of doc-topic probability + */ + public double getTheta(final int docID, final int topicID); + + /** + * Get the value of topic-vocab probability \phi_{topicID, vocabID}. + * @param topicID + * @param vocabID + * @return the value of topic-vocab probability + */ + public double getPhi(final int topicID, final int vocabID); + + public List> getVocabsSortedByPhi(int topicID); +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/InferenceFactory.java b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceFactory.java new file mode 100644 index 00000000000..bc88e8d056d --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceFactory.java @@ -0,0 +1,37 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference; + + +public class InferenceFactory { + private InferenceFactory() {} + + /** + * Get the LDAInference instance specified by the argument + * @param method + * @return the instance which implements LDAInference + */ + public static Inference getInstance(InferenceMethod method) { + Inference clazz = null; + try { + clazz = (Inference)Class.forName(method.toString()).newInstance(); + } catch (ReflectiveOperationException roe) { + roe.printStackTrace(); + } + return clazz; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/InferenceMethod.java b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceMethod.java new file mode 100644 index 00000000000..a1ca4e968a3 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceMethod.java @@ -0,0 +1,36 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference; + +import forge.lda.lda.inference.internal.CollapsedGibbsSampler; + +public enum InferenceMethod { + CGS(CollapsedGibbsSampler.class.getName()), + // more + ; + + private String className; + + private InferenceMethod(final String className) { + this.className = className; + } + + @Override + public String toString() { + return className; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/InferenceProperties.java b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceProperties.java new file mode 100644 index 00000000000..c550cc3277d --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceProperties.java @@ -0,0 +1,70 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.util.Properties; + +public class InferenceProperties { + public static String PROPERTIES_FILE_NAME = "lda.properties"; + + PropertiesLoader loader = new PropertiesLoader(); + private Properties properties; + + public InferenceProperties() { + this.properties = new Properties(); + } + + public void setSeed(Long seed){ + properties.setProperty("seed",seed.toString()); + } + + public void setNumIteration(Integer numIteration){ + properties.setProperty("numIteration",numIteration.toString()); + } + + /** + * Load properties. + * @param fileName + * @throws IOException + * @throws NullPointerException fileName is null + */ + public void load(String fileName) throws IOException { + if (fileName == null) throw new NullPointerException(); + InputStream stream = loader.getInputStream(fileName); + if (stream == null) throw new NullPointerException(); + properties.load(stream); + } + + public Long seed() { + return Long.parseLong(properties.getProperty("seed")); + } + + public Integer numIteration() { + return Integer.parseInt(properties.getProperty("numIteration")); + } +} + +class PropertiesLoader { + public InputStream getInputStream(String fileName) throws FileNotFoundException { + if (fileName == null) throw new NullPointerException(); + return new FileInputStream(fileName); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/AssignmentCounter.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/AssignmentCounter.java new file mode 100644 index 00000000000..9009ed78d34 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/AssignmentCounter.java @@ -0,0 +1,65 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +class AssignmentCounter { + private List counter; + + AssignmentCounter(int size) { + if (size <= 0) throw new IllegalArgumentException(); + this.counter = IntStream.generate(() -> 0) + .limit(size) + .boxed() + .collect(Collectors.toList()); + } + + int size() { + return counter.size(); + } + + int get(int id) { + if (id < 0 || counter.size() <= id) { + throw new IllegalArgumentException(); + } + return counter.get(id); + } + + int getSum() { + return counter.stream().reduce(Integer::sum).get(); + } + + void increment(int id) { + if (id < 0 || counter.size() <= id) { + throw new IllegalArgumentException(); + } + counter.set(id, counter.get(id) + 1); + } + + void decrement(int id) { + if (id < 0 || counter.size() <= id) { + throw new IllegalArgumentException(); + } + if (counter.get(id) == 0) { + throw new IllegalStateException(); + } + counter.set(id, counter.get(id) - 1); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/CollapsedGibbsSampler.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/CollapsedGibbsSampler.java new file mode 100644 index 00000000000..60f969fecc4 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/CollapsedGibbsSampler.java @@ -0,0 +1,220 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.IntStream; + +import forge.lda.lda.LDA; +import forge.lda.lda.inference.Inference; +import forge.lda.lda.inference.InferenceProperties; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution; +import org.apache.commons.math3.distribution.IntegerDistribution; + +import forge.lda.dataset.Vocabulary; + +public class CollapsedGibbsSampler implements Inference { + private LDA lda; + private Topics topics; + private Documents documents; + private int numIteration; + + private static final long DEFAULT_SEED = 0L; + private static final int DEFAULT_NUM_ITERATION = 100; + + // ready for Gibbs sampling + private boolean ready; + + public CollapsedGibbsSampler() { + ready = false; + } + + @Override + public void setUp(LDA lda, InferenceProperties properties) { + if (properties == null) { + setUp(lda); + return; + } + + this.lda = lda; + initialize(this.lda); + + final long seed = properties.seed() != null ? properties.seed() : DEFAULT_SEED; + initializeTopicAssignment(seed); + + this.numIteration + = properties.numIteration() != null ? properties.numIteration() : DEFAULT_NUM_ITERATION; + this.ready = true; + } + + @Override + public void setUp(LDA lda) { + if (lda == null) throw new NullPointerException(); + + this.lda = lda; + + initialize(this.lda); + initializeTopicAssignment(DEFAULT_SEED); + + this.numIteration = DEFAULT_NUM_ITERATION; + this.ready = true; + } + + private void initialize(LDA lda) { + assert lda != null; + this.topics = new Topics(lda); + this.documents = new Documents(lda); + } + + public boolean isReady() { + return ready; + } + + public int getNumIteration() { + return numIteration; + } + + public void setNumIteration(final int numIteration) { + this.numIteration = numIteration; + } + + @Override + public void run() { + if (!ready) { + throw new IllegalStateException("instance has not set up yet"); + } + + for (int i = 1; i <= numIteration; ++i) { + System.out.println("Iteration " + i + "."); + runSampling(); + } + } + + /** + * Run collapsed Gibbs sampling [Griffiths and Steyvers 2004]. + */ + void runSampling() { + for (Document d : documents.getDocuments()) { + for (int w = 0; w < d.getDocLength(); ++w) { + final Topic oldTopic = topics.get(d.getTopicID(w)); + d.decrementTopicCount(oldTopic.id()); + + final Vocabulary v = d.getVocabulary(w); + oldTopic.decrementVocabCount(v.id()); + + IntegerDistribution distribution + = getFullConditionalDistribution(lda.getNumTopics(), d.id(), v.id()); + + final int newTopicID = distribution.sample(); + d.setTopicID(w, newTopicID); + + d.incrementTopicCount(newTopicID); + final Topic newTopic = topics.get(newTopicID); + newTopic.incrementVocabCount(v.id()); + } + } + } + + /** + * Get the full conditional distribution over topics. + * docID and vocabID are passed to this distribution for parameters. + * @param numTopics + * @param docID + * @param vocabID + * @return the integer distribution over topics + */ + IntegerDistribution getFullConditionalDistribution(final int numTopics, final int docID, final int vocabID) { + int[] topics = IntStream.range(0, numTopics).toArray(); + double[] probabilities = Arrays.stream(topics) + .mapToDouble(t -> getTheta(docID, t) * getPhi(t, vocabID)) + .toArray(); + return new EnumeratedIntegerDistribution(topics, probabilities); + } + + /** + * Initialize the topic assignment. + * @param seed the seed of a pseudo random number generator + */ + void initializeTopicAssignment(final long seed) { + documents.initializeTopicAssignment(topics, seed); + } + + /** + * Get the count of topicID assigned to docID. + * @param docID + * @param topicID + * @return the count of topicID assigned to docID + */ + int getDTCount(final int docID, final int topicID) { + if (!ready) throw new IllegalStateException(); + if (docID <= 0 || lda.getBow().getNumDocs() < docID + || topicID < 0 || lda.getNumTopics() <= topicID) { + throw new IllegalArgumentException(); + } + return documents.getTopicCount(docID, topicID); + } + + /** + * Get the count of vocabID assigned to topicID. + * @param topicID + * @param vocabID + * @return the count of vocabID assigned to topicID + */ + int getTVCount(final int topicID, final int vocabID) { + if (!ready) throw new IllegalStateException(); + if (topicID < 0 || lda.getNumTopics() <= topicID || vocabID <= 0) { + throw new IllegalArgumentException(); + } + final Topic topic = topics.get(topicID); + return topic.getVocabCount(vocabID); + } + + /** + * Get the sum of counts of vocabs assigned to topicID. + * This is the sum of topic-vocab count over vocabs. + * @param topicID + * @return the sum of counts of vocabs assigned to topicID + * @throws IllegalArgumentException topicID < 0 || #topic <= topicID + */ + int getTSumCount(final int topicID) { + if (topicID < 0 || lda.getNumTopics() <= topicID) { + throw new IllegalArgumentException(); + } + final Topic topic = topics.get(topicID); + return topic.getSumCount(); + } + + @Override + public double getTheta(final int docID, final int topicID) { + if (!ready) throw new IllegalStateException(); + return documents.getTheta(docID, topicID, lda.getAlpha(topicID), lda.getSumAlpha()); + } + + @Override + public double getPhi(int topicID, int vocabID) { + if (!ready) throw new IllegalStateException(); + return topics.getPhi(topicID, vocabID, lda.getBeta()); + } + + @Override + public List> getVocabsSortedByPhi(int topicID) { + return topics.getVocabsSortedByPhi(topicID, lda.getVocabularies(), lda.getBeta()); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Document.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Document.java new file mode 100644 index 00000000000..fee03880408 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Document.java @@ -0,0 +1,86 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.List; + +import forge.lda.dataset.Vocabulary; + +public class Document { + private final int id; + private TopicCounter topicCount; + private Words words; + private TopicAssignment assignment; + + Document(int id, int numTopics, List words) { + if (id < 0 || numTopics <= 0) throw new IllegalArgumentException(); + this.id = id; + this.topicCount = new TopicCounter(numTopics); + this.words = new Words(words); + this.assignment = new TopicAssignment(); + } + + int id() { + return id; + } + + int getTopicCount(int topicID) { + return topicCount.getTopicCount(topicID); + } + + int getDocLength() { + return words.getNumWords(); + } + + void incrementTopicCount(int topicID) { + topicCount.incrementTopicCount(topicID); + } + + void decrementTopicCount(int topicID) { + topicCount.decrementTopicCount(topicID); + } + + void initializeTopicAssignment(long seed) { + assignment.initialize(getDocLength(), topicCount.size(), seed); + for (int w = 0; w < getDocLength(); ++w) { + incrementTopicCount(assignment.get(w)); + } + } + + int getTopicID(int wordID) { + return assignment.get(wordID); + } + + void setTopicID(int wordID, int topicID) { + assignment.set(wordID, topicID); + } + + Vocabulary getVocabulary(int wordID) { + return words.get(wordID); + } + + List getWords() { + return words.getWords(); + } + + double getTheta(int topicID, double alpha, double sumAlpha) { + if (topicID < 0 || alpha <= 0.0 || sumAlpha <= 0.0) { + throw new IllegalArgumentException(); + } + return (getTopicCount(topicID) + alpha) / (getDocLength() + sumAlpha); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Documents.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Documents.java new file mode 100644 index 00000000000..8dafbab0347 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Documents.java @@ -0,0 +1,100 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import forge.lda.lda.LDA; +import forge.lda.dataset.BagOfWords; +import forge.lda.dataset.Vocabularies; +import forge.lda.dataset.Vocabulary; + +class Documents { + private List documents; + + Documents(LDA lda) { + if (lda == null) throw new NullPointerException(); + + documents = new ArrayList<>(); + for (int d = 0; d < lda.getBow().getNumDocs(); ++d) { + List vocabList = getVocabularyList(d, lda.getBow(), lda.getVocabularies()); + Document doc = new Document(d, lda.getNumTopics(), vocabList); + documents.add(doc); + } + } + + List getVocabularyList(int docID, BagOfWords bow, Vocabularies vocabs) { + assert docID > 0 && bow != null && vocabs != null; + //System.out.println(docID); + //System.out.println(bow.getWords(docID).toString()); + return bow.getWords(docID).stream() + .map(id -> vocabs.get(id)) + .collect(Collectors.toList()); + } + + int getTopicID(int docID, int wordID) { + return documents.get(docID).getTopicID(wordID); + } + + void setTopicID(int docID, int wordID, int topicID) { + documents.get(docID).setTopicID(wordID, topicID); + } + + Vocabulary getVocab(int docID, int wordID) { + return documents.get(docID).getVocabulary(wordID); + } + + List getWords(int docID) { + return documents.get(docID).getWords(); + } + + List getDocuments() { + return Collections.unmodifiableList(documents); + } + + void incrementTopicCount(int docID, int topicID) { + documents.get(docID).incrementTopicCount(topicID); + } + + void decrementTopicCount(int docID, int topicID) { + documents.get(docID).decrementTopicCount(topicID); + } + + int getTopicCount(int docID, int topicID) { + return documents.get(docID).getTopicCount(topicID); + } + + double getTheta(int docID, int topicID, double alpha, double sumAlpha) { + if (docID < 0 || documents.size() < docID) throw new IllegalArgumentException(); + return documents.get(docID).getTheta(topicID, alpha, sumAlpha); + } + + void initializeTopicAssignment(Topics topics, long seed) { + for (Document d : getDocuments()) { + d.initializeTopicAssignment(seed); + for (int w = 0; w < d.getDocLength(); ++w) { + final int topicID = d.getTopicID(w); + final Topic topic = topics.get(topicID); + final Vocabulary vocab = d.getVocabulary(w); + topic.incrementVocabCount(vocab.id()); + } + } + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topic.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topic.java new file mode 100644 index 00000000000..a54c1f249ff --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topic.java @@ -0,0 +1,55 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +class Topic { + private final int id; + private final int numVocabs; + private VocabularyCounter counter; + + Topic(int id, int numVocabs) { + if (id < 0 || numVocabs <= 0) throw new IllegalArgumentException(); + this.id = id; + this.numVocabs = numVocabs; + this.counter = new VocabularyCounter(numVocabs); + } + + int id() { + return id; + } + + int getVocabCount(int vocabID) { + return counter.getVocabCount(vocabID); + } + + int getSumCount() { + return counter.getSumCount(); + } + + void incrementVocabCount(int vocabID) { + counter.incrementVocabCount(vocabID); + } + + void decrementVocabCount(int vocabID) { + counter.decrementVocabCount(vocabID); + } + + double getPhi(int vocabID, double beta) { + if (vocabID < 0 || beta <= 0) throw new IllegalArgumentException(); + return (getVocabCount(vocabID) + beta) / (getSumCount() + beta * numVocabs); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicAssignment.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicAssignment.java new file mode 100644 index 00000000000..8d97bd709e6 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicAssignment.java @@ -0,0 +1,60 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + +class TopicAssignment { + private List topicAssignment; + private boolean ready; + + TopicAssignment() { + topicAssignment = new ArrayList<>(); + ready = false; + } + + void set(int wordID, int topicID) { + if (!ready) throw new IllegalStateException(); + if (wordID < 0 || topicAssignment.size() <= wordID || topicID < 0) { + throw new IllegalArgumentException(); + } + topicAssignment.set(wordID, topicID); + } + + int get(int wordID) { + if (!ready) throw new IllegalStateException(); + if (wordID < 0 || topicAssignment.size() <= wordID) { + throw new IllegalArgumentException(); + } + return topicAssignment.get(wordID); + } + + void initialize(int docLength, int numTopics, long seed) { + if (docLength <= 0 || numTopics <= 0) { + throw new IllegalArgumentException(); + } + + Random random = new Random(seed); + topicAssignment = random.ints(docLength, 0, numTopics) + .boxed() + .collect(Collectors.toList()); + ready = true; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicCounter.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicCounter.java new file mode 100644 index 00000000000..7a04437cd29 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicCounter.java @@ -0,0 +1,45 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +class TopicCounter { + private AssignmentCounter topicCount; + + TopicCounter(int numTopics) { + this.topicCount = new AssignmentCounter(numTopics); + } + + int getTopicCount(int topicID) { + return topicCount.get(topicID); + } + + int getDocLength() { + return topicCount.getSum(); + } + + void incrementTopicCount(int topicID) { + topicCount.increment(topicID); + } + + void decrementTopicCount(int topicID) { + topicCount.decrement(topicID); + } + + int size() { + return topicCount.size(); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topics.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topics.java new file mode 100644 index 00000000000..ec042d0a1d3 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topics.java @@ -0,0 +1,85 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import forge.lda.lda.LDA; +import forge.lda.dataset.Vocabularies; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; + +class Topics { + private List topics; + + Topics(LDA lda) { + if (lda == null) throw new NullPointerException(); + + topics = new ArrayList<>(); + for (int t = 0; t < lda.getNumTopics(); ++t) { + topics.add(new Topic(t, lda.getBow().getNumVocabs())); + } + } + + int numTopics() { + return topics.size(); + } + + Topic get(int id) { + return topics.get(id); + } + + int getVocabCount(int topicID, int vocabID) { + return topics.get(topicID).getVocabCount(vocabID); + } + + int getSumCount(int topicID) { + return topics.get(topicID).getSumCount(); + } + + void incrementVocabCount(int topicID, int vocabID) { + topics.get(topicID).incrementVocabCount(vocabID); + } + + void decrementVocabCount(int topicID, int vocabID) { + topics.get(topicID).decrementVocabCount(vocabID); + } + + double getPhi(int topicID, int vocabID, double beta) { + if (topicID < 0 || topics.size() <= topicID) throw new IllegalArgumentException(); + return topics.get(topicID).getPhi(vocabID, beta); + } + + List> getVocabsSortedByPhi(int topicID, Vocabularies vocabs, final double beta) { + if (topicID < 0 || topics.size() <= topicID || vocabs == null || beta <= 0.0) { + throw new IllegalArgumentException(); + } + + Topic topic = topics.get(topicID); + List> vocabProbPairs + = vocabs.getVocabularyList() + .stream() + .map(v -> new ImmutablePair(v.toString(), topic.getPhi(v.id(), beta))) + .sorted((p1, p2) -> Double.compare(p2.getRight(), p1.getRight())) + .collect(Collectors.toList()); + return Collections.unmodifiableList(vocabProbPairs); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/VocabularyCounter.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/VocabularyCounter.java new file mode 100644 index 00000000000..00e067c4998 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/VocabularyCounter.java @@ -0,0 +1,46 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +class VocabularyCounter { + private AssignmentCounter vocabCount; + private int sumCount; + + VocabularyCounter(int numVocabs) { + this.vocabCount = new AssignmentCounter(numVocabs); + this.sumCount = 0; + } + + int getVocabCount(int vocabID) { + if (vocabCount.size() < vocabID) return 0; + else return vocabCount.get(vocabID); + } + + int getSumCount() { + return sumCount; + } + + void incrementVocabCount(int vocabID) { + vocabCount.increment(vocabID); + ++sumCount; + } + + void decrementVocabCount(int vocabID) { + vocabCount.decrement(vocabID); + --sumCount; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Words.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Words.java new file mode 100644 index 00000000000..11a1c691477 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Words.java @@ -0,0 +1,46 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.Collections; +import java.util.List; + +import forge.lda.dataset.Vocabulary; + +class Words { + private List words; + + Words(List words) { + if (words == null) throw new NullPointerException(); + this.words = words; + } + + int getNumWords() { + return words.size(); + } + + Vocabulary get(int id) { + if (id < 0 || words.size() <= id) { + throw new IllegalArgumentException(); + } + return words.get(id); + } + + List getWords() { + return Collections.unmodifiableList(words); + } +} diff --git a/pom.xml b/pom.xml index a746b57d9d8..10be1f6fc46 100644 --- a/pom.xml +++ b/pom.xml @@ -203,6 +203,7 @@ forge-gui-mobile-dev forge-gui-android forge-gui-ios + forge-lda