TorchLean API

NN.Proofs.Autograd.Tape.Nodes.Elementwise

Elementwise tape nodes

Reusable NodeFDerivCorrect wrappers for scalar functions lifted pointwise to tensors, including common activations such as ReLU, sigmoid, tanh, SiLU, GELU, ELU, and safe differentiable variants.

noncomputable def Proofs.Autograd.TapeNodes.getVec {Γ : List Spec.Shape} {n : ℕ} (idx : Idx Γ (Spec.Shape.dim n Spec.Shape.scalar)) (x : CtxVec Γ) :
Vec n

CtxVec.get specialized to vector shapes.

CtxVec.getCLM specialized to vector shapes .dim n .scalar.

@[simp]
theorem Proofs.Autograd.TapeNodes.getCLM_apply_ofLp {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (x : CtxVec Γ) (i : Fin s.size) :
((CtxVec.getCLM idx) x).ofLp i = (CtxVec.get idx x).ofLp i

noncomputable def Proofs.Autograd.TapeNodes.singleVec {Γ : List Spec.Shape} {n : ℕ} (idx : Idx Γ (Spec.Shape.dim n Spec.Shape.scalar)) (v : Vec n) :

Inject a Vec n into a vectorized context at idx (all other blocks are filled with zeros).

noncomputable def Proofs.Autograd.TapeNodes.elemwise {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (f f' : ℝ → ℝ) :
Node Γ s

Elementwise node: apply a scalar function pointwise on a context entry.

noncomputable def Proofs.Autograd.TapeNodes.elemwiseFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (f f' : ℝ → ℝ) (hf : ∀ (z : ℝ), HasDerivAt f (f' z) z) :

Analytic correctness for elemwise nodes from a scalar HasDerivAt hypothesis.

noncomputable def Proofs.Autograd.TapeNodes.elemwiseFderivAt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (f f' : ℝ → ℝ) (xV : CtxVec Γ) (hf : ∀ (i : Fin s.size), HasDerivAt f (f' ((CtxVec.get idx xV).ofLp i)) ((CtxVec.get idx xV).ofLp i)) :

Pointwise analytic correctness for elemwise nodes from a coordinatewise HasDerivAt hypothesis.
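The two hypothesis shapes can be illustrated with plain Mathlib facts. This is a minimal sketch, assuming a Mathlib import; `v` is a hypothetical stand-in for the coordinates (CtxVec.get idx xV).ofLp and is not part of the TorchLean API.

```lean
import Mathlib

-- Global shape (as in `elemwiseFderiv`): one derivative fact for every real input.
example : ∀ z : ℝ, HasDerivAt Real.exp (Real.exp z) z :=
  fun z => Real.hasDerivAt_exp z

-- Coordinatewise shape (as in `elemwiseFderivAt`): derivative facts are needed
-- only at the entries that actually occur. `v` is a hypothetical coordinate map.
example {n : ℕ} (v : Fin n → ℝ) (hv : ∀ i, v i ≠ 0) :
    ∀ i, HasDerivAt Real.log (v i)⁻¹ (v i) :=
  fun i => Real.hasDerivAt_log (hv i)
```

The global form suits functions that are differentiable everywhere; the coordinatewise form is the one to reach for when the scalar function has exceptional points.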

noncomputable def Proofs.Autograd.TapeNodes.relu {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime relu node (elementwise; nondifferentiable at zero).

noncomputable def Proofs.Autograd.TapeNodes.reluFderivAt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (xV : CtxVec Γ) (hx : ∀ (i : Fin s.size), (CtxVec.get idx xV).ofLp i ≠ 0) :

Pointwise NodeFDerivCorrectAt for relu under the assumption that inputs are nonzero.
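The reason for the nonzero hypothesis can be worked out directly. The sketch below assumes the standard definition relu(x) = max(x, 0), which this page does not spell out.

```latex
\[
\operatorname{relu}(x) = \max(x, 0), \qquad
\operatorname{relu}'(x) =
\begin{cases}
1 & x > 0,\\
0 & x < 0,
\end{cases}
\]
\[
\lim_{h \to 0^-} \frac{\operatorname{relu}(h) - \operatorname{relu}(0)}{h} = 0
\;\neq\;
1 = \lim_{h \to 0^+} \frac{\operatorname{relu}(h) - \operatorname{relu}(0)}{h}.
\]
```

The one-sided limits disagree at zero, so no derivative exists there, and a correctness statement can only be made at inputs whose coordinates are all nonzero.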

noncomputable def Proofs.Autograd.TapeNodes.abs {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime abs node (elementwise; nondifferentiable at zero).

noncomputable def Proofs.Autograd.TapeNodes.absFderivAt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (xV : CtxVec Γ) (hx : ∀ (i : Fin s.size), (CtxVec.get idx xV).ofLp i ≠ 0) :

Pointwise NodeFDerivCorrectAt for abs under the assumption that inputs are nonzero.

noncomputable def Proofs.Autograd.TapeNodes.log {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime log node (elementwise; differentiable only away from zero).

noncomputable def Proofs.Autograd.TapeNodes.logFderivAt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (xV : CtxVec Γ) (hx : ∀ (i : Fin s.size), (CtxVec.get idx xV).ofLp i ≠ 0) :

Pointwise NodeFDerivCorrectAt for log under the assumption that inputs are nonzero.

noncomputable def Proofs.Autograd.TapeNodes.inv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Elementwise inverse node (differentiable only away from zero).

noncomputable def Proofs.Autograd.TapeNodes.invFderivAt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (xV : CtxVec Γ) (hx : ∀ (i : Fin s.size), (CtxVec.get idx xV).ofLp i ≠ 0) :

Pointwise NodeFDerivCorrectAt for inv under the assumption that inputs are nonzero.
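For reference, the scalar fact that such a pointwise statement would rest on is available in Mathlib. This is a sketch assuming a Mathlib import; whether invFderivAt is actually proved via this lemma is not stated on this page.

```lean
import Mathlib

-- Away from zero, y ↦ y⁻¹ has derivative -(x²)⁻¹.
example {x : ℝ} (hx : x ≠ 0) :
    HasDerivAt (fun y : ℝ => y⁻¹) (-(x ^ 2)⁻¹) x :=
  hasDerivAt_inv hx
```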

theorem Proofs.Autograd.TapeNodes.hasDerivAt_sqrt_clamp_of_pos {x : ℝ} (hx : 0 < x) :
HasDerivAt (fun (y : ℝ) => √(max y 0)) (1 / (2 * √x)) x

Derivative of the scalar function y ↦ sqrt (max y 0) at positive points.
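The derivative value follows from a short local argument, sketched here in ordinary notation (the neighbourhood radius is chosen for illustration only).

```latex
% Fix x > 0. For every y with |y - x| < x we have y > 0, hence max(y, 0) = y.
% So the clamped function agrees with the plain square root near x:
\[
\sqrt{\max(y, 0)} = \sqrt{y} \quad \text{for } |y - x| < x,
\]
% and the derivative at x is that of the square root:
\[
\frac{d}{dy}\,\sqrt{\max(y, 0)}\,\Big|_{y = x}
= \frac{d}{dy}\,\sqrt{y}\,\Big|_{y = x}
= \frac{1}{2\sqrt{x}}.
\]
```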

noncomputable def Proofs.Autograd.TapeNodes.sqrtClamp {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Elementwise "clamped sqrt": sqrt (max x 0) (differentiable on x > 0).

noncomputable def Proofs.Autograd.TapeNodes.sqrtClampFderivAt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (xV : CtxVec Γ) (hx : ∀ (i : Fin s.size), 0 < (CtxVec.get idx xV).ofLp i) :

Pointwise NodeFDerivCorrectAt for sqrt_clamp under the assumption that inputs are strictly positive.

noncomputable def Proofs.Autograd.TapeNodes.sqrt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime sqrt node (elementwise; nondifferentiable at zero).

noncomputable def Proofs.Autograd.TapeNodes.sqrtFderivAt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (xV : CtxVec Γ) (hx : ∀ (i : Fin s.size), (CtxVec.get idx xV).ofLp i ≠ 0) :

Pointwise NodeFDerivCorrectAt for sqrt under the assumption that inputs are nonzero.
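A nonzero (rather than positive) hypothesis is enough if the node uses Mathlib's Real.sqrt, which is an assumption here. Mathlib's square root is 0 on negative inputs, so it is locally constant there, and the stated derivative 1 / (2 * √x) also evaluates to 0 because division by zero is 0 in Lean.

```lean
import Mathlib

-- Mathlib's scalar lemma needs only x ≠ 0, not 0 < x.
example {x : ℝ} (hx : x ≠ 0) :
    HasDerivAt (fun y : ℝ => Real.sqrt y) (1 / (2 * Real.sqrt x)) x :=
  Real.hasDerivAt_sqrt hx
```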

noncomputable def Proofs.Autograd.TapeNodes.logistic {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime scalar logistic node, applied elementwise.

Vector and matrix softmax use the dedicated last-axis softmax nodes below; this node is the one-dimensional logistic map used by scalar activations.
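For orientation, the one-dimensional logistic map and its derivative are written out below, assuming the node uses the standard definition (the exact Lean definition is not shown on this page).

```latex
\[
\sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
\sigma'(x) = \frac{e^{-x}}{(1 + e^{-x})^{2}} = \sigma(x)\,\bigl(1 - \sigma(x)\bigr).
\]
```

Because the derivative exists for every real x, a global correctness statement needs no side condition on the inputs.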

Global NodeFDerivCorrect for logistic (uses the scalar derivative lemma).

noncomputable def Proofs.Autograd.TapeNodes.sigmoid {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime sigmoid node (elementwise).

Global NodeFDerivCorrect for sigmoid.

noncomputable def Proofs.Autograd.TapeNodes.tanh {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime tanh node (elementwise).

noncomputable def Proofs.Autograd.TapeNodes.tanhFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :

Global NodeFDerivCorrect for tanh.
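The scalar derivative that a global statement for tanh relies on is the standard identity below, given as a reminder rather than as the form used in the proof.

```latex
\[
\tanh(x) = \frac{\sinh x}{\cosh x}, \qquad
\tanh'(x) = \frac{\cosh^{2} x - \sinh^{2} x}{\cosh^{2} x}
= 1 - \tanh^{2}(x).
\]
```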

noncomputable def Proofs.Autograd.TapeNodes.softplus {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime softplus node (elementwise, smooth ReLU surrogate).

Global NodeFDerivCorrect for softplus.
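Assuming the usual definition of softplus (this page does not state it, and some variants add a scale parameter), the derivative is the logistic function, which is why no hypothesis on the inputs is needed.

```latex
\[
\operatorname{softplus}(x) = \log\bigl(1 + e^{x}\bigr), \qquad
\operatorname{softplus}'(x) = \frac{e^{x}}{1 + e^{x}} = \frac{1}{1 + e^{-x}}.
\]
```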

noncomputable def Proofs.Autograd.TapeNodes.silu {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime silu node (elementwise).

noncomputable def Proofs.Autograd.TapeNodes.siluFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :

Global NodeFDerivCorrect for SiLU.
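Assuming the standard definition SiLU(x) = x · σ(x) with σ the logistic function, the product rule gives the derivative below.

```latex
\[
\operatorname{silu}(x) = x\,\sigma(x), \qquad
\operatorname{silu}'(x) = \sigma(x) + x\,\sigma(x)\,\bigl(1 - \sigma(x)\bigr),
\qquad \sigma(x) = \frac{1}{1 + e^{-x}}.
\]
```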

noncomputable def Proofs.Autograd.TapeNodes.gelu {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime tanh-approximate gelu node (elementwise).

noncomputable def Proofs.Autograd.TapeNodes.geluFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :

Global NodeFDerivCorrect for tanh-approximate GELU.
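For readers unfamiliar with the tanh approximation, the commonly used formula and its derivative are sketched below. The constants c and a are the conventional ones and are an assumption; the page does not list the values used by the node.

```latex
% Conventional constants (assumed): c = \sqrt{2/\pi}, a = 0.044715.
\[
u(x) = c\,\bigl(x + a\,x^{3}\bigr), \qquad
\operatorname{gelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh u(x)\bigr),
\]
\[
\operatorname{gelu}'(x)
= \tfrac{1}{2}\,\bigl(1 + \tanh u(x)\bigr)
+ \tfrac{1}{2}\,x\,\bigl(1 - \tanh^{2} u(x)\bigr)\,c\,\bigl(1 + 3a\,x^{2}\bigr).
\]
```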

noncomputable def Proofs.Autograd.TapeNodes.safeLog {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (ε : ℝ) :
Node Γ s

Runtime safe_log node (elementwise, always-defined log surrogate).

noncomputable def Proofs.Autograd.TapeNodes.safeLogFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (ε : ℝ) ( : 0 < ε) :

Global NodeFDerivCorrect for safe_log (requires 0 < ε).

noncomputable def Proofs.Autograd.TapeNodes.smoothAbs {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (ε : ℝ) :
Node Γ s

Runtime smooth_abs node (elementwise, smooth abs surrogate).

noncomputable def Proofs.Autograd.TapeNodes.smoothAbsFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (ε : ℝ) ( : 0 < ε) :

Global NodeFDerivCorrect for smooth_abs (requires 0 < ε).

noncomputable def Proofs.Autograd.TapeNodes.exp {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime exp node (elementwise).

noncomputable def Proofs.Autograd.TapeNodes.expFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :

Global NodeFDerivCorrect instance for the elementwise exponential.

noncomputable def Proofs.Autograd.TapeNodes.sinh {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime sinh node (elementwise).

noncomputable def Proofs.Autograd.TapeNodes.sinhFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :

Global NodeFDerivCorrect for elementwise hyperbolic sine.

noncomputable def Proofs.Autograd.TapeNodes.cosh {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :
Node Γ s

Runtime cosh node (elementwise).

noncomputable def Proofs.Autograd.TapeNodes.coshFderiv {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) :

Global NodeFDerivCorrect for elementwise hyperbolic cosine.
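The exp, sinh, and cosh wrappers are global because the underlying scalar derivatives hold at every real point. A sketch of those scalar facts, assuming a Mathlib import (the page does not say which lemmas the wrappers actually cite):

```lean
import Mathlib

-- Each fact holds for all z : ℝ, matching the `∀ z, HasDerivAt f (f' z) z` shape.
example (z : ℝ) : HasDerivAt Real.exp (Real.exp z) z := Real.hasDerivAt_exp z
example (z : ℝ) : HasDerivAt Real.sinh (Real.cosh z) z := Real.hasDerivAt_sinh z
example (z : ℝ) : HasDerivAt Real.cosh (Real.sinh z) z := Real.hasDerivAt_cosh z
```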

noncomputable def Proofs.Autograd.TapeNodes.elu {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (alpha : ℝ) :
Node Γ s

Runtime elu node (elementwise; nondifferentiable at zero unless alpha = 1).

noncomputable def Proofs.Autograd.TapeNodes.eluFderivAt {Γ : List Spec.Shape} {s : Spec.Shape} (idx : Idx Γ s) (alpha : ℝ) (xV : CtxVec Γ) (hx : ∀ (i : Fin s.size), (CtxVec.get idx xV).ofLp i ≠ 0) :
NodeFDerivCorrectAt (elu idx alpha) xV

Pointwise NodeFDerivCorrectAt for ELU under the usual assumption that no coordinate sits at the kink (zero).

For arbitrary alpha, ELU has left derivative alpha and right derivative 1 at zero. Keeping the hypothesis here avoids baking PyTorch's subgradient convention into a mathematical derivative theorem.
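The one-sided derivatives quoted above can be checked against the standard piecewise definition of ELU, which is assumed here since the page does not display it.

```latex
\[
\operatorname{elu}_{\alpha}(x) =
\begin{cases}
x & x > 0,\\
\alpha\,(e^{x} - 1) & x \le 0,
\end{cases}
\qquad
\operatorname{elu}_{\alpha}'(x) =
\begin{cases}
1 & x > 0,\\
\alpha\,e^{x} & x < 0.
\end{cases}
\]
\[
\lim_{x \to 0^-} \operatorname{elu}_{\alpha}'(x) = \alpha,
\qquad
\lim_{x \to 0^+} \operatorname{elu}_{\alpha}'(x) = 1.
\]
```

The two sides agree exactly when alpha = 1, which matches the remark on the elu node.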

noncomputable def Proofs.Autograd.TapeNodes.unaryOp {Γ : List Spec.Shape} {inDim outDim : ℕ} (idx : Idx Γ (Spec.Shape.dim inDim Spec.Shape.scalar)) (C : OpSpecFDerivCorrect inDim outDim) :

Unary node applying an analytically correct OpSpec at a context index.

noncomputable def Proofs.Autograd.TapeNodes.unaryOpFderiv {Γ : List Spec.Shape} {inDim outDim : ℕ} (idx : Idx Γ (Spec.Shape.dim inDim Spec.Shape.scalar)) (C : OpSpecFDerivCorrect inDim outDim) :

NodeFDerivCorrect for unaryOp.

noncomputable def Proofs.Autograd.TapeNodes.linear {Γ : List Spec.Shape} {inDim outDim : ℕ} (x : Idx Γ (Spec.Shape.dim inDim Spec.Shape.scalar)) (m : Spec.LinearSpec inDim outDim) :

Linear layer as a single tape node (fixed weights/bias in the Spec.LinearSpec).

noncomputable def Proofs.Autograd.TapeNodes.linearFderiv {Γ : List Spec.Shape} {inDim outDim : ℕ} (x : Idx Γ (Spec.Shape.dim inDim Spec.Shape.scalar)) (m : Spec.LinearSpec inDim outDim) :

NodeFDerivCorrect for linear: the node derivative matches the spec's OpSpec derivative.
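In ordinary notation, the derivative of an affine layer with fixed parameters is the weight matrix itself. The symbols W and b below are illustrative names for the weights and bias held in the Spec.LinearSpec; they are not field names taken from the library.

```latex
% W \in \mathbb{R}^{\text{outDim} \times \text{inDim}}, \; b \in \mathbb{R}^{\text{outDim}} (illustrative names).
\[
f(x) = W x + b, \qquad
f(x + h) - f(x) = W h
\quad\Longrightarrow\quad
Df(x)[h] = W h \ \text{ for every } x.
\]
```

Since the weights are fixed rather than read from the context, the derivative is taken with respect to the input x only.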
