You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

60 lines
1.9 KiB

from typing import List
from thinc.api import Model
from thinc.types import Floats2d
from ..tokens import Doc
from ..util import registry
def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]:
    """Construct the "charembed" layer.

    nM: Number of dimensions per character.
    nC: Number of characters.
    RETURNS: A Model wired to `forward` and `init`, with dims nM, nC,
        nO (= nM * nC) and nV (= 256), and an unset "E" parameter.
    """
    layer_dims = {"nM": nM, "nC": nC, "nO": nM * nC, "nV": 256}
    # "E" starts out unset; `init` is passed to the Model as its initializer.
    layer_params = {"E": None}
    return Model(
        "charembed", forward, init=init, dims=layer_dims, params=layer_params
    )
def init(model: Model, X=None, Y=None):
    """Allocate the embedding table and store it as the "E" parameter.

    model: The layer whose "nC", "nV" and "nM" dims define the table shape.
    X: Unused.
    Y: Unused.
    """
    # Table layout is (nC, nV, nM), read from the model's dims in that order.
    table_shape = tuple(model.get_dim(dim_name) for dim_name in ("nC", "nV", "nM"))
    model.set_param("E", model.ops.alloc3f(*table_shape))
def forward(model: Model, docs: List[Doc], is_train: bool):
    """Embed the first nC characters of every token and return a callback.

    For each doc, `doc.to_utf8_array(nr_char=nC)` supplies one integer id per
    (token, character slot). Each id selects a row from the slice of the "E"
    parameter belonging to its slot (E is indexed as E[slot, id]), and the nC
    selected rows of a token are laid out side by side in the output.

    model: Layer providing the "E" parameter and the "nC", "nM", "nO" dims.
    docs: The documents to embed, or None.
    is_train: Unused.
    RETURNS: An `(output, backprop)` pair. `output` holds one array per doc,
        reshaped to (len(doc), nO). `backprop(d_output)` adds the gradient
        for "E" via `model.inc_grad` and returns an empty list.
    """
    if docs is None:
        # Still honour the (output, backprop) contract: a bare list cannot be
        # unpacked or indexed by callers that expect the pair.
        def backprop_nothing(d_output):
            return []

        return [], backprop_nothing

    E = model.get_param("E")
    nC = model.get_dim("nC")
    nM = model.get_dim("nM")
    nO = model.get_dim("nO")
    xp = model.ops.xp
    # Index vector over the character slots. Broadcasting it against the 2d
    # id array pairs every id with its own slot, which acts like a loop over
    # that dimension.
    nCv = xp.arange(nC)
    ids = []
    output = []
    for doc in docs:
        # presumably shaped (len(doc), nC) with byte values as ids -- confirm
        # against Doc.to_utf8_array
        doc_ids = model.ops.asarray(doc.to_utf8_array(nr_char=nC))
        doc_vectors = model.ops.alloc3f(len(doc), nC, nM)
        # Gather so that doc_vectors[i, j, k] == E[j, doc_ids[i, j], k].
        doc_vectors[:, nCv] = E[nCv, doc_ids[:, nCv]]  # type: ignore[call-overload, index]
        output.append(doc_vectors.reshape((len(doc), nO)))
        ids.append(doc_ids)

    def backprop(d_output):
        dE = model.ops.alloc(E.shape, dtype=E.dtype)
        for doc_ids, d_doc_vectors in zip(ids, d_output):
            d_doc_vectors = d_doc_vectors.reshape((len(doc_ids), nC, nM))
            # Several tokens can hit the same (slot, id) cell. A fancy-indexed
            # `+=` is buffered and would keep only one of those contributions;
            # the unbuffered `add.at` sums all of them.
            xp.add.at(dE, (nCv, doc_ids[:, nCv]), d_doc_vectors[:, nCv])
        model.inc_grad("E", dE)
        return []

    return output, backprop