Merge branch 'ldastream' into 'master'

LDA module

See merge request core-developers/forge!4251
This commit is contained in:
austinio7116
2021-04-11 05:51:11 +00:00
28 changed files with 2154 additions and 5 deletions

View File

@@ -17,9 +17,9 @@ import forge.model.FModel;
public class ArchetypeDeckGenerator extends DeckProxy implements Comparable<ArchetypeDeckGenerator> { public class ArchetypeDeckGenerator extends DeckProxy implements Comparable<ArchetypeDeckGenerator> {
public static List<DeckProxy> getMatrixDecks(GameFormat format, boolean isForAi){ public static List<DeckProxy> getMatrixDecks(GameFormat format, boolean isForAi){
final List<DeckProxy> decks = new ArrayList<>(); final List<DeckProxy> decks = new ArrayList<>();
for(Archetype archetype: CardArchetypeLDAGenerator.ldaArchetypes.get(format.getName())) { for(Archetype archetype: CardArchetypeLDAGenerator.ldaArchetypes.get(format.getName())) {
decks.add(new ArchetypeDeckGenerator(archetype, format, isForAi)); decks.add(new ArchetypeDeckGenerator(archetype, format, isForAi));
} }
return decks; return decks;
} }

View File

@@ -331,12 +331,10 @@ public class DeckgenUtil {
return deck; return deck;
} }
/** /**
* @param selection {@link java.lang.String} array * @param selection {@link java.lang.String} array
* @return {@link forge.deck.Deck} * @return {@link forge.deck.Deck}
*/ */
public static Deck buildColorDeck(List<String> selection, Predicate<PaperCard> formatFilter, boolean forAi) { public static Deck buildColorDeck(List<String> selection, Predicate<PaperCard> formatFilter, boolean forAi) {
try { try {
final Deck deck; final Deck deck;

View File

@@ -1,5 +1,6 @@
package forge.deck.io; package forge.deck.io;
import java.io.Serializable; import java.io.Serializable;
import java.util.List; import java.util.List;

37
forge-lda/pom.xml Normal file
View File

@@ -0,0 +1,37 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<artifactId>forge</artifactId>
<groupId>forge</groupId>
<version>1.6.40-SNAPSHOT</version>
</parent>
<artifactId>forge-lda</artifactId>
<packaging>jar</packaging>
<name>Forge LDA</name>
<properties>
<parsedVersion.majorVersion>0</parsedVersion.majorVersion>
<parsedVersion.minorVersion>0</parsedVersion.minorVersion>
<parsedVersion.incrementalVersion>0</parsedVersion.incrementalVersion>
</properties>
<dependencies>
<dependency>
<groupId>forge</groupId>
<artifactId>forge-gui</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.8.1</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>3.6.1</version>
</dependency>
</dependencies>
</project>

View File

@@ -0,0 +1,389 @@
package forge.lda;
import com.google.common.base.Function;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
//import forge.GuiDesktop;
import forge.StaticData;
import forge.card.CardRules;
import forge.card.CardRulesPredicates;
import forge.deck.Deck;
import forge.deck.DeckFormat;
import forge.deck.io.Archetype;
import forge.deck.io.CardThemedLDAIO;
import forge.deck.io.DeckStorage;
import forge.lda.dataset.Dataset;
import forge.lda.lda.LDA;
import forge.game.GameFormat;
import forge.item.PaperCard;
import forge.localinstance.properties.ForgeConstants;
import forge.localinstance.properties.ForgePreferences;
import forge.model.FModel;
import forge.util.storage.IStorage;
import forge.util.storage.StorageImmediatelySerialized;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.Pair;
import java.io.File;
import java.util.*;
import static forge.lda.lda.inference.InferenceMethod.CGS;
/**
* Created by maustin on 09/05/2017.
*/
/**
 * Builds per-format LDA (Latent Dirichlet Allocation) archetype models from stored decks.
 * Results are cached in {@link #ldaPools} (card name -> topics containing that card) and
 * {@link #ldaArchetypes} (archetypes ranked by deck count), both keyed by format name.
 *
 * Created by maustin on 09/05/2017.
 */
public final class LDAModelGenetrator {

    /** format name -> (card name -> list of topics; each topic is a list of (card, probability) pairs). */
    public static Map<String, Map<String,List<List<Pair<String, Double>>>>> ldaPools = new HashMap<>();
    /** format name -> archetypes sorted by descending deck count. */
    public static Map<String, List<Archetype>> ldaArchetypes = new HashMap<>();

    /** Stand-alone entry point: boots FModel with eager card-script loading, then builds all models. */
    public static final void main(String[] args){
        //GuiBase.setInterface(new GuiDesktop());
        FModel.initialize(null, new Function<ForgePreferences, Void>() {
            @Override
            public Void apply(ForgePreferences preferences) {
                preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false);
                return null;
            }
        });
        initialize();
    }

    /**
     * Initializes models for every supported format.
     * @return true only if every format initialized successfully
     */
    public static boolean initialize(){
        List<String> formatStrings = new ArrayList<>();
        formatStrings.add(FModel.getFormats().getStandard().getName());
        formatStrings.add(FModel.getFormats().getPioneer().getName());
        formatStrings.add(FModel.getFormats().getModern().getName());
        formatStrings.add("Legacy");
        formatStrings.add("Vintage");
        formatStrings.add(DeckFormat.Commander.toString());
        for (String formatString : formatStrings){
            if(!initializeFormat(formatString)){
                return false;
            }
        }
        return true;
    }

    /** Try to load matrix .dat files, otherwise check for deck folders and build .dat, otherwise return false **/
    public static boolean initializeFormat(String format){
        Map<String,List<List<Pair<String, Double>>>> formatMap = CardThemedLDAIO.loadLDA(format);
        List<Archetype> lda = CardThemedLDAIO.loadRawLDA(format);
        if (formatMap == null) {
            try {
                if (lda == null) {
                    if (CardThemedLDAIO.getMatrixFolder(format).exists()) {
                        if (format.equals(FModel.getFormats().getStandard().getName())) {
                            lda = initializeFormat(FModel.getFormats().getStandard());
                        } else if (format.equals(FModel.getFormats().getModern().getName())) {
                            lda = initializeFormat(FModel.getFormats().getModern());
                        // FIX: was "format != DeckFormat.Commander.toString()", a reference
                        // comparison on Strings that is true for any distinct instances;
                        // use equals() so Commander is actually excluded here.
                        } else if (!format.equals(DeckFormat.Commander.toString())) {
                            lda = initializeFormat(FModel.getFormats().get(format));
                        }
                        CardThemedLDAIO.saveRawLDA(format, lda);
                    } else {
                        return false;
                    }
                }
                if (format.equals(FModel.getFormats().getStandard().getName())) {
                    formatMap = loadFormat(FModel.getFormats().getStandard(), lda);
                } else if (format.equals(FModel.getFormats().getModern().getName())) {
                    formatMap = loadFormat(FModel.getFormats().getModern(), lda);
                // FIX: same reference-comparison bug as above.
                } else if (!format.equals(DeckFormat.Commander.toString())) {
                    formatMap = loadFormat(FModel.getFormats().get(format), lda);
                }
                CardThemedLDAIO.saveLDA(format, formatMap);
            } catch (Exception e){
                e.printStackTrace();
                return false;
            }
        }
        ldaPools.put(format, formatMap);
        ldaArchetypes.put(format, lda);
        return true;
    }

    /**
     * Converts raw archetypes into a card -> topics lookup for deck generation.
     * Topics whose strongest card probability is &lt;= 0.01 are discarded as noise; each
     * kept topic is truncated to its strongest non-basic-land cards (up to ~40), and
     * topics retaining fewer than 19 cards are dropped.
     *
     * @param format game format the archetypes belong to
     * @param lda raw archetypes with per-card probabilities, strongest first
     * @return map from card name to the list of kept topics that contain it
     */
    public static Map<String,List<List<Pair<String, Double>>>> loadFormat(GameFormat format, List<Archetype> lda) throws Exception {
        List<List<Pair<String, Double>>> topics = new ArrayList<>();
        Set<String> cards = new HashSet<String>();
        for (int t = 0; t < lda.size(); ++t) {
            List<Pair<String, Double>> topic = new ArrayList<>();
            Set<String> topicCards = new HashSet<>();
            List<Pair<String, Double>> highRankVocabs = lda.get(t).getCardProbabilities();
            // skip topics with no sufficiently probable top card
            if (highRankVocabs.get(0).getRight() <= 0.01d) {
                continue;
            }
            System.out.print("t" + t + ": ");
            int i = 0;
            while (topic.size() <= 40 && i < highRankVocabs.size()) {
                String cardName = highRankVocabs.get(i).getLeft();
                PaperCard card = StaticData.instance().getCommonCards().getUniqueByName(cardName);
                if (card == null) {
                    System.out.println("Card " + cardName + " is MISSING!");
                    i++;
                    continue;
                }
                if (!card.getRules().getType().isBasicLand()) {
                    // only reasonably probable cards count towards the card pool
                    if (highRankVocabs.get(i).getRight() >= 0.005d) {
                        topicCards.add(cardName);
                    }
                    System.out.println("[" + highRankVocabs.get(i).getLeft() + "," + highRankVocabs.get(i).getRight() + "],");
                    topic.add(highRankVocabs.get(i));
                }
                i++;
            }
            System.out.println();
            if (topic.size() > 18) {
                cards.addAll(topicCards);
                topics.add(topic);
            }
        }
        Map<String,List<List<Pair<String, Double>>>> cardTopicMap = new HashMap<>();
        for (String card : cards) {
            List<List<Pair<String, Double>>> cardTopics = new ArrayList<>();
            for (List<Pair<String, Double>> topic : topics) {
                if (topicContains(card, topic)) {
                    cardTopics.add(topic);
                }
            }
            cardTopicMap.put(card, cardTopics);
        }
        return cardTopicMap;
    }

    /**
     * Trains an LDA model on the stored decks for a format and derives named archetypes.
     * @param format game format whose deck corpus is used
     * @return archetypes sorted by descending member-deck count
     * @throws Exception if the deck corpus cannot be read or inference fails
     */
    public static List<Archetype> initializeFormat(GameFormat format) throws Exception {
        Dataset dataset = new Dataset(format);
        //estimate number of topics to attempt to find using power law
        final int numTopics = (int) (347f * dataset.getNumDocs() / (2892f + dataset.getNumDocs()));
        System.out.println("Num Topics = " + numTopics);
        LDA lda = new LDA(0.1, 0.1, numTopics, dataset, CGS);
        lda.run();
        System.out.println(lda.computePerplexity(dataset));
        //sort decks by topic: each deck is assigned to its highest-theta topic
        Map<Integer,List<Deck>> topicDecks = new HashMap<>();
        int deckNum = 0;
        for (Deck deck : dataset.getBow().getLegalDecks()) {
            double maxTheta = 0;
            int mainTopic = 0;
            for (int t = 0; t < lda.getNumTopics(); ++t) {
                double theta = lda.getTheta(deckNum, t);
                if (theta > maxTheta) {
                    maxTheta = theta;
                    mainTopic = t;
                }
            }
            if (topicDecks.containsKey(mainTopic)) {
                topicDecks.get(mainTopic).add(deck);
            } else {
                List<Deck> decks = new ArrayList<>();
                decks.add(deck);
                topicDecks.put(mainTopic, decks);
            }
            ++deckNum;
        }
        List<Archetype> unfilteredTopics = new ArrayList<>();
        for (int t = 0; t < lda.getNumTopics(); ++t) {
            List<Pair<String, Double>> highRankVocabs = lda.getVocabsSortedByPhi(t);
            // drop cards sitting at the minimum (background) probability for this topic
            Double min = 1d;
            for (Pair<String, Double> p : highRankVocabs) {
                if (p.getRight() < min) {
                    min = p.getRight();
                }
            }
            List<Pair<String, Double>> topRankVocabs = new ArrayList<>();
            for (Pair<String, Double> p : highRankVocabs) {
                if (p.getRight() > min) {
                    topRankVocabs.add(p);
                }
            }
            //generate names for topics from the most frequent words in member deck names
            List<Deck> decks = topicDecks.get(t);
            if (decks == null) {
                continue;
            }
            LinkedHashMap<String, Integer> wordCounts = new LinkedHashMap<>();
            for (Deck deck : decks) {
                // strip site-generated prefixes/suffixes from deck names
                String name = deck.getName().replaceAll(".* Version - ","").replaceAll(" \\((Modern|Pioneer|Standard|Legacy|Vintage), #[0-9]+\\)","");
                // NOTE(review): this alternation is not grouped, so it matches literal
                // "(Modern" or bare "Pioneer"/"Standard"/... -- possibly intended
                // "\\((Modern|...)\\)"; kept as-is to preserve behavior, confirm intent.
                name = name.replaceAll("\\(Modern|Pioneer|Standard|Legacy|Vintage|Fuck|Shit|Cunt\\)","");
                String[] tokens = name.split(" ");
                for (String rawtoken : tokens) {
                    String token = rawtoken.toLowerCase();
                    if (token.matches("[0-9]+")) {
                        //skip just numbers as not useful
                        continue;
                    }
                    if (wordCounts.containsKey(token)) {
                        wordCounts.put(token, wordCounts.get(token) + 1);
                    } else {
                        wordCounts.put(token, 1);
                    }
                }
            }
            Map<String, Integer> sortedWordCounts = sortByValue(wordCounts);
            List<String> topWords = new ArrayList<>();
            Iterator<String> wordIterator = sortedWordCounts.keySet().iterator();
            while (wordIterator.hasNext() && topWords.size() < 3) {
                topWords.add(wordIterator.next());
            }
            // emit the three most frequent words in their original (insertion) order
            StringJoiner sb = new StringJoiner(" ");
            for (String word : wordCounts.keySet()) {
                if (topWords.contains(word)) {
                    sb.add(word);
                }
            }
            String deckName = sb.toString();
            System.out.println("============ " + deckName);
            System.out.println(decks.toString());
            unfilteredTopics.add(new Archetype(topRankVocabs, deckName, decks.size()));
        }
        Comparator<Archetype> archetypeComparator = new Comparator<Archetype>() {
            @Override
            public int compare(Archetype o1, Archetype o2) {
                return o2.getDeckCount().compareTo(o1.getDeckCount());
            }
        };
        Collections.sort(unfilteredTopics, archetypeComparator);
        return unfilteredTopics;
    }

    /** Returns a copy of the map with entries ordered by descending value (values must be mutually Comparable). */
    private static <K, V> Map<K, V> sortByValue(Map<K, V> map) {
        List<Map.Entry<K, V>> list = new LinkedList<>(map.entrySet());
        Collections.sort(list, new Comparator<Object>() {
            @SuppressWarnings("unchecked")
            public int compare(Object o1, Object o2) {
                return ((Comparable<V>) ((Map.Entry<K, V>) (o2)).getValue()).compareTo(((Map.Entry<K, V>) (o1)).getValue());
            }
        });
        Map<K, V> result = new LinkedHashMap<>();
        for (Map.Entry<K, V> entry : list) {
            result.put(entry.getKey(), entry.getValue());
        }
        return result;
    }

    /** @return true if the topic's card list contains the given card name. */
    public static boolean topicContains(String card, List<Pair<String, Double>> topic) {
        for (Pair<String,Double> pair : topic) {
            if (pair.getLeft().equals(card)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Builds the Commander card pools from stored Commander decks: for each legal
     * commander, counts how often every other card co-occurs with it.
     * @return commander name -> list of (card, co-occurrence count) entries
     */
    public static HashMap<String,List<Map.Entry<PaperCard,Integer>>> initializeCommanderFormat(){
        IStorage<Deck> decks = new StorageImmediatelySerialized<Deck>("Generator",
                new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR, DeckFormat.Commander.toString()),
                        ForgeConstants.DECK_GEN_DIR, false),
                true);
        //get all cards
        final Iterable<PaperCard> cards = Iterables.filter(FModel.getMagicDb().getCommonCards().getUniqueCards()
                , Predicates.compose(Predicates.not(CardRulesPredicates.Presets.IS_BASIC_LAND_NOT_WASTES), PaperCard.FN_GET_RULES));
        List<PaperCard> cardList = Lists.newArrayList(cards);
        cardList.add(FModel.getMagicDb().getCommonCards().getCard("Wastes"));
        Map<String, Integer> cardIntegerMap = new HashMap<>();
        Map<Integer, PaperCard> integerCardMap = new HashMap<>();
        Map<String, Integer> legendIntegerMap = new HashMap<>();
        Map<Integer, PaperCard> integerLegendMap = new HashMap<>();
        //generate lookups for cards to link card names to matrix columns
        for (int i = 0; i < cardList.size(); ++i){
            cardIntegerMap.put(cardList.get(i).getName(), i);
            integerCardMap.put(i, cardList.get(i));
        }
        //filter to just legal commanders
        List<PaperCard> legends = Lists.newArrayList(Iterables.filter(cardList, Predicates.compose(
                new Predicate<CardRules>() {
                    @Override
                    public boolean apply(CardRules rules) {
                        return DeckFormat.Commander.isLegalCommander(rules);
                    }
                }, PaperCard.FN_GET_RULES)));
        //generate lookups for legends to link commander names to matrix rows
        for (int i = 0; i < legends.size(); ++i){
            legendIntegerMap.put(legends.get(i).getName(), i);
            integerLegendMap.put(i, legends.get(i));
        }
        int[][] matrix = new int[legends.size()][cardList.size()];
        //loop through commanders and decks
        for (PaperCard legend : legends){
            for (Deck deck : decks){
                //if the deck has the commander
                if (deck.getCommanders().contains(legend)){
                    //update the matrix by incrementing the connectivity count for each card in the deck
                    updateLegendMatrix(deck, legend, cardIntegerMap, legendIntegerMap, matrix);
                }
            }
        }
        //convert the matrix into a map of pools for each commander
        HashMap<String,List<Map.Entry<PaperCard,Integer>>> cardPools = new HashMap<>();
        for (PaperCard card : legends){
            int col = legendIntegerMap.get(card.getName());
            int[] distances = matrix[col];
            int max = (Integer) Collections.max(Arrays.asList(ArrayUtils.toObject(distances)));
            // commanders with no co-occurring cards get no pool at all
            if (max > 0) {
                List<Map.Entry<PaperCard,Integer>> deckPool = new ArrayList<>();
                for (int k = 0; k < cardList.size(); k++){
                    if (matrix[col][k] > 0){
                        deckPool.add(new AbstractMap.SimpleEntry<PaperCard, Integer>(integerCardMap.get(k), matrix[col][k]));
                    }
                }
                cardPools.put(card.getName(), deckPool);
            }
        }
        return cardPools;
    }

    //update the matrix by incrementing the connectivity count for each card in the deck
    public static void updateLegendMatrix(Deck deck, PaperCard legend, Map<String, Integer> cardIntegerMap,
                                          Map<String, Integer> legendIntegerMap, int[][] matrix){
        for (PaperCard pairCard : Iterables.filter(deck.getMain().toFlatList(),
                Predicates.compose(Predicates.not(CardRulesPredicates.Presets.IS_BASIC_LAND_NOT_WASTES), PaperCard.FN_GET_RULES))){
            if (!pairCard.getName().equals(legend.getName())){
                try {
                    int old = matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(pairCard.getName())];
                    matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(pairCard.getName())] = old + 1;
                } catch (NullPointerException ne){
                    //Todo: Not sure what was failing here
                    ne.printStackTrace();
                }
            }
        }
        //add partner commanders to matrix
        if (deck.getCommanders().size() > 1){
            for (PaperCard partner : deck.getCommanders()){
                if (!partner.equals(legend)){
                    int old = matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(partner.getName())];
                    matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(partner.getName())] = old + 1;
                }
            }
        }
    }
}

View File

@@ -0,0 +1,210 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.dataset;
import forge.deck.Deck;
import forge.deck.io.DeckStorage;
import forge.game.GameFormat;
import forge.item.PaperCard;
import forge.localinstance.properties.ForgeConstants;
import forge.util.storage.IStorage;
import forge.util.storage.StorageImmediatelySerialized;
import java.io.*;
import java.util.*;
/**
* This class is immutable.
*/
/**
 * Bag-of-words view of a constructed-deck corpus for LDA training:
 * each legal 60-card deck is a "document" and each distinct card a "vocab" term.
 * This class is immutable.
 */
public final class BagOfWords {
    // corpus statistics fixed at construction time
    private final int numDocs;
    private final int numVocabs;
    private final int numNNZ;   // sum of distinct-card counts over all decks
    private final int numWords; // total card count over all decks

