📢 公告:OpenCV 系列文章重構完成(75%)。專案實作篇仍在製作中,完成時間未定,敬請期待!→ 查看文章索引

熱門系列
Like Share Discussion Bookmark Smile

J.J. Huang   2026-03-28   Python OpenCV 07.物件偵測與辨識篇   瀏覽次數:次   DMCA.com Protection Status

Python | OpenCV PyTorch 微調範例

📚 前言

在上一篇 遷移學習與微調原理 中,我們了解了 Feature Extraction、Fine-Tuning、Full Training 三種策略,以及如何根據資料量做選擇。

這一篇進入實作,完整示範如何用 PyTorch 將 ResNet18 以 Feature Extraction 策略微調為自訂分類模型。程式碼的每個環節都對應上一篇的原理,建議對照閱讀。

🛠️ 環境安裝

1
pip install torch torchvision


圖:執行 pip install 安裝 torch 與 torchvision 的結果

💡 以上安裝的是 CPU 版本,這篇範例使用 CPU 執行即可。GPU 加速的安裝與設定將在後續篇章介紹。

🗃️ 資料集目錄結構


圖:專案目錄結構建議 ─ 使用樹狀方式呈現原始素材、訓練資料與模型輸出的分層管理

PyTorch 的 ImageFolder 會自動依資料夾名稱建立類別對應,目錄結構如下:

1
2
3
4
5
6
7
8
9
10
11
12
data/
└── cat_dog/
├── train/
│ ├── cat/
│ │ ├── 0001.jpg
│ │ └── ...
│ └── dog/
│ ├── 0001.jpg
│ └── ...
└── val/
├── cat/
└── dog/

☝ 這只是建議的目錄結構,並非強制規定,你可以依照自己的習慣調整。但請注意,資料夾名稱(例如 cat、dog)會被 ImageFolder 自動當作類別標籤。

如何蒐集圖片?

這篇範例使用 catdog 兩個類別。可以用在 資料蒐集 篇介紹過的 BingImageCrawler 批次下載,分別蒐集訓練集與驗證集:

1
2
3
4
5
6
7
8
9
10
11
12
13
# download_dataset.py
from icrawler.builtin import BingImageCrawler

classes = ["cat", "dog"]

for cls in classes:
# 訓練集:最多下載 200 張(實際數量依搜尋結果而定)
crawler = BingImageCrawler(storage={"root_dir": f"data/cat_dog/train/{cls}"})
crawler.crawl(keyword=cls, max_num=200)

# 驗證集:最多下載 50 張(keyword 加條件避免與訓練集重複)
crawler = BingImageCrawler(storage={"root_dir": f"data/cat_dog/val/{cls}"})
crawler.crawl(keyword=f"{cls} photo", max_num=50)


圖:使用 BingImageCrawler 批次下載 cat 與 dog 圖片並分別存入 train/val 資料夾

💡 max_num 是下載上限,不保證一定能下載到該數量,實際數量取決於 Bing 的搜尋結果。訓練時只要每類有 50 張以上即可正常執行。

⚠️ 注意版權問題,建議僅用於研究與學習目的。

🧠 函式與參數說明


圖:資料載入的兩個核心工具 ─ ImageFolder 與 DataLoader 的功能與常用參數說明

🗂️ 模型輸出目錄結構

訓練完成後,模型權重與類別設定會存放在同一個目錄,方便日後對應使用:

1
2
3
4
5
models/
└── resnet18/
└── cat_dog/
├── best_model.pth ← 模型權重
└── config.json ← 記錄類別數與類別名稱

config.json 內容如下:

1
2
3
4
5
{
"model": "resnet18",
"num_classes": 2,
"classes": ["cat", "dog"]
}

推論時先讀 config.json,就不會忘記當初訓練的類別數與名稱。

🔎 訓練流程總覽

在看程式碼之前,先了解整個流程的每個步驟在做什麼:


圖:PyTorch 模型訓練流程總覽 ─ 從資料前處理到儲存最佳模型的完整六大步驟

💡 train.py 的程式碼區塊順序就對應以上六個步驟,對照著看會清楚很多。

💻 完整範例程式

採用 Feature Extraction 策略:凍結所有預訓練層,只訓練替換後的分類頭。每類資料量在 200 張以下時,這是最穩健的起手方式。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# train.py
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader

# ── 超參數設定 ──────────────────────────────
BATCH_SIZE = 32
NUM_EPOCHS = 15
LEARNING_RATE = 1e-3
NUM_CLASSES = 2 # 依實際類別數修改
TASK_NAME = "cat_dog" # 換任務只改這一行
DATA_DIR = f"data/{TASK_NAME}"
SAVE_DIR = f"models/resnet18/{TASK_NAME}"
# ────────────────────────────────────────────

os.makedirs(SAVE_DIR, exist_ok=True)

# ① 資料前處理
train_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(10),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])

