squiggle/packages/squiggle-lang/src/js/index.ts

555 lines
14 KiB
TypeScript
Raw Normal View History

import * as _ from "lodash";
import {
genericDist,
samplingParams,
2022-04-29 13:50:57 +00:00
evaluateUsingExternalBindings,
evaluatePartialUsingExternalBindings,
externalBindings,
expressionValue,
errorValue,
distributionError,
toPointSet,
continuousShape,
discreteShape,
distributionErrorToString,
mixedShape,
sampleSetDist,
symbolicDist,
} from "../rescript/TypescriptInterface.gen";
export {
makeSampleSetDist,
errorValueToString,
distributionErrorToString,
} from "../rescript/TypescriptInterface.gen";
2022-04-08 19:55:04 +00:00
import {
Constructors_mean,
Constructors_sample,
Constructors_pdf,
Constructors_cdf,
Constructors_inv,
Constructors_normalize,
Constructors_isNormalized,
Constructors_toPointSet,
Constructors_toSampleSet,
Constructors_truncate,
Constructors_inspect,
Constructors_toString,
Constructors_toSparkline,
Constructors_algebraicAdd,
Constructors_algebraicMultiply,
Constructors_algebraicDivide,
Constructors_algebraicSubtract,
Constructors_algebraicLogarithm,
2022-04-09 16:37:26 +00:00
Constructors_algebraicPower,
Constructors_pointwiseAdd,
Constructors_pointwiseMultiply,
Constructors_pointwiseDivide,
Constructors_pointwiseSubtract,
Constructors_pointwiseLogarithm,
2022-04-09 16:37:26 +00:00
Constructors_pointwisePower,
} from "../rescript/Distributions/DistributionOperation/DistributionOperation.gen";
2022-04-29 13:50:57 +00:00
// Public type surface of this module. The evaluator's `externalBindings`
// type is exposed to consumers under the shorter name `bindings`.
export type { samplingParams };
export type { errorValue };
export type { externalBindings as bindings };
2022-04-11 03:16:31 +00:00
/**
 * Sampling configuration used by `run` when the caller does not pass
 * `samplingInputs`.
 *
 * Declared `const`: nothing in this module reassigns the binding, and ES
 * module importers cannot reassign an imported binding either, so `let`
 * only advertised a mutability no caller could use.
 */
export const defaultSamplingInputs: samplingParams = {
  sampleCount: 10000,
  xyPointLength: 10000,
};
2022-04-11 06:16:29 +00:00
export type result<a, b> =
| {
tag: "Ok";
2022-04-11 00:48:45 +00:00
value: a;
}
| {
tag: "Error";
2022-04-11 00:48:45 +00:00
value: b;
2022-04-09 02:55:06 +00:00
};
2022-04-11 00:51:43 +00:00
export function resultMap<a, b, c>(
r: result<a, c>,
mapFn: (x: a) => b
): result<b, c> {
if (r.tag === "Ok") {
return { tag: "Ok", value: mapFn(r.value) };
} else {
return r;
}
}
function Ok<a, b>(x: a): result<a, b> {
return { tag: "Ok", value: x };
2022-04-11 06:16:29 +00:00
}
type tagged<a, b> = { tag: a; value: b };
2022-04-11 03:16:31 +00:00
function tag<a, b>(x: a, y: b): tagged<a, b> {
return { tag: x, value: y };
2022-04-11 03:16:31 +00:00
}
/**
 * TypeScript-side view of a Squiggle value: a `tagged` pair whose tag names
 * the kind of value. Arrays and records nest further `squiggleExpression`s;
 * distributions are wrapped in `Distribution` objects.
 */
export type squiggleExpression =
  | tagged<"number", number>
  | tagged<"boolean", boolean>
  | tagged<"string", string>
  | tagged<"symbol", string>
  | tagged<"call", string>
  | tagged<"distribution", Distribution>
  | tagged<"array", squiggleExpression[]>
  | tagged<"record", { [key: string]: squiggleExpression }>;
2022-04-11 03:16:31 +00:00
/**
 * Evaluates a Squiggle program and converts its value to a
 * `squiggleExpression`.
 *
 * @param squiggleString - Squiggle source code.
 * @param bindings - Variables visible to the program; defaults to none.
 * @param samplingInputs - Sampling configuration attached to any
 *   distributions in the result; defaults to `defaultSamplingInputs`.
 * @param parameters - Plain JS values exposed to the program under
 *   `$`-prefixed names (via `mergeParameters`); defaults to none.
 * @returns `Ok` with the converted value, or `Error` with the evaluator's
 *   `errorValue`.
 */