    // decks that passed the legality/size filter, in document-ID order
    private List<Deck> legalDecks;

    public Vocabularies getVocabs() {
        return vocabs;
    }

    private final Vocabularies vocabs;

    // docID -> the vocabs sequence in the doc
    private Map<Integer, List<Integer>> words;
    // docID -> the doc length
    private Map<Integer, Integer> docLength;

    /**
     * Read the bag-of-words dataset.
     * Loads decks from the generation directory for the given format, keeps those
     * that are format-legal with a main deck of exactly 60 cards, and indexes
     * every card name as a vocabulary entry.
     * @throws FileNotFoundException
     * @throws IOException
     * @throws Exception
     * @throws NullPointerException filePath is null
     */
    public BagOfWords(GameFormat format) throws FileNotFoundException, IOException, Exception {
        IStorage<Deck> decks = new StorageImmediatelySerialized<Deck>("Generator", new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR+ForgeConstants.PATH_SEPARATOR+format.getName()),
                ForgeConstants.DECK_GEN_DIR, false),
                true);
        Set<PaperCard> cardSet = new HashSet<>();
        legalDecks = new ArrayList<>();
        for(Deck deck:decks){
            try {
                if (format.isDeckLegal(deck) && deck.getMain().toFlatList().size() == 60) {
                    legalDecks.add(deck);
                    for (PaperCard card : deck.getMain().toFlatList()) {
                        cardSet.add(card);
                    }
                }
            }catch(Exception e){
                // best-effort: a deck that cannot be evaluated is skipped, not fatal
                System.out.println("Skipping deck "+deck.getName());
            }
        }
        List<PaperCard> cardList = new ArrayList<>(cardSet);
        this.words = new HashMap<>();
        this.docLength = new HashMap<>();
        ArrayList<Vocabulary> vocabList = new ArrayList<Vocabulary>();
        int numDocs = legalDecks.size();
        int numVocabs = cardList.size();
        int numNNZ = 0;
        int numWords = 0;
        // link card names to vocab IDs and back
        Map<String, Integer> cardIntegerMap = new HashMap<>();
        Map<Integer, PaperCard> integerCardMap = new HashMap<>();
        for (int i=0; i<cardList.size(); ++i){
            cardIntegerMap.put(cardList.get(i).getName(), i);
            vocabList.add(new Vocabulary(i,cardList.get(i).getName()));
            integerCardMap.put(i, cardList.get(i));
        }
        this.vocabs = new Vocabularies(vocabList);
        // build the per-document word (vocab ID) sequences
        int deckID = 0;
        for (Deck deck:legalDecks){
            numNNZ += deck.getMain().countDistinct();
            List<Integer> cardNumbers = new ArrayList<>();
            for(PaperCard card:deck.getMain().toFlatList()){
                if(cardIntegerMap.get(card.getName()) == null){
                    // should not happen: every card was indexed above
                    System.out.println(card.getName() + " is missing!!");
                }
                cardNumbers.add(cardIntegerMap.get(card.getName()));
            }
            words.put(deckID,cardNumbers);
            numWords+=cardNumbers.size();
            docLength.put(deckID,cardNumbers.size());
            deckID ++;
        }
        /*String s = null;
        while ((s = reader.readLine()) != null) {
            List<Integer> numbers
            = Arrays.asList(s.split(" ")).stream().map(Integer::parseInt).collect(Collectors.toList());
            if (numbers.size() == 1) {
                if (headerCount == 2) numDocs = numbers.get(0);
                else if (headerCount == 1) numVocabs = numbers.get(0);
                else if (headerCount == 0) numNNZ = numbers.get(0);
                --headerCount;
                continue;
            }
            else if (numbers.size() == 3) {
                final int docID = numbers.get(0);
                final int vocabID = numbers.get(1);
                final int count = numbers.get(2);
                // Set up the words container
                if (!words.containsKey(docID)) {
                    words.put(docID, new ArrayList<>());
                }
                for (int c = 0; c < count; ++c) {
                    words.get(docID).add(vocabID);
                }
                // Set up the doc length map
                Optional<Integer> currentCount
                = Optional.ofNullable(docLength.putIfAbsent(docID, count));
                currentCount.ifPresent(c -> docLength.replace(docID, c + count));
                numWords += count;
            }
            else {
                throw new Exception("Invalid dataset form was detected.");
            }
        }
        reader.close();*/
        this.numDocs = numDocs;
        this.numVocabs = numVocabs;
        this.numNNZ = numNNZ;
        this.numWords = numWords;
        System.out.println("Num Decks" + this.numDocs);
        System.out.println("Num Vocabs" + this.numVocabs);
        System.out.println("Num NNZ" + this.numNNZ);
        System.out.println("Num Cards" + this.numWords);
    }

    /** @return number of documents (legal decks) in the corpus */
    public int getNumDocs() {
        return numDocs;
    }

    /**
     * Get the length of the document.
     * @param docID
     * @return length of the document
     * @throws IllegalArgumentException docID &lt; 0 || #documents &lt; docID
     */
    public int getDocLength(int docID) {
        if (docID < 0 || getNumDocs() < docID) {
            throw new IllegalArgumentException();
        }
        return docLength.get(docID);
    }

    /**
     * Get the unmodifiable list of words in the document.
     * @param docID
     * @return the unmodifiable list of words
     * @throws IllegalArgumentException docID &lt; 0 || #documents &lt; docID
     */
    public List<Integer> getWords(final int docID) {
        if (docID < 0 || getNumDocs() < docID) {
            throw new IllegalArgumentException();
        }
        return Collections.unmodifiableList(words.get(docID));
    }

    /** @return number of distinct vocabulary entries (cards) */
    public int getNumVocabs() {
        return numVocabs;
    }

    /** @return number of non-zero (deck, card) count cells */
    public int getNumNNZ() {
        return numNNZ;
    }

    /** @return total word (card) count over all documents */
    public int getNumWords() {
        return numWords;
    }

    /** @return the decks that passed the legality filter, in document-ID order */
    public List<Deck> getLegalDecks() { return legalDecks; }
}

