Fixed some LDA zero indexed array issues so now it completes.

(cherry picked from commit 5ab7704)
This commit is contained in:
austinio7116
2018-05-05 23:35:28 +01:00
committed by maustin
parent 4a131b63d8
commit a25592ebfd
2 changed files with 9 additions and 9 deletions

View File

@@ -42,7 +42,7 @@ public class Example {
}); });
Dataset dataset = new Dataset(FModel.getFormats().getStandard()); Dataset dataset = new Dataset(FModel.getFormats().getStandard());
final int numTopics = 100; final int numTopics = 50;
LDA lda = new LDA(0.1, 0.1, numTopics, dataset, CGS); LDA lda = new LDA(0.1, 0.1, numTopics, dataset, CGS);
lda.run(); lda.run();
System.out.println(lda.computePerplexity(dataset)); System.out.println(lda.computePerplexity(dataset));
@@ -50,8 +50,8 @@ public class Example {
for (int t = 0; t < numTopics; ++t) { for (int t = 0; t < numTopics; ++t) {
List<Pair<String, Double>> highRankVocabs = lda.getVocabsSortedByPhi(t); List<Pair<String, Double>> highRankVocabs = lda.getVocabsSortedByPhi(t);
System.out.print("t" + t + ": "); System.out.print("t" + t + ": ");
for (int i = 0; i < 5; ++i) { for (int i = 0; i < 20; ++i) {
System.out.print("[" + highRankVocabs.get(i).getLeft() + "," + highRankVocabs.get(i).getRight() + "],"); System.out.println("[" + highRankVocabs.get(i).getLeft() + "," + highRankVocabs.get(i).getRight() + "],");
} }
System.out.println(); System.out.println();
} }

View File

@@ -71,7 +71,7 @@ public class LDA {
* @throws IllegalArgumentException vocabID <= 0 || the number of vocabularies < vocabID * @throws IllegalArgumentException vocabID <= 0 || the number of vocabularies < vocabID
*/ */
public String getVocab(int vocabID) { public String getVocab(int vocabID) {
if (vocabID <= 0 || dataset.getNumVocabs() < vocabID) { if (vocabID < 0 || dataset.getNumVocabs() < vocabID) {
throw new IllegalArgumentException(); throw new IllegalArgumentException();
} }
return dataset.get(vocabID).toString(); return dataset.get(vocabID).toString();
@@ -94,7 +94,7 @@ public class LDA {
* @throws ArrayIndexOutOfBoundsException topic < 0 || #topics <= topic * @throws ArrayIndexOutOfBoundsException topic < 0 || #topics <= topic
*/ */
public double getAlpha(final int topic) { public double getAlpha(final int topic) {
if (topic < 0 || numTopics <= topic) { if (topic < 0 || numTopics < topic) {
throw new ArrayIndexOutOfBoundsException(topic); throw new ArrayIndexOutOfBoundsException(topic);
} }
return hyperparameters.alpha(topic); return hyperparameters.alpha(topic);
@@ -125,8 +125,8 @@ public class LDA {
* @throws IllegalStateException call this method when the inference has not been finished yet * @throws IllegalStateException call this method when the inference has not been finished yet
*/ */
public double getTheta(final int docID, final int topicID) { public double getTheta(final int docID, final int topicID) {
if (docID <= 0 || dataset.getNumDocs() < docID if (docID < 0 || dataset.getNumDocs() < docID
|| topicID < 0 || numTopics <= topicID) { || topicID < 0 || numTopics < topicID) {
throw new IllegalArgumentException(); throw new IllegalArgumentException();
} }
if (!trained) { if (!trained) {
@@ -145,7 +145,7 @@ public class LDA {
* @throws IllegalStateException call this method when the inference has not been finished yet * @throws IllegalStateException call this method when the inference has not been finished yet
*/ */
public double getPhi(final int topicID, final int vocabID) { public double getPhi(final int topicID, final int vocabID) {
if (topicID < 0 || numTopics <= topicID || vocabID <= 0) { if (topicID < 0 || numTopics < topicID || vocabID < 0) {
throw new IllegalArgumentException(); throw new IllegalArgumentException();
} }
if (!trained) { if (!trained) {
@@ -170,7 +170,7 @@ public class LDA {
*/ */
public double computePerplexity(Dataset testDataset) { public double computePerplexity(Dataset testDataset) {
double loglikelihood = 0.0; double loglikelihood = 0.0;
for (int d = 1; d <= testDataset.getNumDocs(); ++d) { for (int d = 0; d < testDataset.getNumDocs(); ++d) {
for (Integer w : testDataset.getWords(d)) { for (Integer w : testDataset.getWords(d)) {
double sum = 0.0; double sum = 0.0;
for (int t = 0; t < getNumTopics(); ++t) { for (int t = 0; t < getNumTopics(); ++t) {