diff --git a/packages/squiggle-lang/__tests__/JS__Test.ts b/packages/squiggle-lang/__tests__/JS__Test.ts index aded69c1..07a4cabe 100644 --- a/packages/squiggle-lang/__tests__/JS__Test.ts +++ b/packages/squiggle-lang/__tests__/JS__Test.ts @@ -1,34 +1,66 @@ -import { run } from '../src/js/index'; +import { run, GenericDist, resultMap } from "../src/js/index"; let testRun = (x: string) => { - let result = run(x) - if(result.tag == 'Ok'){ - return { tag: 'Ok', value: result.value.exports } + let result = run(x); + if (result.tag == "Ok") { + return { tag: "Ok", value: result.value.exports }; + } else { + return result; } - else { - return result - } -} +}; describe("Simple calculations and results", () => { - test("mean(normal(5,2))", () => { - expect(testRun("mean(normal(5,2))")).toEqual({ tag: 'Ok', value: [ { NAME: 'Float', VAL: 5 } ] }) - }) - test("10+10", () => { - let foo = testRun("10 + 10") - expect(foo).toEqual({ tag: 'Ok', value: [ { NAME: 'Float', VAL: 20 } ] }) - }) -}) + test("mean(normal(5,2))", () => { + expect(testRun("mean(normal(5,2))")).toEqual({ + tag: "Ok", + value: [{ NAME: "Float", VAL: 5 }], + }); + }); + test("10+10", () => { + let foo = testRun("10 + 10"); + expect(foo).toEqual({ tag: "Ok", value: [{ NAME: "Float", VAL: 20 }] }); + }); +}); describe("Log function", () => { - test("log(1) = 0", () => { - let foo = testRun("log(1)") - expect(foo).toEqual({ tag: 'Ok', value: [ { NAME: 'Float', VAL: 0} ]}) - }) -}) + test("log(1) = 0", () => { + let foo = testRun("log(1)"); + expect(foo).toEqual({ tag: "Ok", value: [{ NAME: "Float", VAL: 0 }] }); + }); +}); describe("Multimodal too many weights error", () => { - test("mm(0,0,[0,0,0])", () => { - let foo = testRun("mm(0,0,[0,0,0])") - expect(foo).toEqual({ "tag": "Error", "value": "Function multimodal error: Too many weights provided" }) - }) + test("mm(0,0,[0,0,0])", () => { + let foo = testRun("mm(0,0,[0,0,0])"); + expect(foo).toEqual({ + tag: "Error", + value: "Function multimodal error: Too many weights provided", + }); + }); +}); + +describe("GenericDist", () => { + let dist = new GenericDist( + { tag: "SampleSet", value: [3, 4, 5, 6, 6, 7, 10, 15, 30] }, + { sampleCount: 100, xyPointLength: 100 } + ); + test("mean", () => { + expect(dist.mean().value).toBeCloseTo(3.737); + }); + test("pdf", () => { + expect(dist.pdf(5.0).value).toBeCloseTo(0.0431); + }); + test("cdf", () => { + expect(dist.cdf(5.0).value).toBeCloseTo(0.155); + }); + test("inv", () => { + expect(dist.inv(0.5).value).toBeCloseTo(9.458); + }); + test("toPointSet", () => { + expect( + resultMap(dist.toPointSet(), (r: GenericDist) => r.toString()).value.value + ).toBe("Point Set Distribution"); + }); + test("toSparkline", () => { + expect(dist.toSparkline(20).value).toBe("▁▁▃▅███▆▄▃▂▁▁▂▂▃▂▁▁▁"); + }); }); diff --git a/packages/squiggle-lang/src/js/index.ts b/packages/squiggle-lang/src/js/index.ts index a250c4fc..b690d581 100644 --- a/packages/squiggle-lang/src/js/index.ts +++ b/packages/squiggle-lang/src/js/index.ts @@ -7,7 +7,6 @@ import type { } from "../rescript/ProgramEvaluator.gen"; export type { SamplingInputs, exportEnv, exportDistribution }; export type { t as DistPlus } from "../rescript/OldInterpreter/DistPlus.gen"; -import type { Operation_genericFunctionCallInfo } from "../rescript/Distributions/GenericDist/GenericDist_Types.gen"; import { genericDist, resultDist, @@ -16,30 +15,30 @@ import { } from "../rescript/TSInterface.gen"; import { env, - Constructors_UsingDists_mean, - Constructors_UsingDists_sample, - Constructors_UsingDists_pdf, - Constructors_UsingDists_cdf, - Constructors_UsingDists_inv, - Constructors_UsingDists_normalize, - Constructors_UsingDists_toPointSet, - Constructors_UsingDists_toSampleSet, - Constructors_UsingDists_truncate, - Constructors_UsingDists_inspect, - Constructors_UsingDists_toString, - Constructors_UsingDists_toSparkline, - Constructors_UsingDists_algebraicAdd, - Constructors_UsingDists_algebraicMultiply, - Constructors_UsingDists_algebraicDivide, - Constructors_UsingDists_algebraicSubtract, - Constructors_UsingDists_algebraicLogarithm, - Constructors_UsingDists_algebraicExponentiate, - Constructors_UsingDists_pointwiseAdd, - Constructors_UsingDists_pointwiseMultiply, - Constructors_UsingDists_pointwiseDivide, - Constructors_UsingDists_pointwiseSubtract, - Constructors_UsingDists_pointwiseLogarithm, - Constructors_UsingDists_pointwiseExponentiate, + Constructors_mean, + Constructors_sample, + Constructors_pdf, + Constructors_cdf, + Constructors_inv, + Constructors_normalize, + Constructors_toPointSet, + Constructors_toSampleSet, + Constructors_truncate, + Constructors_inspect, + Constructors_toString, + Constructors_toSparkline, + Constructors_algebraicAdd, + Constructors_algebraicMultiply, + Constructors_algebraicDivide, + Constructors_algebraicSubtract, + Constructors_algebraicLogarithm, + Constructors_algebraicExponentiate, + Constructors_pointwiseAdd, + Constructors_pointwiseMultiply, + Constructors_pointwiseDivide, + Constructors_pointwiseSubtract, + Constructors_pointwiseLogarithm, + Constructors_pointwiseExponentiate, } from "../rescript/Distributions/DistributionOperation/DistributionOperation.gen"; export let defaultSamplingInputs: SamplingInputs = { @@ -60,160 +59,172 @@ export function run( return runAll(squiggleString, si, env); } -class GenericDist { +export function resultMap( + r: + | { + tag: "Ok"; + value: any; + } + | { + tag: "Error"; + value: any; + }, + mapFn: any +): + | { + tag: "Ok"; + value: any; + } + | { + tag: "Error"; + value: any; + } { + if (r.tag === "Ok") { + return { tag: "Ok", value: mapFn(r.value) }; + } else { + return r; + } +} + +export class GenericDist { t: genericDist; env: env; constructor(t: genericDist, env: env) { this.t = t; + this.env = env; + return this; } - mean(): resultFloat { - return Constructors_UsingDists_mean({ env: this.env }, this.t); + mapResultDist(r: resultDist) { + return resultMap(r, (v: genericDist) => new GenericDist(v, this.env)); + } + + mean() { + return Constructors_mean({ env: this.env }, this.t); } sample(): resultFloat { - return Constructors_UsingDists_sample({ env: this.env }, this.t); + return Constructors_sample({ env: this.env }, this.t); } pdf(n: number): resultFloat { - return Constructors_UsingDists_pdf({ env: this.env }, this.t, n); + return Constructors_pdf({ env: this.env }, this.t, n); } cdf(n: number): resultFloat { - return Constructors_UsingDists_cdf({ env: this.env }, this.t, n); + return Constructors_cdf({ env: this.env }, this.t, n); } inv(n: number): resultFloat { - return Constructors_UsingDists_inv({ env: this.env }, this.t, n); + return Constructors_inv({ env: this.env }, this.t, n); } - normalize(): resultDist { - return Constructors_UsingDists_normalize({ env: this.env }, this.t); - } - - toPointSet(): resultDist { - return Constructors_UsingDists_toPointSet({ env: this.env }, this.t); - } - - toSampleSet(n: number): resultDist { - return Constructors_UsingDists_toSampleSet({ env: this.env }, this.t, n); - } - - truncate(left: number, right: number): resultDist { - return Constructors_UsingDists_truncate( - { env: this.env }, - this.t, - left, - right + normalize() { + return this.mapResultDist( + Constructors_normalize({ env: this.env }, this.t) ); } - inspect(): resultDist { - return Constructors_UsingDists_inspect({ env: this.env }, this.t); + toPointSet() { + return this.mapResultDist( + Constructors_toPointSet({ env: this.env }, this.t) + ); + } + + toSampleSet(n: number) { + return this.mapResultDist( + Constructors_toSampleSet({ env: this.env }, this.t, n) + ); + } + + truncate(left: number, right: number) { + return this.mapResultDist( + Constructors_truncate({ env: this.env }, this.t, left, right) + ); + } + + inspect() { + return this.mapResultDist(Constructors_inspect({ env: this.env }, this.t)); } toString(): resultString { - return Constructors_UsingDists_toString({ env: this.env }, this.t); + return Constructors_toString({ env: this.env }, this.t); } toSparkline(n: number): resultString { - return Constructors_UsingDists_toSparkline({ env: this.env }, this.t, n); + return Constructors_toSparkline({ env: this.env }, this.t, n); } - algebraicAdd(d2: GenericDist): resultDist { - return Constructors_UsingDists_algebraicAdd( - { env: this.env }, - this.t, - d2.t + algebraicAdd(d2: GenericDist) { + return this.mapResultDist( + Constructors_algebraicAdd({ env: this.env }, this.t, d2.t) ); } - algebraicMultiply(d2: GenericDist): resultDist { - return Constructors_UsingDists_algebraicMultiply( - { env: this.env }, - this.t, - d2.t + algebraicMultiply(d2: GenericDist) { + return this.mapResultDist( + Constructors_algebraicMultiply({ env: this.env }, this.t, d2.t) ); } - algebraicDivide(d2: GenericDist): resultDist { - return Constructors_UsingDists_algebraicDivide( - { env: this.env }, - this.t, - d2.t + algebraicDivide(d2: GenericDist) { + return this.mapResultDist( + Constructors_algebraicDivide({ env: this.env }, this.t, d2.t) ); } - algebraicSubtract(d2: GenericDist): resultDist { - return Constructors_UsingDists_algebraicSubtract( - { env: this.env }, - this.t, - d2.t + algebraicSubtract(d2: GenericDist) { + return this.mapResultDist( + Constructors_algebraicSubtract({ env: this.env }, this.t, d2.t) ); } - algebraicLogarithm(d2: GenericDist): resultDist { - return Constructors_UsingDists_algebraicLogarithm( - { env: this.env }, - this.t, - d2.t + algebraicLogarithm(d2: GenericDist) { + return this.mapResultDist( + Constructors_algebraicLogarithm({ env: this.env }, this.t, d2.t) ); } - algebraicExponentiate(d2: GenericDist): resultDist { - return Constructors_UsingDists_algebraicExponentiate( - { env: this.env }, - this.t, - d2.t + algebraicExponentiate(d2: GenericDist) { + return this.mapResultDist( + Constructors_algebraicExponentiate({ env: this.env }, this.t, d2.t) ); } - pointwiseAdd(d2: GenericDist): resultDist { - return Constructors_UsingDists_pointwiseAdd( - { env: this.env }, - this.t, - d2.t + pointwiseAdd(d2: GenericDist) { + return this.mapResultDist( + Constructors_pointwiseAdd({ env: this.env }, this.t, d2.t) ); } - pointwiseMultiply(d2: GenericDist): resultDist { - return Constructors_UsingDists_pointwiseMultiply( - { env: this.env }, - this.t, - d2.t + pointwiseMultiply(d2: GenericDist) { + return this.mapResultDist( + Constructors_pointwiseMultiply({ env: this.env }, this.t, d2.t) ); } - pointwiseDivide(d2: GenericDist): resultDist { - return Constructors_UsingDists_pointwiseDivide( - { env: this.env }, - this.t, - d2.t + pointwiseDivide(d2: GenericDist) { + return this.mapResultDist( + Constructors_pointwiseDivide({ env: this.env }, this.t, d2.t) ); } - pointwiseSubtract(d2: GenericDist): resultDist { - return Constructors_UsingDists_pointwiseSubtract( - { env: this.env }, - this.t, - d2.t + pointwiseSubtract(d2: GenericDist) { + return this.mapResultDist( + Constructors_pointwiseSubtract({ env: this.env }, this.t, d2.t) ); } - pointwiseLogarithm(d2: GenericDist): resultDist { - return Constructors_UsingDists_pointwiseLogarithm( - { env: this.env }, - this.t, - d2.t + pointwiseLogarithm(d2: GenericDist) { + return this.mapResultDist( + Constructors_pointwiseLogarithm({ env: this.env }, this.t, d2.t) ); } - pointwiseExponentiate(d2: GenericDist): resultDist { - return Constructors_UsingDists_pointwiseExponentiate( - { env: this.env }, - this.t, - d2.t + pointwiseExponentiate(d2: GenericDist) { + return this.mapResultDist( + Constructors_pointwiseExponentiate({ env: this.env }, this.t, d2.t) ); } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res index 7d58af7b..12b94bcc 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.res @@ -187,8 +187,7 @@ module Output = { } module Constructors = { - module UsingDists = { - module C = GenericDist_Types.Constructors.UsingDists + module C = GenericDist_Types.Constructors.UsingDists; open OutputLocal let mean = (~env, dist) => C.mean(dist)->run(~env)->toFloatR let sample = (~env, dist) => C.sample(dist)->run(~env)->toFloatR @@ -225,5 +224,4 @@ module Constructors = { C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR let pointwiseExponentiate = (~env, dist1, dist2) => C.pointwiseSubtract(dist1, dist2)->run(~env)->toDistR - } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.resi b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.resi index b0799dc9..ce0fca72 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.resi +++ b/packages/squiggle-lang/src/rescript/Distributions/DistributionOperation/DistributionOperation.resi @@ -39,7 +39,6 @@ module Output: { } module Constructors: { - module UsingDists: { @genType let mean: (~env: env, genericDist) => result @genType @@ -93,5 +92,4 @@ module Constructors: { let pointwiseLogarithm: (~env: env, genericDist, genericDist) => result @genType let pointwiseExponentiate: (~env: env, genericDist, genericDist) => result - } } diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res index 58366a77..4815e98c 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist.res @@ -83,8 +83,8 @@ let toPointSet = ( let toSparkline = (t: t, ~sampleCount: int, ~buckets: int=20, unit): result => t - ->toPointSet(~xSelection=#Linear, ~xyPointLength=buckets, ~sampleCount, ()) - ->E.R.bind(r => r->PointSetDist.toSparkline->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))) + ->toPointSet(~xSelection=#Linear, ~xyPointLength=buckets*3, ~sampleCount, ()) + ->E.R.bind(r => r->PointSetDist.toSparkline(buckets)->E.R2.errMap(r => Error(GenericDist_Types.Other(r)))) module Truncate = { let trySymbolicSimplification = (leftCutoff, rightCutoff, t: t): option => diff --git a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res index 45ad6f64..1bfb3e82 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res +++ b/packages/squiggle-lang/src/rescript/Distributions/GenericDist/GenericDist_Types.res @@ -102,6 +102,7 @@ module Constructors = { type t = Operation.genericFunctionCallInfo module UsingDists = { + @genType let mean = (dist): t => FromDist(ToFloat(#Mean), dist) let sample = (dist): t => FromDist(ToFloat(#Sample), dist) let cdf = (dist, f): t => FromDist(ToFloat(#Cdf(f)), dist) diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res index f01457b7..aa27fb62 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/Continuous.res @@ -249,6 +249,9 @@ module T = Dist({ ) }) +let downsampleEquallyOverX = (length, t): t => + t |> shapeMap(XYShape.XsConversion.proportionEquallyOverX(length)) + /* This simply creates multiple copies of the continuous distribution, scaled and shifted according to each discrete data point, and then adds them all together. */ let combineAlgebraicallyWithDiscrete = ( diff --git a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res index 8224f4cb..c7bec8a9 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res +++ b/packages/squiggle-lang/src/rescript/Distributions/PointSetDist/PointSetDist.res @@ -203,7 +203,8 @@ let operate = (distToFloatOp: Operation.distToFloatOperation, s): float => | #Mean => T.mean(s) } -let toSparkline = (t: t) => +let toSparkline = (t: t, n) => T.toContinuous(t) + ->E.O2.fmap(Continuous.downsampleEquallyOverX(n)) ->E.O2.toResult("toContinous Error: Could not convert into continuous distribution") ->E.R2.fmap(r => Continuous.getShape(r).ys->Sparklines.create()) \ No newline at end of file