export function run(
  squiggleString: string,
  bindings?: externalBindings,
  samplingInputs?: samplingParams,
  parameters?: parameters
): result<squiggleExpression, errorValue> {
  const env: samplingParams = samplingInputs ?? defaultSamplingInputs;
  const scope = mergeParameters(bindings ?? {}, parameters ?? {});
  const evaluated: result<expressionValue, errorValue> =
    evaluateUsingExternalBindings(squiggleString, scope);
  return resultMap(evaluated, (value) => createTsExport(value, env));
}
2022-04-29 13:50:57 +00:00
/**
 * Evaluates a "partial": a block of Squiggle code that does not end in a
 * value. Instead of a value, it returns the bindings produced by the
 * evaluator; their type matches the `bindings` parameter of `run`.
 *
 * @param squiggleString - Squiggle source code.
 * @param bindings - Variables visible to the code; defaults to none.
 * @param _samplingInputs - Accepted for signature parity with `run`; unused.
 * @param parameters - Plain JS values exposed under `$`-prefixed names.
 */
export function runPartial(
  squiggleString: string,
  bindings?: externalBindings,
  _samplingInputs?: samplingParams,
  parameters?: parameters
): result<externalBindings, errorValue> {
  const scope = mergeParameters(bindings ?? {}, parameters ?? {});
  return evaluatePartialUsingExternalBindings(squiggleString, scope);
}
/**
 * Combines caller bindings with user parameters. Each parameter `key` is
 * converted with `jsValueToBinding` and stored under `"$" + key`.
 *
 * Returns a new object and leaves `bindings` untouched. A shallow spread is
 * used instead of `_.merge(bindings, …)` for two reasons:
 *
 * - `_.merge` mutates its first argument, which leaked the `$` parameters
 *   into the caller's bindings object.
 * - `_.merge` deep-merges colliding keys. When bindings that already contain
 *   `$key` (e.g. ones returned by `runPartial`) are passed again with a new
 *   value, arrays were merged element-by-element and records kept stale
 *   fields from the previous value.
 *
 * With the spread, a parameter simply replaces any existing binding of the
 * same name.
 */
function mergeParameters(
  bindings: externalBindings,
  parameters: parameters
): externalBindings {
  const transformedParameters = Object.fromEntries(
    Object.entries(parameters).map(([key, value]) => [
      "$" + key,
      jsValueToBinding(value),
    ])
  );
  return { ...bindings, ...transformedParameters };
}
// Parameters accepted by `run` / `runPartial`, keyed by plain name;
// `mergeParameters` adds the `$` prefix when exposing them to the program.
type parameters = { [key: string]: jsValue };
// Plain JavaScript values that `jsValueToBinding` can convert: strings,
// numbers, booleans, and (recursively) arrays and string-keyed records of
// these.
type jsValue =
| string
| number
| jsValue[]
| { [key: string]: jsValue }
| boolean;
/**
 * Converts a plain JS parameter value into the raw ReScript form
 * (`rescriptExport`) the evaluator consumes. The TAG numbers follow the
 * variant comments on `rescriptExport`: 1 bool, 6 string, 4 number,
 * 0 array, 5 record.
 *
 * Anything that is not a boolean, string, number or array falls through to
 * the record branch.
 */
function jsValueToBinding(value: jsValue): rescriptExport {
  switch (typeof value) {
    case "boolean":
      return { TAG: 1, _0: value };
    case "string":
      return { TAG: 6, _0: value };
    case "number":
      return { TAG: 4, _0: value };
    default:
      return Array.isArray(value)
        ? { TAG: 0, _0: value.map(jsValueToBinding) }
        : { TAG: 5, _0: _.mapValues(value, jsValueToBinding) };
  }
}
/**
 * Converts an evaluator result (`expressionValue`, as typed by genType) into
 * a `squiggleExpression`. Distributions are wrapped in `Distribution`
 * objects carrying `sampEnv`.
 *
 * genType only converts the outer layers into `{ tag, value }` form. Record
 * fields, and the elements of an array nested inside an array, are still
 * raw ReScript values and go through `convertRawToTypescript`. The casts
 * exist because genType's types claim the conversion is fully recursive
 * when it is not.
 */
