📢 公告:OpenCV 系列文章目前正在重構整理中(進度約 65%),部分文章已暫時下架,後續會陸續補上,完成時間待定。感謝耐心等候!

Like Share Discussion Bookmark Smile

J.J. Huang   2026-04-04   Python OpenCV 07.物件偵測與辨識篇   瀏覽次數:次   DMCA.com Protection Status

Python | OpenCV 模型保存與載入

📚 前言

PyTorch 微調範例train.py 裡,訓練結束時你已經執行了這行:

1
torch.save(model.state_dict(), f"{SAVE_DIR}/best_model.pth")

這就是 state_dict 保存,也是最基本、最常用的格式。

但你可能會碰到幾個 train.py 無法直接解決的情境:

  • 訓練跑到一半當機,想從上次停下來的地方繼續
  • 把模型給別人用,對方不想再手刻一次你的模型結構
  • 要部署到沒有 Python 的環境(C++ 伺服器、手機 App)
  • 下一步要接入 OpenCV,需要轉成 OpenCV 看得懂的格式

這篇先分框架介紹各自的保存格式,再說明如何統一匯出成 ONNX。

💾 模型保存格式比較


圖:模型保存格式比較 ─ PyTorch、TensorFlow/Keras 與跨框架部署推薦格式一覽

💻 PyTorch 的保存格式

🔸 Checkpoint — 訓練中斷後從頭來過?

train.py 只在 val_acc 更新時存一次最佳權重。
如果你訓練 50 個 epoch,跑到第 42 個停電了,重新來就是從 epoch 1 開始。

Checkpoint 的目的就是定期存下「當下的完整狀態」,讓你可以接著跑,而不是重頭來。

「完整狀態」包含:

  • 模型當前的所有權重(不只最佳,而是當下)
  • optimizer 的狀態(學習率衰減、動量等都記在裡面)
  • 當前的 epoch 數
1
2
3
4
5
6
7
8
9
10
11
# save_checkpoint.py
import torch

# 在每個 epoch 結束時存一次 checkpoint
checkpoint = {
"epoch": epoch,
"model": model.state_dict(),
"optimizer": optimizer.state_dict(),
"val_acc": val_acc,
}
torch.save(checkpoint, "checkpoint.pth")
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# resume_training.py
import torch
import torch.nn as nn
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 還原模型
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)

# 還原 optimizer
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)

# 讀取 checkpoint
checkpoint = torch.load("checkpoint.pth", map_location=device)
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1

print(f"從 Epoch {start_epoch} 繼續訓練,上次 Val Acc:{checkpoint['val_acc']:.2f}%")

# 繼續訓練迴圈從 start_epoch 跑到 NUM_EPOCHS ...

💡 optimizer 的狀態為什麼重要?
Adam 等優化器會累積每個參數的動量(momentum)與學習率調整量。
如果只還原 model 權重,optimizer 狀態從零開始,訓練曲線會抖動,甚至退步。

🔸 完整模型格式 — 快速共用,但有代價

有時你只是想把模型丟給同事,不想讓他從頭重建模型結構。
PyTorch 支援把「模型結構 + 權重」一起打包:

1
2
3
4
5
6
7
8
9
# save_load_full.py
import torch

# 保存完整模型
torch.save(model, "full_model.pth")

# 載入
model = torch.load("full_model.pth", map_location="cpu")
model.eval()

聽起來很方便,但有個限制:對方的環境必須能 import 到你定義模型結構的那個 .py
PyTorch 儲存的不是結構本身,而是「哪個 class、在哪個路徑」。
如果對方的專案目錄不一樣,或你用了自訂的 class,就會在 torch.load 時直接報錯。

這就是為什麼正式部署時不用這個格式,而是用下面的 TorchScript。

🔸 TorchScript — 不想依賴 Python 環境時

如果你的推論服務是用 C++ 寫的,或者要部署到 Android/iOS App,Python 程式碼根本沒辦法跑。

TorchScript 把模型序列化成一個獨立的格式,不需要原始 Python class,也不需要 PyTorch 的 Python 環境,只要有 LibTorch(C++ 版的 PyTorch)就能載入執行。

匯出時有兩種方式,選哪個看模型結構:

1
2
3
4
5
6
7
8
9
10
11
12
13
# torchscript_export.py
import torch

model.eval()

# 方式一:trace — 輸入形狀固定時使用(大多數 CNN 都適合)
example_input = torch.randn(1, 3, 224, 224)
scripted = torch.jit.trace(model, example_input)
scripted.save("model_scripted.pt")

# 方式二:script — 模型裡有 if/for 等動態控制流時使用
scripted = torch.jit.script(model)
scripted.save("model_scripted.pt")
1
2
3
4
5
6
7
8
9
10
11
# load_torchscript.py
import torch

# 載入時完全不需要原始 model class
model = torch.jit.load("model_scripted.pt")
model.eval()

dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
output = model(dummy)
print(f"輸出形狀:{output.shape}")

💡 trace vs script 怎麼選?
trace 會追蹤一次前向傳遞的計算路徑並記錄下來,所以模型裡若有「根據輸入內容走不同分支」的邏輯,trace 只會記錄那一次的路徑。這類情況要改用 script,讓 TorchScript 直接分析程式碼。ResNet、VGG、EfficientNet 等標準架構用 trace 都沒問題。

💻 TensorFlow/Keras 的保存格式

TF/Keras 的情況相對簡單,主要有三種格式,對應不同的使用情境:

格式 使用情境
SavedModel(資料夾) 部署到 TF Serving、生產環境推薦
.keras Keras 環境內部使用,Keras 2.12+
.h5 相容舊版專案
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# save_load_keras.py
import tensorflow as tf

