// squiggle/packages/squiggle-lang/src/js/SqPointSetDist.ts
// (102 lines, 2.5 KiB, TypeScript)
import * as _ from "lodash";
import { wrapDistribution } from "./SqDistribution";
import * as RSPointSetDist from "../rescript/ForTS/ForTS_Distribution/ForTS_Distribution_PointSetDistribution.gen";
import { pointSetDistributionTag as Tag } from "../rescript/ForTS/ForTS_Distribution/ForTS_Distribution_PointSetDistribution_tag";
/** The raw ReScript point-set distribution value held by the wrappers in this module. */
type T = RSPointSetDist.pointSetDistribution;

/** One `(x, y)` coordinate of a distribution's shape. */
export type SqPoint = {
  x: number;
  y: number;
};

/**
 * Plain-object view of a point-set distribution, with its points split into
 * a continuous part and a discrete part.
 */
export type SqShape = {
  continuous: SqPoint[];
  discrete: SqPoint[];
};
/**
 * Converts a ReScript continuous or discrete shape into an array of
 * `{ x, y }` points, pairing `xyShape.xs[i]` with `xyShape.ys[i]`.
 *
 * `_.zipWith` is kept deliberately so that pairing behaves exactly as before
 * (per lodash's `zip` semantics) even if `xs` and `ys` ever differ in length.
 *
 * @param shape - a continuous or discrete shape exposing `xyShape.xs` / `xyShape.ys`
 * @returns one point per index of the paired arrays
 */
const shapePoints = (
  shape: RSPointSetDist.continuousShape | RSPointSetDist.discreteShape
): SqPoint[] => {
  // Bindings are never reassigned, so `const`; destructuring also avoids the
  // old outer `x` being shadowed by the callback's `x`.
  const { xs, ys } = shape.xyShape;
  return _.zipWith(xs, ys, (x, y) => ({ x, y }));
};
/**
 * Wraps a raw ReScript point-set distribution in the `Sq*PointSetDist` class
 * registered for its tag in `tagToClass`.
 *
 * The explicit return type pins the exported contract: the compiler checks
 * that whatever class `tagToClass` yields is assignable to `SqPointSetDist`.
 *
 * @param value - the raw ReScript point-set distribution
 * @returns the wrapper instance matching `RSPointSetDist.getTag(value)`
 */
export const wrapPointSetDist = (value: T): SqPointSetDist => {
  const tag = RSPointSetDist.getTag(value);
  return new tagToClass[tag](value);
};

/**
 * Shared base for the `Sq*PointSetDist` wrappers. It holds the raw ReScript
 * value and provides helpers for unwrapping and converting it.
 */
abstract class SqAbstractPointSetDist {
  constructor(private readonly _value: T) {}

  /** Returns the distribution's points split into continuous and discrete parts. */
  abstract asShape(): SqShape;

  /**
   * Applies a ReScript accessor to the wrapped value and returns the result.
   *
   * @param rsMethod - accessor that returns `null`/`undefined` on failure
   *   (presumably when the value is not the variant it expects)
   * @throws Error("Internal casting error") if the accessor yields `null` or `undefined`
   */
  protected valueMethod = <IR>(rsMethod: (v: T) => IR | null | undefined) => {
    const value = rsMethod(this._value);
    // Only null/undefined signal failure. Falsy but valid results such as 0,
    // "" or false must pass through, so check for nullish, not for truthiness.
    if (value == null) throw new Error("Internal casting error");
    return value;
  };

  /** Converts via `RSPointSetDist.toDistribution` and wraps the result with `wrapDistribution`. */
  asDistribution() {
    return wrapDistribution(RSPointSetDist.toDistribution(this._value));
  }
}

/** Wrapper for a point-set distribution tagged `Tag.Mixed`. */
export class SqMixedPointSetDist extends SqAbstractPointSetDist {
  tag = Tag.Mixed as const;

  /**
   * The underlying ReScript mixed shape. Throws "Internal casting error" if
   * `RSPointSetDist.getMixed` yields nothing.
   */
  get value(): RSPointSetDist.mixedShape {
    return this.valueMethod(RSPointSetDist.getMixed);
  }

  /** Converts both the discrete and the continuous component into point arrays. */
  asShape() {
    const { discrete, continuous } = this.value;
    return {
      discrete: shapePoints(discrete),
      continuous: shapePoints(continuous),
    };
  }
}

/** Wrapper for a point-set distribution tagged `Tag.Discrete`. */
export class SqDiscretePointSetDist extends SqAbstractPointSetDist {
  tag = Tag.Discrete as const;

  /**
   * The underlying ReScript discrete shape. Throws "Internal casting error" if
   * `RSPointSetDist.getDiscrete` yields nothing.
   */
  get value(): RSPointSetDist.discreteShape {
    return this.valueMethod(RSPointSetDist.getDiscrete);
  }

  /** All points go into `discrete`; `continuous` is always empty. */
  asShape() {
    const points = shapePoints(this.value);
    return { discrete: points, continuous: [] };
  }
}

/** Wrapper for a point-set distribution tagged `Tag.Continuous`. */
export class SqContinuousPointSetDist extends SqAbstractPointSetDist {
  tag = Tag.Continuous as const;

  /**
   * The underlying ReScript continuous shape. Throws "Internal casting error"
   * if `RSPointSetDist.getContinues` yields nothing.
   */
  get value(): RSPointSetDist.continuousShape {
    return this.valueMethod(RSPointSetDist.getContinues);
  }

  /** All points go into `continuous`; `discrete` is always empty. */
  asShape() {
    const points = shapePoints(this.value);
    return { discrete: [], continuous: points };
  }
}
/**
 * Maps each point-set distribution tag to the wrapper class that
 * `wrapPointSetDist` instantiates for it. `as const` keeps the per-tag
 * class types precise, so indexing by a tag yields the matching constructor.
 */
const tagToClass = {
  [Tag.Continuous]: SqContinuousPointSetDist,
  [Tag.Discrete]: SqDiscretePointSetDist,
  [Tag.Mixed]: SqMixedPointSetDist,
} as const;

/**
 * Any point-set distribution wrapper. Each member carries a literal `tag`,
 * so switching on `tag` narrows to the concrete class.
 */
export type SqPointSetDist =
  | SqContinuousPointSetDist
  | SqDiscretePointSetDist
  | SqMixedPointSetDist;