View File

@@ -0,0 +1,80 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.dataset;
import forge.game.GameFormat;
import java.util.List;
/**
 * Facade pairing a {@link BagOfWords} corpus with its {@link Vocabularies},
 * exposing both through a single object for the inference code.
 */
public class Dataset {
    private BagOfWords corpus;
    private Vocabularies vocabularies;

    /**
     * Builds the corpus from the stored decks of the given format.
     * @throws Exception if the deck corpus cannot be read
     */
    public Dataset(GameFormat format) throws Exception {
        corpus = new BagOfWords(format);
        vocabularies = corpus.getVocabs();
    }

    /** Wraps an existing bag-of-words; no vocabulary information is attached. */
    public Dataset(BagOfWords bow) {
        this.corpus = bow;
        this.vocabularies = null;
    }

    /** @return the underlying bag-of-words */
    public BagOfWords getBow() {
        return corpus;
    }

    /** @return number of documents (decks) in the corpus */
    public int getNumDocs() {
        return corpus.getNumDocs();
    }

    /** @return length (word count) of the given document */
    public int getDocLength(int docID) {
        return corpus.getDocLength(docID);
    }

    /** @return the vocab-ID sequence of the given document */
    public List<Integer> getWords(int docID) {
        return corpus.getWords(docID);
    }