# 保存
model.save("saved_model/") # SavedModel(推薦,部署用)
model.save("model.keras") # .keras 格式
model.save("model.h5") # .h5 舊版格式

# 載入(三種格式都用同一個函式)
model = tf.keras.models.load_model("saved_model/")
model = tf.keras.models.load_model("model.keras")
model = tf.keras.models.load_model("model.h5")

model.summary()

💡 SavedModel vs .h5:SavedModel 保存了完整計算圖與簽名,是 TF Serving 部署的標準格式。.h5 只在 Keras 環境內使用,搬到其他部署工具通常不認識。如果沒有特別原因,SavedModel 是首選。

💻 ONNX — 下一篇要用的格式

ONNX(Open Neural Network Exchange) 是一個開放的模型交換格式,設計目標是讓不同框架之間可以互通。

這篇特別提它,是因為下一步要讓 OpenCV 直接讀取你的模型。
OpenCV 的 cv2.dnn 模組不認識 .pth,也不認識 .keras,但它認識 ONNX。

先安裝:

1
pip install onnx onnxscript onnxruntime

🔸 PyTorch → ONNX

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# export_onnx_pytorch.py
import os
import torch
import torch.nn as nn
from torch.export import Dim
from torchvision import models

TASK_NAME = "cat_dog"
SAVE_DIR = f"models/resnet18/{TASK_NAME}"
ONNX_DIR = f"{SAVE_DIR}/onnx"
os.makedirs(ONNX_DIR, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 重建模型結構(必須與訓練時完全相同)
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load(f"{SAVE_DIR}/best_model.pth", map_location=device))
model = model.to(device)
model.eval()

dummy_input = torch.randn(1, 3, 224, 224, device=device)

batch_dim = Dim("batch_size")

torch.onnx.export(
model,
dummy_input,
f"{ONNX_DIR}/model.onnx",
input_names=["input"],
output_names=["output"],
dynamic_shapes={"x": {0: batch_dim}},
opset_version=18,
)
print(f"ONNX 匯出完成:{ONNX_DIR}/model.onnx")


圖:載入 best_model.pth 的權重,重建 ResNet18 後匯出為 ONNX 格式

🔸 TensorFlow/Keras → ONNX

1
pip install tf2onnx
1
2
3
4
5
6
7
8
9
10
11
12
# export_onnx_keras.py
import tensorflow as tf
import tf2onnx
import onnx

model = tf.keras.models.load_model("saved_model/")

input_spec = (tf.TensorSpec(model.input_shape, tf.float32, name="input"),)
onnx_model, _ = tf2onnx.convert.from_keras(model, input_signature=input_spec)

onnx.save(onnx_model, "model_keras.onnx")
print("ONNX 匯出完成")

驗證 ONNX 模型

匯出後先用 onnxruntime 跑一次確認輸出正常,再接入 OpenCV:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
# verify_onnx.py
import onnxruntime as ort
import numpy as np

TASK_NAME = "cat_dog"
ONNX_PATH = f"models/resnet18/{TASK_NAME}/onnx/model.onnx"

session = ort.InferenceSession(ONNX_PATH)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

dummy = np.random.randn(1, 3, 224, 224).astype(np.float32)
result = session.run([output_name], {input_name: dummy})
print(f"輸出形狀:{result[0].shape}")


圖:使用 onnxruntime 驗證 ONNX 模型的輸入輸出節點與推論結果形狀

⚠️ 注意事項

  • 跨裝置載入要加 map_location:在 CPU 機器上載入原本在 GPU 上訓練的 .pth,必須加 map_location="cpu",否則 PyTorch 會找不到 GPU 裝置而報錯。
  • 載入後呼叫 model.eval():PyTorch 模型載入後預設是訓練模式,Dropout 和 BatchNorm 的行為與推論時不同,推論前必須切換。
  • ONNX opset 版本:PyTorch 2.5+ 的 dynamo exporter 從 opset 18 起實作,建議直接用 opset_version=18dynamic_axes 在新版已棄用,改用 dynamic_shapes 搭配 torch.export.Dim 宣告動態維度。
  • checkpoint 只是工具,不替代最佳模型checkpoint.pth 存的是「當下狀態」,用來恢復訓練;best_model.pth 存的是「最佳準確率的權重」,用來推論。兩個用途不同,不要混用。

🎯 結語

模型保存的核心原則很簡單:
「先用 state_dict 完成基本保存與推論,再根據實際需求選擇其他格式。」

  • 想繼續訓練 → 用 Checkpoint
  • 想在 Python 環境推論 → 用 state_dict(train.py 已經在用)
  • 想快速共用給同環境的人 → 可以考慮完整模型(但要注意路徑問題)
  • 想部署到沒有 Python 的環境 → 用 TorchScript
  • 想接入 OpenCV 或跨框架使用 → 用 ONNX(下一篇文章的重點)

掌握這些格式的選擇時機後,你就能根據不同情境,選擇最適合的保存方式。

下一步是 與 OpenCV 整合推論,把剛才匯出的 ONNX 模型接入 OpenCV,對圖片或即時攝影機畫面進行推論。

📖 如在學習過程中遇到疑問,或是想了解更多相關主題,建議回顧一下 Python | OpenCV 系列導讀,掌握完整的章節目錄,方便快速找到你需要的內容。

註:以上參考了
PyTorch 官方文件 — Saving and Loading Models
PyTorch 官方文件 — TorchScript
TensorFlow 官方文件 — Save and load models
ONNX 官方文件
tf2onnx GitHub