function createTsExport(
  x: expressionValue,
  sampEnv: samplingParams
): squiggleExpression {
  // Converts a value that genType left in raw ReScript form.
  const fromRaw = (raw: unknown): squiggleExpression =>
    convertRawToTypescript(raw as rescriptExport, sampEnv);

  switch (x.tag) {
    case "EvArray": {
      const items = x.value.map((arrayItem): squiggleExpression => {
        if (arrayItem.tag === "EvRecord") {
          return tag("record", _.mapValues(arrayItem.value, fromRaw));
        }
        if (arrayItem.tag === "EvArray") {
          const children = arrayItem.value as unknown as rescriptExport[];
          return tag("array", children.map(fromRaw));
        }
        return createTsExport(arrayItem, sampEnv);
      });
      return tag("array", items);
    }
    case "EvRecord": {
      // genType doesn't support records, so the fields are converted raw.
      const record: tagged<"record", { [key: string]: squiggleExpression }> =
        tag("record", _.mapValues(x.value, fromRaw));
      return record;
    }
    case "EvBool":
      return tag("boolean", x.value);
    case "EvCall":
      return tag("call", x.value);
    case "EvDistribution":
      return tag("distribution", new Distribution(x.value, sampEnv));
    case "EvNumber":
      return tag("number", x.value);
    case "EvString":
      return tag("string", x.value);
    case "EvSymbol":
      return tag("symbol", x.value);
  }
}
2022-04-29 13:50:57 +00:00
/**
 * Converts a raw ReScript value (`rescriptExport`: numeric `TAG` plus `_0`
 * payload), which is the form genType leaves nested values in, into a
 * `squiggleExpression`. Recurses into arrays and records; distributions are
 * wrapped in `Distribution` objects carrying `sampEnv`.
 *
 * There is no default branch: a `TAG` outside 0-7 falls off the end of the
 * switch and yields `undefined` at runtime.
 */
function convertRawToTypescript(
  result: rescriptExport,
  sampEnv: samplingParams
): squiggleExpression {
  const recurse = (child: rescriptExport): squiggleExpression =>
    convertRawToTypescript(child, sampEnv);

  switch (result.TAG) {
    case 0: // EvArray
      return tag("array", result._0.map(recurse));
    case 1: // EvBool
      return tag("boolean", result._0);
    case 2: // EvCall
      return tag("call", result._0);
    case 3: {
      // EvDistribution
      const dist = convertRawDistributionToGenericDist(result._0);
      return tag("distribution", new Distribution(dist, sampEnv));
    }
    case 4: // EvNumber
      return tag("number", result._0);
    case 5: // EvRecord
      return tag("record", _.mapValues(result._0, recurse));
    case 6: // EvString
      return tag("string", result._0);
    case 7: // EvSymbol
      return tag("symbol", result._0);
  }
}
/**
 * Converts a raw ReScript distribution (`rescriptDist`) into genType's
 * `genericDist` form:
 * - `{ tag: "PointSet", value: { tag: "Mixed" | "Discrete" | "Continuous", value } }`
 * - `{ tag: "SampleSet", value }`
 * - `{ tag: "Symbolic", value }`
 *
 * @throws Error if either the distribution TAG or the nested point-set TAG
 *   is not a known variant. Without the explicit defaults, an unknown
 *   point-set TAG would fall through into the `SampleSet` case and be
 *   silently mislabelled, and an unknown outer TAG would return `undefined`.
 */