    /** @return number of distinct vocabulary entries */
    public int getNumVocabs() {
        return corpus.getNumVocabs();
    }

    /** @return number of non-zero (doc, vocab) cells */
    public int getNumNNZ() {
        return corpus.getNumNNZ();
    }

    /** @return total word count over all documents */
    public int getNumWords() {
        return corpus.getNumWords();
    }

    /** @return the vocabulary entry with the given ID */
    public Vocabulary get(int id) {
        return vocabularies.get(id);
    }

    /** @return vocabulary size */
    public int size() {
        return vocabularies.size();
    }

    /** @return the vocabularies wrapper (null when built via the bag-only constructor) */
    public Vocabularies getVocabularies() {
        return vocabularies;
    }

    /** @return list of vocabulary entries */
    public List<Vocabulary> getVocabularyList() {
        return vocabularies.getVocabularyList();
    }
}

View File

@@ -0,0 +1,40 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.dataset;
import java.util.Collections;
import java.util.List;
/**
 * Read-only collection of {@link Vocabulary} entries, addressed by position.
 */
public class Vocabularies {
    private final List<Vocabulary> vocabs;

    /** @param vocabs backing list of entries (not copied) */
    public Vocabularies(List<Vocabulary> vocabs) {
        this.vocabs = vocabs;
    }

    /** @return the entry at the given index */
    public Vocabulary get(int id) {
        return vocabs.get(id);
    }

    /** @return number of entries */
    public int size() {
        return vocabs.size();
    }

    /** @return an unmodifiable view of all entries */
    public List<Vocabulary> getVocabularyList() {
        return Collections.unmodifiableList(vocabs);
    }
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.dataset;
/**
 * A single vocabulary entry: an immutable (id, term) pair used by the LDA model.
 */
public class Vocabulary {
    private final int id;
    private final String vocabulary;

    /**
     * @param id numeric identifier of the term
     * @param vocabulary the term text; must not be null
     * @throws NullPointerException if vocabulary is null
     */
    public Vocabulary(int id, String vocabulary) {
        if (vocabulary == null) {
            throw new NullPointerException();
        }
        this.id = id;
        this.vocabulary = vocabulary;
    }

    /** @return the numeric identifier */
    public int id() {
        return id;
    }

    /** @return the term text */
    @Override
    public String toString() {
        return vocabulary;
    }
}

View File

@@ -0,0 +1,46 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
class Alpha {
private List<Double> alphas;
Alpha(double alpha, int numTopics) {
if (alpha <= 0.0 || numTopics <= 0) {
throw new IllegalArgumentException();
}
this.alphas = Stream.generate(() -> alpha)
.limit(numTopics)
.collect(Collectors.toList());
}
double get(int i) {
return alphas.get(i);
}
void set(int i, double value) {
alphas.set(i, value);
}
double sum() {
return alphas.stream().reduce(Double::sum).get();
}
}

View File

@@ -0,0 +1,54 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Topic-word Dirichlet hyperparameter: either one value per vocabulary entry,
 * or a single shared (symmetric) value. Package-private LDA internal.
 */
class Beta {
    private final List<Double> betas;

    /**
     * Creates a per-vocab beta vector of length {@code numVocabs}.
     * @throws IllegalArgumentException if beta &lt;= 0 or numVocabs &lt;= 0
     */
    Beta(double beta, int numVocabs) {
        if (beta <= 0.0 || numVocabs <= 0) {
            throw new IllegalArgumentException();
        }
        this.betas = Stream.iterate(beta, b -> b)
                           .limit(numVocabs)
                           .collect(Collectors.toList());
    }

    /**
     * Creates a single shared beta.
     * @throws IllegalArgumentException if beta &lt;= 0
     */
    Beta(double beta) {
        if (beta <= 0.0) {
            throw new IllegalArgumentException();
        }
        this.betas = Arrays.asList(beta);
    }

    /** @return the shared beta (element 0) */
    double get() {
        return get(0);
    }

    /** @return beta for vocab i */
    double get(int i) {
        return betas.get(i);
    }

    /** Overwrites beta for vocab i. */
    void set(int i, double value) {
        betas.set(i, value);
    }
}

View File

@@ -0,0 +1,60 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda;
/**
 * Bundles the LDA Dirichlet hyperparameters: a per-topic alpha vector
 * and a beta that is either per-vocab or a single shared value.
 * Package-private LDA internal.
 */
class Hyperparameters {
    private final Alpha alpha;
    private final Beta beta;

    /** Per-topic alpha and per-vocab beta. */
    Hyperparameters(double alpha, double beta, int numTopics, int numVocabs) {
        this.alpha = new Alpha(alpha, numTopics);
        this.beta = new Beta(beta, numVocabs);
    }

    /** Per-topic alpha and a single shared beta. */
    Hyperparameters(double alpha, double beta, int numTopics) {
        this.alpha = new Alpha(alpha, numTopics);
        this.beta = new Beta(beta);
    }

    /** @return alpha for topic i */
    double alpha(int i) {
        return alpha.get(i);
    }

    /** @return sum over the alpha vector */
    double sumAlpha() {
        return alpha.sum();
    }

    /** @return the shared beta value */
    double beta() {
        return beta.get();
    }

    /** @return beta for vocab i */
    double beta(int i) {
        return beta.get(i);
    }

    /** Overwrites alpha for topic i. */
    void setAlpha(int i, double value) {
        alpha.set(i, value);
    }

    /** Overwrites beta for vocab i. */
    void setBeta(int i, double value) {
        beta.set(i, value);
    }

