mirror of
https://github.com/Card-Forge/forge.git
synced 2025-11-20 12:48:00 +00:00
Merge branch 'ldastream' into 'master'
LDA module See merge request core-developers/forge!4251
This commit is contained in:
@@ -17,9 +17,9 @@ import forge.model.FModel;
|
|||||||
public class ArchetypeDeckGenerator extends DeckProxy implements Comparable<ArchetypeDeckGenerator> {
|
public class ArchetypeDeckGenerator extends DeckProxy implements Comparable<ArchetypeDeckGenerator> {
|
||||||
public static List<DeckProxy> getMatrixDecks(GameFormat format, boolean isForAi){
|
public static List<DeckProxy> getMatrixDecks(GameFormat format, boolean isForAi){
|
||||||
final List<DeckProxy> decks = new ArrayList<>();
|
final List<DeckProxy> decks = new ArrayList<>();
|
||||||
for(Archetype archetype: CardArchetypeLDAGenerator.ldaArchetypes.get(format.getName())) {
|
for(Archetype archetype: CardArchetypeLDAGenerator.ldaArchetypes.get(format.getName())) {
|
||||||
decks.add(new ArchetypeDeckGenerator(archetype, format, isForAi));
|
decks.add(new ArchetypeDeckGenerator(archetype, format, isForAi));
|
||||||
}
|
}
|
||||||
|
|
||||||
return decks;
|
return decks;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -331,12 +331,10 @@ public class DeckgenUtil {
|
|||||||
return deck;
|
return deck;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param selection {@link java.lang.String} array
|
* @param selection {@link java.lang.String} array
|
||||||
* @return {@link forge.deck.Deck}
|
* @return {@link forge.deck.Deck}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
public static Deck buildColorDeck(List<String> selection, Predicate<PaperCard> formatFilter, boolean forAi) {
|
public static Deck buildColorDeck(List<String> selection, Predicate<PaperCard> formatFilter, boolean forAi) {
|
||||||
try {
|
try {
|
||||||
final Deck deck;
|
final Deck deck;
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package forge.deck.io;
|
package forge.deck.io;
|
||||||
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
|||||||
37
forge-lda/pom.xml
Normal file
37
forge-lda/pom.xml
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<parent>
|
||||||
|
<artifactId>forge</artifactId>
|
||||||
|
<groupId>forge</groupId>
|
||||||
|
<version>1.6.40-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
|
||||||
|
<artifactId>forge-lda</artifactId>
|
||||||
|
<packaging>jar</packaging>
|
||||||
|
<name>Forge LDA</name>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<parsedVersion.majorVersion>0</parsedVersion.majorVersion>
|
||||||
|
<parsedVersion.minorVersion>0</parsedVersion.minorVersion>
|
||||||
|
<parsedVersion.incrementalVersion>0</parsedVersion.incrementalVersion>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>forge</groupId>
|
||||||
|
<artifactId>forge-gui</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-lang3</artifactId>
|
||||||
|
<version>3.8.1</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.apache.commons</groupId>
|
||||||
|
<artifactId>commons-math3</artifactId>
|
||||||
|
<version>3.6.1</version>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
</project>
|
||||||
389
forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java
Normal file
389
forge-lda/src/main/java/forge/lda/LDAModelGenetrator.java
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
package forge.lda;
|
||||||
|
|
||||||
|
import com.google.common.base.Function;
|
||||||
|
import com.google.common.base.Predicate;
|
||||||
|
import com.google.common.base.Predicates;
|
||||||
|
import com.google.common.collect.Iterables;
|
||||||
|
import com.google.common.collect.Lists;
|
||||||
|
|
||||||
|
//import forge.GuiDesktop;
|
||||||
|
import forge.StaticData;
|
||||||
|
import forge.card.CardRules;
|
||||||
|
import forge.card.CardRulesPredicates;
|
||||||
|
import forge.deck.Deck;
|
||||||
|
import forge.deck.DeckFormat;
|
||||||
|
import forge.deck.io.Archetype;
|
||||||
|
import forge.deck.io.CardThemedLDAIO;
|
||||||
|
import forge.deck.io.DeckStorage;
|
||||||
|
import forge.lda.dataset.Dataset;
|
||||||
|
import forge.lda.lda.LDA;
|
||||||
|
import forge.game.GameFormat;
|
||||||
|
import forge.item.PaperCard;
|
||||||
|
import forge.localinstance.properties.ForgeConstants;
|
||||||
|
import forge.localinstance.properties.ForgePreferences;
|
||||||
|
import forge.model.FModel;
|
||||||
|
import forge.util.storage.IStorage;
|
||||||
|
import forge.util.storage.StorageImmediatelySerialized;
|
||||||
|
import org.apache.commons.lang3.ArrayUtils;
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
import static forge.lda.lda.inference.InferenceMethod.CGS;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Created by maustin on 09/05/2017.
|
||||||
|
*/
|
||||||
|
public final class LDAModelGenetrator {
|
||||||
|
|
||||||
|
public static Map<String, Map<String,List<List<Pair<String, Double>>>>> ldaPools = new HashMap<>();
|
||||||
|
public static Map<String, List<Archetype>> ldaArchetypes = new HashMap<>();
|
||||||
|
|
||||||
|
|
||||||
|
public static final void main(String[] args){
|
||||||
|
//GuiBase.setInterface(new GuiDesktop());
|
||||||
|
FModel.initialize(null, new Function<ForgePreferences, Void>() {
|
||||||
|
@Override
|
||||||
|
public Void apply(ForgePreferences preferences) {
|
||||||
|
preferences.setPref(ForgePreferences.FPref.LOAD_CARD_SCRIPTS_LAZILY, false);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
});
|
||||||
|
initialize();
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean initialize(){
|
||||||
|
List<String> formatStrings = new ArrayList<>();
|
||||||
|
formatStrings.add(FModel.getFormats().getStandard().getName());
|
||||||
|
formatStrings.add(FModel.getFormats().getPioneer().getName());
|
||||||
|
formatStrings.add(FModel.getFormats().getModern().getName());
|
||||||
|
formatStrings.add("Legacy");
|
||||||
|
formatStrings.add("Vintage");
|
||||||
|
formatStrings.add(DeckFormat.Commander.toString());
|
||||||
|
|
||||||
|
for (String formatString : formatStrings){
|
||||||
|
if(!initializeFormat(formatString)){
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Try to load matrix .dat files, otherwise check for deck folders and build .dat, otherwise return false **/
|
||||||
|
public static boolean initializeFormat(String format){
|
||||||
|
Map<String,List<List<Pair<String, Double>>>> formatMap = CardThemedLDAIO.loadLDA(format);
|
||||||
|
List<Archetype> lda = CardThemedLDAIO.loadRawLDA(format);
|
||||||
|
if(formatMap==null) {
|
||||||
|
try {
|
||||||
|
if(lda==null) {
|
||||||
|
if (CardThemedLDAIO.getMatrixFolder(format).exists()) {
|
||||||
|
if (format.equals(FModel.getFormats().getStandard().getName())) {
|
||||||
|
lda = initializeFormat(FModel.getFormats().getStandard());
|
||||||
|
} else if (format.equals(FModel.getFormats().getModern().getName())) {
|
||||||
|
lda = initializeFormat(FModel.getFormats().getModern());
|
||||||
|
} else if (format != DeckFormat.Commander.toString()) {
|
||||||
|
lda = initializeFormat(FModel.getFormats().get(format));
|
||||||
|
}
|
||||||
|
CardThemedLDAIO.saveRawLDA(format, lda);
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (format.equals(FModel.getFormats().getStandard().getName())) {
|
||||||
|
formatMap = loadFormat(FModel.getFormats().getStandard(), lda);
|
||||||
|
} else if (format.equals(FModel.getFormats().getModern().getName())) {
|
||||||
|
formatMap = loadFormat(FModel.getFormats().getModern(), lda);
|
||||||
|
} else if (format != DeckFormat.Commander.toString()) {
|
||||||
|
formatMap = loadFormat(FModel.getFormats().get(format), lda);;
|
||||||
|
}
|
||||||
|
CardThemedLDAIO.saveLDA(format, formatMap);
|
||||||
|
}catch (Exception e){
|
||||||
|
e.printStackTrace();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ldaPools.put(format, formatMap);
|
||||||
|
ldaArchetypes.put(format, lda);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static Map<String,List<List<Pair<String, Double>>>> loadFormat(GameFormat format,List<Archetype> lda) throws Exception{
|
||||||
|
|
||||||
|
List<List<Pair<String, Double>>> topics = new ArrayList<>();
|
||||||
|
Set<String> cards = new HashSet<String>();
|
||||||
|
for (int t = 0; t < lda.size(); ++t) {
|
||||||
|
List<Pair<String, Double>> topic = new ArrayList<>();
|
||||||
|
Set<String> topicCards = new HashSet<>();
|
||||||
|
List<Pair<String, Double>> highRankVocabs = lda.get(t).getCardProbabilities();
|
||||||
|
if (highRankVocabs.get(0).getRight()<=0.01d){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
System.out.print("t" + t + ": ");
|
||||||
|
int i = 0;
|
||||||
|
while (topic.size()<=40&&i<highRankVocabs.size()) {
|
||||||
|
String cardName = highRankVocabs.get(i).getLeft();;
|
||||||
|
PaperCard card = StaticData.instance().getCommonCards().getUniqueByName(cardName);
|
||||||
|
if(card == null){
|
||||||
|
System.out.println("Card " + cardName + " is MISSING!");
|
||||||
|
i++;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if(!card.getRules().getType().isBasicLand()){
|
||||||
|
if(highRankVocabs.get(i).getRight()>=0.005d) {
|
||||||
|
topicCards.add(cardName);
|
||||||
|
}
|
||||||
|
System.out.println("[" + highRankVocabs.get(i).getLeft() + "," + highRankVocabs.get(i).getRight() + "],");
|
||||||
|
topic.add(highRankVocabs.get(i));
|
||||||
|
}
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
System.out.println();
|
||||||
|
if(topic.size()>18) {
|
||||||
|
cards.addAll(topicCards);
|
||||||
|
topics.add(topic);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Map<String,List<List<Pair<String, Double>>>> cardTopicMap = new HashMap<>();
|
||||||
|
for (String card:cards){
|
||||||
|
List<List<Pair<String, Double>>> cardTopics = new ArrayList<>();
|
||||||
|
for( List<Pair<String, Double>> topic:topics){
|
||||||
|
if(topicContains(card,topic)){
|
||||||
|
cardTopics.add(topic);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cardTopicMap.put(card,cardTopics);
|
||||||
|
}
|
||||||
|
return cardTopicMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static List<Archetype> initializeFormat(GameFormat format) throws Exception{
|
||||||
|
Dataset dataset = new Dataset(format);
|
||||||
|
|
||||||
|
//estimate number of topics to attempt to find using power law
|
||||||
|
final int numTopics = Float.valueOf(347f*dataset.getNumDocs()/(2892f + dataset.getNumDocs())).intValue();
|
||||||
|
System.out.println("Num Topics = " + numTopics);
|
||||||
|
LDA lda = new LDA(0.1, 0.1, numTopics, dataset, CGS);
|
||||||
|
lda.run();
|
||||||
|
System.out.println(lda.computePerplexity(dataset));
|
||||||
|
|
||||||
|
//sort decks by topic
|
||||||
|
Map<Integer,List<Deck>> topicDecks = new HashMap<>();
|
||||||
|
|
||||||
|
int deckNum=0;
|
||||||
|
for(Deck deck: dataset.getBow().getLegalDecks()){
|
||||||
|
double maxTheta = 0;
|
||||||
|
int mainTopic = 0;
|
||||||
|
for (int t = 0; t < lda.getNumTopics(); ++t){
|
||||||
|
double theta = lda.getTheta(deckNum,t);
|
||||||
|
if (theta > maxTheta){
|
||||||
|
maxTheta = theta;
|
||||||
|
mainTopic = t;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if(topicDecks.containsKey(mainTopic)){
|
||||||
|
topicDecks.get(mainTopic).add(deck);
|
||||||
|
}else{
|
||||||
|
List<Deck> decks = new ArrayList<>();
|
||||||
|
decks.add(deck);
|
||||||
|
topicDecks.put(mainTopic,decks);
|
||||||
|
}
|
||||||
|
++deckNum;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
List<Archetype> unfilteredTopics = new ArrayList<>();
|
||||||
|
for (int t = 0; t < lda.getNumTopics(); ++t) {
|
||||||
|
List<Pair<String, Double>> highRankVocabs = lda.getVocabsSortedByPhi(t);
|
||||||
|
Double min = 1d;
|
||||||
|
for(Pair<String, Double> p:highRankVocabs){
|
||||||
|
if(p.getRight()<min){
|
||||||
|
min=p.getRight();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
List<Pair<String, Double>> topRankVocabs = new ArrayList<>();
|
||||||
|
for(Pair<String, Double> p:highRankVocabs){
|
||||||
|
if(p.getRight()>min){
|
||||||
|
topRankVocabs.add(p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//generate names for topics
|
||||||
|
List<Deck> decks = topicDecks.get(t);
|
||||||
|
if(decks==null){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
LinkedHashMap<String, Integer> wordCounts = new LinkedHashMap<>();
|
||||||
|
for( Deck deck: decks){
|
||||||
|
String name = deck.getName().replaceAll(".* Version - ","").replaceAll(" \\((Modern|Pioneer|Standard|Legacy|Vintage), #[0-9]+\\)","");
|
||||||
|
name = name.replaceAll("\\(Modern|Pioneer|Standard|Legacy|Vintage|Fuck|Shit|Cunt\\)","");
|
||||||
|
String[] tokens = name.split(" ");
|
||||||
|
for(String rawtoken: tokens){
|
||||||
|
String token = rawtoken.toLowerCase();
|
||||||
|
if (token.matches("[0-9]+")) {
|
||||||
|
//skip just numbers as not useful
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if(wordCounts.containsKey(token)){
|
||||||
|
wordCounts.put(token, wordCounts.get(token)+1);
|
||||||
|
}else{
|
||||||
|
wordCounts.put(token, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Map<String, Integer> sortedWordCounts = sortByValue(wordCounts);
|
||||||
|
|
||||||
|
List<String> topWords = new ArrayList<>();
|
||||||
|
Iterator<String> wordIterator = sortedWordCounts.keySet().iterator();
|
||||||
|
while(wordIterator.hasNext() && topWords.size() < 3){
|
||||||
|
topWords.add(wordIterator.next());
|
||||||
|
}
|
||||||
|
StringJoiner sb = new StringJoiner(" ");
|
||||||
|
for(String word: wordCounts.keySet()){
|
||||||
|
if(topWords.contains(word)){
|
||||||
|
sb.add(word);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
String deckName = sb.toString();
|
||||||
|
System.out.println("============ " + deckName);
|
||||||
|
System.out.println(decks.toString());
|
||||||
|
|
||||||
|
unfilteredTopics.add(new Archetype(topRankVocabs,deckName,decks.size()));
|
||||||
|
}
|
||||||
|
Comparator<Archetype> archetypeComparator = new Comparator<Archetype>() {
|
||||||
|
@Override
|
||||||
|
public int compare(Archetype o1, Archetype o2) {
|
||||||
|
return o2.getDeckCount().compareTo(o1.getDeckCount());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Collections.sort(unfilteredTopics,archetypeComparator);
|
||||||
|
return unfilteredTopics;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
private static <K, V> Map<K, V> sortByValue(Map<K, V> map) {
|
||||||
|
List<Map.Entry<K, V>> list = new LinkedList<>(map.entrySet());
|
||||||
|
Collections.sort(list, new Comparator<Object>() {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
public int compare(Object o1, Object o2) {
|
||||||
|
return ((Comparable<V>) ((Map.Entry<K, V>) (o2)).getValue()).compareTo(((Map.Entry<K, V>) (o1)).getValue());
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Map<K, V> result = new LinkedHashMap<>();
|
||||||
|
for (Iterator<Map.Entry<K, V>> it = list.iterator(); it.hasNext();) {
|
||||||
|
Map.Entry<K, V> entry = (Map.Entry<K, V>) it.next();
|
||||||
|
result.put(entry.getKey(), entry.getValue());
|
||||||
|
}
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static boolean topicContains(String card, List<Pair<String, Double>> topic){
|
||||||
|
for(Pair<String,Double> pair:topic){
|
||||||
|
if(pair.getLeft().equals(card)){
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public static HashMap<String,List<Map.Entry<PaperCard,Integer>>> initializeCommanderFormat(){
|
||||||
|
|
||||||
|
IStorage<Deck> decks = new StorageImmediatelySerialized<Deck>("Generator",
|
||||||
|
new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR,DeckFormat.Commander.toString()),
|
||||||
|
ForgeConstants.DECK_GEN_DIR, false),
|
||||||
|
true);
|
||||||
|
|
||||||
|
//get all cards
|
||||||
|
final Iterable<PaperCard> cards = Iterables.filter(FModel.getMagicDb().getCommonCards().getUniqueCards()
|
||||||
|
, Predicates.compose(Predicates.not(CardRulesPredicates.Presets.IS_BASIC_LAND_NOT_WASTES), PaperCard.FN_GET_RULES));
|
||||||
|
List<PaperCard> cardList = Lists.newArrayList(cards);
|
||||||
|
cardList.add(FModel.getMagicDb().getCommonCards().getCard("Wastes"));
|
||||||
|
Map<String, Integer> cardIntegerMap = new HashMap<>();
|
||||||
|
Map<Integer, PaperCard> integerCardMap = new HashMap<>();
|
||||||
|
Map<String, Integer> legendIntegerMap = new HashMap<>();
|
||||||
|
Map<Integer, PaperCard> integerLegendMap = new HashMap<>();
|
||||||
|
//generate lookups for cards to link card names to matrix columns
|
||||||
|
for (int i=0; i<cardList.size(); ++i){
|
||||||
|
cardIntegerMap.put(cardList.get(i).getName(), i);
|
||||||
|
integerCardMap.put(i, cardList.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
//filter to just legal commanders
|
||||||
|
List<PaperCard> legends = Lists.newArrayList(Iterables.filter(cardList,Predicates.compose(
|
||||||
|
new Predicate<CardRules>() {
|
||||||
|
@Override
|
||||||
|
public boolean apply(CardRules rules) {
|
||||||
|
return DeckFormat.Commander.isLegalCommander(rules);
|
||||||
|
}
|
||||||
|
}, PaperCard.FN_GET_RULES)));
|
||||||
|
|
||||||
|
//generate lookups for legends to link commander names to matrix rows
|
||||||
|
for (int i=0; i<legends.size(); ++i){
|
||||||
|
legendIntegerMap.put(legends.get(i).getName(), i);
|
||||||
|
integerLegendMap.put(i, legends.get(i));
|
||||||
|
}
|
||||||
|
int[][] matrix = new int[legends.size()][cardList.size()];
|
||||||
|
|
||||||
|
//loop through commanders and decks
|
||||||
|
for (PaperCard legend:legends){
|
||||||
|
for (Deck deck:decks){
|
||||||
|
//if the deck has the commander
|
||||||
|
if (deck.getCommanders().contains(legend)){
|
||||||
|
//update the matrix by incrementing the connectivity count for each card in the deck
|
||||||
|
updateLegendMatrix(deck, legend, cardIntegerMap, legendIntegerMap, matrix);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//convert the matrix into a map of pools for each commander
|
||||||
|
HashMap<String,List<Map.Entry<PaperCard,Integer>>> cardPools = new HashMap<>();
|
||||||
|
for (PaperCard card:legends){
|
||||||
|
int col=legendIntegerMap.get(card.getName());
|
||||||
|
int[] distances = matrix[col];
|
||||||
|
int max = (Integer) Collections.max(Arrays.asList(ArrayUtils.toObject(distances)));
|
||||||
|
if (max>0) {
|
||||||
|
List<Map.Entry<PaperCard,Integer>> deckPool=new ArrayList<>();
|
||||||
|
for(int k=0;k<cardList.size(); k++){
|
||||||
|
if(matrix[col][k]>0){
|
||||||
|
deckPool.add(new AbstractMap.SimpleEntry<PaperCard, Integer>(integerCardMap.get(k),matrix[col][k]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cardPools.put(card.getName(), deckPool);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cardPools;
|
||||||
|
}
|
||||||
|
|
||||||
|
//update the matrix by incrementing the connectivity count for each card in the deck
|
||||||
|
public static void updateLegendMatrix(Deck deck, PaperCard legend, Map<String, Integer> cardIntegerMap,
|
||||||
|
Map<String, Integer> legendIntegerMap, int[][] matrix){
|
||||||
|
for (PaperCard pairCard:Iterables.filter(deck.getMain().toFlatList(),
|
||||||
|
Predicates.compose(Predicates.not(CardRulesPredicates.Presets.IS_BASIC_LAND_NOT_WASTES), PaperCard.FN_GET_RULES))){
|
||||||
|
if (!pairCard.getName().equals(legend.getName())){
|
||||||
|
try {
|
||||||
|
int old = matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(pairCard.getName())];
|
||||||
|
matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(pairCard.getName())] = old + 1;
|
||||||
|
}catch (NullPointerException ne){
|
||||||
|
//Todo: Not sure what was failing here
|
||||||
|
ne.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
//add partner commanders to matrix
|
||||||
|
if(deck.getCommanders().size()>1){
|
||||||
|
for(PaperCard partner:deck.getCommanders()){
|
||||||
|
if(!partner.equals(legend)){
|
||||||
|
int old = matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(partner.getName())];
|
||||||
|
matrix[legendIntegerMap.get(legend.getName())][cardIntegerMap.get(partner.getName())] = old + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
210
forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java
Normal file
210
forge-lda/src/main/java/forge/lda/dataset/BagOfWords.java
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.dataset;
|
||||||
|
|
||||||
|
import forge.deck.Deck;
|
||||||
|
import forge.deck.io.DeckStorage;
|
||||||
|
import forge.game.GameFormat;
|
||||||
|
import forge.item.PaperCard;
|
||||||
|
import forge.localinstance.properties.ForgeConstants;
|
||||||
|
import forge.util.storage.IStorage;
|
||||||
|
import forge.util.storage.StorageImmediatelySerialized;
|
||||||
|
|
||||||
|
import java.io.*;
|
||||||
|
import java.util.*;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class is immutable.
|
||||||
|
*/
|
||||||
|
public final class BagOfWords {
|
||||||
|
|
||||||
|
private final int numDocs;
|
||||||
|
private final int numVocabs;
|
||||||
|
private final int numNNZ;
|
||||||
|
private final int numWords;
|
||||||
|
private List<Deck> legalDecks;
|
||||||
|
|
||||||
|
public Vocabularies getVocabs() {
|
||||||
|
return vocabs;
|
||||||
|
}
|
||||||
|
|
||||||
|
private final Vocabularies vocabs;
|
||||||
|
|
||||||
|
// docID -> the vocabs sequence in the doc
|
||||||
|
private Map<Integer, List<Integer>> words;
|
||||||
|
|
||||||
|
// docID -> the doc length
|
||||||
|
private Map<Integer, Integer> docLength;
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Read the bag-of-words dataset.
|
||||||
|
* @throws FileNotFoundException
|
||||||
|
* @throws IOException
|
||||||
|
* @throws Exception
|
||||||
|
* @throws NullPointerException filePath is null
|
||||||
|
*/
|
||||||
|
public BagOfWords(GameFormat format) throws FileNotFoundException, IOException, Exception {
|
||||||
|
IStorage<Deck> decks = new StorageImmediatelySerialized<Deck>("Generator", new DeckStorage(new File(ForgeConstants.DECK_GEN_DIR+ForgeConstants.PATH_SEPARATOR+format.getName()),
|
||||||
|
ForgeConstants.DECK_GEN_DIR, false),
|
||||||
|
true);
|
||||||
|
|
||||||
|
Set<PaperCard> cardSet = new HashSet<>();
|
||||||
|
legalDecks = new ArrayList<>();
|
||||||
|
for(Deck deck:decks){
|
||||||
|
try {
|
||||||
|
if (format.isDeckLegal(deck) && deck.getMain().toFlatList().size() == 60) {
|
||||||
|
legalDecks.add(deck);
|
||||||
|
for (PaperCard card : deck.getMain().toFlatList()) {
|
||||||
|
cardSet.add(card);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}catch(Exception e){
|
||||||
|
System.out.println("Skipping deck "+deck.getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
List<PaperCard> cardList = new ArrayList<>(cardSet);
|
||||||
|
|
||||||
|
this.words = new HashMap<>();
|
||||||
|
this.docLength = new HashMap<>();
|
||||||
|
ArrayList<Vocabulary> vocabList = new ArrayList<Vocabulary>();
|
||||||
|
|
||||||
|
int numDocs = legalDecks.size();
|
||||||
|
int numVocabs = cardList.size();
|
||||||
|
int numNNZ = 0;
|
||||||
|
int numWords = 0;
|
||||||
|
|
||||||
|
Map<String, Integer> cardIntegerMap = new HashMap<>();
|
||||||
|
Map<Integer, PaperCard> integerCardMap = new HashMap<>();
|
||||||
|
for (int i=0; i<cardList.size(); ++i){
|
||||||
|
cardIntegerMap.put(cardList.get(i).getName(), i);
|
||||||
|
vocabList.add(new Vocabulary(i,cardList.get(i).getName()));
|
||||||
|
integerCardMap.put(i, cardList.get(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
this.vocabs = new Vocabularies(vocabList);
|
||||||
|
int deckID = 0;
|
||||||
|
for (Deck deck:legalDecks){
|
||||||
|
numNNZ += deck.getMain().countDistinct();
|
||||||
|
List<Integer> cardNumbers = new ArrayList<>();
|
||||||
|
for(PaperCard card:deck.getMain().toFlatList()){
|
||||||
|
if(cardIntegerMap.get(card.getName()) == null){
|
||||||
|
System.out.println(card.getName() + " is missing!!");
|
||||||
|
}
|
||||||
|
cardNumbers.add(cardIntegerMap.get(card.getName()));
|
||||||
|
}
|
||||||
|
words.put(deckID,cardNumbers);
|
||||||
|
numWords+=cardNumbers.size();
|
||||||
|
docLength.put(deckID,cardNumbers.size());
|
||||||
|
deckID ++;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*String s = null;
|
||||||
|
while ((s = reader.readLine()) != null) {
|
||||||
|
List<Integer> numbers
|
||||||
|
= Arrays.asList(s.split(" ")).stream().map(Integer::parseInt).collect(Collectors.toList());
|
||||||
|
|
||||||
|
if (numbers.size() == 1) {
|
||||||
|
if (headerCount == 2) numDocs = numbers.get(0);
|
||||||
|
else if (headerCount == 1) numVocabs = numbers.get(0);
|
||||||
|
else if (headerCount == 0) numNNZ = numbers.get(0);
|
||||||
|
--headerCount;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
else if (numbers.size() == 3) {
|
||||||
|
final int docID = numbers.get(0);
|
||||||
|
final int vocabID = numbers.get(1);
|
||||||
|
final int count = numbers.get(2);
|
||||||
|
|
||||||
|
// Set up the words container
|
||||||
|
if (!words.containsKey(docID)) {
|
||||||
|
words.put(docID, new ArrayList<>());
|
||||||
|
}
|
||||||
|
for (int c = 0; c < count; ++c) {
|
||||||
|
words.get(docID).add(vocabID);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up the doc length map
|
||||||
|
Optional<Integer> currentCount
|
||||||
|
= Optional.ofNullable(docLength.putIfAbsent(docID, count));
|
||||||
|
currentCount.ifPresent(c -> docLength.replace(docID, c + count));
|
||||||
|
|
||||||
|
numWords += count;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
throw new Exception("Invalid dataset form was detected.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
reader.close();*/
|
||||||
|
|
||||||
|
this.numDocs = numDocs;
|
||||||
|
this.numVocabs = numVocabs;
|
||||||
|
this.numNNZ = numNNZ;
|
||||||
|
this.numWords = numWords;
|
||||||
|
|
||||||
|
System.out.println("Num Decks" + this.numDocs);
|
||||||
|
System.out.println("Num Vocabs" + this.numVocabs);
|
||||||
|
System.out.println("Num NNZ" + this.numNNZ);
|
||||||
|
System.out.println("Num Cards" + this.numWords);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumDocs() {
|
||||||
|
return numDocs;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the length of the document.
|
||||||
|
* @param docID
|
||||||
|
* @return length of the document
|
||||||
|
* @throws IllegalArgumentException docID <= 0 || #documents < docID
|
||||||
|
*/
|
||||||
|
public int getDocLength(int docID) {
|
||||||
|
if (docID < 0 || getNumDocs() < docID) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
|
||||||
|
return docLength.get(docID);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the unmodifiable list of words in the document.
|
||||||
|
* @param docID
|
||||||
|
* @return the unmodifiable list of words
|
||||||
|
* @throws IllegalArgumentException docID <= 0 || #documents < docID
|
||||||
|
*/
|
||||||
|
public List<Integer> getWords(final int docID) {
|
||||||
|
if (docID < 0 || getNumDocs() < docID) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
return Collections.unmodifiableList(words.get(docID));
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumVocabs() {
|
||||||
|
return numVocabs;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumNNZ() {
|
||||||
|
return numNNZ;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumWords() {
|
||||||
|
return numWords;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Deck> getLegalDecks() { return legalDecks; }
|
||||||
|
}
|
||||||
|
|
||||||
80
forge-lda/src/main/java/forge/lda/dataset/Dataset.java
Normal file
80
forge-lda/src/main/java/forge/lda/dataset/Dataset.java
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.dataset;
|
||||||
|
|
||||||
|
import forge.game.GameFormat;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class Dataset {
|
||||||
|
private BagOfWords bow;
|
||||||
|
private Vocabularies vocabs;
|
||||||
|
|
||||||
|
public Dataset(GameFormat format) throws Exception {
|
||||||
|
bow = new BagOfWords(format);
|
||||||
|
vocabs = bow.getVocabs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Dataset(BagOfWords bow) {
|
||||||
|
this.bow = bow;
|
||||||
|
this.vocabs = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BagOfWords getBow() {
|
||||||
|
return bow;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumDocs() {
|
||||||
|
return bow.getNumDocs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getDocLength(int docID) {
|
||||||
|
return bow.getDocLength(docID);
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Integer> getWords(int docID) {
|
||||||
|
return bow.getWords(docID);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumVocabs() {
|
||||||
|
return bow.getNumVocabs();
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumNNZ() {
|
||||||
|
return bow.getNumNNZ();
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumWords() {
|
||||||
|
return bow.getNumWords();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Vocabulary get(int id) {
|
||||||
|
return vocabs.get(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int size() {
|
||||||
|
return vocabs.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
public Vocabularies getVocabularies() {
|
||||||
|
return vocabs;
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Vocabulary> getVocabularyList() {
|
||||||
|
return vocabs.getVocabularyList();
|
||||||
|
}
|
||||||
|
}
|
||||||
40
forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java
Normal file
40
forge-lda/src/main/java/forge/lda/dataset/Vocabularies.java
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.dataset;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
public class Vocabularies {
|
||||||
|
private List<Vocabulary> vocabs;
|
||||||
|
|
||||||
|
public Vocabularies(List<Vocabulary> vocabs) {
|
||||||
|
this.vocabs = vocabs;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Vocabulary get(int id) {
|
||||||
|
return vocabs.get(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int size() {
|
||||||
|
return vocabs.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Vocabulary> getVocabularyList() {
|
||||||
|
return Collections.unmodifiableList(vocabs);
|
||||||
|
}
|
||||||
|
}
|
||||||
37
forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java
Normal file
37
forge-lda/src/main/java/forge/lda/dataset/Vocabulary.java
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.dataset;
|
||||||
|
|
||||||
|
public class Vocabulary {
|
||||||
|
private final int id;
|
||||||
|
private String vocabulary;
|
||||||
|
|
||||||
|
public Vocabulary(int id, String vocabulary) {
|
||||||
|
if (vocabulary == null) throw new NullPointerException();
|
||||||
|
this.id = id;
|
||||||
|
this.vocabulary = vocabulary;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int id() {
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return vocabulary;
|
||||||
|
}
|
||||||
|
}
|
||||||
46
forge-lda/src/main/java/forge/lda/lda/Alpha.java
Normal file
46
forge-lda/src/main/java/forge/lda/lda/Alpha.java
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.lda;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
class Alpha {
|
||||||
|
private List<Double> alphas;
|
||||||
|
|
||||||
|
Alpha(double alpha, int numTopics) {
|
||||||
|
if (alpha <= 0.0 || numTopics <= 0) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.alphas = Stream.generate(() -> alpha)
|
||||||
|
.limit(numTopics)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
double get(int i) {
|
||||||
|
return alphas.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set(int i, double value) {
|
||||||
|
alphas.set(i, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
double sum() {
|
||||||
|
return alphas.stream().reduce(Double::sum).get();
|
||||||
|
}
|
||||||
|
}
|
||||||
54
forge-lda/src/main/java/forge/lda/lda/Beta.java
Normal file
54
forge-lda/src/main/java/forge/lda/lda/Beta.java
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.lda;
|
||||||
|
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.stream.Collectors;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
|
class Beta {
|
||||||
|
private List<Double> betas;
|
||||||
|
|
||||||
|
Beta(double beta, int numVocabs) {
|
||||||
|
if (beta <= 0.0 || numVocabs <= 0) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.betas = Stream.generate(() -> beta)
|
||||||
|
.limit(numVocabs)
|
||||||
|
.collect(Collectors.toList());
|
||||||
|
}
|
||||||
|
|
||||||
|
Beta(double beta) {
|
||||||
|
if (beta <= 0.0) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
this.betas = Arrays.asList(beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
double get() {
|
||||||
|
return get(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
double get(int i) {
|
||||||
|
return betas.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
void set(int i, double value) {
|
||||||
|
betas.set(i, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
60
forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java
Normal file
60
forge-lda/src/main/java/forge/lda/lda/Hyperparameters.java
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.lda;
|
||||||
|
|
||||||
|
class Hyperparameters {
|
||||||
|
private Alpha alpha;
|
||||||
|
private Beta beta;
|
||||||
|
|
||||||
|
Hyperparameters(double alpha, double beta, int numTopics, int numVocabs) {
|
||||||
|
this.alpha = new Alpha(alpha, numTopics);
|
||||||
|
this.beta = new Beta(beta, numVocabs);
|
||||||
|
}
|
||||||
|
|
||||||
|
Hyperparameters(double alpha, double beta, int numTopics) {
|
||||||
|
this.alpha = new Alpha(alpha, numTopics);
|
||||||
|
this.beta = new Beta(beta);
|
||||||
|
}
|
||||||
|
|
||||||
|
double alpha(int i) {
|
||||||
|
return alpha.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
double sumAlpha() {
|
||||||
|
return alpha.sum();
|
||||||
|
}
|
||||||
|
|
||||||
|
double beta() {
|
||||||
|
return beta.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
double beta(int i) {
|
||||||
|
return beta.get(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
void setAlpha(int i, double value) {
|
||||||
|
alpha.set(i, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void setBeta(int i, double value) {
|
||||||
|
beta.set(i, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
void setBeta(double value) {
|
||||||
|
beta.set(0, value);
|
||||||
|
}
|
||||||
|
}
|
||||||
183
forge-lda/src/main/java/forge/lda/lda/LDA.java
Normal file
183
forge-lda/src/main/java/forge/lda/lda/LDA.java
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.lda;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import forge.lda.dataset.BagOfWords;
|
||||||
|
import forge.lda.dataset.Dataset;
|
||||||
|
import forge.lda.dataset.Vocabularies;
|
||||||
|
|
||||||
|
import forge.lda.lda.inference.Inference;
|
||||||
|
import forge.lda.lda.inference.InferenceFactory;
|
||||||
|
import forge.lda.lda.inference.InferenceMethod;
|
||||||
|
import forge.lda.lda.inference.InferenceProperties;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
|
||||||
|
public class LDA {
|
||||||
|
private Hyperparameters hyperparameters;
|
||||||
|
private final int numTopics;
|
||||||
|
private Dataset dataset;
|
||||||
|
private final Inference inference;
|
||||||
|
private InferenceProperties properties;
|
||||||
|
private boolean trained;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @param alpha doc-topic hyperparameter
|
||||||
|
* @param beta topic-vocab hyperparameter
|
||||||
|
* @param numTopics the number of topics
|
||||||
|
* @param dataset dataset
|
||||||
|
* @param bow bag-of-words
|
||||||
|
* @param method inference method
|
||||||
|
*/
|
||||||
|
public LDA(final double alpha, final double beta, final int numTopics,
|
||||||
|
final Dataset dataset, InferenceMethod method) {
|
||||||
|
this(alpha, beta, numTopics, dataset, method, InferenceProperties.PROPERTIES_FILE_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
|
LDA(final double alpha, final double beta, final int numTopics,
|
||||||
|
final Dataset dataset, InferenceMethod method, String propertiesFilePath) {
|
||||||
|
this.hyperparameters = new Hyperparameters(alpha, beta, numTopics);
|
||||||
|
this.numTopics = numTopics;
|
||||||
|
this.dataset = dataset;
|
||||||
|
this.inference = InferenceFactory.getInstance(method);
|
||||||
|
this.properties = new InferenceProperties();
|
||||||
|
this.trained = false;
|
||||||
|
|
||||||
|
properties.setSeed(123L);
|
||||||
|
properties.setNumIteration(100);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the vocabulary from its ID.
|
||||||
|
* @param vocabID
|
||||||
|
* @return the vocabulary
|
||||||
|
* @throws IllegalArgumentException vocabID <= 0 || the number of vocabularies < vocabID
|
||||||
|
*/
|
||||||
|
public String getVocab(int vocabID) {
|
||||||
|
if (vocabID < 0 || dataset.getNumVocabs() < vocabID) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
return dataset.get(vocabID).toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run model inference.
|
||||||
|
*/
|
||||||
|
public void run() {
|
||||||
|
if (properties == null) inference.setUp(this);
|
||||||
|
else inference.setUp(this, properties);
|
||||||
|
inference.run();
|
||||||
|
trained = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get hyperparameter alpha corresponding to topic.
|
||||||
|
* @param topic
|
||||||
|
* @return alpha corresponding to topicID
|
||||||
|
* @throws ArrayIndexOutOfBoundsException topic < 0 || #topics <= topic
|
||||||
|
*/
|
||||||
|
public double getAlpha(final int topic) {
|
||||||
|
if (topic < 0 || numTopics < topic) {
|
||||||
|
throw new ArrayIndexOutOfBoundsException(topic);
|
||||||
|
}
|
||||||
|
return hyperparameters.alpha(topic);
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getSumAlpha() {
|
||||||
|
return hyperparameters.sumAlpha();
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getBeta() {
|
||||||
|
return hyperparameters.beta();
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumTopics() {
|
||||||
|
return numTopics;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BagOfWords getBow() {
|
||||||
|
return dataset.getBow();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the value of doc-topic probability \theta_{docID, topicID}.
|
||||||
|
* @param docID
|
||||||
|
* @param topicID
|
||||||
|
* @return the value of doc-topic probability
|
||||||
|
* @throws IllegalArgumentException docID <= 0 || #docs < docID || topicID < 0 || #topics <= topicID
|
||||||
|
* @throws IllegalStateException call this method when the inference has not been finished yet
|
||||||
|
*/
|
||||||
|
public double getTheta(final int docID, final int topicID) {
|
||||||
|
if (docID < 0 || dataset.getNumDocs() < docID
|
||||||
|
|| topicID < 0 || numTopics < topicID) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if (!trained) {
|
||||||
|
throw new IllegalStateException();
|
||||||
|
}
|
||||||
|
|
||||||
|
return inference.getTheta(docID, topicID);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the value of topic-vocab probability \phi_{topicID, vocabID}.
|
||||||
|
* @param topicID
|
||||||
|
* @param vocabID
|
||||||
|
* @return the value of topic-vocab probability
|
||||||
|
* @throws IllegalArgumentException topicID < 0 || #topics <= topicID || vocabID <= 0
|
||||||
|
* @throws IllegalStateException call this method when the inference has not been finished yet
|
||||||
|
*/
|
||||||
|
public double getPhi(final int topicID, final int vocabID) {
|
||||||
|
if (topicID < 0 || numTopics < topicID || vocabID < 0) {
|
||||||
|
throw new IllegalArgumentException();
|
||||||
|
}
|
||||||
|
if (!trained) {
|
||||||
|
throw new IllegalStateException();
|
||||||
|
}
|
||||||
|
|
||||||
|
return inference.getPhi(topicID, vocabID);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Vocabularies getVocabularies() {
|
||||||
|
return dataset.getVocabularies();
|
||||||
|
}
|
||||||
|
|
||||||
|
public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
|
||||||
|
return inference.getVocabsSortedByPhi(topicID);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Compute the perplexity of trained LDA for the test bag-of-words dataset.
|
||||||
|
* @param testDataset
|
||||||
|
* @return the perplexity for the test bag-of-words dataset
|
||||||
|
*/
|
||||||
|
public double computePerplexity(Dataset testDataset) {
|
||||||
|
double loglikelihood = 0.0;
|
||||||
|
for (int d = 0; d < testDataset.getNumDocs(); ++d) {
|
||||||
|
for (Integer w : testDataset.getWords(d)) {
|
||||||
|
double sum = 0.0;
|
||||||
|
for (int t = 0; t < getNumTopics(); ++t) {
|
||||||
|
sum += getTheta(d, t) * getPhi(t, w.intValue());
|
||||||
|
}
|
||||||
|
loglikelihood += Math.log(sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Math.exp(-loglikelihood / testDataset.getNumWords());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.lda.inference;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
import forge.lda.lda.LDA;
|
||||||
|
|
||||||
|
import org.apache.commons.lang3.tuple.Pair;
|
||||||
|
|
||||||
|
public interface Inference {
|
||||||
|
/**
|
||||||
|
* Set up for inference.
|
||||||
|
* @param lda
|
||||||
|
*/
|
||||||
|
public void setUp(LDA lda);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set up for inference.
|
||||||
|
* The configuration is read from properties class.
|
||||||
|
* @param lda
|
||||||
|
* @param properties
|
||||||
|
*/
|
||||||
|
public void setUp(LDA lda, InferenceProperties properties);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Run model inference.
|
||||||
|
*/
|
||||||
|
public void run();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the value of doc-topic probability \theta_{docID, topicID}.
|
||||||
|
* @param docID
|
||||||
|
* @param topicID
|
||||||
|
* @return the value of doc-topic probability
|
||||||
|
*/
|
||||||
|
public double getTheta(final int docID, final int topicID);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the value of topic-vocab probability \phi_{topicID, vocabID}.
|
||||||
|
* @param topicID
|
||||||
|
* @param vocabID
|
||||||
|
* @return the value of topic-vocab probability
|
||||||
|
*/
|
||||||
|
public double getPhi(final int topicID, final int vocabID);
|
||||||
|
|
||||||
|
public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID);
|
||||||
|
}
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.lda.inference;
|
||||||
|
|
||||||
|
|
||||||
|
public class InferenceFactory {
|
||||||
|
private InferenceFactory() {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the LDAInference instance specified by the argument
|
||||||
|
* @param method
|
||||||
|
* @return the instance which implements LDAInference
|
||||||
|
*/
|
||||||
|
public static Inference getInstance(InferenceMethod method) {
|
||||||
|
Inference clazz = null;
|
||||||
|
try {
|
||||||
|
clazz = (Inference)Class.forName(method.toString()).newInstance();
|
||||||
|
} catch (ReflectiveOperationException roe) {
|
||||||
|
roe.printStackTrace();
|
||||||
|
}
|
||||||
|
return clazz;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,36 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.lda.inference;
|
||||||
|
|
||||||
|
import forge.lda.lda.inference.internal.CollapsedGibbsSampler;
|
||||||
|
|
||||||
|
public enum InferenceMethod {
|
||||||
|
CGS(CollapsedGibbsSampler.class.getName()),
|
||||||
|
// more
|
||||||
|
;
|
||||||
|
|
||||||
|
private String className;
|
||||||
|
|
||||||
|
private InferenceMethod(final String className) {
|
||||||
|
this.className = className;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return className;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,70 @@
|
|||||||
|
/*
|
||||||
|
* Copyright 2015 Kohei Yamamoto
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package forge.lda.lda.inference;
|
||||||
|
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.io.InputStream;
|
||||||
|
import java.util.Properties;
|
||||||
|
|
||||||
|
public class InferenceProperties {
|
||||||
|
public static String PROPERTIES_FILE_NAME = "lda.properties";
|
||||||
|
|
||||||
|
PropertiesLoader loader = new PropertiesLoader();
|
||||||
|
private Properties properties;
|
||||||
|
|
||||||
|
public InferenceProperties() {
|
||||||
|
this.properties = new Properties();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setSeed(Long seed){
|
||||||
|
properties.setProperty("seed",seed.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
public void setNumIteration(Integer numIteration){
|
||||||
|
properties.setProperty("numIteration",numIteration.toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Load properties.
|
||||||
|
* @param fileName
|
||||||
|
* @throws IOException
|
||||||
|
* @throws NullPointerException fileName is null
|
||||||
|
*/
|
||||||
|
public void load(String fileName) throws IOException {
|
||||||
|
if (fileName == null) throw new NullPointerException();
|
||||||
|
InputStream stream = loader.getInputStream(fileName);
|
||||||
|
if (stream == null) throw new NullPointerException();
|
||||||
|
properties.load(stream);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Long seed() {
|
||||||
|
return Long.parseLong(properties.getProperty("seed"));
|
||||||
|
}
|
||||||
|
|
||||||
|
public Integer numIteration() {
|
||||||
|
return Integer.parseInt(properties.getProperty("numIteration"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
class PropertiesLoader {
|
||||||
|
public InputStream getInputStream(String fileName) throws FileNotFoundException {
|
||||||
|
if (fileName == null) throw new NullPointerException();
|
||||||
|
return new FileInputStream(fileName);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,65 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

class AssignmentCounter {
    private List<Integer> counter;

    AssignmentCounter(int size) {
        if (size <= 0) throw new IllegalArgumentException();
        this.counter = IntStream.generate(() -> 0)
                                .limit(size)
                                .boxed()
                                .collect(Collectors.toList());
    }

    int size() {
        return counter.size();
    }

    int get(int id) {
        if (id < 0 || counter.size() <= id) {
            throw new IllegalArgumentException();
        }
        return counter.get(id);
    }

    int getSum() {
        return counter.stream().reduce(Integer::sum).get();
    }

    void increment(int id) {
        if (id < 0 || counter.size() <= id) {
            throw new IllegalArgumentException();
        }
        counter.set(id, counter.get(id) + 1);
    }

    void decrement(int id) {
        if (id < 0 || counter.size() <= id) {
            throw new IllegalArgumentException();
        }
        if (counter.get(id) == 0) {
            throw new IllegalStateException();
        }
        counter.set(id, counter.get(id) - 1);
    }
}
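A quick illustration of the counter's contract (a hypothetical snippet from inside the same package, since the class is package-private; it backs the TopicCounter and VocabularyCounter classes later in this change):

AssignmentCounter counter = new AssignmentCounter(3); // three slots, all starting at zero
counter.increment(0);
counter.increment(0);
counter.increment(2);
// counter.get(0) == 2, counter.get(1) == 0, counter.get(2) == 1, counter.getSum() == 3
counter.decrement(0);                                 // counter.getSum() == 2
// decrementing a slot that is already zero throws IllegalStateException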
@@ -0,0 +1,220 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;

import forge.lda.lda.LDA;
import forge.lda.lda.inference.Inference;
import forge.lda.lda.inference.InferenceProperties;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.commons.math3.distribution.EnumeratedIntegerDistribution;
import org.apache.commons.math3.distribution.IntegerDistribution;

import forge.lda.dataset.Vocabulary;

public class CollapsedGibbsSampler implements Inference {
    private LDA lda;
    private Topics topics;
    private Documents documents;
    private int numIteration;

    private static final long DEFAULT_SEED = 0L;
    private static final int DEFAULT_NUM_ITERATION = 100;

    // ready for Gibbs sampling
    private boolean ready;

    public CollapsedGibbsSampler() {
        ready = false;
    }

    @Override
    public void setUp(LDA lda, InferenceProperties properties) {
        if (properties == null) {
            setUp(lda);
            return;
        }

        this.lda = lda;
        initialize(this.lda);

        final long seed = properties.seed() != null ? properties.seed() : DEFAULT_SEED;
        initializeTopicAssignment(seed);

        this.numIteration
            = properties.numIteration() != null ? properties.numIteration() : DEFAULT_NUM_ITERATION;
        this.ready = true;
    }

    @Override
    public void setUp(LDA lda) {
        if (lda == null) throw new NullPointerException();

        this.lda = lda;

        initialize(this.lda);
        initializeTopicAssignment(DEFAULT_SEED);

        this.numIteration = DEFAULT_NUM_ITERATION;
        this.ready = true;
    }

    private void initialize(LDA lda) {
        assert lda != null;
        this.topics = new Topics(lda);
        this.documents = new Documents(lda);
    }

    public boolean isReady() {
        return ready;
    }

    public int getNumIteration() {
        return numIteration;
    }

    public void setNumIteration(final int numIteration) {
        this.numIteration = numIteration;
    }

    @Override
    public void run() {
        if (!ready) {
            throw new IllegalStateException("instance has not been set up yet");
        }

        for (int i = 1; i <= numIteration; ++i) {
            System.out.println("Iteration " + i + ".");
            runSampling();
        }
    }

    /**
     * Run one sweep of collapsed Gibbs sampling [Griffiths and Steyvers 2004]:
     * for every word, remove its current topic assignment from the counts,
     * sample a new topic from the full conditional distribution, and add the
     * new assignment back into the counts.
     */
    void runSampling() {
        for (Document d : documents.getDocuments()) {
            for (int w = 0; w < d.getDocLength(); ++w) {
                final Topic oldTopic = topics.get(d.getTopicID(w));
                d.decrementTopicCount(oldTopic.id());

                final Vocabulary v = d.getVocabulary(w);
                oldTopic.decrementVocabCount(v.id());

                IntegerDistribution distribution
                    = getFullConditionalDistribution(lda.getNumTopics(), d.id(), v.id());

                final int newTopicID = distribution.sample();
                d.setTopicID(w, newTopicID);

                d.incrementTopicCount(newTopicID);
                final Topic newTopic = topics.get(newTopicID);
                newTopic.incrementVocabCount(v.id());
            }
        }
    }

    /**
     * Get the full conditional distribution over topics for the word at (docID, vocabID).
     * The weight of each topic t is proportional to theta(docID, t) * phi(t, vocabID).
     * @param numTopics
     * @param docID
     * @param vocabID
     * @return the integer distribution over topics
     */
    IntegerDistribution getFullConditionalDistribution(final int numTopics, final int docID, final int vocabID) {
        int[] topics = IntStream.range(0, numTopics).toArray();
        double[] probabilities = Arrays.stream(topics)
                                       .mapToDouble(t -> getTheta(docID, t) * getPhi(t, vocabID))
                                       .toArray();
        return new EnumeratedIntegerDistribution(topics, probabilities);
    }

    /**
     * Initialize the topic assignment.
     * @param seed the seed of a pseudo random number generator
     */
    void initializeTopicAssignment(final long seed) {
        documents.initializeTopicAssignment(topics, seed);
    }

    /**
     * Get the count of topicID assigned to docID.
     * @param docID
     * @param topicID
     * @return the count of topicID assigned to docID
     */
    int getDTCount(final int docID, final int topicID) {
        if (!ready) throw new IllegalStateException();
        if (docID <= 0 || lda.getBow().getNumDocs() < docID
                || topicID < 0 || lda.getNumTopics() <= topicID) {
            throw new IllegalArgumentException();
        }
        return documents.getTopicCount(docID, topicID);
    }

    /**
     * Get the count of vocabID assigned to topicID.
     * @param topicID
     * @param vocabID
     * @return the count of vocabID assigned to topicID
     */
    int getTVCount(final int topicID, final int vocabID) {
        if (!ready) throw new IllegalStateException();
        if (topicID < 0 || lda.getNumTopics() <= topicID || vocabID <= 0) {
            throw new IllegalArgumentException();
        }
        final Topic topic = topics.get(topicID);
        return topic.getVocabCount(vocabID);
    }

    /**
     * Get the sum of counts of vocabs assigned to topicID.
     * This is the sum of the topic-vocab count over all vocabs.
     * @param topicID
     * @return the sum of counts of vocabs assigned to topicID
     * @throws IllegalArgumentException topicID < 0 || #topic <= topicID
     */
    int getTSumCount(final int topicID) {
        if (topicID < 0 || lda.getNumTopics() <= topicID) {
            throw new IllegalArgumentException();
        }
        final Topic topic = topics.get(topicID);
        return topic.getSumCount();
    }

    @Override
    public double getTheta(final int docID, final int topicID) {
        if (!ready) throw new IllegalStateException();
        return documents.getTheta(docID, topicID, lda.getAlpha(topicID), lda.getSumAlpha());
    }

    @Override
    public double getPhi(int topicID, int vocabID) {
        if (!ready) throw new IllegalStateException();
        return topics.getPhi(topicID, vocabID, lda.getBeta());
    }

    @Override
    public List<Pair<String, Double>> getVocabsSortedByPhi(int topicID) {
        return topics.getVocabsSortedByPhi(topicID, lda.getVocabularies(), lda.getBeta());
    }
}
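Written out, the unnormalized full conditional that runSampling draws from matches this class's getTheta and getPhi accessors (backed by Document.getTheta and Topic.getPhi later in this change); this is the estimated-parameter form of the collapsed Gibbs update rather than a particular textbook variant:

\[
P(z_{d,i}=t \mid \cdot) \;\propto\; \hat{\theta}_{d,t}\,\hat{\phi}_{t,v},
\qquad
\hat{\theta}_{d,t} = \frac{n_{d,t} + \alpha_t}{N_d + \sum_{t'}\alpha_{t'}},
\qquad
\hat{\phi}_{t,v} = \frac{n_{t,v} + \beta}{n_t + \beta V}
\]

Here n_{d,t} is the number of words in document d currently assigned to topic t (the word being resampled has already been decremented), n_{t,v} is the number of times vocabulary v is assigned to topic t, n_t is the sum of n_{t,v} over the vocabulary, N_d is the document length, V is the vocabulary size, and alpha_t, beta are the Dirichlet hyperparameters exposed by LDA.getAlpha, LDA.getSumAlpha and LDA.getBeta. Since the denominator of theta-hat is constant across topics, it only rescales the weights and does not change the sampled distribution.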
@@ -0,0 +1,86 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

import java.util.List;

import forge.lda.dataset.Vocabulary;

public class Document {
    private final int id;
    private TopicCounter topicCount;
    private Words words;
    private TopicAssignment assignment;

    Document(int id, int numTopics, List<Vocabulary> words) {
        if (id < 0 || numTopics <= 0) throw new IllegalArgumentException();
        this.id = id;
        this.topicCount = new TopicCounter(numTopics);
        this.words = new Words(words);
        this.assignment = new TopicAssignment();
    }

    int id() {
        return id;
    }

    int getTopicCount(int topicID) {
        return topicCount.getTopicCount(topicID);
    }

    int getDocLength() {
        return words.getNumWords();
    }

    void incrementTopicCount(int topicID) {
        topicCount.incrementTopicCount(topicID);
    }

    void decrementTopicCount(int topicID) {
        topicCount.decrementTopicCount(topicID);
    }

    void initializeTopicAssignment(long seed) {
        assignment.initialize(getDocLength(), topicCount.size(), seed);
        for (int w = 0; w < getDocLength(); ++w) {
            incrementTopicCount(assignment.get(w));
        }
    }

    int getTopicID(int wordID) {
        return assignment.get(wordID);
    }

    void setTopicID(int wordID, int topicID) {
        assignment.set(wordID, topicID);
    }

    Vocabulary getVocabulary(int wordID) {
        return words.get(wordID);
    }

    List<Vocabulary> getWords() {
        return words.getWords();
    }

    double getTheta(int topicID, double alpha, double sumAlpha) {
        if (topicID < 0 || alpha <= 0.0 || sumAlpha <= 0.0) {
            throw new IllegalArgumentException();
        }
        return (getTopicCount(topicID) + alpha) / (getDocLength() + sumAlpha);
    }
}
@@ -0,0 +1,100 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import forge.lda.lda.LDA;
import forge.lda.dataset.BagOfWords;
import forge.lda.dataset.Vocabularies;
import forge.lda.dataset.Vocabulary;

class Documents {
    private List<Document> documents;

    Documents(LDA lda) {
        if (lda == null) throw new NullPointerException();

        documents = new ArrayList<>();
        for (int d = 0; d < lda.getBow().getNumDocs(); ++d) {
            List<Vocabulary> vocabList = getVocabularyList(d, lda.getBow(), lda.getVocabularies());
            Document doc = new Document(d, lda.getNumTopics(), vocabList);
            documents.add(doc);
        }
    }

    List<Vocabulary> getVocabularyList(int docID, BagOfWords bow, Vocabularies vocabs) {
        // document IDs start at 0 (see the constructor above)
        assert docID >= 0 && bow != null && vocabs != null;
        //System.out.println(docID);
        //System.out.println(bow.getWords(docID).toString());
        return bow.getWords(docID).stream()
                  .map(id -> vocabs.get(id))
                  .collect(Collectors.toList());
    }

    int getTopicID(int docID, int wordID) {
        return documents.get(docID).getTopicID(wordID);
    }

    void setTopicID(int docID, int wordID, int topicID) {
        documents.get(docID).setTopicID(wordID, topicID);
    }

    Vocabulary getVocab(int docID, int wordID) {
        return documents.get(docID).getVocabulary(wordID);
    }

    List<Vocabulary> getWords(int docID) {
        return documents.get(docID).getWords();
    }

    List<Document> getDocuments() {
        return Collections.unmodifiableList(documents);
    }

    void incrementTopicCount(int docID, int topicID) {
        documents.get(docID).incrementTopicCount(topicID);
    }

    void decrementTopicCount(int docID, int topicID) {
        documents.get(docID).decrementTopicCount(topicID);
    }

    int getTopicCount(int docID, int topicID) {
        return documents.get(docID).getTopicCount(topicID);
    }

    double getTheta(int docID, int topicID, double alpha, double sumAlpha) {
        // reject out-of-range document IDs (0-based)
        if (docID < 0 || documents.size() <= docID) throw new IllegalArgumentException();
        return documents.get(docID).getTheta(topicID, alpha, sumAlpha);
    }

    void initializeTopicAssignment(Topics topics, long seed) {
        for (Document d : getDocuments()) {
            d.initializeTopicAssignment(seed);
            for (int w = 0; w < d.getDocLength(); ++w) {
                final int topicID = d.getTopicID(w);
                final Topic topic = topics.get(topicID);
                final Vocabulary vocab = d.getVocabulary(w);
                topic.incrementVocabCount(vocab.id());
            }
        }
    }
}
@@ -0,0 +1,55 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

class Topic {
    private final int id;
    private final int numVocabs;
    private VocabularyCounter counter;

    Topic(int id, int numVocabs) {
        if (id < 0 || numVocabs <= 0) throw new IllegalArgumentException();
        this.id = id;
        this.numVocabs = numVocabs;
        this.counter = new VocabularyCounter(numVocabs);
    }

    int id() {
        return id;
    }

    int getVocabCount(int vocabID) {
        return counter.getVocabCount(vocabID);
    }

    int getSumCount() {
        return counter.getSumCount();
    }

    void incrementVocabCount(int vocabID) {
        counter.incrementVocabCount(vocabID);
    }

    void decrementVocabCount(int vocabID) {
        counter.decrementVocabCount(vocabID);
    }

    double getPhi(int vocabID, double beta) {
        if (vocabID < 0 || beta <= 0) throw new IllegalArgumentException();
        return (getVocabCount(vocabID) + beta) / (getSumCount() + beta * numVocabs);
    }
}
@@ -0,0 +1,60 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;

class TopicAssignment {
    private List<Integer> topicAssignment;
    private boolean ready;

    TopicAssignment() {
        topicAssignment = new ArrayList<>();
        ready = false;
    }

    void set(int wordID, int topicID) {
        if (!ready) throw new IllegalStateException();
        if (wordID < 0 || topicAssignment.size() <= wordID || topicID < 0) {
            throw new IllegalArgumentException();
        }
        topicAssignment.set(wordID, topicID);
    }

    int get(int wordID) {
        if (!ready) throw new IllegalStateException();
        if (wordID < 0 || topicAssignment.size() <= wordID) {
            throw new IllegalArgumentException();
        }
        return topicAssignment.get(wordID);
    }

    void initialize(int docLength, int numTopics, long seed) {
        if (docLength <= 0 || numTopics <= 0) {
            throw new IllegalArgumentException();
        }

        Random random = new Random(seed);
        topicAssignment = random.ints(docLength, 0, numTopics)
                                .boxed()
                                .collect(Collectors.toList());
        ready = true;
    }
}
@@ -0,0 +1,45 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

class TopicCounter {
    private AssignmentCounter topicCount;

    TopicCounter(int numTopics) {
        this.topicCount = new AssignmentCounter(numTopics);
    }

    int getTopicCount(int topicID) {
        return topicCount.get(topicID);
    }

    int getDocLength() {
        return topicCount.getSum();
    }

    void incrementTopicCount(int topicID) {
        topicCount.increment(topicID);
    }

    void decrementTopicCount(int topicID) {
        topicCount.decrement(topicID);
    }

    int size() {
        return topicCount.size();
    }
}
@@ -0,0 +1,85 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import forge.lda.lda.LDA;
import forge.lda.dataset.Vocabularies;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;

class Topics {
    private List<Topic> topics;

    Topics(LDA lda) {
        if (lda == null) throw new NullPointerException();

        topics = new ArrayList<>();
        for (int t = 0; t < lda.getNumTopics(); ++t) {
            topics.add(new Topic(t, lda.getBow().getNumVocabs()));
        }
    }

    int numTopics() {
        return topics.size();
    }

    Topic get(int id) {
        return topics.get(id);
    }

    int getVocabCount(int topicID, int vocabID) {
        return topics.get(topicID).getVocabCount(vocabID);
    }

    int getSumCount(int topicID) {
        return topics.get(topicID).getSumCount();
    }

    void incrementVocabCount(int topicID, int vocabID) {
        topics.get(topicID).incrementVocabCount(vocabID);
    }

    void decrementVocabCount(int topicID, int vocabID) {
        topics.get(topicID).decrementVocabCount(vocabID);
    }

    double getPhi(int topicID, int vocabID, double beta) {
        if (topicID < 0 || topics.size() <= topicID) throw new IllegalArgumentException();
        return topics.get(topicID).getPhi(vocabID, beta);
    }

    List<Pair<String, Double>> getVocabsSortedByPhi(int topicID, Vocabularies vocabs, final double beta) {
        if (topicID < 0 || topics.size() <= topicID || vocabs == null || beta <= 0.0) {
            throw new IllegalArgumentException();
        }

        Topic topic = topics.get(topicID);
        List<Pair<String, Double>> vocabProbPairs
            = vocabs.getVocabularyList()
                    .stream()
                    .map(v -> new ImmutablePair<String, Double>(v.toString(), topic.getPhi(v.id(), beta)))
                    .sorted((p1, p2) -> Double.compare(p2.getRight(), p1.getRight()))
                    .collect(Collectors.toList());
        return Collections.unmodifiableList(vocabProbPairs);
    }
}
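The sorted phi list built above is what CollapsedGibbsSampler.getVocabsSortedByPhi hands back to callers. A hypothetical driver (the helper method and the LDA construction are assumptions; the sampler calls come from this change) might use it like this:

static void printTopWords(LDA lda) {
    CollapsedGibbsSampler sampler = new CollapsedGibbsSampler();
    sampler.setUp(lda);   // default seed (0) and 100 iterations; pass InferenceProperties to override
    sampler.run();

    for (int t = 0; t < lda.getNumTopics(); ++t) {
        // entries are sorted by descending phi, so the head of the list holds the topic's top words
        sampler.getVocabsSortedByPhi(t).stream()
               .limit(5)
               .forEach(p -> System.out.println(t + "\t" + p.getLeft() + "\t" + p.getRight()));
    }
}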
@@ -0,0 +1,46 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

class VocabularyCounter {
    private AssignmentCounter vocabCount;
    private int sumCount;

    VocabularyCounter(int numVocabs) {
        this.vocabCount = new AssignmentCounter(numVocabs);
        this.sumCount = 0;
    }

    int getVocabCount(int vocabID) {
        if (vocabCount.size() < vocabID) return 0;
        else return vocabCount.get(vocabID);
    }

    int getSumCount() {
        return sumCount;
    }

    void incrementVocabCount(int vocabID) {
        vocabCount.increment(vocabID);
        ++sumCount;
    }

    void decrementVocabCount(int vocabID) {
        vocabCount.decrement(vocabID);
        --sumCount;
    }
}
@@ -0,0 +1,46 @@
/*
 * Copyright 2015 Kohei Yamamoto
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package forge.lda.lda.inference.internal;

import java.util.Collections;
import java.util.List;

import forge.lda.dataset.Vocabulary;

class Words {
    private List<Vocabulary> words;

    Words(List<Vocabulary> words) {
        if (words == null) throw new NullPointerException();
        this.words = words;
    }

    int getNumWords() {
        return words.size();
    }

    Vocabulary get(int id) {
        if (id < 0 || words.size() <= id) {
            throw new IllegalArgumentException();
        }
        return words.get(id);
    }

    List<Vocabulary> getWords() {
        return Collections.unmodifiableList(words);
    }
}
1
pom.xml
@@ -65,6 +65,7 @@
        <module>forge-gui-mobile-dev</module>
        <module>forge-gui-android</module>
        <module>forge-gui-ios</module>
        <module>forge-lda</module>
    </modules>

    <distributionManagement>