asco/fit.py
Jacob 2cb11e4933 Initial commit.
Yeah I know there are a lot of nonessential files but w/e.
2024-12-17 01:39:52 -05:00

59 lines
1.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as ν
import csv
import get_emb
# Upper bound on the number of results closest() will return.
MAX_RES = 10

# Coefficients used by percent(): one float per line of 'beta'.
# The last element of the '\n'-split is discarded -- presumably the empty
# string left by a trailing newline (assumes the file ends with one).
# Context managers make sure every handle is closed once its data is read.
with open('beta', 'r') as φ:
    β = [float(el) for el in φ.read().split('\n')[:-1]]

# Reference values consumed by percentile_far(); same one-float-per-line
# layout (and same trailing-element handling) as 'beta'.
with open('Delta', 'r') as _fh:
    Δ = [float(el) for el in _fh.read().split('\n')[:-1]]

# Intercept added to the dot product in percent().
# A previous value, -0.8569279 ("some magic constant"), is disabled.
α = 0

# Abstract metadata, read with a stride of four lines per record:
# line 0 = title, line 1 = authors, line 2 = accepted flag ("TRUE" is
# treated as presented by closest()).  The fourth line of each record is
# not used here -- presumably a separator; confirm against the data file.
with open('titles2.text', 'r') as _fh:
    M = _fh.read().strip().split('\n')
T = ν.array(M[0::4])
A = ν.array(M[1::4])
O = ν.array(M[2::4])

# Comma-delimited numeric matrix; closest() scores its rows against a query.
with open('embeddings2.nsv', 'rb') as _fh:
    X = ν.loadtxt(_fh, delimiter=',', skiprows=0)

# Raw lines of the 'NN' file (not referenced by the functions visible here).
with open('NN', 'r') as _fh:
    NN = _fh.read().split('\n')
def get(θ):
    """Return whatever ``get_emb.get_embedding`` produces for ``θ``.

    Thin convenience wrapper: the argument is forwarded unchanged and the
    result is handed back unchanged.
    """
    embedding = get_emb.get_embedding(θ)
    return embedding
def percent(χ, *, coef=None, intercept=None):
    """Return the logistic-model probability for ``χ`` as a percent string.

    The linear predictor is ``intercept + dot(χ, coef)``; it is passed
    through the logistic function and rendered as a whole-number percentage
    (truncated, zero-padded to at least two digits), e.g. ``"07%"``,
    ``"50%"``, ``"100%"``.

    Args:
        χ: 1-D feature vector; must have the same length as ``coef``.
        coef: coefficient vector.  Defaults to the module-level ``β``.
        intercept: scalar offset.  Defaults to the module-level ``α``.

    Returns:
        str: the percentage followed by ``'%'``.

    Raises:
        ValueError: if the computed probability is NaN (e.g. NaN in ``χ``).
    """
    weights = β if coef is None else coef
    bias = α if intercept is None else intercept
    γ = bias + ν.dot(χ, weights)
    # tanh form of the logistic function: identical to exp(γ) / (1 + exp(γ))
    # but cannot overflow to inf/inf = nan when |γ| is large.
    π = 0.5 * (1.0 + ν.tanh(0.5 * γ))
    # Format numerically rather than slicing str(π): slicing turned 0.5 into
    # "5%", 1.0 into "0%" and values in scientific notation into garbage.
    # round(..., 9) only absorbs binary noise such as 0.29 * 100 == 28.99...
    whole = int(round(float(π) * 100, 9))
    return f"{whole:02d}%"
def closest(χ, n):
    """List the rows of ``X`` with the highest dot-product score against ``χ``.

    Args:
        χ: query vector, scored against every row of the module-level ``X``.
        n: requested number of results.  It is folded into ``1..MAX_RES``:
            ``abs(n) % MAX_RES``, with a remainder of 0 meaning ``MAX_RES``
            (so 0 and MAX_RES both give MAX_RES, MAX_RES + 1 gives 1).

    Returns:
        list: ``[out, tailprob]`` where ``out`` is one line per hit, best
        score first, formatted as
        ``"<title> <i>(<authors>, presented|online-only)</i>\\n"``, and
        ``tailprob`` is ``int(percentile_far(best_score) * 100)``.

    Side effects:
        Prints the selected row indices (ascending by score) to stdout.
    """
    n = abs(n) % MAX_RES
    if n == 0:
        n = MAX_RES

    ψ = ν.array(ν.dot(X, χ))
    # Sort once and reuse the indices for titles, authors and flags; the
    # last n entries of an ascending argsort are the n highest scores.
    nearest = ν.argsort(ψ)[-n:]
    print(nearest)

    lines = []
    for idx in nearest[::-1]:  # walk from best score to worst
        status = "presented" if O[idx] == "TRUE" else "online-only"
        lines.append(f"{T[idx]} <i>({A[idx]}, {status})</i>\n")
    out = "".join(lines)

    tailprob = int(percentile_far(ν.max(ψ)) * 100)
    return [out, tailprob]
def percentile_far(q_dist, *, reference=None):
    """Return the fraction of reference values that are ``<= q_dist``.

    Per the original author's note this is meant as the fraction of
    abstracts further from their nearest neighbour than the query; that
    reading depends on how ``Δ`` was produced, which is not visible here.

    Args:
        q_dist: scalar to compare against the reference values.
        reference: 1-D sequence of numbers to compare with.  Defaults to
            the module-level ``Δ``.

    Returns:
        float: a value in ``[0, 1]``.

    Raises:
        ZeroDivisionError: if the reference sequence is empty.
    """
    values = ν.asarray(Δ if reference is None else reference)
    # One vectorised count instead of a Python-level sum over array elements.
    at_or_below = int(ν.count_nonzero(values <= q_dist))
    return at_or_below / len(values)