    /** Overwrites the shared beta (element 0). */
    void setBeta(double value) {
        beta.set(0, value);
    }
}

View File

@@ -0,0 +1,183 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda;
import java.util.List;
import forge.lda.dataset.BagOfWords;
import forge.lda.dataset.Dataset;
import forge.lda.dataset.Vocabularies;
import forge.lda.lda.inference.Inference;
import forge.lda.lda.inference.InferenceFactory;
import forge.lda.lda.inference.InferenceMethod;
import forge.lda.lda.inference.InferenceProperties;
import org.apache.commons.lang3.tuple.Pair;
public class LDA {
private Hyperparameters hyperparameters;
private final int numTopics;
private Dataset dataset;
private final Inference inference;
private InferenceProperties properties;
private boolean trained;
/**
* @param alpha doc-topic hyperparameter
* @param beta topic-vocab hyperparameter
* @param numTopics the number of topics
* @param dataset dataset
* @param bow bag-of-words
* @param method inference method
*/
public LDA(final double alpha, final double beta, final int numTopics,
final Dataset dataset, InferenceMethod method) {
this(alpha, beta, numTopics, dataset, method, InferenceProperties.PROPERTIES_FILE_NAME);
}
LDA(final double alpha, final double beta, final int numTopics,
final Dataset dataset, InferenceMethod method, String propertiesFilePath) {
this.hyperparameters = new Hyperparameters(alpha, beta, numTopics);
this.numTopics = numTopics;
this.dataset = dataset;
this.inference = InferenceFactory.getInstance(method);
this.properties = new InferenceProperties();
this.trained = false;
properties.setSeed(123L);
properties.setNumIteration(100);
}
/**
* Get the vocabulary from its ID.
* @param vocabID
* @return the vocabulary
* @throws IllegalArgumentException vocabID <= 0 || the number of vocabularies < vocabID
*/
public String getVocab(int vocabID) {
if (vocabID < 0 || dataset.getNumVocabs() < vocabID) {
throw new IllegalArgumentException();
}
return dataset.get(vocabID).toString();
}
/**
* Run model inference.
*/
public void run() {
if (properties == null) inference.setUp(this);
else inference.setUp(this, properties);
inference.run();
trained = true;
}
/**
* Get hyperparameter alpha corresponding to topic.
* @param topic
* @return alpha corresponding to topicID
* @throws ArrayIndexOutOfBoundsException topic < 0 || #topics <= topic
*/
public double getAlpha(final int topic) {
if (topic < 0 || numTopics < topic) {
throw new ArrayIndexOutOfBoundsException(topic);
}
return hyperparameters.alpha(topic);
}
public double getSumAlpha() {
return hyperparameters.sumAlpha();
}
public double getBeta() {
return hyperparameters.beta();
}
public int getNumTopics() {
return numTopics;
}
public BagOfWords getBow() {
return dataset.getBow();
}
/**
* Get the value of doc-topic probability \theta_{docID, topicID}.
* @param docID
* @param topicID
* @return the value of doc-topic probability
* @throws IllegalArgumentException docID <= 0 || #docs < docID || topicID < 0 || #topics <= topicID
* @throws IllegalStateException call this method when the inference has not been finished yet
*/
public double getTheta(final int docID, final int topicID) {
if (docID < 0 || dataset.getNumDocs() < docID
|| topicID < 0 || numTopics < topicID) {
throw new IllegalArgumentException();
}
if (!trained) {
throw new IllegalStateException();
}
return inference.getTheta(docID, topicID);
}
/**
* Get the value of topic-vocab probability \phi_{topicID, vocabID}.
* @param topicID
* @param vocabID
* @return the value of topic-vocab probability
* @throws IllegalArgumentException topicID < 0 || #topics <= topicID || vocabID <= 0
* @throws IllegalStateException call this method when the inference has not been finished yet
*/
public double getPhi(final int topicID, final int vocabID) {
if (topicID < 0 || numTopics < topicID || vocabID < 0) {
throw new IllegalArgumentException();
}
if (!trained) {
throw new IllegalStateException();
}
return inference.getPhi(topicID, vocabID);
}
    /** @return the vocabulary table of the training dataset */
    public Vocabularies getVocabularies() {
        return dataset.getVocabularies();
    }
    /** @return vocab/probability pairs for the topic, sorted by descending phi */
    public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
        return inference.getVocabsSortedByPhi(topicID);
    }
/**
* Compute the perplexity of trained LDA for the test bag-of-words dataset.
* @param testDataset
* @return the perplexity for the test bag-of-words dataset
*/
public double computePerplexity(Dataset testDataset) {
double loglikelihood = 0.0;
for (int d = 0; d < testDataset.getNumDocs(); ++d) {
for (Integer w : testDataset.getWords(d)) {
double sum = 0.0;
for (int t = 0; t < getNumTopics(); ++t) {
sum += getTheta(d, t) * getPhi(t, w.intValue());
}
loglikelihood += Math.log(sum);
}
}
return Math.exp(-loglikelihood / testDataset.getNumWords());
}
}

View File

@@ -0,0 +1,62 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference;
import java.util.List;
import forge.lda.lda.LDA;
import org.apache.commons.lang3.tuple.Pair;
/**
 * Contract for an LDA inference algorithm.
 * Implementations must be set up via one of the {@code setUp} overloads
 * before {@link #run()} is invoked.
 */
public interface Inference {
    /**
     * Prepare this inference instance for the given model using default
     * configuration values.
     * @param lda the model to run inference on
     */
    void setUp(LDA lda);

    /**
     * Prepare this inference instance for the given model, reading the
     * configuration (seed, iteration count, ...) from {@code properties}.
     * @param lda the model to run inference on
     * @param properties inference configuration
     */
    void setUp(LDA lda, InferenceProperties properties);

    /** Run model inference. */
    void run();

    /**
     * Get the value of doc-topic probability \theta_{docID, topicID}.
     * @param docID document index
     * @param topicID topic index
     * @return the value of doc-topic probability
     */
    double getTheta(final int docID, final int topicID);

    /**
     * Get the value of topic-vocab probability \phi_{topicID, vocabID}.
     * @param topicID topic index
     * @param vocabID vocabulary ID
     * @return the value of topic-vocab probability
     */
    double getPhi(final int topicID, final int vocabID);

    /** @return vocab/probability pairs for the topic, sorted by descending phi */
    List<Pair<String, Double>> getVocabsSortedByPhi(int topicID);
}

View File

@@ -0,0 +1,37 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference;
/**
 * Reflective factory for {@link Inference} implementations.
 */
public class InferenceFactory {
    // Utility class; not instantiable.
    private InferenceFactory() {}

    /**
     * Get the Inference instance specified by the argument.
     * @param method inference method whose {@code toString()} yields the
     *               fully-qualified implementation class name
     * @return a fresh instance of the requested inference implementation
     * @throws IllegalStateException if the implementation class cannot be
     *         loaded or instantiated
     */
    public static Inference getInstance(InferenceMethod method) {
        try {
            // getDeclaredConstructor().newInstance() replaces the deprecated
            // Class.newInstance(), which bypassed compile-time exception checking.
            return (Inference) Class.forName(method.toString())
                                    .getDeclaredConstructor()
                                    .newInstance();
        } catch (ReflectiveOperationException roe) {
            // Fail fast with the cause attached instead of printing the trace
            // and returning null (which deferred the failure to an opaque NPE
            // at the call site).
            throw new IllegalStateException("cannot instantiate inference: " + method, roe);
        }
    }
}

View File

@@ -0,0 +1,36 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference;
import forge.lda.lda.inference.internal.CollapsedGibbsSampler;
/**
 * Enumeration of the available LDA inference algorithms.
 * Each constant carries the fully-qualified class name of an Inference
 * implementation; InferenceFactory instantiates it reflectively.
 */
public enum InferenceMethod {
    /** Collapsed Gibbs sampling (see CollapsedGibbsSampler). */
    CGS(CollapsedGibbsSampler.class.getName()),
    // more
    ;
    // Fully-qualified implementation class name used by the factory.
    private String className;
    private InferenceMethod(final String className) {
        this.className = className;
    }
    // NOTE: toString() deliberately returns the implementation class name
    // rather than the constant name; InferenceFactory.getInstance relies on
    // this for Class.forName, so do not "fix" it to name().
    @Override
    public String toString() {
        return className;
    }
}

View File

@@ -0,0 +1,70 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.Properties;
/**
 * Configuration holder for LDA inference (random seed, iteration count).
 * Values may be set programmatically or loaded from a properties file; an
 * unset value is reported as null so callers can fall back to defaults.
 */
public class InferenceProperties {
    /** Default properties file name looked up by callers. */
    public static final String PROPERTIES_FILE_NAME = "lda.properties";

    PropertiesLoader loader = new PropertiesLoader();
    private Properties properties;

    public InferenceProperties() {
        this.properties = new Properties();
    }

    public void setSeed(Long seed) {
        properties.setProperty("seed", seed.toString());
    }

    public void setNumIteration(Integer numIteration) {
        properties.setProperty("numIteration", numIteration.toString());
    }

    /**
     * Load properties from the given file.
     * The stream is closed after loading (the original leaked it).
     * @param fileName path of the properties file
     * @throws IOException if the file cannot be read
     * @throws NullPointerException fileName is null
     */
    public void load(String fileName) throws IOException {
        if (fileName == null) throw new NullPointerException();
        try (InputStream stream = loader.getInputStream(fileName)) {
            if (stream == null) throw new NullPointerException();
            properties.load(stream);
        }
    }

