/* Copyright (C) 2005 Univ. of Massachusetts Amherst, Computer Science Dept.
This file is part of "MALLET" (MAchine Learning for LanguagE Toolkit).
This software is provided under the terms of the Common Public License,
version 1.0, as published by http://www.opensource.org. For further
information, see the file `LICENSE' included with this distribution. */
package cc.mallet.topics;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.io.Serializable;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Formatter;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.logging.Logger;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import com.carrotsearch.hppc.ObjectIntHashMap;
import com.google.errorprone.annotations.Var;
import cc.mallet.types.Alphabet;
import cc.mallet.types.AugmentableFeatureVector;
import cc.mallet.types.Dirichlet;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.FeatureSequenceWithBigrams;
import cc.mallet.types.IDSorter;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;
import cc.mallet.types.LabelAlphabet;
import cc.mallet.types.LabelSequence;
import cc.mallet.types.MatrixOps;
import cc.mallet.types.RankedFeatureVector;
import cc.mallet.util.MalletLogger;
import cc.mallet.util.Randoms;
/**
 * Simple parallel threaded implementation of LDA,
 * following Newman, Asuncion, Smyth and Welling, Distributed Algorithms for Topic Models
 * JMLR (2009), with SparseLDA sampling scheme and data structure from
 * Yao, Mimno and McCallum, Efficient Methods for Topic Model Inference on Streaming Document Collections, KDD (2009).
 *
 * @author David Mimno, Andrew McCallum
 */