function convertRawDistributionToGenericDist(
  result: rescriptDist
): genericDist {
  switch (result.TAG) {
    case 0: // Point Set Dist
      switch (result._0.TAG) {
        case 0: // Mixed
          return tag("PointSet", tag("Mixed", result._0._0));
        case 1: // Discrete
          return tag("PointSet", tag("Discrete", result._0._0));
        case 2: // Continuous
          return tag("PointSet", tag("Continuous", result._0._0));
        default:
          // Unreachable per the types; guards against the ReScript side
          // growing a variant, instead of falling into the SampleSet case.
          throw new Error(
            "Unknown point set distribution TAG: " +
              String((result._0 as { TAG: unknown }).TAG)
          );
      }
    case 1: // Sample Set Dist
      return tag("SampleSet", result._0);
    case 2: // Symbolic Dist
      return tag("Symbolic", result._0);
    default:
      throw new Error(
        "Unknown distribution TAG: " +
          String((result as { TAG: unknown }).TAG)
      );
  }
}
// Raw ReScript runtime representations that genType leaves unconverted;
// they are decoded by `convertRawToTypescript` and
// `convertRawDistributionToGenericDist`, and produced by `jsValueToBinding`.
// Each variant is an object `{ TAG, _0 }`: a numeric constructor tag plus
// the constructor's payload.
// NOTE(review): the TAG numbers are assumed to match the constructor order
// of the corresponding ReScript variant types; keep them in sync if those
// types change.
type rescriptExport =
| {
TAG: 0; // EvArray
_0: rescriptExport[];
}
| {
TAG: 1; // EvBool
_0: boolean;
}
| {
TAG: 2; // EvCall
_0: string;
}
| {
TAG: 3; // EvDistribution
_0: rescriptDist;
}
| {
TAG: 4; // EvNumber
_0: number;
}
| {
TAG: 5; // EvRecord
_0: { [key: string]: rescriptExport };
}
| {
TAG: 6; // EvString
_0: string;
}
| {
TAG: 7; // EvSymbol
_0: string;
};
// Raw distribution, mapped to genericDist by
// convertRawDistributionToGenericDist.
type rescriptDist =
| { TAG: 0; _0: rescriptPointSetDist } // Point set
| { TAG: 1; _0: sampleSetDist } // Sample set
| { TAG: 2; _0: symbolicDist }; // Symbolic
// Raw point-set distribution: one of three shape kinds.
type rescriptPointSetDist =
| {
TAG: 0; // Mixed
_0: mixedShape;
}
| {
TAG: 1; // Discrete
_0: discreteShape;
}
| {
TAG: 2; // Continuous
_0: continuousShape;
};
2022-04-11 03:16:31 +00:00
2022-04-11 00:51:43 +00:00
/**
 * Unwraps a `result`, returning its payload whichever variant it is: the
 * `Ok` value or the `Error` value.
 *
 * Despite the name, this does not throw on `Error`. Callers that need to
 * tell the two apart must check `r.tag` themselves.
 */
export function resultExn<a, c>(r: result<a, c>): a | c {
  const { value } = r;
  return value;
}
/** A single `(x, y)` coordinate of a point-set shape. */
export type point = { x: number; y: number };

/**
 * A point-set distribution split into its continuous and discrete parts,
 * each as a list of points. `Distribution.pointSet` leaves a part empty
 * when the distribution has no such component.
 */
export type shape = {
  continuous: point[];
  discrete: point[];
};

/**
 * Pairs the i-th x with the i-th y of a shape's `xyShape` into points.
 * Presumably `xs` and `ys` always have equal length (TODO confirm).
 */
function shapePoints(x: continuousShape | discreteShape): point[] {
  const { xs, ys } = x.xyShape;
  return _.zipWith(xs, ys, (px, py) => ({ x: px, y: py }));
}
2022-04-11 00:48:45 +00:00
/**
 * A Squiggle distribution (`t`, a genType `genericDist`) bundled with the
 * sampling parameters (`env`) that every operation on it is run with.
 *
 * Operations delegate to the ReScript `Constructors_*` functions and return
 * a `result` instead of throwing. Operations that yield a new distribution
 * wrap it in a `Distribution` sharing this instance's `env`. Binary
 * operations use the receiver's `env`, never the argument's.
 */
export class Distribution {
  t: genericDist;
  env: samplingParams;

  constructor(t: genericDist, env: samplingParams) {
    this.t = t;
    this.env = env;
  }

  /** Builds the `{ env }` context argument shared by all `Constructors_*` calls. */
  private operationEnv() {
    return { env: this.env };
  }

  /** Wraps an `Ok` distribution as a `Distribution` with this `env`; errors pass through. */
  mapResultDist(
    r: result<genericDist, distributionError>
  ): result<Distribution, distributionError> {
    return resultMap(r, (v: genericDist) => new Distribution(v, this.env));
  }

  /** Mean of the distribution. */
  mean(): result<number, distributionError> {
    return Constructors_mean(this.operationEnv(), this.t);
  }

  /** A single sample drawn from the distribution. */
  sample(): result<number, distributionError> {
    return Constructors_sample(this.operationEnv(), this.t);
  }

  /** Probability density at `n`. */
  pdf(n: number): result<number, distributionError> {
    return Constructors_pdf(this.operationEnv(), this.t, n);
  }

  /** Cumulative probability at `n`. */
  cdf(n: number): result<number, distributionError> {
    return Constructors_cdf(this.operationEnv(), this.t, n);
  }