    /**
     * @return the configured seed, or null when unset. The original threw a
     *         NumberFormatException on an unset property, although callers
     *         (e.g. CollapsedGibbsSampler.setUp) null-check this value.
     */
    public Long seed() {
        final String value = properties.getProperty("seed");
        return value == null ? null : Long.valueOf(value);
    }

    /** @return the configured iteration count, or null when unset */
    public Integer numIteration() {
        final String value = properties.getProperty("numIteration");
        return value == null ? null : Integer.valueOf(value);
    }
}

// Thin indirection over FileInputStream — presumably kept separate so the
// stream source can be stubbed; confirm before inlining.
class PropertiesLoader {
    public InputStream getInputStream(String fileName) throws FileNotFoundException {
        if (fileName == null) throw new NullPointerException();
        return new FileInputStream(fileName);
    }
}

View File

@@ -0,0 +1,65 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
/**
 * Bounded counter over assignment IDs in [0, size).
 * Used on the Gibbs-sampling hot path, so the backing store is a plain
 * int[] rather than the previous boxed List&lt;Integer&gt;.
 */
class AssignmentCounter {
    /** Count per ID. */
    private final int[] counter;

    /**
     * @param size number of distinct IDs; must be positive
     * @throws IllegalArgumentException size <= 0
     */
    AssignmentCounter(int size) {
        if (size <= 0) throw new IllegalArgumentException();
        this.counter = new int[size];
    }

    /** @return the number of distinct IDs */
    int size() {
        return counter.length;
    }

    /**
     * @return the count for the given ID
     * @throws IllegalArgumentException id out of range
     */
    int get(int id) {
        checkId(id);
        return counter[id];
    }

    /** @return the total over all IDs */
    int getSum() {
        return Arrays.stream(counter).sum();
    }

    /**
     * Increment the count of the given ID.
     * @throws IllegalArgumentException id out of range
     */
    void increment(int id) {
        checkId(id);
        ++counter[id];
    }

    /**
     * Decrement the count of the given ID.
     * @throws IllegalArgumentException id out of range
     * @throws IllegalStateException the count is already zero
     */
    void decrement(int id) {
        checkId(id);
        if (counter[id] == 0) {
            throw new IllegalStateException();
        }
        --counter[id];
    }

    // Range guard shared by all id-indexed accessors.
    private void checkId(int id) {
        if (id < 0 || counter.length <= id) {
            throw new IllegalArgumentException();
        }
    }
}

View File

@@ -0,0 +1,220 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;
import forge.lda.lda.LDA;
import forge.lda.lda.inference.Inference;
import forge.lda.lda.inference.InferenceProperties;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;
import forge.lda.dataset.Vocabulary;
/**
 * LDA inference by collapsed Gibbs sampling [Griffiths and Steyvers 2004].
 * One of the setUp overloads must be called before run(); the getters throw
 * IllegalStateException until then.
 */
public class CollapsedGibbsSampler implements Inference {
    private LDA lda;
    private Topics topics;       // per-topic vocabulary counts
    private Documents documents; // per-document topic counts and word assignments
    private int numIteration;
    // Defaults used when properties are absent or incomplete.
    private static final long DEFAULT_SEED = 0L;
    private static final int DEFAULT_NUM_ITERATION = 100;
    // ready for Gibbs sampling
    private boolean ready;
    public CollapsedGibbsSampler() {
        ready = false;
    }
    /**
     * Set up using the supplied configuration; falls back to {@link #setUp(LDA)}
     * when properties is null.
     * NOTE(review): unlike the other overload, this one does not null-check lda
     * before use — confirm whether that asymmetry is intended.
     * @param lda model to run inference on
     * @param properties seed / iteration configuration (null entries use defaults)
     */
    @Override
    public void setUp(LDA lda, InferenceProperties properties) {
        if (properties == null) {
            setUp(lda);
            return;
        }
        this.lda = lda;
        initialize(this.lda);
        final long seed = properties.seed() != null ? properties.seed() : DEFAULT_SEED;
        initializeTopicAssignment(seed);
        this.numIteration
        = properties.numIteration() != null ? properties.numIteration() : DEFAULT_NUM_ITERATION;
        this.ready = true;
    }
    /**
     * Set up with the default seed and iteration count.
     * @param lda model to run inference on
     * @throws NullPointerException lda is null
     */
    @Override
    public void setUp(LDA lda) {
        if (lda == null) throw new NullPointerException();
        this.lda = lda;
        initialize(this.lda);
        initializeTopicAssignment(DEFAULT_SEED);
        this.numIteration = DEFAULT_NUM_ITERATION;
        this.ready = true;
    }
    // Build the topic and document count structures from the model.
    private void initialize(LDA lda) {
        assert lda != null;
        this.topics = new Topics(lda);
        this.documents = new Documents(lda);
    }
    /** @return true once one of the setUp overloads has completed */
    public boolean isReady() {
        return ready;
    }
    /** @return the configured number of Gibbs sweeps */
    public int getNumIteration() {
        return numIteration;
    }
    /** Override the configured number of Gibbs sweeps. */
    public void setNumIteration(final int numIteration) {
        this.numIteration = numIteration;
    }
    /**
     * Run the configured number of full Gibbs sweeps over the corpus.
     * @throws IllegalStateException setUp has not been called
     */
    @Override
    public void run() {
        if (!ready) {
            throw new IllegalStateException("instance has not set up yet");
        }
        for (int i = 1; i <= numIteration; ++i) {
            System.out.println("Iteration " + i + ".");
            runSampling();
        }
    }
    /**
     * Run collapsed Gibbs sampling [Griffiths and Steyvers 2004].
     * One sweep: for every word of every document, remove its current topic
     * from the counts, resample a topic from the full conditional, and add
     * the new topic back. Statement order matters — the decrement must happen
     * before the distribution is built.
     */
    void runSampling() {
        for (Document d : documents.getDocuments()) {
            for (int w = 0; w < d.getDocLength(); ++w) {
                final Topic oldTopic = topics.get(d.getTopicID(w));
                d.decrementTopicCount(oldTopic.id());
                final Vocabulary v = d.getVocabulary(w);
                oldTopic.decrementVocabCount(v.id());
                IntegerDistribution distribution
                = getFullConditionalDistribution(lda.getNumTopics(), d.id(), v.id());
                final int newTopicID = distribution.sample();
                d.setTopicID(w, newTopicID);
                d.incrementTopicCount(newTopicID);
                final Topic newTopic = topics.get(newTopicID);
                newTopic.incrementVocabCount(v.id());
            }
        }
    }
    /**
     * Get the full conditional distribution over topics, proportional to
     * theta(doc, t) * phi(t, vocab) computed from the current counts.
     * docID and vocabID are passed to this distribution for parameters.
     * @param numTopics
     * @param docID
     * @param vocabID
     * @return the integer distribution over topics
     */
    IntegerDistribution getFullConditionalDistribution(final int numTopics, final int docID, final int vocabID) {
        // NOTE: local 'topics' (candidate topic indices) shadows the 'topics' field.
        int[] topics = IntStream.range(0, numTopics).toArray();
        double[] probabilities = Arrays.stream(topics)
                                       .mapToDouble(t -> getTheta(docID, t) * getPhi(t, vocabID))
                                       .toArray();
        return new EnumeratedIntegerDistribution(topics, probabilities);
    }
    /**
     * Initialize the topic assignment.
     * @param seed the seed of a pseudo random number generator
     */
    void initializeTopicAssignment(final long seed) {
        documents.initializeTopicAssignment(topics, seed);
    }
    /**
     * Get the count of topicID assigned to docID.
     * NOTE(review): the guard rejects docID <= 0, but documents elsewhere in
     * this class are 0-based (runSampling iterates Documents built from d = 0);
     * confirm the intended ID base before relying on this boundary.
     * @param docID
     * @param topicID
     * @return the count of topicID assigned to docID
     * @throws IllegalStateException setUp has not been called
     * @throws IllegalArgumentException argument out of the accepted range
     */
    int getDTCount(final int docID, final int topicID) {
        if (!ready) throw new IllegalStateException();
        if (docID <= 0 || lda.getBow().getNumDocs() < docID
        || topicID < 0 || lda.getNumTopics() <= topicID) {
            throw new IllegalArgumentException();
        }
        return documents.getTopicCount(docID, topicID);
    }
    /**
     * Get the count of vocabID assigned to topicID.
     * NOTE(review): vocabID <= 0 is rejected here, implying 1-based vocab IDs —
     * confirm against the BagOfWords convention.
     * @param topicID
     * @param vocabID
     * @return the count of vocabID assigned to topicID
     * @throws IllegalStateException setUp has not been called
     * @throws IllegalArgumentException argument out of the accepted range
     */
    int getTVCount(final int topicID, final int vocabID) {
        if (!ready) throw new IllegalStateException();
        if (topicID < 0 || lda.getNumTopics() <= topicID || vocabID <= 0) {
            throw new IllegalArgumentException();
        }
        final Topic topic = topics.get(topicID);
        return topic.getVocabCount(vocabID);
    }
    /**
     * Get the sum of counts of vocabs assigned to topicID.
     * This is the sum of topic-vocab count over vocabs.
     * NOTE(review): unlike its siblings, this accessor has no !ready guard.
     * @param topicID
     * @return the sum of counts of vocabs assigned to topicID
     * @throws IllegalArgumentException topicID < 0 || #topic <= topicID
     */
    int getTSumCount(final int topicID) {
        if (topicID < 0 || lda.getNumTopics() <= topicID) {
            throw new IllegalArgumentException();
        }
        final Topic topic = topics.get(topicID);
        return topic.getSumCount();
    }
    /** {@inheritDoc} Computed from the current document-topic counts. */
    @Override
    public double getTheta(final int docID, final int topicID) {
        if (!ready) throw new IllegalStateException();
        return documents.getTheta(docID, topicID, lda.getAlpha(topicID), lda.getSumAlpha());
    }
    /** {@inheritDoc} Computed from the current topic-vocab counts. */
    @Override
    public double getPhi(int topicID, int vocabID) {
        if (!ready) throw new IllegalStateException();
        return topics.getPhi(topicID, vocabID, lda.getBeta());
    }
    /** {@inheritDoc} */
    @Override
    public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
        return topics.getVocabsSortedByPhi(topicID, lda.getVocabularies(), lda.getBeta());
    }
}

