/*
* Copyright (C) 2007 by
*
* Xuan-Hieu Phan
* hieuxuan@ecei.tohoku.ac.jp or pxhieu@gmail.com
* Graduate School of Information Sciences
* Tohoku University
*
* GibbsLDA++ is a free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published
* by the Free Software Foundation; either version 2 of the License,
* or (at your option) any later version.
*
* GibbsLDA++ is distributed in the hope that it will be useful, but
* WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with GibbsLDA++; if not, write to the Free Software Foundation,
* Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA.
*/
/*
* References:
* + The Java code of Gregor Heinrich (gregor@arbylon.net)
* http://www.arbylon.net/projects/LdaGibbsSampler.java
* + "Parameter estimation for text analysis" by Gregor Heinrich
* http://www.arbylon.net/publications/text-est.pdf
*/
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
#include "constants.h"
#include "strtokenizer.h"
#include "utils.h"
#include "dataset.h"
#include "model.h"
using namespace std;
// Frees a table of `count` rows together with the table itself.
// Both the rows and the table are released with delete[]; a NULL table is
// ignored, and NULL rows are harmless because delete[] on NULL is a no-op.
template <typename T>
static void free_rows(T ** rows, int count) {
    if (!rows) {
        return;
    }
    for (int i = 0; i < count; i++) {
        delete[] rows[i];
    }
    delete[] rows;
}