  /** Inverse CDF at `n`; presumably `n` is a probability in [0, 1]. */
  inv(n: number): result<number, distributionError> {
    return Constructors_inv(this.operationEnv(), this.t, n);
  }

  /** Whether the distribution is normalized, per `Constructors_isNormalized`. */
  isNormalized(): result<boolean, distributionError> {
    return Constructors_isNormalized(this.operationEnv(), this.t);
  }

  /** Normalized copy of the distribution. */
  normalize(): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_normalize(this.operationEnv(), this.t)
    );
  }

  /** The genType tag of the underlying distribution ("PointSet", "SampleSet", "Symbolic"). */
  type() {
    return this.t.tag;
  }

  /**
   * Converts to a point set and returns it as a `shape`. A continuous or
   * discrete point set fills only the matching part; the remaining (mixed)
   * case fills both.
   */
  pointSet(): result<shape, distributionError> {
    const pointSet = toPointSet(
      this.t,
      {
        xyPointLength: this.env.xyPointLength,
        sampleCount: this.env.sampleCount,
      },
      undefined
    );
    if (pointSet.tag !== "Ok") {
      return pointSet;
    }
    const distribution = pointSet.value;
    if (distribution.tag === "Continuous") {
      return Ok({
        continuous: shapePoints(distribution.value),
        discrete: [],
      });
    }
    if (distribution.tag === "Discrete") {
      return Ok({
        discrete: shapePoints(distribution.value),
        continuous: [],
      });
    }
    // Mixed: both parts present.
    return Ok({
      discrete: shapePoints(distribution.value.discrete),
      continuous: shapePoints(distribution.value.continuous),
    });
  }

  /** Converts to a point-set distribution, wrapped as a `Distribution`. */
  toPointSet(): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_toPointSet(this.operationEnv(), this.t)
    );
  }

  /** Converts to a sample-set distribution using `n` (presumably the sample count). */
  toSampleSet(n: number): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_toSampleSet(this.operationEnv(), this.t, n)
    );
  }

  /** Truncates the distribution to the interval bounded by `left` and `right`. */
  truncate(
    left: number,
    right: number
  ): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_truncate(this.operationEnv(), this.t, left, right)
    );
  }

  /** Delegates to `Constructors_inspect` and re-wraps the resulting distribution. */
  inspect(): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_inspect(this.operationEnv(), this.t)
    );
  }

  /** String form from `Constructors_toString`; on failure, the error rendered as a string. */
  toString(): string {
    const rendered = Constructors_toString(this.operationEnv(), this.t);
    return rendered.tag === "Ok"
      ? rendered.value
      : distributionErrorToString(rendered.value);
  }

  /** Sparkline string; `n` is presumably its width (TODO confirm). */
  toSparkline(n: number): result<string, distributionError> {
    return Constructors_toSparkline(this.operationEnv(), this.t, n);
  }

  // Algebraic binary operations with `d2`, evaluated with this `env`.

  algebraicAdd(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_algebraicAdd(this.operationEnv(), this.t, d2.t)
    );
  }

  algebraicMultiply(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_algebraicMultiply(this.operationEnv(), this.t, d2.t)
    );
  }

  algebraicDivide(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_algebraicDivide(this.operationEnv(), this.t, d2.t)
    );
  }

  algebraicSubtract(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_algebraicSubtract(this.operationEnv(), this.t, d2.t)
    );
  }

  algebraicLogarithm(
    d2: Distribution
  ): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_algebraicLogarithm(this.operationEnv(), this.t, d2.t)
    );
  }

  algebraicPower(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_algebraicPower(this.operationEnv(), this.t, d2.t)
    );
  }

  // Pointwise binary operations with `d2`, evaluated with this `env`.

  pointwiseAdd(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_pointwiseAdd(this.operationEnv(), this.t, d2.t)
    );
  }

  pointwiseMultiply(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_pointwiseMultiply(this.operationEnv(), this.t, d2.t)
    );
  }

  pointwiseDivide(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_pointwiseDivide(this.operationEnv(), this.t, d2.t)
    );
  }

  pointwiseSubtract(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_pointwiseSubtract(this.operationEnv(), this.t, d2.t)
    );
  }

  pointwiseLogarithm(
    d2: Distribution
  ): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_pointwiseLogarithm(this.operationEnv(), this.t, d2.t)
    );
  }

  pointwisePower(d2: Distribution): result<Distribution, distributionError> {
    return this.mapResultDist(
      Constructors_pointwisePower(this.operationEnv(), this.t, d2.t)
    );
  }
}