# Embedding

Module wrappers for spec-layer embeddings. We intentionally expose the one-hot embedding variant here (purely numeric, no integer indices).
Why one-hot in the spec layer:

- It avoids committing to an "index tensor" representation. In Lean, indices would typically live in `Nat`/`Fin`, which is great for proofs, but many numeric backends are scalar-only.
- It keeps the forward definition completely algebraic: an embedding becomes `one_hot @ W`.
In PyTorch terms: the usual API is `nn.Embedding(vocab, embed_dim)` applied to integer indices. This file packages the equivalent "one-hot then matmul" semantics.
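To make the equivalence concrete, here is a minimal numeric sketch (in NumPy rather than the Lean spec layer; the names `W`, `tokens`, `one_hot` are illustrative, not part of this library's API). It shows that the purely algebraic `one_hot @ W` formulation produces exactly the same result as the integer-index row lookup that `nn.Embedding` performs:

```python
import numpy as np

vocab, embed_dim, seq_len = 5, 3, 4
rng = np.random.default_rng(0)
W = rng.standard_normal((vocab, embed_dim))  # embedding table, shape (vocab, embed_dim)

tokens = np.array([2, 0, 4, 2])              # integer token indices, shape (seq_len,)

# One-hot encode: shape (seq_len, vocab); purely numeric, no integer
# indices survive past this point.
one_hot = np.eye(vocab)[tokens]

# "one_hot @ W": (seq_len, vocab) @ (vocab, embed_dim) -> (seq_len, embed_dim)
embedded = one_hot @ W

# Agrees with the usual gather-by-index semantics of nn.Embedding.
assert np.allclose(embedded, W[tokens])
```

The matmul is of course less efficient than a gather, but in the spec layer that trade-off is intentional: it keeps the definition in the language of linear algebra, where proofs are easier.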
```lean
def Spec.EmbeddingOneHotModuleSpec
    {α : Type} [Context α]
    {vocab embedDim seqLen : ℕ}
    (emb : EmbeddingSpec vocab embedDim α) :
    ModSpec.NNModuleSpec α
      (Shape.dim seqLen (Shape.dim vocab Shape.scalar))
      (Shape.dim seqLen (Shape.dim embedDim Shape.scalar))
```
One-hot embedding wrapper: `(seqLen, vocab) → (seqLen, embedDim)`.