diff --git a/ocaml/out/samples b/ocaml/out/samples index ac03c5db..10c720d7 100755 Binary files a/ocaml/out/samples and b/ocaml/out/samples differ diff --git a/ocaml/out/samples.cmi b/ocaml/out/samples.cmi index 17b39170..dabde895 100644 Binary files a/ocaml/out/samples.cmi and b/ocaml/out/samples.cmi differ diff --git a/ocaml/out/samples.cmx b/ocaml/out/samples.cmx index bebd406a..1d6ce640 100644 Binary files a/ocaml/out/samples.cmx and b/ocaml/out/samples.cmx differ diff --git a/ocaml/out/samples.o b/ocaml/out/samples.o index 7dffff0e..85ec9631 100644 Binary files a/ocaml/out/samples.o and b/ocaml/out/samples.o differ diff --git a/ocaml/samples.ml b/ocaml/samples.ml index 7f47e89d..b1bf8440 100644 --- a/ocaml/samples.ml +++ b/ocaml/samples.ml @@ -2,17 +2,27 @@ let pi = acos (-1.) let normal_95_ci_length = 1.6448536269514722 -(* Array manipulation helpers *) -let sumFloats xs = Array.fold_left(fun acc x -> acc +. x) 0.0 xs +(* List manipulation helpers *) +let sumFloats xs = List.fold_left(fun acc x -> acc +. x) 0.0 xs let normalizeXs xs = let sum_xs = sumFloats xs in - Array.map(fun x -> x /. sum_xs) xs + List.map(fun x -> x /. sum_xs) xs let cumsumXs xs = - let _, cum_sum = Array.fold_left(fun (sum, ys) x -> + let _, cum_sum = List.fold_left(fun (sum, ys) x -> let new_sum = sum +. x in new_sum, ys @ [new_sum] ) (0.0, []) xs in cum_sum +let rec nth xs (n: int) = + match xs with + | [] -> None + | y :: ys -> if n = 0 then Some(y) else nth ys (n-1) + (* + Note that this is O(n) access. + That is the cost of using the nice match syntax, + which is not possible with OCaml arrays + *) + let findIndex xs test = let rec recursiveHelper ys i = match ys with @@ -41,18 +51,22 @@ let sampleTo low high = let logstd = (loghigh -. loglow) /. (2.0 -. normal_95_ci_length ) in sampleLognormal logmean logstd -let mixture (samplers: (unit -> float) array) (weights: float array): float option = - if (Array.length samplers == Array.length weights) +let mixture (samplers: (unit -> float) list) (weights: float list): float option = + if (List.length samplers == List.length weights) then None else let normalized_weights = normalizeXs weights in let cumsummed_normalized_weights = cumsumXs normalized_weights in let p = sampleZeroToOne () in let chosenSamplerIndex = findIndex cumsummed_normalized_weights (fun x -> x < p) in - let sample = match chosenSamplerIndex with + let sampler = match chosenSamplerIndex with | None -> None - | Some(i) -> Some (samplers.(i) ()) + | Some(i) -> nth samplers i in + let sample = match sampler with + | None -> None + | Some(f) -> Some(f ()) + in sample let () = @@ -63,9 +77,9 @@ let () = let p1 = 0.8 in let p2 = 0.5 in let p3 = p1 *. p2 in - let weights = [| 1. -. p3; p3 /. 2.; p3 /. 4.; p3/. 4. |] in - let sampler () = mixture [| sample0; sample1; sampleFew; sampleMany |] weights in + let weights = [ 1. -. p3; p3 /. 2.; p3 /. 4.; p3/. 4. ] in + let sampler () = mixture [ sample0; sample1; sampleFew; sampleMany ] weights in let n = 1_000_000 in - let samples = Array.init n (fun _ -> sampler ()) in + let samples = List.init n (fun _ -> sampler ()) in (* let mean = sumFloats samples /. n in *) Printf.printf "Hello world\n"