View File

@@ -0,0 +1,86 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
import java.util.List;
import forge.lda.dataset.Vocabulary;
/**
 * One document of the corpus: its word sequence, the per-word topic
 * assignment, and the per-topic assignment counts.
 */
public class Document {
    private final int id;
    private TopicCounter topicCount;
    private Words words;
    private TopicAssignment assignment;

    /**
     * @param id 0-based document ID
     * @param numTopics number of latent topics (positive)
     * @param words the document's word sequence
     * @throws IllegalArgumentException id < 0 || numTopics <= 0
     */
    Document(int id, int numTopics, List<Vocabulary> words) {
        if (id < 0 || numTopics <= 0) {
            throw new IllegalArgumentException();
        }
        this.id = id;
        this.topicCount = new TopicCounter(numTopics);
        this.words = new Words(words);
        this.assignment = new TopicAssignment();
    }

    /** @return this document's ID */
    int id() {
        return id;
    }

    /** @return how many words of this document are currently assigned to the topic */
    int getTopicCount(int topicID) {
        return topicCount.getTopicCount(topicID);
    }

    /** @return the number of word tokens in this document */
    int getDocLength() {
        return words.getNumWords();
    }

    void incrementTopicCount(int topicID) {
        topicCount.incrementTopicCount(topicID);
    }

    void decrementTopicCount(int topicID) {
        topicCount.decrementTopicCount(topicID);
    }

    /**
     * Randomly assign a topic to every word position and record the counts.
     * @param seed pseudo random number generator seed
     */
    void initializeTopicAssignment(long seed) {
        assignment.initialize(getDocLength(), topicCount.size(), seed);
        for (int position = 0; position < getDocLength(); ++position) {
            incrementTopicCount(assignment.get(position));
        }
    }

    /** @return the topic currently assigned to the word at the given position */
    int getTopicID(int wordID) {
        return assignment.get(wordID);
    }

    void setTopicID(int wordID, int topicID) {
        assignment.set(wordID, topicID);
    }

    /** @return the vocabulary of the word at the given position */
    Vocabulary getVocabulary(int wordID) {
        return words.get(wordID);
    }

    /** @return the full word sequence */
    List<Vocabulary> getWords() {
        return words.getWords();
    }

    /**
     * Smoothed document-topic probability estimate for this document.
     * @throws IllegalArgumentException topicID < 0 || alpha <= 0.0 || sumAlpha <= 0.0
     */
    double getTheta(int topicID, double alpha, double sumAlpha) {
        final boolean valid = topicID >= 0 && alpha > 0.0 && sumAlpha > 0.0;
        if (!valid) {
            throw new IllegalArgumentException();
        }
        return (getTopicCount(topicID) + alpha) / (getDocLength() + sumAlpha);
    }
}

View File

@@ -0,0 +1,100 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import forge.lda.lda.LDA;
import forge.lda.dataset.BagOfWords;
import forge.lda.dataset.Vocabularies;
import forge.lda.dataset.Vocabulary;
/**
 * All documents of the corpus, indexed by 0-based document ID
 * (the constructor builds them for d = 0 .. numDocs - 1).
 */
class Documents {
    private List<Document> documents;

    Documents(LDA lda) {
        if (lda == null) throw new NullPointerException();
        documents = new ArrayList<>();
        for (int d = 0; d < lda.getBow().getNumDocs(); ++d) {
            List<Vocabulary> vocabList = getVocabularyList(d, lda.getBow(), lda.getVocabularies());
            Document doc = new Document(d, lda.getNumTopics(), vocabList);
            documents.add(doc);
        }
    }

    /**
     * Map one document's bag-of-words word IDs to Vocabulary objects.
     * @param docID 0-based document ID
     * @param bow bag-of-words of the corpus
     * @param vocabs vocabulary table
     */
    List<Vocabulary> getVocabularyList(int docID, BagOfWords bow, Vocabularies vocabs) {
        // Fixed: the original asserted docID > 0, which the constructor's very
        // first call (d == 0) violates whenever assertions are enabled.
        assert docID >= 0 && bow != null && vocabs != null;
        return bow.getWords(docID).stream()
                .map(id -> vocabs.get(id))
                .collect(Collectors.toList());
    }

    int getTopicID(int docID, int wordID) {
        return documents.get(docID).getTopicID(wordID);
    }

    void setTopicID(int docID, int wordID, int topicID) {
        documents.get(docID).setTopicID(wordID, topicID);
    }

    Vocabulary getVocab(int docID, int wordID) {
        return documents.get(docID).getVocabulary(wordID);
    }

    List<Vocabulary> getWords(int docID) {
        return documents.get(docID).getWords();
    }

    /** @return an unmodifiable view of all documents */
    List<Document> getDocuments() {
        return Collections.unmodifiableList(documents);
    }

    void incrementTopicCount(int docID, int topicID) {
        documents.get(docID).incrementTopicCount(topicID);
    }

    void decrementTopicCount(int docID, int topicID) {
        documents.get(docID).decrementTopicCount(topicID);
    }

    int getTopicCount(int docID, int topicID) {
        return documents.get(docID).getTopicCount(topicID);
    }

    /**
     * Smoothed document-topic probability.
     * @throws IllegalArgumentException docID out of range
     */
    double getTheta(int docID, int topicID, double alpha, double sumAlpha) {
        // Fixed off-by-one: the original (documents.size() < docID) admitted
        // docID == size(), which would surface as IndexOutOfBoundsException below.
        if (docID < 0 || documents.size() <= docID) throw new IllegalArgumentException();
        return documents.get(docID).getTheta(topicID, alpha, sumAlpha);
    }

    /**
     * Randomly initialize topic assignments for every document and propagate
     * the resulting counts into the topic-vocabulary counters.
     * NOTE(review): every document is initialized with the same seed, so
     * documents of equal length receive identical initial assignments —
     * confirm whether a per-document seed was intended.
     */
    void initializeTopicAssignment(Topics topics, long seed) {
        for (Document d : getDocuments()) {
            d.initializeTopicAssignment(seed);
            for (int w = 0; w < d.getDocLength(); ++w) {
                final int topicID = d.getTopicID(w);
                final Topic topic = topics.get(topicID);
                final Vocabulary vocab = d.getVocabulary(w);
                topic.incrementVocabCount(vocab.id());
            }
        }
    }
}

View File

@@ -0,0 +1,55 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
/**
 * One latent topic: its ID and the counts of vocabularies assigned to it.
 */
class Topic {
    private final int id;
    private final int numVocabs;
    private VocabularyCounter counter;

    /**
     * @param id 0-based topic ID
     * @param numVocabs vocabulary size (positive)
     * @throws IllegalArgumentException id < 0 || numVocabs <= 0
     */
    Topic(int id, int numVocabs) {
        if (id < 0 || numVocabs <= 0) {
            throw new IllegalArgumentException();
        }
        this.id = id;
        this.numVocabs = numVocabs;
        this.counter = new VocabularyCounter(numVocabs);
    }

    /** @return this topic's ID */
    int id() {
        return id;
    }

