93 lines
2.4 KiB
JavaScript
93 lines
2.4 KiB
JavaScript
|
"use strict";
|
||
|
|
||
|
Object.defineProperty(exports, "__esModule", {
|
||
|
value: true
|
||
|
});
|
||
|
exports.createKldivergence = void 0;
|
||
|
|
||
|
var _factory = require("../../utils/factory.js");
|
||
|
|
||
|
// Identifier handed to both the factory and the typed-function dispatcher.
var name = 'kldivergence';

// Names of the functions the factory callback reads from its argument.
var dependencies = [
  'typed',
  'matrix',
  'divide',
  'sum',
  'multiply',
  'dotDivide',
  'log',
  'isNumeric'
];
|
||
|
var createKldivergence = /* #__PURE__ */(0, _factory.factory)(name, dependencies, function (_ref) {
|
||
|
var typed = _ref.typed,
|
||
|
matrix = _ref.matrix,
|
||
|
divide = _ref.divide,
|
||
|
sum = _ref.sum,
|
||
|
multiply = _ref.multiply,
|
||
|
dotDivide = _ref.dotDivide,
|
||
|
log = _ref.log,
|
||
|
isNumeric = _ref.isNumeric;
|
||
|
|
||
|
/**
|
||
|
* Calculate the Kullback-Leibler (KL) divergence between two distributions
|
||
|
*
|
||
|
* Syntax:
|
||
|
*
|
||
|
* math.kldivergence(x, y)
|
||
|
*
|
||
|
* Examples:
|
||
|
*
|
||
|
* math.kldivergence([0.7,0.5,0.4], [0.2,0.9,0.5]) //returns 0.24376698773121153
|
||
|
*
|
||
|
*
|
||
|
* @param {Array | Matrix} q First vector
|
||
|
* @param {Array | Matrix} p Second vector
|
||
|
* @return {number} Returns distance between q and p
|
||
|
*/
|
||
|
return typed(name, {
|
||
|
'Array, Array': function ArrayArray(q, p) {
|
||
|
return _kldiv(matrix(q), matrix(p));
|
||
|
},
|
||
|
'Matrix, Array': function MatrixArray(q, p) {
|
||
|
return _kldiv(q, matrix(p));
|
||
|
},
|
||
|
'Array, Matrix': function ArrayMatrix(q, p) {
|
||
|
return _kldiv(matrix(q), p);
|
||
|
},
|
||
|
'Matrix, Matrix': function MatrixMatrix(q, p) {
|
||
|
return _kldiv(q, p);
|
||
|
}
|
||
|
});
|
||
|
|
||
|
/**
 * Compute the KL divergence of two one-dimensional matrices.
 *
 * Both inputs are divided by the sum of their elements, after which
 * sum(multiply(qn, log(dotDivide(qn, pn)))) is evaluated with the injected
 * functions.
 *
 * @param  {Matrix} q   First vector
 * @param  {Matrix} p   Second vector
 * @return {number}     The computed divergence, or NaN when isNumeric rejects
 *                      the computed value
 * @throws {Error}      When an input has more than one dimension, when the
 *                      inputs differ in length, or when the elements of an
 *                      input sum to zero
 */
function _kldiv(q, p) {
  var qSize = q.size();
  var pSize = p.size();

  // q is the first argument and p the second; the messages follow that order.
  if (qSize.length > 1) {
    throw new Error('first object must be one dimensional');
  }

  if (pSize.length > 1) {
    throw new Error('second object must be one dimensional');
  }

  // Compare the element counts as well: comparing only the number of
  // dimensions cannot tell two vectors of different lengths apart.
  if (qSize.length !== pSize.length || qSize[0] !== pSize[0]) {
    throw new Error('Length of two vectors must be equal');
  }

  // Before calculation, apply normalization
  var sumq = sum(q);

  if (sumq === 0) {
    throw new Error('Sum of elements in first object must be non zero');
  }

  var sump = sum(p);

  if (sump === 0) {
    throw new Error('Sum of elements in second object must be non zero');
  }

  // Reuse the sums validated above instead of computing them a second time.
  var qnorm = divide(q, sumq);
  var pnorm = divide(p, sump);
  var result = sum(multiply(qnorm, log(dotDivide(qnorm, pnorm))));

  return isNumeric(result) ? result : Number.NaN;
}
|
||
|
});
|
||
|
exports.createKldivergence = createKldivergence;
|