Logistic regression (spec model)
This file implements a small, deterministic logistic regression baseline.
Model (binary classification):
- logits: z = X w + b
- probabilities: p = σ(z), where σ is the logistic sigmoid
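As a concrete illustration of these two equations (a minimal NumPy sketch, not part of the spec code; the toy numbers are made up):

```python
import numpy as np

# Hypothetical toy data: n = 3 rows, p = 2 features.
X = np.array([[0.5, 1.0],
              [1.5, -0.5],
              [-1.0, 2.0]])
w = np.array([0.3, -0.2])        # weight vector w
b = 0.1                          # scalar intercept b

z = X @ w + b                    # logits z = X w + b, shape (3,)
p = 1.0 / (1.0 + np.exp(-z))     # probabilities p = sigma(z), elementwise
```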
PyTorch mental model:
- parameters correspond to nn.Linear(p, 1) (weights + bias),
- probabilities correspond to torch.sigmoid(logits),
- training is a plain gradient-descent loop (similar to torch.optim.SGD), written in a simple, explicit style rather than tuned for performance.
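In PyTorch terms, that correspondence looks roughly like the sketch below (illustrative only; the spec code itself does not use PyTorch, and the data here is random):

```python
import torch
import torch.nn as nn

p = 2                                        # example feature count
model = nn.Linear(p, 1)                      # weights + bias correspond to w and b
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.BCELoss()

X = torch.randn(8, p)                        # toy inputs
y = torch.randint(0, 2, (8, 1)).float()      # toy {0, 1} labels

for _ in range(100):                         # plain gradient-descent loop
    probs = torch.sigmoid(model(X))          # probabilities = sigmoid(logits)
    loss = loss_fn(probs, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```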
Notes:
- We augment the input matrix with a column of ones to represent the intercept term.
- This is reference/spec code: it prioritizes clarity and auditability over performance.
Numerical note:
PyTorch often uses BCEWithLogitsLoss for stability (it works directly on logits without forming the sigmoid explicitly). Here we keep the math explicit.
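For contrast, a short PyTorch sketch of the two formulations (the explicit form mirrors what this file does; the tensors are made-up examples):

```python
import torch
import torch.nn as nn

logits = torch.tensor([8.0, -12.0, 0.5])
targets = torch.tensor([1.0, 0.0, 1.0])

# Explicit form: form the sigmoid first, then binary cross-entropy.
explicit = nn.BCELoss()(torch.sigmoid(logits), targets)

# Fused form: works directly on logits; more stable for large |logits|.
fused = nn.BCEWithLogitsLoss()(logits, targets)
```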
Parameters for logistic regression: a weight vector w and scalar intercept b.
We store intercept : α separately rather than folding it into weights, but fitLogistic
internally learns (p + 1) parameters by augmenting the input with a trailing column of ones.
- weights : Spec.Tensor α (Spec.Shape.dim p Spec.Shape.scalar)
  p-dimensional weight vector w.
- intercept : α
  Scalar intercept term b.
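A rough Python analogue of this parameter record (names and types are illustrative, mirroring the weights/intercept split rather than the Lean definition):

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class LogisticParams:
    weights: np.ndarray   # shape (p,): the weight vector w
    intercept: float      # scalar intercept b
```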
Augment an n × p design matrix with a final column of ones.
This lets us represent the affine model X w + b as a single matrix-vector product with a
(p + 1)-vector of parameters.
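A minimal NumPy sketch of this augmentation (the function name is hypothetical, not the Lean identifier):

```python
import numpy as np

def augment_with_ones(X):
    """Append a final column of ones to an (n, p) matrix, giving shape (n, p + 1)."""
    n = X.shape[0]
    return np.hstack([X, np.ones((n, 1))])

# With the augmented matrix, the affine model X w + b becomes a single product:
#   augment_with_ones(X) @ np.append(w, b)  ==  X @ w + b
```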
Gradient of the logistic negative log-likelihood, expressed as Xᵀ (σ(Xw) - y).
This is the standard expression used for (unregularized) logistic regression under labels
y ∈ {0,1}. We do not divide by n here; callers can rescale if they want the mean loss.
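In NumPy, a sketch of this unscaled gradient on the ones-augmented matrix (names hypothetical):

```python
import numpy as np

def logistic_grad(Xa, w_full, y):
    """Gradient of the unscaled logistic negative log-likelihood: Xa^T (sigmoid(Xa w) - y).

    Xa is the ones-augmented design matrix, w_full the (p + 1) parameter vector,
    y the vector of {0, 1} labels.
    """
    probs = 1.0 / (1.0 + np.exp(-(Xa @ w_full)))
    return Xa.T @ (probs - y)
```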
Fit logistic regression by plain gradient descent (structural recursion).
This is a simple deterministic baseline that is easy to reason about. It does not attempt to match optimized solvers (LBFGS/Newton/IRLS); it is a small reference implementation that can be instantiated over different scalar backends.
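A small NumPy sketch of such a fixed-step, fixed-iteration-count loop (function and parameter names are illustrative, not the Lean signatures):

```python
import numpy as np

def fit_logistic(X, y, lr=0.1, steps=500):
    """Plain gradient descent on the unscaled logistic negative log-likelihood."""
    n, p = X.shape
    Xa = np.hstack([X, np.ones((n, 1))])   # trailing column of ones for the intercept
    w_full = np.zeros(p + 1)               # p weights followed by the intercept
    for _ in range(steps):
        probs = 1.0 / (1.0 + np.exp(-(Xa @ w_full)))
        grad = Xa.T @ (probs - y)          # X^T (sigmoid(Xw) - y)
        w_full = w_full - lr * grad
    return w_full[:-1], w_full[-1]         # (weights, intercept)
```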
Predict probabilities σ(Xw + b) for each row in X.
Convert probabilities to hard labels using a threshold (default 0.5).
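For prediction, a matching NumPy sketch of both steps (helper names are hypothetical; the threshold direction is a convention, not fixed by the spec text):

```python
import numpy as np

def predict_proba(X, weights, intercept):
    """Probabilities sigmoid(X w + b) for each row of X."""
    return 1.0 / (1.0 + np.exp(-(X @ weights + intercept)))

def predict_label(X, weights, intercept, threshold=0.5):
    """Hard {0, 1} labels from probabilities via a threshold (default 0.5)."""
    return (predict_proba(X, weights, intercept) >= threshold).astype(int)
```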