feat: prettify normal prediction

This commit is contained in:
NunoSempere 2023-05-24 00:19:05 -04:00
parent 22eb848911
commit 599340fd45
2 changed files with 19 additions and 7 deletions

Binary file not shown.

View File

@ -2,6 +2,7 @@ import print
import strutils
import sequtils
import std/sugar
import std/algorithm
## Get sequences
let file_path = "../data/stripped"
@ -54,7 +55,14 @@ proc findIndex(xs: seq[string], y: string): int =
return -1
## Do simple predictions
proc predictContinuation(seqs: seq[seq[string]], start: seq[string]): int =
proc compareTuple (x: (string, float), y: (string, float)): int =
let (_, p1) = x
let (_, p2) = y
if p1 < p2: return -1
elif p2 > p2: return 1
else: return 0
proc predictContinuation(seqs: seq[seq[string]], start: seq[string]): seq[(string, float)] =
let continuations = getSequencesWithStart(seqs, start)
let l = start.len
var nexts: seq[string]
@ -68,11 +76,15 @@ proc predictContinuation(seqs: seq[seq[string]], start: seq[string]): int =
else:
ps[i] = ps[i] + 1.0
let sum = foldl(ps, a + b, 0.0)
echo nexts
echo ps
echo ps.map(p=> p/sum)
return 1
# to do: wrangle this in some kind of sequence of tuples, e.g., using some zip type of thing.
ps = ps.map( p => p/sum)
var next_and_ps = zip(nexts, ps)
# next_and_ps = sort(next_and_ps, compareTuple)
sort(next_and_ps, compareTuple)
# ^ sorts in place
# also, openArray refers to both arrays and sequences.
return next_and_ps
var start = @["1", "2", "3", "4", "5", "6"]
echo predictContinuation(seqs, start)
print "Full prediction with access to all hypotheses:"
print predictContinuation(seqs, start)