feat: add lua.

This commit is contained in:
NunoSempere 2023-06-10 19:07:04 -06:00
parent 82ef16b55d
commit c9eeaf5b5d

View File

@ -1 +1,82 @@
print("Hello world")
-- Consts and prep
PI = 3.14159265358979323846;
NORMAL95CONFIDENCE = 1.6448536269514722;
math.randomseed(1234)
-- Random distribution functions
function sample_normal_0_1()
local u1 = math.random()
local u2 = math.random()
local result = math.sqrt(-2 * math.log(u1)) * math.sin(2 * PI * u2)
return result
end
function sample_normal(mean, sigma)
return mean + (sigma * sample_normal_0_1())
end
function sample_uniform(min, max)
return math.random() * (max - min) + min
end
function sample_lognormal(logmean, logsigma)
return math.exp(sample_normal(logmean, logsigma))
end
function sample_to(low, high)
local loglow = math.log(low);
local loghigh = math.log(high);
local logmean = (loglow + loghigh) / 2;
local logsigma = (loghigh - loglow) / (2.0 * NORMAL95CONFIDENCE);
return sample_lognormal(logmean, logsigma, seed);
end
-- Mixture
function mixture(samplers, weights, n)
assert(#samplers == #weights)
local l = #weights
local sum_weights = 0
for i = 1, l, 1 do
-- ^ arrays start at 1
sum_weights = sum_weights + weights[i]
end
local cumsummed_normalized_weights = {}
table.insert(cumsummed_normalized_weights, weights[1]/sum_weights)
for i = 2, l, 1 do
table.insert(cumsummed_normalized_weights, cumsummed_normalized_weights[i-1] + weights[i]/sum_weights)
end
local result = {}
for i = 1, n, 1 do
r = math.random()
local i = 1
while r > cumsummed_normalized_weights[i] do
i = i+1
end
table.insert(result, samplers[i]())
end
return result
end
-- Main
p_a = 0.8
p_b = 0.5
p_c = p_a * p_b
function sample_0() return 0 end
function sample_1() return 1 end
function sample_few() return sample_to(1, 3) end
function sample_many() return sample_to(2, 10) end
samplers = {sample_0, sample_1, sample_few, sample_many}
weights = { (1 - p_c), p_c/2, p_c/4, p_c/4 }
n = 1000000
result = mixture(samplers, weights, n)
sum = 0
for i = 1, n, 1 do
sum = sum + result[i]
-- print(result[i])
end
print(sum/n)