diff --git a/forge-gui/src/main/java/forge/deck/ArchetypeDeckGenerator.java b/forge-gui/src/main/java/forge/deck/ArchetypeDeckGenerator.java index 37f69fa44e8..f53e5f2597b 100644 --- a/forge-gui/src/main/java/forge/deck/ArchetypeDeckGenerator.java +++ b/forge-gui/src/main/java/forge/deck/ArchetypeDeckGenerator.java @@ -17,9 +17,9 @@ import forge.model.FModel; public class ArchetypeDeckGenerator extends DeckProxy implements Comparable { public static List getMatrixDecks(GameFormat format, boolean isForAi){ final List decks = new ArrayList<>(); - for(Archetype archetype: CardArchetypeLDAGenerator.ldaArchetypes.get(format.getName())) { - decks.add(new ArchetypeDeckGenerator(archetype, format, isForAi)); - } + for(Archetype archetype: CardArchetypeLDAGenerator.ldaArchetypes.get(format.getName())) { + decks.add(new ArchetypeDeckGenerator(archetype, format, isForAi)); + } return decks; } diff --git a/forge-gui/src/main/java/forge/deck/DeckgenUtil.java b/forge-gui/src/main/java/forge/deck/DeckgenUtil.java index 6de099da5ae..76e07081e85 100644 --- a/forge-gui/src/main/java/forge/deck/DeckgenUtil.java +++ b/forge-gui/src/main/java/forge/deck/DeckgenUtil.java @@ -331,12 +331,10 @@ public class DeckgenUtil { return deck; } - /** * @param selection {@link java.lang.String} array * @return {@link forge.deck.Deck} */ - public static Deck buildColorDeck(List selection, Predicate formatFilter, boolean forAi) { try { final Deck deck; diff --git a/forge-gui/src/main/java/forge/deck/io/Archetype.java b/forge-gui/src/main/java/forge/deck/io/Archetype.java index 8bade9a7faa..4f3bea8e0d6 100644 --- a/forge-gui/src/main/java/forge/deck/io/Archetype.java +++ b/forge-gui/src/main/java/forge/deck/io/Archetype.java @@ -1,5 +1,6 @@ package forge.deck.io; + import java.io.Serializable; import java.util.List; diff --git a/forge-lda/pom.xml b/forge-lda/pom.xml new file mode 100644 index 00000000000..bdda207b546 --- /dev/null +++ b/forge-lda/pom.xml @@ -0,0 +1,37 @@ + + 4.0.0 + + + forge + forge 
+ 1.6.40-SNAPSHOT + + + forge-lda + jar + Forge LDA + + + 0 + 0 + 0 + + + + + forge + forge-gui + ${project.version} + + + org.apache.commons + commons-lang3 + 3.8.1 + + + org.apache.commons + commons-math3 + 3.6.1 + + + \ No newline at end of file diff --git a/forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java b/forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java new file mode 100644 index 00000000000..1e0ce8b9202 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java @@ -0,0 +1,389 @@ +package forge.lda; + +import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.base.Predicates; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; + +//import forge.GuiDesktop; +import forge.StaticData; +import forge.card.CardRules; +import forge.card.CardRulesPredicates; +import forge.deck.Deck; +import forge.deck.DeckFormat; +import forge.deck.io.Archetype; +import forge.deck.io.CardThemedLDAIO; +import forge.deck.io.DeckStorage; +import forge.lda.dataset.Dataset; +import forge.lda.lda.LDA; +import forge.game.GameFormat; +import forge.item.PaperCard; +import forge.localinstance.properties.ForgeConstants; +import forge.localinstance.properties.ForgePreferences; +import forge.model.FModel; +import forge.util.storage.IStorage; +import forge.util.storage.StorageImmediatelySerialized; +import org.apache.commons.lang3.ArrayUtils; +import org.apache.commons.lang3.tuple.Pair; + +import java.io.File; +import java.util.*; + +import static forge.lda.lda.inference.InferenceMethod.CGS; + +/** + * Created by maustin on 09/05/2017. 
+ */ +public final class LDAModelGenetrator { + + public static Map>>>> ldaPools = new HashMap<>(); + public static Map> ldaArchetypes = new HashMap<>(); + + + public static final void main(String[] args){ + //GuiBase.setInterface(new GuiDesktop()); + FModel.initialize(null, new Function() { + @Override + public Void apply(ForgePreferences preferences) { + preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false); + return null; + } + }); + initialize(); + } + + public static boolean initialize(){ + List formatStrings = new ArrayList<>(); + formatStrings.add(FModel.getFormats().getStandard().getName()); + formatStrings.add(FModel.getFormats().getPioneer().getName()); + formatStrings.add(FModel.getFormats().getModern().getName()); + formatStrings.add("Legacy"); + formatStrings.add("Vintage"); + formatStrings.add(DeckFormat.Commander.toString()); + + for (String formatString : formatStrings){ + if(!initializeFormat(formatString)){ + return false; + } + } + + return true; + } + + /** Try to load matrix .dat files, otherwise check for deck folders and build .dat, otherwise return false **/ + public static boolean initializeFormat(String format){ + Map>>> formatMap = CardThemedLDAIO.loadLDA(format); + List lda = CardThemedLDAIO.loadRawLDA(format); + if(formatMap==null) { + try { + if(lda==null) { + if (CardThemedLDAIO.getMatrixFolder(format).exists()) { + if (format.equals(FModel.getFormats().getStandard().getName())) { + lda = initializeFormat(FModel.getFormats().getStandard()); + } else if (format.equals(FModel.getFormats().getModern().getName())) { + lda = initializeFormat(FModel.getFormats().getModern()); + } else if (format != DeckFormat.Commander.toString()) { + lda = initializeFormat(FModel.getFormats().get(format)); + } + CardThemedLDAIO.saveRawLDA(format, lda); + } else { + return false; + } + } + if (format.equals(FModel.getFormats().getStandard().getName())) { + formatMap = loadFormat(FModel.getFormats().getStandard(), lda); + } else if 
(format.equals(FModel.getFormats().getModern().getName())) { + formatMap = loadFormat(FModel.getFormats().getModern(), lda); + } else if (format != DeckFormat.Commander.toString()) { + formatMap = loadFormat(FModel.getFormats().get(format), lda);; + } + CardThemedLDAIO.saveLDA(format, formatMap); + }catch (Exception e){ + e.printStackTrace(); + return false; + } + } + ldaPools.put(format, formatMap); + ldaArchetypes.put(format, lda); + return true; + } + + public static Map>>> loadFormat(GameFormat format,List lda) throws Exception{ + + List>> topics = new ArrayList<>(); + Set cards = new HashSet(); + for (int t = 0; t < lda.size(); ++t) { + List> topic = new ArrayList<>(); + Set topicCards = new HashSet<>(); + List> highRankVocabs = lda.get(t).getCardProbabilities(); + if (highRankVocabs.get(0).getRight()<=0.01d){ + continue; + } + System.out.print("t" + t + ": "); + int i = 0; + while (topic.size()<=40&&i=0.005d) { + topicCards.add(cardName); + } + System.out.println("[" + highRankVocabs.get(i).getLeft() + "," + highRankVocabs.get(i).getRight() + "],"); + topic.add(highRankVocabs.get(i)); + } + i++; + } + System.out.println(); + if(topic.size()>18) { + cards.addAll(topicCards); + topics.add(topic); + } + } + Map>>> cardTopicMap = new HashMap<>(); + for (String card:cards){ + List>> cardTopics = new ArrayList<>(); + for( List> topic:topics){ + if(topicContains(card,topic)){ + cardTopics.add(topic); + } + } + cardTopicMap.put(card,cardTopics); + } + return cardTopicMap; + } + + public static List initializeFormat(GameFormat format) throws Exception{ + Dataset dataset = new Dataset(format); + + //estimate number of topics to attempt to find using power law + final int numTopics = Float.valueOf(347f*dataset.getNumDocs()/(2892f + dataset.getNumDocs())).intValue(); + System.out.println("Num Topics = " + numTopics); + LDA lda = new LDA(0.1, 0.1, numTopics, dataset, CGS); + lda.run(); + System.out.println(lda.computePerplexity(dataset)); + + //sort decks by topic + Map> 
topicDecks = new HashMap<>(); + + int deckNum=0; + for(Deck deck: dataset.getBow().getLegalDecks()){ + double maxTheta = 0; + int mainTopic = 0; + for (int t = 0; t < lda.getNumTopics(); ++t){ + double theta = lda.getTheta(deckNum,t); + if (theta > maxTheta){ + maxTheta = theta; + mainTopic = t; + } + } + if(topicDecks.containsKey(mainTopic)){ + topicDecks.get(mainTopic).add(deck); + }else{ + List decks = new ArrayList<>(); + decks.add(deck); + topicDecks.put(mainTopic,decks); + } + ++deckNum; + } + + + List unfilteredTopics = new ArrayList<>(); + for (int t = 0; t < lda.getNumTopics(); ++t) { + List> highRankVocabs = lda.getVocabsSortedByPhi(t); + Double min = 1d; + for(Pair p:highRankVocabs){ + if(p.getRight()> topRankVocabs = new ArrayList<>(); + for(Pair p:highRankVocabs){ + if(p.getRight()>min){ + topRankVocabs.add(p); + } + } + + //generate names for topics + List decks = topicDecks.get(t); + if(decks==null){ + continue; + } + LinkedHashMap wordCounts = new LinkedHashMap<>(); + for( Deck deck: decks){ + String name = deck.getName().replaceAll(".* Version - ","").replaceAll(" \\((Modern|Pioneer|Standard|Legacy|Vintage), #[0-9]+\\)",""); + name = name.replaceAll("\\(Modern|Pioneer|Standard|Legacy|Vintage|Fuck|Shit|Cunt\\)",""); + String[] tokens = name.split(" "); + for(String rawtoken: tokens){ + String token = rawtoken.toLowerCase(); + if (token.matches("[0-9]+")) { + //skip just numbers as not useful + continue; + } + if(wordCounts.containsKey(token)){ + wordCounts.put(token, wordCounts.get(token)+1); + }else{ + wordCounts.put(token, 1); + } + } + } + Map sortedWordCounts = sortByValue(wordCounts); + + List topWords = new ArrayList<>(); + Iterator wordIterator = sortedWordCounts.keySet().iterator(); + while(wordIterator.hasNext() && topWords.size() < 3){ + topWords.add(wordIterator.next()); + } + StringJoiner sb = new StringJoiner(" "); + for(String word: wordCounts.keySet()){ + if(topWords.contains(word)){ + sb.add(word); + } + } + String deckName = 
sb.toString(); + System.out.println("============ " + deckName); + System.out.println(decks.toString()); + + unfilteredTopics.add(new Archetype(topRankVocabs,deckName,decks.size())); + } + Comparator archetypeComparator = new Comparator() { + @Override + public int compare(Archetype o1, Archetype o2) { + return o2.getDeckCount().compareTo(o1.getDeckCount()); + } + }; + + Collections.sort(unfilteredTopics,archetypeComparator); + return unfilteredTopics; + } + + + + private static Map sortByValue(Map map) { + List> list = new LinkedList<>(map.entrySet()); + Collections.sort(list, new Comparator() { + @SuppressWarnings("unchecked") + public int compare(Object o1, Object o2) { + return ((Comparable) ((Map.Entry) (o2)).getValue()).compareTo(((Map.Entry) (o1)).getValue()); + } + }); + + Map result = new LinkedHashMap<>(); + for (Iterator> it = list.iterator(); it.hasNext();) { + Map.Entry entry = (Map.Entry) it.next(); + result.put(entry.getKey(), entry.getValue()); + } + + return result; + } + + public static boolean topicContains(String card, List> topic){ + for(Pair pair:topic){ + if(pair.getLeft().equals(card)){ + return true; + } + } + return false; + } + + public static HashMap>> initializeCommanderFormat(){ + + IStorage decks = new StorageImmediatelySerialized("Generator", + new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR,DeckFormat.Commander.toString()), + ForgeConstants.DECK_GEN_DIR, false), + true); + + //get all cards + final Iterable cards = Iterables.filter(FModel.getMagicDb().getCommonCards().getUniqueCards() + , Predicates.compose(Predicates.not(CardRulesPredicates.Presets.IS_BASIC_LAND_NOT_WASTES), PaperCard.FN_GET_RULES)); + List cardList = Lists.newArrayList(cards); + cardList.add(FModel.getMagicDb().getCommonCards().getCard("Wastes")); + Map cardIntegerMap = new HashMap<>(); + Map integerCardMap = new HashMap<>(); + Map legendIntegerMap = new HashMap<>(); + Map integerLegendMap = new HashMap<>(); + //generate lookups for cards to link card names 
to matrix columns + for (int i=0; i legends = Lists.newArrayList(Iterables.filter(cardList,Predicates.compose( + new Predicate() { + @Override + public boolean apply(CardRules rules) { + return DeckFormat.Commander.isLegalCommander(rules); + } + }, PaperCard.FN_GET_RULES))); + + //generate lookups for legends to link commander names to matrix rows + for (int i=0; i>> cardPools = new HashMap<>(); + for (PaperCard card:legends){ + int col=legendIntegerMap.get(card.getName()); + int[] distances = matrix[col]; + int max = (Integer) Collections.max(Arrays.asList(ArrayUtils.toObject(distances))); + if (max>0) { + List> deckPool=new ArrayList<>(); + for(int k=0;k0){ + deckPool.add(new AbstractMap.SimpleEntry(integerCardMap.get(k),matrix[col][k])); + } + } + cardPools.put(card.getName(), deckPool); + } + } + return cardPools; + } + + //update the matrix by incrementing the connectivity count for each card in the deck + public static void updateLegendMatrix(Deck deck, PaperCard legend, Map cardIntegerMap, + Map legendIntegerMap, int[][] matrix){ + for (PaperCard pairCard:Iterables.filter(deck.getMain().toFlatList(), + Predicates.compose(Predicates.not(CardRulesPredicates.Presets.IS_BASIC_LAND_NOT_WASTES), PaperCard.FN_GET_RULES))){ + if (!pairCard.getName().equals(legend.getName())){ + try { + int old = matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(pairCard.getName())]; + matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(pairCard.getName())] = old + 1; + }catch (NullPointerException ne){ + //Todo: Not sure what was failing here + ne.printStackTrace(); + } + } + + } + //add partner commanders to matrix + if(deck.getCommanders().size()>1){ + for(PaperCard partner:deck.getCommanders()){ + if(!partner.equals(legend)){ + int old = matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(partner.getName())]; + matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(partner.getName())] = old + 1; + } + } + } + } + +} diff 
--git a/forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java b/forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java new file mode 100644 index 00000000000..d142496c8e5 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java @@ -0,0 +1,210 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.dataset; + +import forge.deck.Deck; +import forge.deck.io.DeckStorage; +import forge.game.GameFormat; +import forge.item.PaperCard; +import forge.localinstance.properties.ForgeConstants; +import forge.util.storage.IStorage; +import forge.util.storage.StorageImmediatelySerialized; + +import java.io.*; +import java.util.*; + +/** + * This class is immutable. + */ +public final class BagOfWords { + + private final int numDocs; + private final int numVocabs; + private final int numNNZ; + private final int numWords; + private List legalDecks; + + public Vocabularies getVocabs() { + return vocabs; + } + + private final Vocabularies vocabs; + + // docID -> the vocabs sequence in the doc + private Map> words; + + // docID -> the doc length + private Map docLength; + + + /** + * Read the bag-of-words dataset. 
+ * @throws FileNotFoundException + * @throws IOException + * @throws Exception + * @throws NullPointerException filePath is null + */ + public BagOfWords(GameFormat format) throws FileNotFoundException, IOException, Exception { + IStorage decks = new StorageImmediatelySerialized("Generator", new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR+ForgeConstants.PATH_SEPARATOR+format.getName()), + ForgeConstants.DECK_GEN_DIR, false), + true); + + Set cardSet = new HashSet<>(); + legalDecks = new ArrayList<>(); + for(Deck deck:decks){ + try { + if (format.isDeckLegal(deck) && deck.getMain().toFlatList().size() == 60) { + legalDecks.add(deck); + for (PaperCard card : deck.getMain().toFlatList()) { + cardSet.add(card); + } + } + }catch(Exception e){ + System.out.println("Skipping deck "+deck.getName()); + } + } + List cardList = new ArrayList<>(cardSet); + + this.words = new HashMap<>(); + this.docLength = new HashMap<>(); + ArrayList vocabList = new ArrayList(); + + int numDocs = legalDecks.size(); + int numVocabs = cardList.size(); + int numNNZ = 0; + int numWords = 0; + + Map cardIntegerMap = new HashMap<>(); + Map integerCardMap = new HashMap<>(); + for (int i=0; i cardNumbers = new ArrayList<>(); + for(PaperCard card:deck.getMain().toFlatList()){ + if(cardIntegerMap.get(card.getName()) == null){ + System.out.println(card.getName() + " is missing!!"); + } + cardNumbers.add(cardIntegerMap.get(card.getName())); + } + words.put(deckID,cardNumbers); + numWords+=cardNumbers.size(); + docLength.put(deckID,cardNumbers.size()); + deckID ++; + } + + /*String s = null; + while ((s = reader.readLine()) != null) { + List numbers + = Arrays.asList(s.split(" ")).stream().map(Integer::parseInt).collect(Collectors.toList()); + + if (numbers.size() == 1) { + if (headerCount == 2) numDocs = numbers.get(0); + else if (headerCount == 1) numVocabs = numbers.get(0); + else if (headerCount == 0) numNNZ = numbers.get(0); + --headerCount; + continue; + } + else if (numbers.size() == 3) { + 
final int docID = numbers.get(0); + final int vocabID = numbers.get(1); + final int count = numbers.get(2); + + // Set up the words container + if (!words.containsKey(docID)) { + words.put(docID, new ArrayList<>()); + } + for (int c = 0; c < count; ++c) { + words.get(docID).add(vocabID); + } + + // Set up the doc length map + Optional currentCount + = Optional.ofNullable(docLength.putIfAbsent(docID, count)); + currentCount.ifPresent(c -> docLength.replace(docID, c + count)); + + numWords += count; + } + else { + throw new Exception("Invalid dataset form was detected."); + } + } + reader.close();*/ + + this.numDocs = numDocs; + this.numVocabs = numVocabs; + this.numNNZ = numNNZ; + this.numWords = numWords; + + System.out.println("Num Decks" + this.numDocs); + System.out.println("Num Vocabs" + this.numVocabs); + System.out.println("Num NNZ" + this.numNNZ); + System.out.println("Num Cards" + this.numWords); + } + + public int getNumDocs() { + return numDocs; + } + + /** + * Get the length of the document. + * @param docID + * @return length of the document + * @throws IllegalArgumentException docID <= 0 || #documents < docID + */ + public int getDocLength(int docID) { + if (docID < 0 || getNumDocs() < docID) { + throw new IllegalArgumentException(); + } + + return docLength.get(docID); + } + + /** + * Get the unmodifiable list of words in the document. 
+ * @param docID + * @return the unmodifiable list of words + * @throws IllegalArgumentException docID <= 0 || #documents < docID + */ + public List getWords(final int docID) { + if (docID < 0 || getNumDocs() < docID) { + throw new IllegalArgumentException(); + } + return Collections.unmodifiableList(words.get(docID)); + } + + public int getNumVocabs() { + return numVocabs; + } + + public int getNumNNZ() { + return numNNZ; + } + + public int getNumWords() { + return numWords; + } + + public List getLegalDecks() { return legalDecks; } +} + diff --git a/forge-lda/src/main/java/forge/lda/dataset/Dataset.java b/forge-lda/src/main/java/forge/lda/dataset/Dataset.java new file mode 100644 index 00000000000..0a35f73bab4 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/dataset/Dataset.java @@ -0,0 +1,80 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.dataset; + +import forge.game.GameFormat; + +import java.util.List; + +public class Dataset { + private BagOfWords bow; + private Vocabularies vocabs; + + public Dataset(GameFormat format) throws Exception { + bow = new BagOfWords(format); + vocabs = bow.getVocabs(); + } + + public Dataset(BagOfWords bow) { + this.bow = bow; + this.vocabs = null; + } + + public BagOfWords getBow() { + return bow; + } + + public int getNumDocs() { + return bow.getNumDocs(); + } + + public int getDocLength(int docID) { + return bow.getDocLength(docID); + } + + public List getWords(int docID) { + return bow.getWords(docID); + } + + public int getNumVocabs() { + return bow.getNumVocabs(); + } + + public int getNumNNZ() { + return bow.getNumNNZ(); + } + + public int getNumWords() { + return bow.getNumWords(); + } + + public Vocabulary get(int id) { + return vocabs.get(id); + } + + public int size() { + return vocabs.size(); + } + + public Vocabularies getVocabularies() { + return vocabs; + } + + public List getVocabularyList() { + return vocabs.getVocabularyList(); + } +} diff --git a/forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java b/forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java new file mode 100644 index 00000000000..4d32c8c19a2 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java @@ -0,0 +1,40 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.dataset; + +import java.util.Collections; +import java.util.List; + +public class Vocabularies { + private List vocabs; + + public Vocabularies(List vocabs) { + this.vocabs = vocabs; + } + + public Vocabulary get(int id) { + return vocabs.get(id); + } + + public int size() { + return vocabs.size(); + } + + public List getVocabularyList() { + return Collections.unmodifiableList(vocabs); + } +} diff --git a/forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java b/forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java new file mode 100644 index 00000000000..2420acf7f41 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java @@ -0,0 +1,37 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.dataset; + +public class Vocabulary { + private final int id; + private String vocabulary; + + public Vocabulary(int id, String vocabulary) { + if (vocabulary == null) throw new NullPointerException(); + this.id = id; + this.vocabulary = vocabulary; + } + + public int id() { + return id; + } + + @Override + public String toString() { + return vocabulary; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/Alpha.java b/forge-lda/src/main/java/forge/lda/lda/Alpha.java new file mode 100644 index 00000000000..68880d18b2b --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/Alpha.java @@ -0,0 +1,46 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda; + +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +class Alpha { + private List alphas; + + Alpha(double alpha, int numTopics) { + if (alpha <= 0.0 || numTopics <= 0) { + throw new IllegalArgumentException(); + } + this.alphas = Stream.generate(() -> alpha) + .limit(numTopics) + .collect(Collectors.toList()); + } + + double get(int i) { + return alphas.get(i); + } + + void set(int i, double value) { + alphas.set(i, value); + } + + double sum() { + return alphas.stream().reduce(Double::sum).get(); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/Beta.java b/forge-lda/src/main/java/forge/lda/lda/Beta.java new file mode 100644 index 00000000000..8471aadb8a3 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/Beta.java @@ -0,0 +1,54 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +class Beta { + private List betas; + + Beta(double beta, int numVocabs) { + if (beta <= 0.0 || numVocabs <= 0) { + throw new IllegalArgumentException(); + } + this.betas = Stream.generate(() -> beta) + .limit(numVocabs) + .collect(Collectors.toList()); + } + + Beta(double beta) { + if (beta <= 0.0) { + throw new IllegalArgumentException(); + } + this.betas = Arrays.asList(beta); + } + + double get() { + return get(0); + } + + double get(int i) { + return betas.get(i); + } + + void set(int i, double value) { + betas.set(i, value); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java b/forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java new file mode 100644 index 00000000000..a468535de7b --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java @@ -0,0 +1,60 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda; + +class Hyperparameters { + private Alpha alpha; + private Beta beta; + + Hyperparameters(double alpha, double beta, int numTopics, int numVocabs) { + this.alpha = new Alpha(alpha, numTopics); + this.beta = new Beta(beta, numVocabs); + } + + Hyperparameters(double alpha, double beta, int numTopics) { + this.alpha = new Alpha(alpha, numTopics); + this.beta = new Beta(beta); + } + + double alpha(int i) { + return alpha.get(i); + } + + double sumAlpha() { + return alpha.sum(); + } + + double beta() { + return beta.get(); + } + + double beta(int i) { + return beta.get(i); + } + + void setAlpha(int i, double value) { + alpha.set(i, value); + } + + void setBeta(int i, double value) { + beta.set(i, value); + } + + void setBeta(double value) { + beta.set(0, value); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/LDA.java b/forge-lda/src/main/java/forge/lda/lda/LDA.java new file mode 100644 index 00000000000..f07a19749a0 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/LDA.java @@ -0,0 +1,183 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda; + +import java.util.List; + +import forge.lda.dataset.BagOfWords; +import forge.lda.dataset.Dataset; +import forge.lda.dataset.Vocabularies; + +import forge.lda.lda.inference.Inference; +import forge.lda.lda.inference.InferenceFactory; +import forge.lda.lda.inference.InferenceMethod; +import forge.lda.lda.inference.InferenceProperties; + +import org.apache.commons.lang3.tuple.Pair; + +public class LDA { + private Hyperparameters hyperparameters; + private final int numTopics; + private Dataset dataset; + private final Inference inference; + private InferenceProperties properties; + private boolean trained; + + /** + * @param alpha doc-topic hyperparameter + * @param beta topic-vocab hyperparameter + * @param numTopics the number of topics + * @param dataset dataset + * @param bow bag-of-words + * @param method inference method + */ + public LDA(final double alpha, final double beta, final int numTopics, + final Dataset dataset, InferenceMethod method) { + this(alpha, beta, numTopics, dataset, method, InferenceProperties.PROPERTIES_FILE_NAME); + } + + LDA(final double alpha, final double beta, final int numTopics, + final Dataset dataset, InferenceMethod method, String propertiesFilePath) { + this.hyperparameters = new Hyperparameters(alpha, beta, numTopics); + this.numTopics = numTopics; + this.dataset = dataset; + this.inference = InferenceFactory.getInstance(method); + this.properties = new InferenceProperties(); + this.trained = false; + + properties.setSeed(123L); + properties.setNumIteration(100); + } + + /** + * Get the vocabulary from its ID. + * @param vocabID + * @return the vocabulary + * @throws IllegalArgumentException vocabID <= 0 || the number of vocabularies < vocabID + */ + public String getVocab(int vocabID) { + if (vocabID < 0 || dataset.getNumVocabs() < vocabID) { + throw new IllegalArgumentException(); + } + return dataset.get(vocabID).toString(); + } + + /** + * Run model inference. 
+ */ + public void run() { + if (properties == null) inference.setUp(this); + else inference.setUp(this, properties); + inference.run(); + trained = true; + } + + /** + * Get hyperparameter alpha corresponding to topic. + * @param topic + * @return alpha corresponding to topicID + * @throws ArrayIndexOutOfBoundsException topic < 0 || #topics <= topic + */ + public double getAlpha(final int topic) { + if (topic < 0 || numTopics < topic) { + throw new ArrayIndexOutOfBoundsException(topic); + } + return hyperparameters.alpha(topic); + } + + public double getSumAlpha() { + return hyperparameters.sumAlpha(); + } + + public double getBeta() { + return hyperparameters.beta(); + } + + public int getNumTopics() { + return numTopics; + } + + public BagOfWords getBow() { + return dataset.getBow(); + } + + /** + * Get the value of doc-topic probability \theta_{docID, topicID}. + * @param docID + * @param topicID + * @return the value of doc-topic probability + * @throws IllegalArgumentException docID <= 0 || #docs < docID || topicID < 0 || #topics <= topicID + * @throws IllegalStateException call this method when the inference has not been finished yet + */ + public double getTheta(final int docID, final int topicID) { + if (docID < 0 || dataset.getNumDocs() < docID + || topicID < 0 || numTopics < topicID) { + throw new IllegalArgumentException(); + } + if (!trained) { + throw new IllegalStateException(); + } + + return inference.getTheta(docID, topicID); + } + + /** + * Get the value of topic-vocab probability \phi_{topicID, vocabID}. 
+ * @param topicID + * @param vocabID + * @return the value of topic-vocab probability + * @throws IllegalArgumentException topicID < 0 || #topics <= topicID || vocabID <= 0 + * @throws IllegalStateException call this method when the inference has not been finished yet + */ + public double getPhi(final int topicID, final int vocabID) { + if (topicID < 0 || numTopics < topicID || vocabID < 0) { + throw new IllegalArgumentException(); + } + if (!trained) { + throw new IllegalStateException(); + } + + return inference.getPhi(topicID, vocabID); + } + + public Vocabularies getVocabularies() { + return dataset.getVocabularies(); + } + + public List> getVocabsSortedByPhi(int topicID) { + return inference.getVocabsSortedByPhi(topicID); + } + + /** + * Compute the perplexity of trained LDA for the test bag-of-words dataset. + * @param testDataset + * @return the perplexity for the test bag-of-words dataset + */ + public double computePerplexity(Dataset testDataset) { + double loglikelihood = 0.0; + for (int d = 0; d < testDataset.getNumDocs(); ++d) { + for (Integer w : testDataset.getWords(d)) { + double sum = 0.0; + for (int t = 0; t < getNumTopics(); ++t) { + sum += getTheta(d, t) * getPhi(t, w.intValue()); + } + loglikelihood += Math.log(sum); + } + } + return Math.exp(-loglikelihood / testDataset.getNumWords()); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/Inference.java b/forge-lda/src/main/java/forge/lda/lda/inference/Inference.java new file mode 100644 index 00000000000..89b6ffdf2d1 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/Inference.java @@ -0,0 +1,62 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference; + +import java.util.List; + +import forge.lda.lda.LDA; + +import org.apache.commons.lang3.tuple.Pair; + +public interface Inference { + /** + * Set up for inference. + * @param lda + */ + public void setUp(LDA lda); + + /** + * Set up for inference. + * The configuration is read from properties class. + * @param lda + * @param properties + */ + public void setUp(LDA lda, InferenceProperties properties); + + /** + * Run model inference. + */ + public void run(); + + /** + * Get the value of doc-topic probability \theta_{docID, topicID}. + * @param docID + * @param topicID + * @return the value of doc-topic probability + */ + public double getTheta(final int docID, final int topicID); + + /** + * Get the value of topic-vocab probability \phi_{topicID, vocabID}. + * @param topicID + * @param vocabID + * @return the value of topic-vocab probability + */ + public double getPhi(final int topicID, final int vocabID); + + public List> getVocabsSortedByPhi(int topicID); +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/InferenceFactory.java b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceFactory.java new file mode 100644 index 00000000000..bc88e8d056d --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceFactory.java @@ -0,0 +1,37 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference; + + +public class InferenceFactory { + private InferenceFactory() {} + + /** + * Get the LDAInference instance specified by the argument + * @param method + * @return the instance which implements LDAInference + */ + public static Inference getInstance(InferenceMethod method) { + Inference clazz = null; + try { + clazz = (Inference)Class.forName(method.toString()).newInstance(); + } catch (ReflectiveOperationException roe) { + roe.printStackTrace(); + } + return clazz; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/InferenceMethod.java b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceMethod.java new file mode 100644 index 00000000000..a1ca4e968a3 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceMethod.java @@ -0,0 +1,36 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda.inference; + +import forge.lda.lda.inference.internal.CollapsedGibbsSampler; + +public enum InferenceMethod { + CGS(CollapsedGibbsSampler.class.getName()), + // more + ; + + private String className; + + private InferenceMethod(final String className) { + this.className = className; + } + + @Override + public String toString() { + return className; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/InferenceProperties.java b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceProperties.java new file mode 100644 index 00000000000..c550cc3277d --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/InferenceProperties.java @@ -0,0 +1,70 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
/**
 * Inference configuration backed by {@link Properties}: the random seed
 * ({@code "seed"}) and the number of iterations ({@code "numIteration"}).
 */
public class InferenceProperties {
    public static String PROPERTIES_FILE_NAME = "lda.properties";

    PropertiesLoader loader = new PropertiesLoader();
    private Properties properties;

    public InferenceProperties() {
        this.properties = new Properties();
    }

    /** @param seed value stored under the {@code "seed"} key */
    public void setSeed(Long seed) {
        properties.setProperty("seed", seed.toString());
    }

    /** @param numIteration value stored under the {@code "numIteration"} key */
    public void setNumIteration(Integer numIteration) {
        properties.setProperty("numIteration", numIteration.toString());
    }

    /**
     * Load properties.
     * @param fileName
     * @throws IOException
     * @throws NullPointerException fileName is null
     */
    public void load(String fileName) throws IOException {
        if (fileName == null) throw new NullPointerException();
        // try-with-resources closes the stream even when parsing fails.
        try (InputStream stream = loader.getInputStream(fileName)) {
            if (stream == null) throw new NullPointerException();
            properties.load(stream);
        }
    }

    /**
     * @return the configured seed, or {@code null} when no seed is set
     * @throws NumberFormatException the stored value is not a valid long
     */
    public Long seed() {
        final String value = properties.getProperty("seed");
        return value == null ? null : Long.valueOf(value);
    }

    /**
     * @return the configured iteration count, or {@code null} when none is set
     * @throws NumberFormatException the stored value is not a valid int
     */
    public Integer numIteration() {
        final String value = properties.getProperty("numIteration");
        return value == null ? null : Integer.valueOf(value);
    }
}

/** Opens property files; kept separate so the loader field can be replaced. */
class PropertiesLoader {
    /**
     * @param fileName path of the file to open
     * @return a stream over the file
     * @throws FileNotFoundException the file cannot be opened
     * @throws NullPointerException fileName is null
     */
    public InputStream getInputStream(String fileName) throws FileNotFoundException {
        if (fileName == null) throw new NullPointerException();
        return new FileInputStream(fileName);
    }
}
/**
 * Fixed-size array of non-negative counters addressed by a 0-based id.
 */
class AssignmentCounter {
    private final int[] counts;

    /**
     * @param size number of counters, all starting at zero
     * @throws IllegalArgumentException size <= 0
     */
    AssignmentCounter(int size) {
        if (size <= 0) throw new IllegalArgumentException();
        counts = new int[size];
    }

    int size() {
        return counts.length;
    }

    /**
     * @throws IllegalArgumentException id < 0 || size <= id
     */
    int get(int id) {
        requireValid(id);
        return counts[id];
    }

    /** @return the sum over all counters */
    int getSum() {
        return IntStream.of(counts).sum();
    }

    /**
     * @throws IllegalArgumentException id < 0 || size <= id
     */
    void increment(int id) {
        requireValid(id);
        counts[id]++;
    }

    /**
     * @throws IllegalArgumentException id < 0 || size <= id
     * @throws IllegalStateException the counter is already zero
     */
    void decrement(int id) {
        requireValid(id);
        if (counts[id] == 0) {
            throw new IllegalStateException();
        }
        counts[id]--;
    }

    // Shared range guard for every id-taking method.
    private void requireValid(int id) {
        if (id < 0 || id >= counts.length) {
            throw new IllegalArgumentException();
        }
    }
}
License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.IntStream; + +import forge.lda.lda.LDA; +import forge.lda.lda.inference.Inference; +import forge.lda.lda.inference.InferenceProperties; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution; +import org.apache.commons.math3.distribution.IntegerDistribution; + +import forge.lda.dataset.Vocabulary; + +public class CollapsedGibbsSampler implements Inference { + private LDA lda; + private Topics topics; + private Documents documents; + private int numIteration; + + private static final long DEFAULT_SEED = 0L; + private static final int DEFAULT_NUM_ITERATION = 100; + + // ready for Gibbs sampling + private boolean ready; + + public CollapsedGibbsSampler() { + ready = false; + } + + @Override + public void setUp(LDA lda, InferenceProperties properties) { + if (properties == null) { + setUp(lda); + return; + } + + this.lda = lda; + initialize(this.lda); + + final long seed = properties.seed() != null ? properties.seed() : DEFAULT_SEED; + initializeTopicAssignment(seed); + + this.numIteration + = properties.numIteration() != null ? 
properties.numIteration() : DEFAULT_NUM_ITERATION; + this.ready = true; + } + + @Override + public void setUp(LDA lda) { + if (lda == null) throw new NullPointerException(); + + this.lda = lda; + + initialize(this.lda); + initializeTopicAssignment(DEFAULT_SEED); + + this.numIteration = DEFAULT_NUM_ITERATION; + this.ready = true; + } + + private void initialize(LDA lda) { + assert lda != null; + this.topics = new Topics(lda); + this.documents = new Documents(lda); + } + + public boolean isReady() { + return ready; + } + + public int getNumIteration() { + return numIteration; + } + + public void setNumIteration(final int numIteration) { + this.numIteration = numIteration; + } + + @Override + public void run() { + if (!ready) { + throw new IllegalStateException("instance has not set up yet"); + } + + for (int i = 1; i <= numIteration; ++i) { + System.out.println("Iteration " + i + "."); + runSampling(); + } + } + + /** + * Run collapsed Gibbs sampling [Griffiths and Steyvers 2004]. + */ + void runSampling() { + for (Document d : documents.getDocuments()) { + for (int w = 0; w < d.getDocLength(); ++w) { + final Topic oldTopic = topics.get(d.getTopicID(w)); + d.decrementTopicCount(oldTopic.id()); + + final Vocabulary v = d.getVocabulary(w); + oldTopic.decrementVocabCount(v.id()); + + IntegerDistribution distribution + = getFullConditionalDistribution(lda.getNumTopics(), d.id(), v.id()); + + final int newTopicID = distribution.sample(); + d.setTopicID(w, newTopicID); + + d.incrementTopicCount(newTopicID); + final Topic newTopic = topics.get(newTopicID); + newTopic.incrementVocabCount(v.id()); + } + } + } + + /** + * Get the full conditional distribution over topics. + * docID and vocabID are passed to this distribution for parameters. 
+ * @param numTopics + * @param docID + * @param vocabID + * @return the integer distribution over topics + */ + IntegerDistribution getFullConditionalDistribution(final int numTopics, final int docID, final int vocabID) { + int[] topics = IntStream.range(0, numTopics).toArray(); + double[] probabilities = Arrays.stream(topics) + .mapToDouble(t -> getTheta(docID, t) * getPhi(t, vocabID)) + .toArray(); + return new EnumeratedIntegerDistribution(topics, probabilities); + } + + /** + * Initialize the topic assignment. + * @param seed the seed of a pseudo random number generator + */ + void initializeTopicAssignment(final long seed) { + documents.initializeTopicAssignment(topics, seed); + } + + /** + * Get the count of topicID assigned to docID. + * @param docID + * @param topicID + * @return the count of topicID assigned to docID + */ + int getDTCount(final int docID, final int topicID) { + if (!ready) throw new IllegalStateException(); + if (docID <= 0 || lda.getBow().getNumDocs() < docID + || topicID < 0 || lda.getNumTopics() <= topicID) { + throw new IllegalArgumentException(); + } + return documents.getTopicCount(docID, topicID); + } + + /** + * Get the count of vocabID assigned to topicID. + * @param topicID + * @param vocabID + * @return the count of vocabID assigned to topicID + */ + int getTVCount(final int topicID, final int vocabID) { + if (!ready) throw new IllegalStateException(); + if (topicID < 0 || lda.getNumTopics() <= topicID || vocabID <= 0) { + throw new IllegalArgumentException(); + } + final Topic topic = topics.get(topicID); + return topic.getVocabCount(vocabID); + } + + /** + * Get the sum of counts of vocabs assigned to topicID. + * This is the sum of topic-vocab count over vocabs. 
+ * @param topicID + * @return the sum of counts of vocabs assigned to topicID + * @throws IllegalArgumentException topicID < 0 || #topic <= topicID + */ + int getTSumCount(final int topicID) { + if (topicID < 0 || lda.getNumTopics() <= topicID) { + throw new IllegalArgumentException(); + } + final Topic topic = topics.get(topicID); + return topic.getSumCount(); + } + + @Override + public double getTheta(final int docID, final int topicID) { + if (!ready) throw new IllegalStateException(); + return documents.getTheta(docID, topicID, lda.getAlpha(topicID), lda.getSumAlpha()); + } + + @Override + public double getPhi(int topicID, int vocabID) { + if (!ready) throw new IllegalStateException(); + return topics.getPhi(topicID, vocabID, lda.getBeta()); + } + + @Override + public List> getVocabsSortedByPhi(int topicID) { + return topics.getVocabsSortedByPhi(topicID, lda.getVocabularies(), lda.getBeta()); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Document.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Document.java new file mode 100644 index 00000000000..fee03880408 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Document.java @@ -0,0 +1,86 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda.inference.internal; + +import java.util.List; + +import forge.lda.dataset.Vocabulary; + +public class Document { + private final int id; + private TopicCounter topicCount; + private Words words; + private TopicAssignment assignment; + + Document(int id, int numTopics, List words) { + if (id < 0 || numTopics <= 0) throw new IllegalArgumentException(); + this.id = id; + this.topicCount = new TopicCounter(numTopics); + this.words = new Words(words); + this.assignment = new TopicAssignment(); + } + + int id() { + return id; + } + + int getTopicCount(int topicID) { + return topicCount.getTopicCount(topicID); + } + + int getDocLength() { + return words.getNumWords(); + } + + void incrementTopicCount(int topicID) { + topicCount.incrementTopicCount(topicID); + } + + void decrementTopicCount(int topicID) { + topicCount.decrementTopicCount(topicID); + } + + void initializeTopicAssignment(long seed) { + assignment.initialize(getDocLength(), topicCount.size(), seed); + for (int w = 0; w < getDocLength(); ++w) { + incrementTopicCount(assignment.get(w)); + } + } + + int getTopicID(int wordID) { + return assignment.get(wordID); + } + + void setTopicID(int wordID, int topicID) { + assignment.set(wordID, topicID); + } + + Vocabulary getVocabulary(int wordID) { + return words.get(wordID); + } + + List getWords() { + return words.getWords(); + } + + double getTheta(int topicID, double alpha, double sumAlpha) { + if (topicID < 0 || alpha <= 0.0 || sumAlpha <= 0.0) { + throw new IllegalArgumentException(); + } + return (getTopicCount(topicID) + alpha) / (getDocLength() + sumAlpha); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Documents.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Documents.java new file mode 100644 index 00000000000..8dafbab0347 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Documents.java @@ -0,0 +1,100 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed 
under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import forge.lda.lda.LDA; +import forge.lda.dataset.BagOfWords; +import forge.lda.dataset.Vocabularies; +import forge.lda.dataset.Vocabulary; + +class Documents { + private List documents; + + Documents(LDA lda) { + if (lda == null) throw new NullPointerException(); + + documents = new ArrayList<>(); + for (int d = 0; d < lda.getBow().getNumDocs(); ++d) { + List vocabList = getVocabularyList(d, lda.getBow(), lda.getVocabularies()); + Document doc = new Document(d, lda.getNumTopics(), vocabList); + documents.add(doc); + } + } + + List getVocabularyList(int docID, BagOfWords bow, Vocabularies vocabs) { + assert docID > 0 && bow != null && vocabs != null; + //System.out.println(docID); + //System.out.println(bow.getWords(docID).toString()); + return bow.getWords(docID).stream() + .map(id -> vocabs.get(id)) + .collect(Collectors.toList()); + } + + int getTopicID(int docID, int wordID) { + return documents.get(docID).getTopicID(wordID); + } + + void setTopicID(int docID, int wordID, int topicID) { + documents.get(docID).setTopicID(wordID, topicID); + } + + Vocabulary getVocab(int docID, int wordID) { + return documents.get(docID).getVocabulary(wordID); + } + + List getWords(int docID) { + return documents.get(docID).getWords(); + } + + List getDocuments() { + return 
Collections.unmodifiableList(documents); + } + + void incrementTopicCount(int docID, int topicID) { + documents.get(docID).incrementTopicCount(topicID); + } + + void decrementTopicCount(int docID, int topicID) { + documents.get(docID).decrementTopicCount(topicID); + } + + int getTopicCount(int docID, int topicID) { + return documents.get(docID).getTopicCount(topicID); + } + + double getTheta(int docID, int topicID, double alpha, double sumAlpha) { + if (docID < 0 || documents.size() < docID) throw new IllegalArgumentException(); + return documents.get(docID).getTheta(topicID, alpha, sumAlpha); + } + + void initializeTopicAssignment(Topics topics, long seed) { + for (Document d : getDocuments()) { + d.initializeTopicAssignment(seed); + for (int w = 0; w < d.getDocLength(); ++w) { + final int topicID = d.getTopicID(w); + final Topic topic = topics.get(topicID); + final Vocabulary vocab = d.getVocabulary(w); + topic.incrementVocabCount(vocab.id()); + } + } + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topic.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topic.java new file mode 100644 index 00000000000..a54c1f249ff --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topic.java @@ -0,0 +1,55 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda.inference.internal; + +class Topic { + private final int id; + private final int numVocabs; + private VocabularyCounter counter; + + Topic(int id, int numVocabs) { + if (id < 0 || numVocabs <= 0) throw new IllegalArgumentException(); + this.id = id; + this.numVocabs = numVocabs; + this.counter = new VocabularyCounter(numVocabs); + } + + int id() { + return id; + } + + int getVocabCount(int vocabID) { + return counter.getVocabCount(vocabID); + } + + int getSumCount() { + return counter.getSumCount(); + } + + void incrementVocabCount(int vocabID) { + counter.incrementVocabCount(vocabID); + } + + void decrementVocabCount(int vocabID) { + counter.decrementVocabCount(vocabID); + } + + double getPhi(int vocabID, double beta) { + if (vocabID < 0 || beta <= 0) throw new IllegalArgumentException(); + return (getVocabCount(vocabID) + beta) / (getSumCount() + beta * numVocabs); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicAssignment.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicAssignment.java new file mode 100644 index 00000000000..8d97bd709e6 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicAssignment.java @@ -0,0 +1,60 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda.inference.internal; + +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.stream.Collectors; + +class TopicAssignment { + private List topicAssignment; + private boolean ready; + + TopicAssignment() { + topicAssignment = new ArrayList<>(); + ready = false; + } + + void set(int wordID, int topicID) { + if (!ready) throw new IllegalStateException(); + if (wordID < 0 || topicAssignment.size() <= wordID || topicID < 0) { + throw new IllegalArgumentException(); + } + topicAssignment.set(wordID, topicID); + } + + int get(int wordID) { + if (!ready) throw new IllegalStateException(); + if (wordID < 0 || topicAssignment.size() <= wordID) { + throw new IllegalArgumentException(); + } + return topicAssignment.get(wordID); + } + + void initialize(int docLength, int numTopics, long seed) { + if (docLength <= 0 || numTopics <= 0) { + throw new IllegalArgumentException(); + } + + Random random = new Random(seed); + topicAssignment = random.ints(docLength, 0, numTopics) + .boxed() + .collect(Collectors.toList()); + ready = true; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicCounter.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicCounter.java new file mode 100644 index 00000000000..7a04437cd29 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/TopicCounter.java @@ -0,0 +1,45 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +class TopicCounter { + private AssignmentCounter topicCount; + + TopicCounter(int numTopics) { + this.topicCount = new AssignmentCounter(numTopics); + } + + int getTopicCount(int topicID) { + return topicCount.get(topicID); + } + + int getDocLength() { + return topicCount.getSum(); + } + + void incrementTopicCount(int topicID) { + topicCount.increment(topicID); + } + + void decrementTopicCount(int topicID) { + topicCount.decrement(topicID); + } + + int size() { + return topicCount.size(); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topics.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topics.java new file mode 100644 index 00000000000..ec042d0a1d3 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Topics.java @@ -0,0 +1,85 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package forge.lda.lda.inference.internal; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +import forge.lda.lda.LDA; +import forge.lda.dataset.Vocabularies; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; + +class Topics { + private List topics; + + Topics(LDA lda) { + if (lda == null) throw new NullPointerException(); + + topics = new ArrayList<>(); + for (int t = 0; t < lda.getNumTopics(); ++t) { + topics.add(new Topic(t, lda.getBow().getNumVocabs())); + } + } + + int numTopics() { + return topics.size(); + } + + Topic get(int id) { + return topics.get(id); + } + + int getVocabCount(int topicID, int vocabID) { + return topics.get(topicID).getVocabCount(vocabID); + } + + int getSumCount(int topicID) { + return topics.get(topicID).getSumCount(); + } + + void incrementVocabCount(int topicID, int vocabID) { + topics.get(topicID).incrementVocabCount(vocabID); + } + + void decrementVocabCount(int topicID, int vocabID) { + topics.get(topicID).decrementVocabCount(vocabID); + } + + double getPhi(int topicID, int vocabID, double beta) { + if (topicID < 0 || topics.size() <= topicID) throw new IllegalArgumentException(); + return topics.get(topicID).getPhi(vocabID, beta); + } + + List> getVocabsSortedByPhi(int topicID, Vocabularies vocabs, final double beta) { + if (topicID < 0 || topics.size() <= topicID || vocabs == null || beta <= 0.0) { + throw new IllegalArgumentException(); + } + + Topic topic = topics.get(topicID); + List> vocabProbPairs + = vocabs.getVocabularyList() + .stream() + .map(v -> new ImmutablePair(v.toString(), topic.getPhi(v.id(), beta))) + .sorted((p1, p2) -> Double.compare(p2.getRight(), p1.getRight())) + .collect(Collectors.toList()); + return Collections.unmodifiableList(vocabProbPairs); + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/VocabularyCounter.java 
b/forge-lda/src/main/java/forge/lda/lda/inference/internal/VocabularyCounter.java new file mode 100644 index 00000000000..00e067c4998 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/VocabularyCounter.java @@ -0,0 +1,46 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +class VocabularyCounter { + private AssignmentCounter vocabCount; + private int sumCount; + + VocabularyCounter(int numVocabs) { + this.vocabCount = new AssignmentCounter(numVocabs); + this.sumCount = 0; + } + + int getVocabCount(int vocabID) { + if (vocabCount.size() < vocabID) return 0; + else return vocabCount.get(vocabID); + } + + int getSumCount() { + return sumCount; + } + + void incrementVocabCount(int vocabID) { + vocabCount.increment(vocabID); + ++sumCount; + } + + void decrementVocabCount(int vocabID) { + vocabCount.decrement(vocabID); + --sumCount; + } +} diff --git a/forge-lda/src/main/java/forge/lda/lda/inference/internal/Words.java b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Words.java new file mode 100644 index 00000000000..11a1c691477 --- /dev/null +++ b/forge-lda/src/main/java/forge/lda/lda/inference/internal/Words.java @@ -0,0 +1,46 @@ +/* +* Copyright 2015 Kohei Yamamoto +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. 
+* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package forge.lda.lda.inference.internal; + +import java.util.Collections; +import java.util.List; + +import forge.lda.dataset.Vocabulary; + +class Words { + private List words; + + Words(List words) { + if (words == null) throw new NullPointerException(); + this.words = words; + } + + int getNumWords() { + return words.size(); + } + + Vocabulary get(int id) { + if (id < 0 || words.size() <= id) { + throw new IllegalArgumentException(); + } + return words.get(id); + } + + List getWords() { + return Collections.unmodifiableList(words); + } +} diff --git a/pom.xml b/pom.xml index e378b7b49b5..9645f2a4962 100644 --- a/pom.xml +++ b/pom.xml @@ -65,6 +65,7 @@ forge-gui-mobile-dev forge-gui-android forge-gui-ios + forge-lda