線性回歸

線性回歸#

讓我們實作一個基本的線性回歸模型,作為學習 MLX 的起點。先匯入核心軟體包並設定一些問題的中繼資料:

import mlx.core as mx

num_features = 100
num_examples = 1_000
num_iters = 10_000  # iterations of SGD
lr = 0.01  # learning rate for SGD

我們將透過以下方式產生合成資料集:

  1. 抽樣設計矩陣 X

  2. 抽樣真實參數向量 w_star

  3. X @ w_star 加上高斯雜訊來計算目標值 y

# True parameters
w_star = mx.random.normal((num_features,))

# Input examples (design matrix)
X = mx.random.normal((num_examples, num_features))

# Noisy labels
eps = 1e-2 * mx.random.normal((num_examples,))
y = X @ w_star + eps

我們會使用 SGD 來尋找最佳權重。首先定義平方損失,並取得損失對參數的梯度函式。

def loss_fn(w):
    return 0.5 * mx.mean(mx.square(X @ w - y))

grad_fn = mx.grad(loss_fn)

先隨機初始化參數 w 以開始最佳化,接著重複更新參數 num_iters 次。

w = 1e-2 * mx.random.normal((num_features,))

for _ in range(num_iters):
    grad = grad_fn(w)
    w = w - lr * grad
    mx.eval(w)

最後計算學到的參數的損失,並確認它們接近真實參數。

loss = loss_fn(w)
error_norm = mx.sum(mx.square(w - w_star)).item() ** 0.5

print(
    f"Loss {loss.item():.5f}, |w-w*| = {error_norm:.5f}, "
)
# Should print something close to: Loss 0.00005, |w-w*| = 0.00364

完整的 線性回歸邏輯迴歸 範例可在 MLX GitHub 專案中取得。