137 lines
3.6 KiB
JavaScript
137 lines
3.6 KiB
JavaScript
|
/**
|
||
|
* @license Apache-2.0
|
||
|
*
|
||
|
* Copyright (c) 2018 The Stdlib Authors.
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
'use strict';
|
||
|
|
||
|
// MODULES //
|
||
|
|
||
|
var isPositiveInteger = require( '@stdlib/assert/is-positive-integer' );
|
||
|
var randu = require( '@stdlib/random/base/randu' );
|
||
|
var avgMatrix = require( './avg_matrix.js' );
|
||
|
|
||
|
|
||
|
// MAIN //
|
||
|
|
||
|
/**
|
||
|
* Fit model using collapsed Gibbs sampling.
|
||
|
*
|
||
|
* @private
|
||
|
* @param {PositiveInteger} iter - number of sampling iterations
|
||
|
* @param {PositiveInteger} burnin - number of estimates to be thrown away at beginning
|
||
|
* @param {PositiveInteger} thin - number of discarded in-between iterations
|
||
|
* @throws {TypeError} first argument must be a positive integer
|
||
|
* @throws {TypeError} second argument must be a positive integer
|
||
|
* @throws {TypeError} third argument must be a positive integer
|
||
|
*/
|
||
|
function fit( iter, burnin, thin ) {
|
||
|
/* eslint-disable no-invalid-this */
|
||
|
var kalpha;
|
||
|
var wbeta;
|
||
|
var topic;
|
||
|
var theta;
|
||
|
var prob;
|
||
|
var word;
|
||
|
var phi;
|
||
|
var len;
|
||
|
var nt;
|
||
|
var d;
|
||
|
var i;
|
||
|
var j;
|
||
|
var u;
|
||
|
var w;
|
||
|
|
||
|
if ( !isPositiveInteger( iter ) ) {
|
||
|
throw new TypeError( 'invalid argument. First argument must be a positive integer. Value: `' + iter + '`.' );
|
||
|
}
|
||
|
if ( !isPositiveInteger( burnin ) ) {
|
||
|
throw new TypeError( 'invalid argument. Second argument must be a positive integer. Value: `' + burnin + '`.' );
|
||
|
}
|
||
|
if ( !isPositiveInteger( thin ) ) {
|
||
|
throw new TypeError( 'invalid argument. Third argument must be a positive integer. Value: `' + thin + '`.' );
|
||
|
}
|
||
|
|
||
|
wbeta = this.W * this.beta;
|
||
|
kalpha = this.K * this.alpha;
|
||
|
|
||
|
for ( i = 0; i < iter; i++ ) {
|
||
|
for ( d = 0; d < this.D; d++ ) {
|
||
|
for ( w = 0; w < this.ndSum[ d ]; w++ ) {
|
||
|
word = this.w[ d ][ w ];
|
||
|
topic = this.z[ d ][ w ];
|
||
|
|
||
|
this.nw.set( word, topic, this.nw.get( word, topic ) - 1 );
|
||
|
this.nd.set( d, topic, this.nd.get( d, topic ) - 1 );
|
||
|
this.ndSum[ d ] -= 1;
|
||
|
this.nwSum[ topic ] -= 1;
|
||
|
|
||
|
prob = [];
|
||
|
for ( j = 0; j < this.K; j++ ) {
|
||
|
prob.push( ( this.nw.get( word, j ) + this.beta ) /
|
||
|
( this.nwSum[ j ] + wbeta ) *
|
||
|
( this.nd.get( d, j ) + this.alpha ) /
|
||
|
( this.ndSum[ d ] + kalpha ) );
|
||
|
}
|
||
|
for ( j = 1; j < this.K; j++ ) {
|
||
|
prob[ j ] += prob[ j - 1 ];
|
||
|
}
|
||
|
u = prob[ this.K - 1 ] * randu();
|
||
|
topic = 0;
|
||
|
for ( nt = 0; nt < this.K; nt++ ) {
|
||
|
if ( prob[ nt ] > u ) {
|
||
|
topic = nt;
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
// Assign new z_i to counts...
|
||
|
this.nw.set( word, topic, this.nw.get( word, topic ) + 1 );
|
||
|
this.nd.set( d, topic, this.nd.get( d, topic ) + 1 );
|
||
|
this.nwSum[ topic ] += 1;
|
||
|
this.ndSum[ d ] += 1;
|
||
|
|
||
|
this.z[ d ][ w ] = topic;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if ( i % thin === 0 && i > burnin ) {
|
||
|
phi = this.getPhis();
|
||
|
theta = this.getThetas();
|
||
|
|
||
|
this.phiList.push( phi );
|
||
|
this.thetaList.push( theta );
|
||
|
|
||
|
len = this.phiList.length;
|
||
|
if ( len === 1 ) {
|
||
|
this.avgPhi = phi;
|
||
|
} else {
|
||
|
this.avgPhi = avgMatrix( this.avgPhi, phi, len );
|
||
|
}
|
||
|
len = this.thetaList.length;
|
||
|
if ( len === 1 ) {
|
||
|
this.avgTheta = theta;
|
||
|
} else {
|
||
|
this.avgTheta = avgMatrix( this.avgTheta, theta, len );
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
// EXPORTS //
|
||
|
|
||
|
module.exports = fit;
|