Last-axis softmax and log-softmax tape nodes
Softmax and log-softmax over matrix rows, together with their Jacobian/VJP lemmas and
NodeFDerivCorrect wrappers for graph-level autograd proofs.
@[reducible, inline]
Flattened size for an m×n matrix when viewed as a single vector (m*n).
@[simp]
theorem
Proofs.Autograd.TapeNodes.SoftmaxLastAxis.divNat_finProdFinEquiv
{m n : ℕ}
(p : Fin m × Fin n)
:
@[simp]
theorem
Proofs.Autograd.TapeNodes.SoftmaxLastAxis.modNat_finProdFinEquiv
{m n : ℕ}
(p : Fin m × Fin n)
:
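Judging by their names, these two simp lemmas relate Mathlib's `finProdFinEquiv` (the equivalence `Fin m × Fin n ≃ Fin (m * n)` that flattens a row/column pair into a single index) to `Fin.divNat` and `Fin.modNat`, which recover the row and column from a flattened index. The sketch below shows facts of that shape, assuming Mathlib is available. The `example`s are illustrations only and are not claimed to be the lemmas' actual statements.

```lean
import Mathlib

-- Illustration only: these are NOT the library's lemma statements.
-- `finProdFinEquiv : Fin m × Fin n ≃ Fin (m * n)` flattens a (row, column) pair.

-- Unflattening after flattening returns the original pair.
example {m n : ℕ} (p : Fin m × Fin n) :
    finProdFinEquiv.symm (finProdFinEquiv p) = p :=
  finProdFinEquiv.symm_apply_apply p

-- Component form: `Fin.divNat` recovers the row index of a flattened pair.
example {m n : ℕ} (p : Fin m × Fin n) :
    (finProdFinEquiv p).divNat = p.1 :=
  congrArg Prod.fst (finProdFinEquiv.symm_apply_apply p)

-- Component form: `Fin.modNat` recovers the column index.
example {m n : ℕ} (p : Fin m × Fin n) :
    (finProdFinEquiv p).modNat = p.2 :=
  congrArg Prod.snd (finProdFinEquiv.symm_apply_apply p)
```

Facts like these let a proof about the flattened `Vec (MNSize m n)` be reduced to a statement about a single row of the matrix.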
theorem
Proofs.Autograd.TapeNodes.SoftmaxLastAxis.hasFDerivAt_forwardMN
{m n : ℕ}
(x : Vec (MNSize m n))
:
HasFDerivAt forwardMN (derivMN x) x
HasFDerivAt statement for forwardMN, using the rowwise softmax derivative.
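For orientation, the rowwise softmax derivative has the standard closed form below. This is a sketch in assumed notation: X is the m×n input, S its rowwise softmax, H a perturbation, and G an upstream gradient. The library's `derivMN` acts on the flattened `Vec (MNSize m n)` and may be stated differently.

```latex
S_{ij} = \frac{\exp(X_{ij})}{\sum_{k=1}^{n} \exp(X_{ik})},
\qquad
\frac{\partial S_{ij}}{\partial X_{i'k}} = \delta_{ii'}\, S_{ij}\,\bigl(\delta_{jk} - S_{ik}\bigr)

% Derivative applied to a perturbation H (each row depends only on itself):
\bigl(D\,\mathrm{softmax}(X)[H]\bigr)_{ij}
  = S_{ij}\Bigl(H_{ij} - \sum_{k=1}^{n} S_{ik} H_{ik}\Bigr)

% Vector-Jacobian product with an upstream gradient G:
\bigl(\mathrm{VJP}(X)[G]\bigr)_{ik}
  = S_{ik}\Bigl(G_{ik} - \sum_{j=1}^{n} S_{ij} G_{ij}\Bigr)
```

The factor \delta_{ii'} expresses that the Jacobian is block diagonal, with one n×n block per row.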
@[reducible, inline]
Reuse the m*n flattened size from SoftmaxLastAxis.
theorem
Proofs.Autograd.TapeNodes.LogSoftmaxLastAxis.hasFDerivAt_forwardMN
{m n : ℕ}
(x : Vec (MNSize m n))
:
HasFDerivAt forwardMN (derivMN x) x
HasFDerivAt statement for logSoftmaxVec applied rowwise.
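The corresponding standard formulas for rowwise log-softmax are sketched below. The notation is again an assumption: X is the m×n input, S its rowwise softmax, L its rowwise log-softmax, H a perturbation, and G an upstream gradient. The library's own `derivMN` may be phrased differently on the flattened vector.

```latex
L_{ij} = X_{ij} - \log \sum_{k=1}^{n} \exp(X_{ik}),
\qquad
\frac{\partial L_{ij}}{\partial X_{i'k}} = \delta_{ii'}\,\bigl(\delta_{jk} - S_{ik}\bigr)

% Derivative applied to a perturbation H:
\bigl(D\,\mathrm{logsoftmax}(X)[H]\bigr)_{ij}
  = H_{ij} - \sum_{k=1}^{n} S_{ik} H_{ik}

% Vector-Jacobian product with an upstream gradient G:
\bigl(\mathrm{VJP}(X)[G]\bigr)_{ik}
  = G_{ik} - S_{ik} \sum_{j=1}^{n} G_{ij}
```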
noncomputable def
Proofs.Autograd.TapeNodes.softmaxLast
{Γ : List Spec.Shape}
{m n : ℕ}
(idx : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)))
:
Node Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar))
Tape node for applying softmaxVec along the last axis of an m×n matrix (rowwise).
noncomputable def
Proofs.Autograd.TapeNodes.logSoftmaxLast
{Γ : List Spec.Shape}
{m n : ℕ}
(idx : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)))
:
Node Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar))
Tape node for applying logSoftmaxVec along the last axis of an m×n matrix (rowwise).
noncomputable def
Proofs.Autograd.TapeNodes.softmaxLastFderiv
{Γ : List Spec.Shape}
{m n : ℕ}
(idx : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)))
:
NodeFDerivCorrect (softmaxLast idx)
NodeFDerivCorrect for softmaxLast (rowwise softmax).
noncomputable def
Proofs.Autograd.TapeNodes.logSoftmaxLastFderiv
{Γ : List Spec.Shape}
{m n : ℕ}
(idx : Idx Γ (Spec.Shape.dim m (Spec.Shape.dim n Spec.Shape.scalar)))
:
NodeFDerivCorrect (logSoftmaxLast idx)
NodeFDerivCorrect for logSoftmaxLast (rowwise log-softmax).