public class ParallelTopicModel implements Serializable {
// Sentinel topic id; presumably marks a token with no topic assigned yet (not used in this part of the file -- confirm).
public static final int UNASSIGNED_TOPIC = -1;
// Class-wide logger obtained from MalletLogger under this class's name.
public static Logger logger = MalletLogger.getLogger(ParallelTopicModel.class.getName());
public ArrayList<TopicAssignment> data; // the training instances and their topic assignments
public Alphabet alphabet; // the alphabet for the input data
public LabelAlphabet topicAlphabet; // the alphabet for the topics
public int numTopics; // Number of topics to be fit
// These values are used to encode type/topic counts as
// count/topic pairs in a single int.
// Both are derived from numTopics in setNumTopics(): topicMask is an
// all-ones bit mask (numTopics - 1 when numTopics is a power of two,
// otherwise highestOneBit(numTopics) * 2 - 1) and topicBits is the
// number of set bits in that mask.
public int topicMask;
public int topicBits;
public int numTypes;
public long totalTokens;
// setNumTopics() allocates alpha with numTopics entries, each alphaSum / numTopics.
public double[] alpha; // Dirichlet(alpha,alpha,...) is the distribution over topics
public double alphaSum;
public double beta; // Prior on per-topic multinomial distribution over words
public double betaSum;
// Toggled through setSymmetricAlpha().
public boolean usingSymmetricAlpha = false;
// Beta used by the single-argument constructor.
public static final double DEFAULT_BETA = 0.01;
public int[][] typeTopicCounts; // indexed by <feature index, topic index>
// (Re)allocated with numTopics entries by setNumTopics().
public int[] tokensPerTopic; // indexed by <topic index>
// for dirichlet estimation
public int[] docLengthCounts; // histogram of document sizes
public int[][] topicDocCounts; // histogram of document/topic counts, indexed by <topic index, sequence position index>
// Run-schedule settings with their defaults; most have a setter below.
// NOTE(review): the code that consumes these intervals is outside this
// part of the file -- confirm exact semantics (e.g. whether 0 disables) there.
public int numIterations = 1000;
public int burninPeriod = 200;
// Lowered by setOptimizeInterval() so it never exceeds optimizeInterval.
public int saveSampleInterval = 10;
public int optimizeInterval = 50;
public int temperingInterval = 0;
// Both set together by setTopicDisplay(interval, n).
public int showTopicsInterval = 50;
public int wordsPerTopic = 7;
public int saveStateInterval = 0;
public String stateFilename = null;
public int saveModelInterval = 0;
public String modelFilename = null;
// Set by setRandomSeed(); presumably -1 means "no fixed seed" -- confirm where the generator is created.
public int randomSeed = -1;
// Assigned NumberFormat.getInstance() by the (LabelAlphabet, double, double) constructor.
public NumberFormat formatter;
public boolean printLogLikelihood = true;
// The number of times each type appears in the corpus
int[] typeTotals;
// The max over typeTotals, used for beta optimization
int maxTypeCount;
// Set by setNumThreads(); defaults to a single thread.
int numThreads = 1;
/**
 * Creates a model with the given number of topics, using an alpha sum equal
 * to the number of topics and {@link #DEFAULT_BETA} as beta.
 *
 * @param numberOfTopics number of topics to fit
 */
public ParallelTopicModel (int numberOfTopics) {
    this (numberOfTopics, numberOfTopics, DEFAULT_BETA);
}
/**
 * Creates a model with the given number of topics and hyperparameters,
 * building a fresh topic alphabet with one label per topic.
 *
 * @param numberOfTopics number of topics to fit
 * @param alphaSum sum of the Dirichlet parameters over topics
 * @param beta prior on the per-topic multinomial distribution over words
 */
public ParallelTopicModel (int numberOfTopics, double alphaSum, double beta) {
    this (newLabelAlphabet (numberOfTopics), alphaSum, beta);
}
private static LabelAlphabet newLabelAlphabet (int numTopics) {
LabelAlphabet ret = new LabelAlphabet();
for (int i = 0; i < numTopics; i++) {
return ret;
/**
 * Creates a model whose topics are the entries of the given label alphabet.
 *
 * @param topicAlphabet alphabet of topic labels; its size becomes the number of topics
 * @param alphaSum sum of the Dirichlet parameters over topics
 * @param beta prior on the per-topic multinomial distribution over words
 */
public ParallelTopicModel (LabelAlphabet topicAlphabet, double alphaSum, double beta) {
    this.data = new ArrayList<TopicAssignment>();
    this.topicAlphabet = topicAlphabet;
    this.alphaSum = alphaSum;
    this.beta = beta;

    // Derive numTopics, topicMask, topicBits, alpha and tokensPerTopic.
    // Must run after alphaSum is stored, because alpha is filled from it;
    // without this call the log line below would report zeros.
    setNumTopics(topicAlphabet.size());

    formatter = NumberFormat.getInstance();

    logger.info("Mallet LDA: " + numTopics + " topics, " + topicBits + " topic bits, " +
                Integer.toBinaryString(topicMask) + " topic mask");
}
/** @return the alphabet for the input data */
public Alphabet getAlphabet() {
    return this.alphabet;
}
/** @return the alphabet for the topics */
public LabelAlphabet getTopicAlphabet() {
    return this.topicAlphabet;
}
/** @return the number of topics to be fit */
public int getNumTopics() {
    return this.numTopics;
}
/** Set or reset the number of topics. This method will not change any token-topic assignments,
so it should only be used before initializing or restoring a previously saved state. */
public void setNumTopics(int numTopics) {
this.numTopics = numTopics;
if (Integer.bitCount(numTopics) == 1) {
// exact power of 2
topicMask = numTopics - 1;
topicBits = Integer.bitCount(topicMask);
else {
// otherwise add an extra bit
topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
topicBits = Integer.bitCount(topicMask);
this.alpha = new double[numTopics];
Arrays.fill(alpha, alphaSum / numTopics);
tokensPerTopic = new int[numTopics];
/** @return the training instances and their topic assignments (the live list, not a copy) */
public ArrayList<TopicAssignment> getData() {
    return this.data;
}
/** @return the type/topic count table (the live array, not a copy) */
public int[][] getTypeTopicCounts() {
    return this.typeTopicCounts;
}
/** @return the per-topic token counts, indexed by topic (the live array, not a copy) */
public int[] getTokensPerTopic() {
    return this.tokensPerTopic;
}
public void setNumIterations (int numIterations) {
this.numIterations = numIterations;
public void setBurninPeriod (int burninPeriod) {
this.burninPeriod = burninPeriod;
public void setTopicDisplay(int interval, int n) {
this.showTopicsInterval = interval;
this.wordsPerTopic = n;
public void setRandomSeed(int seed) {
randomSeed = seed;
/** Interval for optimizing Dirichlet hyperparameters */
public void setOptimizeInterval(int interval) {
this.optimizeInterval = interval;
// Make sure we always have at least one sample
// before optimizing hyperparameters
if (saveSampleInterval > optimizeInterval) {
saveSampleInterval = optimizeInterval;
public void setSymmetricAlpha(boolean b) {
usingSymmetricAlpha = b;
public void setTemperingInterval(int interval) {
temperingInterval = interval;
public void setNumThreads(int threads) {
this.numThreads = threads;
/** Define how often and where to save a text representation of the current state.
* Files are GZipped.