// Releases everything the model owns.
//
// load_model() allocates z with `new int*[M]` and each z[i] with
// `new int[...]`, so these buffers must be released with delete[] (plain
// delete on an array is undefined behaviour) and the outer table must be
// freed as well.
// NOTE(review): the allocation sites of p, nw, nd, nwsum, ndsum, theta, phi
// and the new* buffers are outside this view; they are assumed to follow the
// same new[] pattern as z -- confirm against init_est/init_estc/init_inf.
model::~model() {
    // delete / delete[] on a NULL pointer is a no-op, so no guards are needed
    delete[] p;

    // ptrndata is created with `new dataset(M)` in load_model(): single object
    delete ptrndata;
    delete pnewdata;

    // buffers used for estimation
    free_rows(z, M);
    free_rows(nw, V);
    free_rows(nd, M);
    delete[] nwsum;
    delete[] ndsum;
    free_rows(theta, M);
    free_rows(phi, K);

    // only for inference
    free_rows(newz, newM);
    free_rows(newnw, newV);
    free_rows(newnd, newM);
    delete[] newnwsum;
    delete[] newndsum;
    free_rows(newtheta, newM);
    free_rows(newphi, K);
}
// Puts every member into its built-in default state: default file names,
// unknown run mode, default sampling parameters, empty (NULL) buffers and
// zeroed size counters.
void model::set_default_values() {
    // location and name of the model on disk
    dir = "./";
    dfile = "trndocs.dat";
    model_name = "model-final";
    wordmapfile = "wordmap.txt";
    trainlogfile = "trainlog.txt";

    // suffixes appended to dir + model_name when saving/loading
    tassign_suffix = ".tassign";
    others_suffix = ".others";
    theta_suffix = ".theta";
    phi_suffix = ".phi";
    twords_suffix = ".twords";

    // run mode is decided later (see init())
    model_status = MODEL_STATUS_UNKNOWN;

    // corpus sizes and sampling parameters; alpha is derived from K,
    // so K has to be assigned first
    M = V = 0;
    K = 100;
    alpha = 50.0 / K;
    beta = 0.1;
    niters = 2000;
    liter = 0;
    savestep = 200;
    twords = 0;
    withrawstrs = 0;

    // data sets
    ptrndata = NULL;
    pnewdata = NULL;

    // buffers used for estimation
    p = NULL;
    z = NULL;
    nw = NULL;
    nd = NULL;
    nwsum = NULL;
    ndsum = NULL;
    theta = NULL;
    phi = NULL;

    // buffers and sizes used only for inference
    newM = newV = 0;
    newz = NULL;
    newnw = NULL;
    newnd = NULL;
    newnwsum = NULL;
    newndsum = NULL;
    newtheta = NULL;
    newphi = NULL;
}
// Forwards the command line to utils::parse_args, which receives this model
// as its third argument, and returns that call's result unchanged.
int model::parse_args(int argc, char ** argv) {
    const int status = utils::parse_args(argc, argv, this);
    return status;
}
// Parses the command line and then runs the initializer that matches
// model_status.
// Returns 1 if argument parsing or the selected initializer reports failure
// (non-zero), 0 otherwise -- including when model_status matches none of
// the three known modes.
int model::init(int argc, char ** argv) {
    if (parse_args(argc, argv) != 0) {
        return 1;
    }

    // estimating the model from scratch
    if (model_status == MODEL_STATUS_EST) {
        return init_est() ? 1 : 0;
    }

    // continuing estimation from a previously estimated model
    if (model_status == MODEL_STATUS_ESTC) {
        return init_estc() ? 1 : 0;
    }

    // inference
    if (model_status == MODEL_STATUS_INF) {
        return init_inf() ? 1 : 0;
    }

    return 0;
}
// Loads the word-topic assignments stored in
// dir + model_name + tassign_suffix.
//
// The file is expected to hold M lines (one per document); each line is a
// whitespace-separated list of "word:topic" tokens. For every line a new
// document is added to a freshly allocated ptrndata and the topics are
// stored in z[i]. M and V are read from the members and must already hold
// the right values when this is called.
//
// Returns 0 on success, 1 if the file cannot be opened, has fewer than M
// lines, or contains a token that is not of the form "a:b".
int model::load_model(string model_name) {
    string filename = dir + model_name + tassign_suffix;
    FILE * fin = fopen(filename.c_str(), "r");
    if (!fin) {
        // %s, not %d: the argument is a C string
        printf("Cannot open file %s to load model!\n", filename.c_str());
        return 1;
    }

    char buff[BUFF_SIZE_LONG];
    string line;

    // allocate memory for z and ptrndata; the trailing () value-initializes
    // every row pointer to NULL so that, after an early error return, the
    // destructor never frees an uninitialized pointer
    z = new int*[M]();
    ptrndata = new dataset(M);
    ptrndata->V = V;

    for (int i = 0; i < M; i++) {
        if (!fgets(buff, BUFF_SIZE_LONG, fin)) {
            printf("Invalid word-topic assignment file, check the number of docs!\n");
            fclose(fin);
            return 1;
        }

        line = buff;
        strtokenizer line_tok(line, " \t\r\n");
        int length = line_tok.count_tokens();

        vector<int> words;
        vector<int> topics;
        for (int j = 0; j < length; j++) {
            string token = line_tok.token(j);

            // each token must split into exactly "word" and "topic"
            strtokenizer tok(token, ":");
            if (tok.count_tokens() != 2) {
                printf("Invalid word-topic assignment line!\n");
                fclose(fin);
                return 1;
            }

            words.push_back(atoi(tok.token(0).c_str()));
            topics.push_back(atoi(tok.token(1).c_str()));
        }

        // allocate and add new document to the corpus
        document * pdoc = new document(words);
        ptrndata->add_doc(pdoc, i);

        // assign values for z
        z[i] = new int[topics.size()];
        for (vector<int>::size_type j = 0; j < topics.size(); j++) {
            z[i][j] = topics[j];
        }
    }

    fclose(fin);
    return 0;
}
// Writes all files of the model under the common prefix dir + model_name:
// topic assignments, the "others" parameter file, theta and phi, in that
// order, followed by the top-words file when twords > 0.
// Stops at the first writer that fails and returns 1; returns 0 when every
// writer succeeded.
int model::save_model(string model_name) {
    const string prefix = dir + model_name;

    // || short-circuits, so the writers run in order and stop at the
    // first one that reports an error
    const bool failed =
        save_model_tassign(prefix + tassign_suffix) ||
        save_model_others(prefix + others_suffix) ||
        save_model_theta(prefix + theta_suffix) ||
        save_model_phi(prefix + phi_suffix);
    if (failed) {
        return 1;
    }

    if (twords > 0 && save_model_twords(prefix + twords_suffix)) {
        return 1;
    }

    return 0;
}
// Writes the topic assignment of every word of the training data to
// `filename`: one line per document, each word as "word:topic " taken from
// ptrndata->docs[d]->words[n] and z[d][n].
// Returns 0 on success, 1 if the file cannot be opened for writing.
int model::save_model_tassign(string filename) {
    FILE * fout = fopen(filename.c_str(), "w");
    if (fout == NULL) {
        printf("Cannot open file %s to save!\n", filename.c_str());
        return 1;
    }

    // write docs with topic assignments for words
    for (int d = 0; d < ptrndata->M; ++d) {
        for (int n = 0; n < ptrndata->docs[d]->length; ++n) {
            fprintf(fout, "%d:%d ", ptrndata->docs[d]->words[n], z[d][n]);
        }
        fputs("\n", fout);
    }

    fclose(fout);
    return 0;
}
// Writes theta to `filename` as M lines of K values each, every value
// printed with "%f " (so each line ends with a trailing space).
// Returns 0 on success, 1 if the file cannot be opened for writing.
int model::save_model_theta(string filename) {
    FILE * fout = fopen(filename.c_str(), "w");
    if (fout == NULL) {
        printf("Cannot open file %s to save!\n", filename.c_str());
        return 1;
    }

    for (int m = 0; m < M; ++m) {
        for (int k = 0; k < K; ++k) {
            fprintf(fout, "%f ", theta[m][k]);
        }
        fputs("\n", fout);
    }

    fclose(fout);
    return 0;
}
// Writes phi to `filename` as K lines of V values each, every value
// printed with "%f " (so each line ends with a trailing space).
// Returns 0 on success, 1 if the file cannot be opened for writing.
int model::save_model_phi(string filename) {
    FILE * fout = fopen(filename.c_str(), "w");
    if (fout == NULL) {
        printf("Cannot open file %s to save!\n", filename.c_str());
        return 1;
    }

    for (int k = 0; k < K; ++k) {
        for (int w = 0; w < V; ++w) {
            fprintf(fout, "%f ", phi[k][w]);
        }
        fputs("\n", fout);
    }

    fclose(fout);
    return 0;
}
// Writes the scalar parameters of the model to `filename`, one "key=value"
// line each: alpha, beta, ntopics (K), ndocs (M), nwords (V) and liter.
// Returns 0 on success, 1 if the file cannot be opened for writing.
int model::save_model_others(string filename) {
    FILE * fout = fopen(filename.c_str(), "w");
    if (fout == NULL) {
        printf("Cannot open file %s to save!\n", filename.c_str());
        return 1;
    }

    // hyper-parameters
    fprintf(fout, "alpha=%f\n", alpha);
    fprintf(fout, "beta=%f\n", beta);

    // sizes
    fprintf(fout, "ntopics=%d\n", K);
    fprintf(fout, "ndocs=%d\n", M);
    fprintf(fout, "nwords=%d\n", V);

    // value of liter at the time of saving
    fprintf(fout, "liter=%d\n", liter);

    fclose(fout);
    return 0;
}
int model::save_model_twords(string filename) {
FILE * fout = fopen(filename.c_str(),