TorchLean API

NN.Runtime.RL.PPO.Collect

PPO Rollout Collection (Checked Sessions)

This file provides the rollout-collection loop used by executable PPO workflows. The key goals are:

The unified session interface lives in NN.Runtime.RL.Session (Session.CheckedSession). The lower-level Gymnasium subprocess protocol is implemented in NN.Runtime.RL.Gymnasium.


Rollout collection (ergonomic core API)

def Runtime.RL.PPO.collectRolloutSessionWith {α : Type} [Context α] {obsShape : Spec.Shape} {nActions horizon : ℕ} {Sess : Type} [Fact (0 < horizon)] [Fact (0 < nActions)] (start : IO Sess) (observe : Sess → Spec.Tensor Float obsShape) (stepChecked : Sess → Fin nActions → IO (Boundary.Transition obsShape nActions × Sess)) (castObs castReward : Float → α) (predictLogits : Spec.Tensor α obsShape → Spec.Tensor α (Spec.Shape.dim nActions Spec.Shape.scalar)) (predictValue : Spec.Tensor α obsShape → α) (rngSeed rngCounter : ℕ) :
IO (Rollout α obsShape nActions horizon × ℕ)

Collect a fixed-horizon rollout from any stateful environment session that can produce fully-observed, contract-checked transitions.

The caller provides:

  • start: how to initialize the session (typically by resetting the environment),
  • observe: how to read the current observation from the session,
  • stepChecked: one checked step returning an observed transition and the updated session,
  • castObs to inject host Float observations into the chosen scalar backend α,
  • castReward to inject host Float rewards into the chosen scalar backend α,
  • predictLogits for the current actor,
  • predictValue for the current critic (returns a scalar α).

This keeps the PPO runtime API small while still supporting the “compiled model + parameters” calling convention used throughout TorchLean.
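The loop described above can be pictured with the following plain Lean 4 sketch. It is a simplified, self-contained illustration under stated assumptions, not TorchLean's implementation: observations are `Array Float` / `Array α` instead of shaped `Spec.Tensor`s; `Transition`, `StepRecord`, `argmax`, `collectSketch`, and `linearForward` are hypothetical names; actions are picked by argmax instead of being sampled from the policy (so `rngSeed` / `rngCounter` are omitted); and episode resets are assumed to happen inside `stepChecked`. The demo fixes a hypothetical linear actor's parameters by partial application, which is one way to read the "compiled model + parameters" convention.

```lean
/-- Hypothetical stand-in for `Boundary.Transition` (only the fields this sketch reads). -/
structure Transition where
  reward : Float
  done   : Bool

/-- Hypothetical per-step record; the real buffer type is `Rollout`. -/
structure StepRecord (α : Type) where
  obs    : Array α
  action : Nat
  reward : α
  value  : α
  done   : Bool

/-- Index of the largest entry. Stands in for sampling from the policy (assumption). -/
def argmax {α : Type} [Inhabited α] [LT α] [∀ a b : α, Decidable (a < b)]
    (xs : Array α) : Nat := Id.run do
  let mut best := 0
  for i in [1:xs.size] do
    if xs[best]! < xs[i]! then best := i
  return best

/-- Fixed-horizon collection over an arbitrary session type `Sess` (hypothetical sketch). -/
def collectSketch {α Sess : Type} [Inhabited α] [LT α] [∀ a b : α, Decidable (a < b)]
    (horizon nActions : Nat) (hPos : 0 < nActions)
    (start : IO Sess)
    (observe : Sess → Array Float)
    (stepChecked : Sess → Fin nActions → IO (Transition × Sess))
    (castObs castReward : Float → α)
    (predictLogits : Array α → Array α)
    (predictValue : Array α → α) : IO (Array (StepRecord α)) := do
  let mut sess ← start
  let mut buf : Array (StepRecord α) := #[]
  for _ in [0:horizon] do
    -- Read the host observation and inject it into the scalar backend α.
    let obs := (observe sess).map castObs
    -- Pick an action index and reduce it into `Fin nActions`.
    let a : Fin nActions := ⟨argmax (predictLogits obs) % nActions, Nat.mod_lt _ hPos⟩
    -- One checked environment step; the session is threaded through explicitly.
    let (tr, sess') ← stepChecked sess a
    buf := buf.push { obs := obs, action := a.val, reward := castReward tr.reward,
                      value := predictValue obs, done := tr.done }
    sess := sess'
  return buf

/-- Hypothetical linear actor: logits_j = Σ_i w[j][i] * obs[i]. -/
def linearForward (w : Array (Array Float)) (obs : Array Float) : Array Float :=
  w.map fun row => (row.zip obs).foldl (fun (acc : Float) p => acc + p.1 * p.2) 0.0

/-- Toy run: the "session" is a step counter, observation = [counter], reward = 1.0. -/
def demo : IO Unit := do
  let params : Array (Array Float) := #[#[1.0], #[-1.0]]  -- fixed actor parameters
  let buf ← collectSketch (α := Float) 4 2 (by decide)
    (pure (0 : Nat))
    (fun n => #[n.toFloat])
    (fun n _ => pure ({ reward := 1.0, done := false }, n + 1))
    id id
    (linearForward params)  -- model + parameters, partially applied
    (fun _ => 0.0)
  IO.println s!"collected {buf.size} steps"

#eval demo  -- prints "collected 4 steps"
```

In this sketch, threading `Sess` explicitly through `stepChecked` (instead of mutating hidden state) is what lets the same loop accept any session type.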


Rollout collection from a checked session

def Runtime.RL.PPO.collectRolloutCheckedSessionWith {α : Type} [Context α] {obsShape : Spec.Shape} {nActions horizon : ℕ} [Fact (0 < horizon)] [Fact (0 < nActions)] (sess : Session.CheckedSession obsShape nActions) (castObs castReward : Float → α) (predictLogits : Spec.Tensor α obsShape → Spec.Tensor α (Spec.Shape.dim nActions Spec.Shape.scalar)) (predictValue : Spec.Tensor α obsShape → α) (rngSeed rngCounter : ℕ) :
IO (Rollout α obsShape nActions horizon × ℕ)

Collect a fixed-horizon rollout from a unified Runtime.RL.Session.CheckedSession.
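One way to picture how a bundled session relates to the generic start / observe / stepChecked interface of collectRolloutSessionWith is sketched below. The record `CheckedSessionSketch`, its `reset` / `step` fields, and the `adapt` helper are hypothetical; the actual fields of `Session.CheckedSession` are not listed on this page. The idea shown is that the generic loop's `Sess` can be a pair of the session handle and its latest observation, which makes `observe` a pure projection.

```lean
/-- Hypothetical record-of-closures session: reset once, then step with actions. -/
structure CheckedSessionSketch (Obs Act Tr : Type) where
  reset : IO Obs
  step  : Act → IO (Tr × Obs)

/-- State threaded by the generic loop: the handle plus the latest observation. -/
abbrev SessState (Obs Act Tr : Type) := CheckedSessionSketch Obs Act Tr × Obs

/-- Derive a `start` / `observe` / `stepChecked` triple from a bundled session (hypothetical). -/
def adapt {Obs Act Tr : Type} (s : CheckedSessionSketch Obs Act Tr) :
    IO (SessState Obs Act Tr) ×
    (SessState Obs Act Tr → Obs) ×
    (SessState Obs Act Tr → Act → IO (Tr × SessState Obs Act Tr)) :=
  ( -- start: reset the environment and keep the handle alongside the observation
    (fun (o : Obs) => (s, o)) <$> s.reset,
    -- observe: read the cached latest observation
    (fun st => st.2),
    -- stepChecked: step through the handle, then cache the new observation
    (fun st a => (fun (p : Tr × Obs) => (p.1, (st.1, p.2))) <$> st.1.step a) )
```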


Rollout collection from Gymnasium (subprocess bridge)

def Runtime.RL.PPO.collectRolloutWith {α : Type} [Context α] {obsShape : Spec.Shape} {nActions horizon : ℕ} [Fact (0 < horizon)] [Fact (0 < nActions)] (castObs castReward : Float → α) (gym : Gymnasium.Client obsShape nActions) (predictLogits : Spec.Tensor α obsShape → Spec.Tensor α (Spec.Shape.dim nActions Spec.Shape.scalar)) (predictValue : Spec.Tensor α obsShape → α) (rngSeed rngCounter resetSeed : ℕ) :
IO (Rollout α obsShape nActions horizon × ℕ)

Collect a fixed-horizon rollout from a Gymnasium subprocess environment.

This is a thin wrapper around collectRolloutSessionWith specialized to Gymnasium.Session.
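The "thin wrapper" relationship can be illustrated as partial application of a generic collector. In the sketch below, `collectRewards`, `GymSketch`, `toyGym`, and `collectFromGym` are hypothetical simplifications: actions and predictors are omitted, only rewards are recorded, and the real wrapper additionally forwards the casts and predictors and works with Gymnasium.Session. That `resetSeed` seeds the initial environment reset is an assumption based on the parameter name.

```lean
/-- Hypothetical generic collector: run `horizon` steps from `start`, recording rewards. -/
def collectRewards {Sess : Type} (horizon : Nat) (start : IO Sess)
    (step : Sess → IO (Float × Sess)) : IO (Array Float) := do
  let mut s ← start
  let mut rewards : Array Float := #[]
  for _ in [0:horizon] do
    let (r, s') ← step s
    rewards := rewards.push r
    s := s'
  return rewards

/-- Hypothetical Gymnasium-like client whose reset takes a seed (actions omitted). -/
structure GymSketch where
  reset : Nat → IO Nat            -- seed ↦ initial (toy) environment state
  step  : Nat → IO (Float × Nat)  -- state ↦ (reward, next state)

/-- Thin wrapper: the only environment-specific choice is how the session starts. -/
def collectFromGym (gym : GymSketch) (horizon resetSeed : Nat) : IO (Array Float) :=
  collectRewards horizon (gym.reset resetSeed) gym.step

/-- Toy client: the state is a counter starting at the seed; reward = current state. -/
def toyGym : GymSketch where
  reset seed := pure seed
  step s := pure (s.toFloat, s + 1)

#eval collectFromGym toyGym 3 10
```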
