/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
http://www.cs.umass.edu/~mccallum/mallet
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Formatter;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.logging.Logger;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import com.carrotsearch.hppc.ObjectIntHashMap;
import com.google.errorprone.annotations.Var;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureSequenceWithBigrams;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
/**
* Simple parallel threaded implementation of LDA,
* following Newman, Asuncion, Smyth and Welling, Distributed Algorithms for Topic Models
* JMLR (2009), with SparseLDA sampling scheme and data structure from
* Yao, Mimno and McCallum, Efficient Methods for Topic Model Inference on Streaming Document Collections, KDD (2009).
*
* @author David Mimno, Andrew McCallum
*/
public class ParallelTopicModel implements Serializable {
// Sentinel topic value; presumably marks tokens that do not yet have a topic — confirm at use sites.
public static final int UNASSIGNED_TOPIC = -1;
public static Logger logger = MalletLogger.getLogger(ParallelTopicModel.class.getName());
public ArrayList<TopicAssignment> data; // the training instances and their topic assignments
public Alphabet alphabet; // the alphabet for the input data
public LabelAlphabet topicAlphabet; // the alphabet for the topics
public int numTopics; // Number of topics to be fit
// These values are used to encode type/topic counts as
// count/topic pairs in a single int.
// Both are recomputed by setNumTopics: topicMask is the smallest all-ones bit mask that
// covers every topic index below numTopics, and topicBits is its number of set bits.
public int topicMask;
public int topicBits;
public int numTypes;
public long totalTokens;
// alpha is reallocated by setNumTopics and filled with alphaSum / numTopics.
public double[] alpha; // Dirichlet(alpha,alpha,...) is the distribution over topics
public double alphaSum;
public double beta; // Prior on per-topic multinomial distribution over words
public double betaSum;
public boolean usingSymmetricAlpha = false;
public static final double DEFAULT_BETA = 0.01;
public int[][] typeTopicCounts; // indexed by <feature index, topic index>
// Reallocated as a zeroed array by setNumTopics.
public int[] tokensPerTopic; // indexed by <topic index>
// for Dirichlet estimation
public int[] docLengthCounts; // histogram of document sizes
public int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, sequence position index>
// Run-time settings with their defaults; several have setters further down in this class.
public int numIterations = 1000;
public int burninPeriod = 200;
// setOptimizeInterval may lower saveSampleInterval to the optimization interval.
public int saveSampleInterval = 10;
public int optimizeInterval = 50;
public int temperingInterval = 0;
public int showTopicsInterval = 50;
public int wordsPerTopic = 7;
public int saveStateInterval = 0;
public String stateFilename = null;
public int saveModelInterval = 0;
public String modelFilename = null;
public int randomSeed = -1;
// Configured by the LabelAlphabet constructor to print at most five fraction digits.
public NumberFormat formatter;
public boolean printLogLikelihood = true;
// The number of times each type appears in the corpus
int[] typeTotals;
// The max over typeTotals, used for beta optimization
int maxTypeCount;
// Set via setNumThreads; defaults to 1.
int numThreads = 1;
/**
 * Creates a model with {@code numberOfTopics} topics.
 * <p>
 * The total Dirichlet prior mass over topics ({@code alphaSum}) is set equal to
 * {@code numberOfTopics}, and the topic-word prior is {@link #DEFAULT_BETA}.
 *
 * @param numberOfTopics number of topics to fit
 */
public ParallelTopicModel (int numberOfTopics) {
    this(numberOfTopics, (double) numberOfTopics, DEFAULT_BETA);
}
/**
 * Creates a model whose topic alphabet is built from the labels
 * {@code "topic0"}, {@code "topic1"}, ... {@code "topic" + (numberOfTopics - 1)}.
 *
 * @param numberOfTopics number of topic labels to generate
 * @param alphaSum       total Dirichlet prior mass over topics
 * @param beta           prior on the per-topic distribution over words
 */
public ParallelTopicModel (int numberOfTopics, double alphaSum, double beta) {
    this(ParallelTopicModel.newLabelAlphabet(numberOfTopics), alphaSum, beta);
}
/**
 * Builds a fresh label alphabet by looking up {@code "topic0"} through
 * {@code "topic" + (topicCount - 1)}, in increasing order.
 *
 * @param topicCount how many labels to look up; zero or a negative value performs no lookups
 * @return the newly created alphabet
 */
private static LabelAlphabet newLabelAlphabet (int topicCount) {
    LabelAlphabet labels = new LabelAlphabet();
    int topic = 0;
    while (topic < topicCount) {
        labels.lookupIndex("topic" + topic);
        topic++;
    }
    return labels;
}
public ParallelTopicModel (LabelAlphabet topicAlphabet, double alphaSum, double beta) {
this.data = new ArrayList<TopicAssignment>();
this.topicAlphabet = topicAlphabet;
this.alphaSum = alphaSum;
this.beta = beta;
setNumTopics(topicAlphabet.size());
formatter = NumberFormat.getInstance();
formatter.setMaximumFractionDigits(5);
logger.info("Mallet LDA: " + numTopics + " topics, " + topicBits + " topic bits, " +
Integer.toBinaryString(topicMask) + " topic mask");
}
/**
 * Returns the alphabet of the input data.
 * <p>
 * None of the constructors assign it, so it is null until set elsewhere.
 *
 * @return the input-data alphabet
 */
public Alphabet getAlphabet() {
    return this.alphabet;
}
/**
 * Returns the alphabet of topic labels that was supplied to, or built by, the constructor.
 *
 * @return the topic label alphabet
 */
public LabelAlphabet getTopicAlphabet() {
    return this.topicAlphabet;
}
/**
 * Returns the number of topics most recently set through {@link #setNumTopics(int)}.
 *
 * @return the current topic count
 */
public int getNumTopics() {
    return this.numTopics;
}
/**
 * Sets or resets the number of topics and rebuilds the state derived from it.
 * <p>
 * Along with storing {@code numTopics}, this:
 * <ul>
 *   <li>recomputes {@link #topicMask}, the smallest all-ones bit mask covering every topic
 *       index in {@code [0, numTopics)} for positive {@code numTopics}, and
 *       {@link #topicBits}, the mask's number of set bits;</li>
 *   <li>replaces {@link #alpha} with a new array filled with {@code alphaSum / numTopics};</li>
 *   <li>replaces {@link #tokensPerTopic} with a new zeroed array.</li>
 * </ul>
 * Token-topic assignments are left untouched. Use this only before initializing the model,
 * or before restoring a previously saved state.
 *
 * @param numTopics the new topic count. NOTE(review): values below 1 are degenerate. With 0
 *        the mask becomes -1 and topicBits 32; a negative value fails when the arrays are
 *        allocated. Callers presumably pass a positive count.
 */
public void setNumTopics(int numTopics) {
    this.numTopics = numTopics;
    // A power of two n needs exactly log2(n) bits; any other count rounds up to the
    // next power of two.
    boolean exactPowerOfTwo = Integer.bitCount(numTopics) == 1;
    topicMask = exactPowerOfTwo
        ? numTopics - 1
        : Integer.highestOneBit(numTopics) * 2 - 1;
    topicBits = Integer.bitCount(topicMask);

    this.alpha = new double[numTopics];
    Arrays.fill(this.alpha, alphaSum / numTopics);
    this.tokensPerTopic = new int[numTopics];
}
/**
 * Returns the live list of training instances and their topic assignments.
 * <p>
 * No copy is made.
 *
 * @return the internal data list
 */
public ArrayList<TopicAssignment> getData() {
    return this.data;
}
/**
 * Returns the internal type/topic count table itself, not a copy.
 * <p>
 * The constructors do not allocate it.
 *
 * @return the internal typeTopicCounts array
 */
public int[][] getTypeTopicCounts() {
    return this.typeTopicCounts;
}
/**
 * Returns the internal per-topic token count array, not a copy.
 * <p>
 * It is replaced by a fresh zeroed array whenever {@link #setNumTopics(int)} runs.
 *
 * @return the internal tokensPerTopic array
 */
public int[] getTokensPerTopic() {
    return this.tokensPerTopic;
}
/**
 * Sets the number of iterations. The default is 1000.
 *
 * @param numIterations the iteration count to store
 */
public void setNumIterations (int numIterations) {
    this.numIterations = numIterations;
}
/**
 * Sets the burn-in period, in iterations. The default is 200.
 *
 * @param burninPeriod the burn-in length to store
 */
public void setBurninPeriod (int burninPeriod) {
    this.burninPeriod = burninPeriod;
}
/**
 * Configures topic display.
 *
 * @param interval stored as {@link #showTopicsInterval} (default 50)
 * @param n        stored as {@link #wordsPerTopic} (default 7)
 */
public void setTopicDisplay(int interval, int n) {
    this.wordsPerTopic = n;
    this.showTopicsInterval = interval;
}
/**
 * Stores the random seed. The default is -1.
 * <p>
 * NOTE(review): -1 presumably means "no fixed seed"; confirm where randomSeed is consumed.
 *
 * @param seed the seed value to store
 */
public void setRandomSeed(int seed) {
    this.randomSeed = seed;
}
/**
 * Sets the interval for optimizing the Dirichlet hyperparameters.
 * <p>
 * When {@code interval} is positive and {@link #saveSampleInterval} is larger, the sample
 * interval is lowered to {@code interval}. This guarantees at least one sample is always
 * collected before the hyperparameters are optimized.
 * <p>
 * A non-positive interval leaves {@code saveSampleInterval} unchanged. Previously,
 * {@code setOptimizeInterval(0)} clamped it to 0. Because this clamp only ever lowers the
 * value, a later call that re-enabled optimization could never restore a positive sample
 * interval. NOTE(review): this assumes saveSampleInterval must stay positive (for example,
 * if it is used as a modulus in the sampling loop); confirm against its use sites.
 *
 * @param interval number of iterations between hyperparameter optimizations;
 *        presumably 0 disables optimization
 */
public void setOptimizeInterval(int interval) {
    this.optimizeInterval = interval;
    // Only clamp when optimization actually runs; never drive the sample interval to <= 0.
    if (interval > 0 && saveSampleInterval > interval) {
        saveSampleInterval = interval;
    }
}
/**
 * Sets the {@link #usingSymmetricAlpha} flag. The default is {@code false}.
 *
 * @param b the new flag value
 */
public void setSymmetricAlpha(boolean b) {
    this.usingSymmetricAlpha = b;
}
/**
 * Sets the tempering interval. The default is 0.
 *
 * @param interval the value to store in {@link #temperingInterval}
 */
public void setTemperingInterval(int interval) {
    this.temperingInterval = interval;
}
/**
 * Sets the number of threads. The default is 1.
 *
 * @param threads the value to store in {@code numThreads}
 */
public void setNumThreads(int threads) {
    numThreads = threads;
}
/** Define how often and where to save a text representation of the current state.
* Files are GZipped.
*