    /** @return the count of the vocabulary assigned to this topic */
    int getVocabCount(int vocabID) {
        return counter.getVocabCount(vocabID);
    }

    /** @return the total count over all vocabularies */
    int getSumCount() {
        return counter.getSumCount();
    }

    void incrementVocabCount(int vocabID) {
        counter.incrementVocabCount(vocabID);
    }

    void decrementVocabCount(int vocabID) {
        counter.decrementVocabCount(vocabID);
    }

    /**
     * Smoothed topic-vocabulary probability estimate.
     * @throws IllegalArgumentException vocabID < 0 || beta <= 0
     */
    double getPhi(int vocabID, double beta) {
        if (vocabID < 0 || beta <= 0) {
            throw new IllegalArgumentException();
        }
        final double numerator = getVocabCount(vocabID) + beta;
        final double denominator = getSumCount() + beta * numVocabs;
        return numerator / denominator;
    }
}

View File

@@ -0,0 +1,60 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
/**
 * Per-word topic assignment for a single document.
 * {@link #initialize} must be called before the positional accessors work.
 */
class TopicAssignment {
    private List<Integer> topicAssignment;
    private boolean ready;

    TopicAssignment() {
        topicAssignment = new ArrayList<>();
        ready = false;
    }

    /**
     * Assign a topic to the word at the given position.
     * @throws IllegalStateException initialize has not been called
     * @throws IllegalArgumentException wordID or topicID out of range
     */
    void set(int wordID, int topicID) {
        requireReady();
        if (wordID < 0 || topicAssignment.size() <= wordID || topicID < 0) {
            throw new IllegalArgumentException();
        }
        topicAssignment.set(wordID, topicID);
    }

    /**
     * @return the topic assigned to the word at the given position
     * @throws IllegalStateException initialize has not been called
     * @throws IllegalArgumentException wordID out of range
     */
    int get(int wordID) {
        requireReady();
        if (wordID < 0 || topicAssignment.size() <= wordID) {
            throw new IllegalArgumentException();
        }
        return topicAssignment.get(wordID);
    }

    /**
     * Draw a uniformly random topic for every word position.
     * @param docLength number of word positions (positive)
     * @param numTopics number of topics (positive)
     * @param seed pseudo random number generator seed
     * @throws IllegalArgumentException docLength <= 0 || numTopics <= 0
     */
    void initialize(int docLength, int numTopics, long seed) {
        if (docLength <= 0 || numTopics <= 0) {
            throw new IllegalArgumentException();
        }
        // random.ints(n, lo, hi) yields n values in [lo, hi); kept identical to
        // the original so a fixed seed reproduces the same assignment.
        Random random = new Random(seed);
        topicAssignment = random.ints(docLength, 0, numTopics)
                                .boxed()
                                .collect(Collectors.toList());
        ready = true;
    }

    // Guard shared by the positional accessors.
    private void requireReady() {
        if (!ready) {
            throw new IllegalStateException();
        }
    }
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
/**
 * Per-document topic assignment counts, a thin facade over AssignmentCounter.
 */
class TopicCounter {
    private AssignmentCounter topicCount;

    TopicCounter(int numTopics) {
        this.topicCount = new AssignmentCounter(numTopics);
    }

    /** @return how many words are currently assigned to the topic */
    int getTopicCount(int topicID) {
        return topicCount.get(topicID);
    }

    /** @return the total count over all topics, i.e. the document length */
    int getDocLength() {
        return topicCount.getSum();
    }

    void incrementTopicCount(int topicID) {
        topicCount.increment(topicID);
    }

    void decrementTopicCount(int topicID) {
        topicCount.decrement(topicID);
    }

    /** @return the number of topics */
    int size() {
        return topicCount.size();
    }
}

View File

@@ -0,0 +1,85 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import forge.lda.lda.LDA;
import forge.lda.dataset.Vocabularies;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
/**
 * All latent topics of a model, indexed by 0-based topic ID.
 */
class Topics {
    private List<Topic> topics;

    Topics(LDA lda) {
        if (lda == null) throw new NullPointerException();
        topics = new ArrayList<>();
        for (int topicID = 0; topicID < lda.getNumTopics(); ++topicID) {
            topics.add(new Topic(topicID, lda.getBow().getNumVocabs()));
        }
    }

    /** @return the number of topics */
    int numTopics() {
        return topics.size();
    }

    /** @return the topic with the given ID */
    Topic get(int id) {
        return topics.get(id);
    }

    int getVocabCount(int topicID, int vocabID) {
        return topics.get(topicID).getVocabCount(vocabID);
    }

    int getSumCount(int topicID) {
        return topics.get(topicID).getSumCount();
    }

    void incrementVocabCount(int topicID, int vocabID) {
        topics.get(topicID).incrementVocabCount(vocabID);
    }

    void decrementVocabCount(int topicID, int vocabID) {
        topics.get(topicID).decrementVocabCount(vocabID);
    }

    /**
     * Smoothed topic-vocabulary probability.
     * @throws IllegalArgumentException topicID out of range
     */
    double getPhi(int topicID, int vocabID, double beta) {
        if (topicID < 0 || topics.size() <= topicID) throw new IllegalArgumentException();
        return topics.get(topicID).getPhi(vocabID, beta);
    }

    /**
     * @return an unmodifiable list of (vocabulary, phi) pairs for the topic,
     *         sorted by descending probability
     * @throws IllegalArgumentException invalid topicID, null vocabs, or beta <= 0
     */
    List<Pair<String, Double>> getVocabsSortedByPhi(int topicID, Vocabularies vocabs, final double beta) {
        if (topicID < 0 || topics.size() <= topicID || vocabs == null || beta <= 0.0) {
            throw new IllegalArgumentException();
        }
        final Topic topic = topics.get(topicID);
        final List<Pair<String, Double>> sortedPairs = vocabs.getVocabularyList()
                .stream()
                .<Pair<String, Double>>map(v -> new ImmutablePair<>(v.toString(), topic.getPhi(v.id(), beta)))
                .sorted((p1, p2) -> Double.compare(p2.getRight(), p1.getRight()))
                .collect(Collectors.toList());
        return Collections.unmodifiableList(sortedPairs);
    }
}

View File

@@ -0,0 +1,46 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
/**
 * Counts of vocabularies assigned to one topic, with a cached running total
 * so getSumCount() is O(1).
 */
class VocabularyCounter {
    private AssignmentCounter vocabCount;
    private int sumCount;

    VocabularyCounter(int numVocabs) {
        this.vocabCount = new AssignmentCounter(numVocabs);
        this.sumCount = 0;
    }

    /**
     * @return the count for the vocabulary, or 0 when vocabID lies beyond the
     *         counter's capacity.
     *         NOTE(review): the guard uses '<', so vocabID == size() falls
     *         through to AssignmentCounter.get and throws — confirm whether
     *         vocab IDs are 0- or 1-based and whether '<=' was intended.
     */
    int getVocabCount(int vocabID) {
        if (vocabCount.size() < vocabID) {
            return 0;
        }
        return vocabCount.get(vocabID);
    }

    /** @return the total count over all vocabularies */
    int getSumCount() {
        return sumCount;
    }

    void incrementVocabCount(int vocabID) {
        vocabCount.increment(vocabID);
        ++sumCount;
    }

    void decrementVocabCount(int vocabID) {
        vocabCount.decrement(vocabID);
        --sumCount;
    }
}

View File

@@ -0,0 +1,46 @@
/*
* Copyright 2015 Kohei Yamamoto
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package forge.lda.lda.inference.internal;
import java.util.Collections;
import java.util.List;
import forge.lda.dataset.Vocabulary;
/**
 * The word (vocabulary) sequence of a document, exposed read-only.
 */
class Words {
    private List<Vocabulary> words;

    Words(List<Vocabulary> words) {
        if (words == null) {
            throw new NullPointerException();
        }
        this.words = words;
    }

    /** @return the number of word tokens */
    int getNumWords() {
        return words.size();
    }

    /**
     * @return the vocabulary at the given position
     * @throws IllegalArgumentException id out of range
     */
    Vocabulary get(int id) {
        final boolean inRange = 0 <= id && id < words.size();
        if (!inRange) {
            throw new IllegalArgumentException();
        }
        return words.get(id);
    }

    /** @return a read-only view of the word list */
    List<Vocabulary> getWords() {
        return Collections.unmodifiableList(words);
    }
}

View File

@@ -65,6 +65,7 @@
<module>forge-gui-mobile-dev</module> <module>forge-gui-mobile-dev</module>
<module>forge-gui-android</module> <module>forge-gui-android</module>
<module>forge-gui-ios</module> <module>forge-gui-ios</module>
<module>forge-lda</module>
</modules> </modules>
<distributionManagement> <distributionManagement>