You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
2.8 KiB
111 lines
2.8 KiB
import numpy as np
|
|
|
|
|
|
from cython cimport floating
|
|
from libc.math cimport exp, fabs, log
|
|
|
|
from sklearn.utils._typedefs cimport float64_t, intp_t
|
|
|
|
|
|
def mean_change(const floating[:] arr_1, const floating[:] arr_2):
    """Calculate the mean absolute difference between two arrays.

    Equivalent to np.abs(arr_1 - arr_2).mean().

    Parameters
    ----------
    arr_1, arr_2 : 1-d memoryviews of float32 or float64
        Arrays of the same length to compare (both share the same
        ``floating`` specialization).

    Returns
    -------
    float
        Mean of ``|arr_1[i] - arr_2[i]|``; the sum is accumulated in
        float64 regardless of the input dtype.

    Raises
    ------
    ValueError
        If ``arr_1`` and ``arr_2`` have different lengths (matching the
        broadcasting error the NumPy equivalent would raise).
    """
    cdef float64_t total, diff
    cdef intp_t i, size

    size = arr_1.shape[0]
    # The loop below indexes arr_2 with arr_1's bounds. Without this check a
    # shorter arr_2 would be read past its end when compiled without bounds
    # checking, and a longer one would be silently truncated.
    if arr_2.shape[0] != size:
        raise ValueError(
            f"arr_1 and arr_2 must have the same length, got "
            f"{size} and {arr_2.shape[0]}."
        )

    total = 0.0
    for i in range(size):
        diff = fabs(arr_1[i] - arr_2[i])
        total += diff

    return total / size
|
|
|
|
|
|
def _dirichlet_expectation_1d(
    floating[:] doc_topic,
    floating doc_topic_prior,
    floating[:] out
):
    """Exponentiated Dirichlet expectation for one sample.

    First shifts every entry of ``doc_topic`` by ``doc_topic_prior``
    (modifying ``doc_topic`` in place), then stores
    ``exp(E[log(theta)])`` for theta ~ Dir(doc_topic) into ``out``.

    Produces the same values as::

        doc_topic += doc_topic_prior
        out[:] = np.exp(psi(doc_topic) - psi(np.sum(doc_topic)))

    ``out`` is assumed to hold at least as many entries as ``doc_topic``;
    this is not checked here.
    """
    cdef floating shifted, psi_of_sum, running_sum
    cdef intp_t k, n_topics = doc_topic.shape[0]

    # Pass 1: add the prior to each entry in place while summing the
    # shifted values.
    running_sum = 0.0
    for k in range(n_topics):
        shifted = doc_topic[k] + doc_topic_prior
        doc_topic[k] = shifted
        running_sum += shifted

    # Pass 2: exponentiate each entry's digamma minus the digamma of the sum.
    psi_of_sum = psi(running_sum)
    for k in range(n_topics):
        out[k] = exp(psi(doc_topic[k]) - psi_of_sum)
|
|
|
|
|
|
def _dirichlet_expectation_2d(const floating[:, :] arr):
    """Row-wise Dirichlet expectation: E[log(theta)] for theta ~ Dir(arr[i]).

    Computes ``psi(arr) - psi(np.sum(arr, axis=1))[:, np.newaxis]`` into a
    freshly allocated array created with ``np.empty_like(arr)``; ``arr``
    itself is not modified.

    In contrast to _dirichlet_expectation_1d, no prior is added and the
    result is not exponentiated.
    """
    cdef intp_t row, col
    cdef intp_t n_rows = arr.shape[0]
    cdef intp_t n_cols = arr.shape[1]
    cdef floating row_sum, psi_of_row_sum
    cdef floating[:, :] result = np.empty_like(arr)

    for row in range(n_rows):
        # Total of this row's parameters, then its digamma.
        row_sum = 0
        for col in range(n_cols):
            row_sum += arr[row, col]
        psi_of_row_sum = psi(row_sum)

        # Each entry's digamma, offset by the row total's digamma.
        for col in range(n_cols):
            result[row, col] = psi(arr[row, col]) - psi_of_row_sum

    # ``result`` is a typed view; hand back the array object it wraps.
    return result.base
|
|
|
|
|
|
# Psi function for positive arguments. Optimized for speed, not accuracy.
#
# After: J. Bernardo (1976). Algorithm AS 103: Psi (Digamma) Function.
# https://www.uv.es/~bernardo/1976AppStatist.pdf
#
# Approximates the digamma function psi(x) = d/dx log(Gamma(x)) in three
# regimes:
#   1. x <= 1e-6: leading terms of the expansion around 0.
#   2. otherwise, shift x upward with the recurrence until x >= 6 ...
#   3. ... and evaluate the asymptotic series at the shifted x.
# Declared `noexcept nogil`, so it can be called from GIL-free loops and
# never propagates a Python exception to its caller.
# NOTE(review): intended for x > 0 only. Inputs <= 0 take the small-x branch
# and get -EULER - 1/x, which is not psi(x) there (psi has poles at the
# non-positive integers); callers are expected to pass positive values.
cdef floating psi(floating x) noexcept nogil:
    # Euler-Mascheroni constant gamma (psi(1) == -gamma).
    cdef double EULER = 0.577215664901532860606512090082402431
    if x <= 1e-6:
        # psi(x) = -EULER - 1/x + O(x)
        return -EULER - 1. / x

    cdef floating r, result = 0

    # psi(x + 1) = psi(x) + 1/x
    # Rearranged as psi(x) = psi(x + 1) - 1/x: each step subtracts 1/x from
    # the running result and moves x up by one, until the asymptotic series
    # below is accurate enough (x >= 6).
    while x < 6:
        result -= 1. / x
        x += 1

    # psi(x) = log(x) - 1/(2x) - 1/(12x**2) + 1/(120x**4) - 1/(252x**6)
    #          + O(1/x**8)
    r = 1. / x
    result += log(x) - .5 * r
    # From here r holds 1/x**2; the remaining even-power terms are
    # evaluated in Horner form in r.
    r = r * r
    result -= r * ((1./12.) - r * ((1./120.) - r * (1./252.)))
    return result
|