223 lines
5.5 KiB
JavaScript
223 lines
5.5 KiB
JavaScript
|
/**
|
||
|
* @license Apache-2.0
|
||
|
*
|
||
|
* Copyright (c) 2018 The Stdlib Authors.
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
'use strict';
|
||
|
|
||
|
// MODULES //
|
||
|
|
||
|
var isNonNegativeInteger = require( '@stdlib/assert/is-nonnegative-integer' );
|
||
|
var isPositiveInteger = require( '@stdlib/assert/is-positive-integer' );
|
||
|
var isStringArray = require( '@stdlib/assert/is-string-array' );
|
||
|
var setReadOnly = require( '@stdlib/utils/define-read-only-property' );
|
||
|
var contains = require( '@stdlib/assert/contains' );
|
||
|
var tokenize = require( './../../tokenize' );
|
||
|
var Int32Array = require( '@stdlib/array/int32' );
|
||
|
var matrix = require( './matrix.js' );
|
||
|
var getThetas = require( './get_thetas.js' );
|
||
|
var validate = require( './validate.js' );
|
||
|
var getPhis = require( './get_phis.js' );
|
||
|
var init = require( './init.js' );
|
||
|
var fit = require( './fit.js' );
|
||
|
|
||
|
|
||
|
// FUNCTIONS //
|
||
|
|
||
|
/**
* Locates a term in the vocabulary.
*
* Elements are compared against the search value using strict equality (`===`), and the position of the first match is returned.
*
* @private
* @param {Array} vocab - vocabulary
* @param {string} searchVal - search value
* @returns {integer} index in vocab if search value is found, -1 otherwise
*/
function findIndex( vocab, searchVal ) {
	var idx = 0;

	// Walk the vocabulary front to back so that the first occurrence wins:
	while ( idx < vocab.length ) {
		if ( vocab[ idx ] === searchVal ) {
			return idx;
		}
		idx += 1;
	}
	// Search value is not part of the vocabulary:
	return -1;
}
|
||
|
|
||
|
|
||
|
// MAIN //
|
||
|
|
||
|
/**
* Latent Dirichlet Allocation via collapsed Gibbs sampling.
*
* ## Notes
*
* -   Each document is tokenized and every distinct token is assigned an integer id in order of first appearance; the documents are stored on the model as arrays of these ids.
* -   The returned model exposes read-only `K`, `D`, `W`, `alpha`, and `beta` properties, read-only `init`, `fit`, `getPhis`, `getThetas`, and `getTerms` methods, and writable working state (`nwSum`, `ndSum`, `nw`, `nd`, `phiList`, `thetaList`, `w`).
* -   `model.init()` is invoked before the model is returned.
*
* @param {StringArray} documents - document corpus
* @param {PositiveInteger} K - number of topics
* @param {Options} [options] - options object
* @param {PositiveNumber} [options.alpha=50/K] - Dirichlet hyper-parameter of topic vector theta
* @param {PositiveNumber} [options.beta=0.1] - Dirichlet hyper-parameter for word vector phi
* @throws {TypeError} first argument must be an array of strings
* @throws {TypeError} second argument must be a positive integer
* @throws {TypeError} must provide valid options
* @returns {Object} model object
*/
function lda( documents, K, options ) {
	var target;
	var vocab;
	var model;
	var alpha;
	var beta;
	var opts;
	var err;
	var pos;
	var nd;
	var it;
	var wd;
	var D;
	var d;
	var i;
	var W;
	var w;

	if ( !isStringArray( documents ) ) {
		throw new TypeError( 'invalid argument. First argument must be a string array. Value: `' + documents + '`.' );
	}
	if ( !isPositiveInteger( K ) ) {
		throw new TypeError( 'invalid argument. Number of topics `K` must be a positive integer. Value: `' + K + '`.' );
	}
	opts = {};
	if ( arguments.length > 2 ) {
		// `validate` populates `opts` and returns an error when the options are invalid:
		err = validate( opts, options );
		if ( err ) {
			throw err;
		}
	}

	// Number of documents:
	D = documents.length;

	// Hyper-parameter for Dirichlet distribution of topic vector theta:
	alpha = opts.alpha || 50 / K;

	// Hyper-parameter of Dirichlet distribution of phi:
	beta = opts.beta || 0.1;

	// Extract words & construct vocabulary:
	vocab = [];
	w = [];
	pos = 0;
	for ( d = 0; d < D; d++ ) {
		w.push( [] );
		wd = tokenize( documents[ d ] );
		nd = wd.length;
		for ( i = 0; i < nd; i++ ) {
			target = wd[ i ];
			it = findIndex( vocab, target );
			if ( it === -1 ) {
				// First occurrence of this token: append it to the vocabulary and use its new position as the id...
				vocab.push( target );
				w[ d ].push( pos );
				pos += 1;
			} else {
				w[ d ].push( it );
			}
		}
	}
	// Size of vocabulary:
	W = vocab.length;

	model = {};

	// Attach read-only properties:
	setReadOnly( model, 'K', K );
	setReadOnly( model, 'D', D );
	setReadOnly( model, 'W', W );
	setReadOnly( model, 'alpha', alpha );
	setReadOnly( model, 'beta', beta );

	// Attach methods:
	setReadOnly( model, 'init', init );
	setReadOnly( model, 'fit', fit );
	setReadOnly( model, 'getPhis', getPhis );
	setReadOnly( model, 'getThetas', getThetas );
	setReadOnly( model, 'getTerms', getTerms );

	// Allocate working buffers sized by the number of topics, documents, and vocabulary terms (NOTE: how these are populated is defined by `init`/`fit`, which live in separate modules):
	model.nwSum = new Int32Array( K );
	model.ndSum = new Int32Array( D );
	model.nw = matrix( [ W, K ], 'int32' );
	model.nd = matrix( [ D, K ], 'int32' );

	model.phiList = [];
	model.thetaList = [];

	// Documents encoded as arrays of vocabulary indices:
	model.w = w;
	model.init();

	return model;

	/**
	* Get top terms for the specified topic.
	*
	* ## Notes
	*
	* -   Terms are selected greedily in order of decreasing probability as reported by `this.avgPhi.get( k, j )`. Only terms having a probability greater than zero are selected.
	* -   If fewer than `no` such terms exist (e.g., when `no` exceeds the vocabulary size), the returned array contains fewer than `no` elements.
	* -   A falsy `no` (e.g., `undefined` or `0`) selects the default of `10` terms.
	* -   The function is declared after the `return` statement and relies on function declaration hoisting.
	*
	* @private
	* @param {NonNegativeInteger} k - topic
	* @param {PositiveInteger} [no=10] - number of terms
	* @throws {TypeError} first argument must be a nonnegative integer smaller than the total number of topics
	* @throws {TypeError} second argument must be a positive integer
	* @returns {Array} word probability array
	*/
	function getTerms( k, no ) {
		/* eslint-disable no-invalid-this */
		var skip;
		var phi;
		var ret;
		var max;
		var mid;
		var i;
		var j;

		if ( !isNonNegativeInteger( k ) || k >= K ) {
			throw new TypeError( 'invalid argument. First argument must be a nonnegative integer smaller than the total number of topics. Value: `' + k + '`.' );
		}
		if ( no ) {
			if ( !isPositiveInteger( no ) ) {
				throw new TypeError( 'invalid argument. Second argument must be a positive integer. Value: `' + no + '`.' );
			}
		} else {
			no = 10;
		}

		ret = [];
		skip = [];
		for ( i = 0; i < no; i++ ) {
			// Reset the running maximum AND its index so that a previous round's winner cannot be reported again:
			max = 0;
			mid = -1;
			for ( j = 0; j < this.W; j++ ) {
				phi = this.avgPhi.get( k, j );
				if ( phi > max && !contains( skip, j ) ) {
					max = phi;
					mid = j;
				}
			}
			// No unselected term with positive probability remains, so further rounds cannot yield anything:
			if ( mid === -1 ) {
				break;
			}
			skip.push( mid );
			ret.push({
				'word': vocab[ mid ],
				'prob': max
			});
		}
		return ret;
	}
}
|
||
|
|
||
|
|
||
|
// EXPORTS //
|
||
|
|
||
|
module.exports = lda;
|