NN.Spec.Models.Hmm

Hidden Markov Model (HMM) (spec model) #

This file defines an HMM with discrete observations: hidden states evolve under a Markov chain, and each state emits a symbol from a finite observation alphabet.

The model parameters are:

• π : the initial state distribution (one entry per hidden state),
• A : the state transition matrix (A[i, j] is the probability of moving from state i to state j),
• B : the emission matrix (B[i, o] is the probability of emitting symbol o from state i).

We represent observations as List (Fin nObservations) to keep the observation alphabet explicit and avoid mixing “probabilities” with “indices” in the scalar type α.

Notation and shapes #

We use the conventional HMM notation:

• nStates and nObservations : the number of hidden states and the size of the observation alphabet,
• π, A, B : the initial distribution (length nStates), transition matrix (nStates × nStates), and emission matrix (nStates × nObservations),
• T : the length of the observation sequence.

An observation sequence is o₀, o₁, ..., o_{T-1} where each o_t : Fin nObservations.

PyTorch analogy:

In practice, PyTorch users often reach for a dedicated HMM library (e.g. hmmlearn) or implement HMMs in log-space with logsumexp; TorchLean keeps the spec in a simple, explicit form that is good for reading and proofs.

structure Spec.HMMSpec (α : Type) (nStates nObservations : ℕ) : Type

A discrete-observation HMM.

We do not enforce probabilistic validity (nonnegativity / rows summing to 1) at the type level; that is a modeling assumption, similar to how PyTorch will happily store unconstrained tensors until you feed them to a distribution or a loss.
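
As a rough picture of the data layout, here is a minimal self-contained sketch with hypothetical field names (the actual HMMSpec fields may differ), storing the parameters as plain functions over Fin indices:

structure HMMSketch (α : Type) (nStates nObservations : Nat) where
  initDist : Fin nStates → α                      -- π : initial state distribution
  trans    : Fin nStates → Fin nStates → α        -- A : trans i j ≈ P(next = j ∣ current = i)
  emit     : Fin nStates → Fin nObservations → α  -- B : emit i o ≈ P(observe o ∣ state = i)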

@[reducible, inline]
abbrev Spec.ObservationSeq (nObservations : ℕ) : Type

Observation sequence as a list of discrete symbols (indices into the observation alphabet).
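
Per the note above, an observation sequence is literally a list of symbol indices; a self-contained equivalent is:

-- Same shape as the documented abbreviation: a list of alphabet indices.
abbrev ObservationSeqSketch (nObservations : Nat) : Type :=
  List (Fin nObservations)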

Basic helpers #

def Spec.getEmissionProbDiscrete {α : Type} {nStates nObservations : ℕ} (m : HMMSpec α nStates nObservations) (state : Fin nStates) (obs : Fin nObservations) : α

Get emission probability B[state, obs] for a discrete observation symbol.
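
Against the hypothetical HMMSketch above, this accessor is a single indexed read (a sketch, not the TorchLean definition):

-- Read B[state, obs] from the sketch structure.
def getEmissionSketch {α : Type} {nS nO : Nat} (m : HMMSketch α nS nO)
    (state : Fin nS) (obs : Fin nO) : α :=
  m.emit state obs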

Baum–Welch (EM) training #

The forward-pass APIs above are enough to use a fixed HMM, but a “fully implemented” baseline should also include classical training. For discrete-observation HMMs, the standard training procedure is the Baum–Welch algorithm (an EM procedure):

• E-step: run forward–backward under the current parameters to compute expected state-occupancy and state-transition counts,
• M-step: re-estimate π, A, and B from those expected counts.

This implementation uses scaled forward–backward to reduce numerical underflow: each forward message α_t is normalized by a scalar c_t, and the backward messages divide by those same scalars. The sequence likelihood is then ∏_t c_t, so the log-likelihood is Σ_t log c_t.

Concretely:

• c_t = Σ_j α_t(j) is the total mass of the unnormalized forward message at step t,
• the stored forward message is α_t / c_t, which sums to 1,
• the sequence likelihood is p(o₀, ..., o_{T-1}) = ∏_t c_t, and the log-likelihood is Σ_t log c_t.

This is the same basic idea used in many practical HMM implementations (sometimes also expressed as log-space forward–backward).

This is deterministic and written for clarity; it is not intended to be a high-performance HMM trainer.

Normalize a nonnegative vector v to sum to 1, returning (v / sum(v), sum(v)).

If the sum is 0, we fall back to a uniform distribution. This keeps the forward pass total and avoids propagating NaN/division-by-zero behavior into later computations.
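
A minimal sketch of this normalization, assuming Float scalars and a list representation (the helper's actual name and types are not shown on this page):

-- Normalize v to sum to 1; fall back to uniform when the sum is 0.
def normalizeSketch (v : List Float) : List Float × Float :=
  let s := v.foldl (· + ·) 0.0
  if s == 0.0 then
    -- Degenerate input: return a uniform distribution (sum reported as 0).
    (v.map (fun _ => 1.0 / Float.ofNat v.length), 0.0)
  else
    (v.map (· / s), s)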

def Spec.emissionVec {α : Type} {nStates nObservations : ℕ} (m : HMMSpec α nStates nObservations) (obs : Fin nObservations) :

Emission probabilities B[:, obs] as a vector over states.

def Spec.hmmForwardScaled {α : Type} [Context α] {nStates nObservations : ℕ} [Inhabited (Fin nObservations)] (m : HMMSpec α nStates nObservations) (observations : ObservationSeq nObservations) :

Scaled forward pass, returning (α_t, c_t) for each timestep.

• Each α_t is normalized to sum to 1.
• Each c_t is the normalization constant used at step t.

If you need the total likelihood, multiply the scales: p(o₀, ..., o_{T-1}) = ∏_t c_t.
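
A self-contained sketch of the scaled forward recurrence, assuming Float scalars and the hypothetical HMMSketch fields from above (the TorchLean definition may use a different vector representation):

-- Sum f over all states 0, ..., n-1.
def finSum (n : Nat) (f : Fin n → Float) : Float :=
  (List.range n).foldl (fun acc i =>
    if h : i < n then acc + f ⟨i, h⟩ else acc) 0.0

-- Scaled forward pass: the normalized message and its scale (α_t, c_t)
-- for each timestep. The real spec falls back to uniform when a scale is 0.
def forwardScaledSketch {nS nO : Nat} (m : HMMSketch Float nS nO)
    (obs : List (Fin nO)) : List ((Fin nS → Float) × Float) :=
  match obs with
  | [] => []
  | o₀ :: rest =>
    -- α₀(j) = π(j) · B(j, o₀), rescaled by c₀ = Σ_j α₀(j)
    let a0 : Fin nS → Float := fun j => m.initDist j * m.emit j o₀
    let c0 := finSum nS a0
    let first : Fin nS → Float := fun j => a0 j / c0
    -- α_t(j) = (Σ_i α_{t-1}(i) · A(i, j)) · B(j, o_t), rescaled by c_t
    let step := fun (prev : Fin nS → Float) (o : Fin nO) =>
      let a : Fin nS → Float := fun j =>
        finSum nS (fun i => prev i * m.trans i j) * m.emit j o
      let c := finSum nS a
      ((fun j : Fin nS => a j / c), c)
    let result := rest.foldl
      (fun (acc : List ((Fin nS → Float) × Float) × (Fin nS → Float)) o =>
        let (p, c) := step acc.2 o
        (acc.1 ++ [(p, c)], p))
      ([(first, c0)], first)
    result.1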

def Spec.baumWelchStepSpec {α : Type} [Context α] {nStates nObservations : ℕ} [Inhabited (Fin nObservations)] [DecidableEq (Fin nObservations)] (m : HMMSpec α nStates nObservations) (observations : ObservationSeq nObservations) : HMMSpec α nStates nObservations × α

One Baum–Welch (EM) step on a single sequence.

def Spec.baumWelchEpochSpec {α : Type} [Context α] {nStates nObservations : ℕ} [Inhabited (Fin nObservations)] [DecidableEq (Fin nObservations)] (m : HMMSpec α nStates nObservations) (dataset : List (ObservationSeq nObservations)) : HMMSpec α nStates nObservations × α

One Baum–Welch epoch over a dataset of observation sequences (sums expected counts).
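
Given these signatures, a typical training loop simply iterates the epoch step, threading the model through and collecting the per-epoch score. A sketch against the documented API (whether the returned α component is exactly the dataset log-likelihood is not pinned down on this page):

-- Run a fixed number of Baum–Welch epochs, collecting per-epoch scores.
def trainSketch {α : Type} [Context α] {nS nO : ℕ}
    [Inhabited (Fin nO)] [DecidableEq (Fin nO)]
    (m : Spec.HMMSpec α nS nO) (dataset : List (Spec.ObservationSeq nO))
    (epochs : Nat) : Spec.HMMSpec α nS nO × List α :=
  (List.range epochs).foldl
    (fun acc _ =>
      let (m', score) := Spec.baumWelchEpochSpec acc.1 dataset
      (m', acc.2 ++ [score]))
    (m, ([] : List α))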

Forward / likelihood #

def Spec.hmmForwardSpec {α : Type} [Context α] {nStates nObservations : ℕ} [Inhabited (Fin nObservations)] (m : HMMSpec α nStates nObservations) (observations : ObservationSeq nObservations) : α

Forward algorithm (scaled) returning the total sequence likelihood.

Implementation note: we compute the likelihood from the per-timestep scaling factors produced by hmmForwardScaled. This avoids the worst underflow behavior of multiplying many small probabilities directly.
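
In terms of the sketch above, this amounts to taking the product of the per-step scales (an empty sequence gives the empty product 1):

-- p(o₀, ..., o_{T-1}) = ∏_t c_t, read off the scaled forward pass.
def likelihoodSketch {nS nO : Nat} (m : HMMSketch Float nS nO)
    (obs : List (Fin nO)) : Float :=
  (forwardScaledSketch m obs).foldl (fun acc (_, c) => acc * c) 1.0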

def Spec.hmmBatchedForwardSpec {α : Type} [Context α] {nStates nObservations : ℕ} [Inhabited (Fin nObservations)] (m : HMMSpec α nStates nObservations) (observations : List (ObservationSeq nObservations)) : List α

Batched forward pass: compute likelihood for each observation sequence in a list.
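
Behaviorally this presumably matches mapping the single-sequence forward over the list; a sketch against the documented signatures:

-- One likelihood per observation sequence.
example {α : Type} [Context α] {nS nO : ℕ} [Inhabited (Fin nO)]
    (m : Spec.HMMSpec α nS nO) (seqs : List (Spec.ObservationSeq nO)) : List α :=
  seqs.map (Spec.hmmForwardSpec m)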

def Spec.hmmInitSpec {α : Type} [Context α] {nStates nObservations : ℕ} : HMMSpec α nStates nObservations

Initialize an HMM with uniform (uninformative) parameters.

This is a deterministic uniform initializer (useful for examples/tests); it is not intended as a statistically meaningful random initialization.
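
On the hypothetical sketch structure, uniform initialization sets π, every row of A, and every row of B to the uniform distribution over their supports:

-- Every distribution starts at 1/n over its support.
def initUniformSketch (nS nO : Nat) : HMMSketch Float nS nO :=
  { initDist := fun _ => 1.0 / Float.ofNat nS,
    trans := fun _ _ => 1.0 / Float.ofNat nS,
    emit := fun _ _ => 1.0 / Float.ofNat nO }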

def Spec.hmmLogLikelihoodSpec {α : Type} [Context α] {nStates nObservations : ℕ} [Inhabited (Fin nObservations)] (m : HMMSpec α nStates nObservations) (observations : ObservationSeq nObservations) : α

Log-likelihood of an observation sequence.

We compute this from the same scaling factors used in the EM implementation: log p(o₀, ..., o_{T-1}) = Σ_t log c_t.
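
Against the sketch, this is the same traversal as the likelihood, summing logs of the scales instead of multiplying them:

-- log p(o₀, ..., o_{T-1}) = Σ_t log c_t.
def logLikelihoodSketch {nS nO : Nat} (m : HMMSketch Float nS nO)
    (obs : List (Fin nO)) : Float :=
  (forwardScaledSketch m obs).foldl (fun acc (_, c) => acc + Float.log c) 0.0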