# ② 資料載入
train_dataset = datasets.ImageFolder(f"{DATA_DIR}/train", transform=train_transform)
val_dataset = datasets.ImageFolder(f"{DATA_DIR}/val", transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

print(f"類別:{train_dataset.classes}")
print(f"訓練集:{len(train_dataset)} 張,驗證集:{len(val_dataset)} 張")

# ③ 模型建立(Feature Extraction:凍結預訓練層,只訓練分類頭)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用裝置:{device}")

model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

for param in model.parameters():
param.requires_grad = False

model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(device)

# ④ 損失函數與優化器(只優化分類頭)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

# ⑤ 訓練迴圈
best_val_acc = 0.0

for epoch in range(NUM_EPOCHS):
# 訓練階段
model.train()
running_loss = 0.0
correct = total = 0

for inputs, labels in train_loader:
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

running_loss += loss.item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

train_acc = correct / total * 100
train_loss = running_loss / len(train_loader)

# 驗證階段
model.eval()
val_correct = val_total = 0

with torch.no_grad():
for inputs, labels in val_loader:
inputs, labels = inputs.to(device), labels.to(device)
outputs = model(inputs)
_, predicted = torch.max(outputs, 1)
val_total += labels.size(0)
val_correct += (predicted == labels).sum().item()

val_acc = val_correct / val_total * 100
scheduler.step()

print(f"Epoch [{epoch+1:02d}/{NUM_EPOCHS}] "
f"Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%")

# ⑥ 儲存最佳模型
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), f"{SAVE_DIR}/best_model.pth")
config = {
"model": "resnet18",
"num_classes": NUM_CLASSES,
"classes": train_dataset.classes
}
with open(f"{SAVE_DIR}/config.json", "w") as f:
json.dump(config, f, ensure_ascii=False, indent=2)
print(f" → 已儲存最佳模型 (Val Acc: {best_val_acc:.2f}%)")

print(f"\n訓練完成,最佳驗證準確率:{best_val_acc:.2f}%")


圖:凍結 ResNet18 預訓練層後只訓練分類頭,完整執行訓練與驗證迴圈並儲存最佳模型

💻 單張圖片推論

訓練完成後,用以下程式驗證模型對單張圖片的推論結果。測試圖片可直接使用驗證集中的任一張,例如 data/cat_dog/val/cat/000001.jpg

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
# inference.py
import json
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image

MODEL_DIR = "models/resnet18/cat_dog" # 與 train.py 的 SAVE_DIR 對應
IMAGE_PATH = "data/cat_dog/val/cat/000001.jpg" # 替換為實際測試圖片路徑

# 從 config.json 讀取類別資訊,不需要手動填寫
with open(f"{MODEL_DIR}/config.json") as f:
config = json.load(f)
NUM_CLASSES = config["num_classes"]
CLASS_NAMES = config["classes"]

model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load(f"{MODEL_DIR}/best_model.pth", map_location="cpu"))
model.eval()

transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])

img = Image.open(IMAGE_PATH).convert("RGB")
input_tensor = transform(img).unsqueeze(0)

with torch.no_grad():
outputs = model(input_tensor)
probs = torch.nn.functional.softmax(outputs[0], dim=0)
conf, pred = torch.max(probs, 0)

print(f"預測類別:{CLASS_NAMES[pred]},信心度:{conf:.4f}")


圖:載入已訓練的 ResNet18 模型對單張圖片進行推論並輸出預測類別與信心度

⚠️ 注意事項

  • 換任務時 NUM_CLASSES 要同步修改TASK_NAME 只控制資料路徑與模型儲存位置,類別數需手動更新。兩者不一致會導致模型輸出維度錯誤,訓練時直接報錯。
  • num_workers 在 Windows 上設 0:設成其他數值在 Windows 上可能造成 DataLoader 死鎖。
  • 驗證集不做資料增強val_transform 只做 Resize 與 Normalize,不做翻轉或旋轉,確保評估結果穩定。
  • model.eval()torch.no_grad():驗證與推論時都必須同時呼叫,前者關閉 Dropout/BatchNorm 的訓練行為,後者節省記憶體並加速計算。
  • scheduler.step() 的位置:需在每個 epoch 結束後呼叫,而非每個 batch 後。

📊 應用場景

  • 寵物品種辨識:以本篇的 cat/dog 分類為基礎,擴展到更多品種,只需調整 TASK_NAMENUM_CLASSES 與對應的訓練資料即可。
  • 自訂商品分類:以少量商品圖片微調,建立特定品類辨識模型。
  • 工廠品管:辨識生產線上的良品與瑕疵品,每個產品線建立獨立的任務目錄,互不干擾。

🎯 結語

這篇完整示範了 Feature Extraction 策略在 PyTorch 的實作,從資料集準備、模型建立、訓練迴圈到推論,每個環節都對應上一篇介紹的原理。
跑完整個流程之後,如果想進一步提升準確率,可以試著按照上一篇的建議逐步解凍 layer4,加入 Fine-Tuning。

下一步是 TensorFlow/Keras 微調範例,換一個框架做同樣的事,對比之下會更容易理解兩者的差異。

📖 如在學習過程中遇到疑問,或是想了解更多相關主題,建議回顧一下 Python | OpenCV 系列導讀,掌握完整的章節目錄,方便快速找到你需要的內容。

註:以上參考了
PyTorch 官方文件 — Transfer Learning Tutorial
PyTorch 官方文件 — torchvision.datasets
PyTorch 官方文件 — torch.optim