diff --git a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res index 4625867d..00705ccd 100644 --- a/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res +++ b/packages/squiggle-lang/src/rescript/Distributions/SymbolicDist/SymbolicDistTypes.res @@ -31,7 +31,85 @@ type triangular = { high: float, } +type symbolicValidationError = +| InvalidNormal(string) +| InvalidLognormal(string) +| InvalidUniform(string) +| InvalidBeta(string) +| InvalidExponential(string) +| InvalidCauchy(string) +| InvalidTriangular(string) + +type validated<'a> = result<'a, symbolicValidationError> + +let valiNormal: normal => validated = t => { + if t.stdev <= 0.0 { + Error(InvalidNormal("Stdev must be strictly greater than 0")) + } else { + Ok(t) + } +} + +let valiExponential: exponential => validated = t => { + if t.rate <= 0.0 { + Error(InvalidExponential("Exponential distribtion rate must be larger than 0")) + } else { + Ok(t) + } +} + +let valiCauchy: cauchy => validated = t => { + Ok(t) +} + +let valiTriangular: triangular => validated = t => { + if t.low >= t.medium || t.medium >= t.high { + Error(InvalidTriangular("Triangular values must be in increasing order")) + } else { + Ok(t) + } +} + +let valiBeta: beta => validated = t => { + if t.alpha <= 0.0 || t.beta <= 0.0 { + Error(InvalidBeta("Beta distribution parameters must be strictly positive")) + } else { + Ok(t) + } +} + +let valiLognormal: lognormal => validated = t => { + if t.sigma <= 0.0 { + Error(InvalidLognormal("Lognormal standard deviation must be strictly positive")) + } else { + Ok(t) + } +} + +let valiUniform: uniform => validated = t => { + if t.low >= t.high { + Error(InvalidUniform("High must be strictly greater than low")) + } else { + Ok(t) + } +} + +let valiFloat: float => validated = t => { + Ok(t) +} + @genType +type symbolicDistR = [ + | #NormalR(validated) + | #BetaR(validated) + | #LognormalR(validated) + | #UniformR(validated) + | #ExponentialR(validated) + | #CauchyR(validated) + | #TriangularR(validated) + | #FloatR(validated) +] + type symbolicDist = [ | #Normal(normal) | #Beta(beta) @@ -48,3 +126,50 @@ type analyticalSimplificationResult = [ | #Error(string) | #NoSolution ] + +// I feel like this should be something in `E.R.`... +let f: symbolicDistR => validated = x => { + switch x { + | #NormalR(vNormal) => switch vNormal { + | Ok(t) => Ok(#Normal(t)) + | Error(t) => Error(t) + } + | #BetaR(vBeta) => switch vBeta { + | Ok(t) => Ok(#Beta(t)) + | Error(t) => Error(t) + } + | #LognormalR(vLognormal) => switch vLognormal { + | Ok(t) => Ok(#Lognormal(t)) + | Error(t) => Error(t) + } + | #UniformR(vUniform) => switch vUniform { + | Ok(t) => Ok(#Uniform(t)) + | Error(t) => Error(t) + } + | #ExponentialR(vExponential) => switch vExponential { + | Ok(t) => Ok(#Exponential(t)) + | Error(t) => Error(t) + } + | #CauchyR(vExponential) => switch vExponential { + | Ok(t) => Ok(#Cauchy(t)) + | Error(t) => Error(t) + } + | #TriangularR(vExponential) => switch vExponential { + | Ok(t) => Ok(#Triangular(t)) + | Error(t) => Error(t) + } + | #FloatR(vExponential) => switch vExponential { + | Ok(t) => Ok(#Float(t)) + | Error(t) => Error(t) + } + } +} + +let normalConstr: normal => validated = t => t -> valiNormal -> #NormalR -> f +let exponentialConstr: exponential => validated = t => t -> valiExponential -> #ExponentialR -> f +let cauchyConstr: cauchy => validated = t => t -> valiCauchy -> #CauchyR -> f +let triangularConstr: triangular => validated = t => t -> valiTriangular -> #TriangularR -> f +let betaConstr: beta => validated = t => t -> valiBeta -> #BetaR -> f +let lognormalConstr: lognormal => validated = t => t -> valiLognormal -> #LognormalR -> f +let uniformConstr: uniform => validated = t => t -> valiUniform -> #UniformR -> f +let floatConstr: float => validated = t => t -> valiFloat -> #FloatR -> f