Od zadania biznesowego do problemu klasyfikacji obrazów
Co chcesz rozpoznać na obrazie?
Bez jasno nazwanego celu technicznego klasyfikacja obrazów w PyTorch szybko zamienia się w eksperymenty bez końca. Od czego zaczynasz: od problemu technicznego („zbuduję CNN”), czy od biznesowego („chcę automatycznie odrzucać wadliwe produkty na linii”)?
Najprościej myśleć o klasyfikacji obrazów jako o przypisaniu każdemu obrazowi jednej lub kilku etykiet z ograniczonego zbioru. Żeby to zadziałało, klasy muszą być praktyczne. Przykład z produkcji: masz zdjęcia butelek z linii i chcesz wykrywać wady. Możliwe klasy:
- OK – butelka bez widocznych wad, dopuszczona do sprzedaży,
- Pęknięcie – widoczne pęknięcie szkła,
- Brud – zabrudzenie powierzchni,
- Niedolewka – poziom cieczy poniżej normy.
Technicznie możesz stworzyć kilkanaście klas („mikropęknięcia powierzchniowe”, „pęknięcia na dnie”, „zarysowania” itd.), ale czy ktoś z produkcji realnie będzie tego używał? W większości przypadków wystarczy podział: „OK” vs „zła” + ewentualnie kilka głównych typów wady. Im prostszy, dobrze zdefiniowany podział, tym łatwiej zebrać dane i utrzymać model.
Z drugiej strony, zbyt mało klas też bywa pułapką. Jeśli masz sklep internetowy i chcesz oznaczać kategorię produktu na zdjęciu, dwie klasy typu „odzież” i „nie odzież” niewiele dają. Realny podział to np. „buty”, „spodnie”, „koszulki”, „sukienki”. Pytanie do ciebie: jaki poziom szczegółowości klas naprawdę daje wartość użytkownikowi lub biznesowi?
Kolejna rzecz: balans klas i jakość etykiet. Jak wygląda rozkład twoich danych? Masz 95% klasy „OK” i po kilka sztuk rzadkich wad? Taki zbiór da bardzo wysoki accuracy przy kompletnie bezużytecznym modelu, który „zawsze mówi OK”. Kluczem są:
- sensowne proporcje klas (lub odpowiednie ważenie strat, jeśli balans jest naturalnie zaburzony),
- spójne etykietowanie – ten sam typ obiektu nie może raz być „OK”, a raz „pęknięcie” zależnie od osoby oznaczającej.
Zanim napiszesz linię kodu w PyTorch, odpowiedz sobie: czy masz choć 100–200 przykładów na każdą klasę oraz jasną definicję klasy?
Czy klasyfikacja to na pewno właściwe podejście?
Nie każdy problem z obrazami da się sensownie zamknąć w czystej klasyfikacji. Zadaj sobie trzy krótkie pytania:
- Czy na jednym obrazie jest zwykle jeden dominujący obiekt, który cię interesuje?
- Czy lokalizacja obiektu nie ma znaczenia (ważne tylko „co”, a nie „gdzie”)?
- Czy dokładny kształt obiektu nie jest kluczowy (wystarczy klasa, nie maska)?
Jeśli odpowiedź na wszystkie jest „tak” – klasyfikacja jest dobrym startem. Jeśli nie – możesz potrzebować detekcji lub segmentacji.
Krótko porównując:
- Klasyfikacja – jeden (lub kilka) globalnych labeli na obraz, bez współrzędnych. Przykład: rozpoznanie, czy zdjęcie jest „dzień” czy „noc”.
- Detekcja obiektów – klasy + prostokątne ramki (bounding boxes) wskazujące gdzie jest obiekt. Przykład: liczenie samochodów na parkingu.
- Segmentacja – etykieta piksela, zwłaszcza przy segmentacji semantycznej lub instancyjnej. Przykład: zaznaczenie dokładnego obrysu guza na obrazie medycznym.
Graniczne przypadki pojawiają się przy zdjęciach z wieloma obiektami. Jeśli na jednym zdjęciu są trzy psy i dwa koty, a ty tylko chcesz skrzynkować zdjęcie w kategoriach „jest pies” / „jest kot”, klasyczna klasyfikacja typu multi-class się nie sprawdzi. Potrzebujesz wtedy multi-label classification (obraz może mieć kilka etykiet jednocześnie, np. [pies=1, kot=1]).
Jak to rozpoznać? Znowu pytanie: czy jeden obraz może „należeć” do kilku klas na raz, czy tylko do jednej? Jeśli tylko do jednej – multi-class. Jeśli do kilku – multi-label. W PyTorch różnica przełoży się na:
- inną funkcję straty (np.
nn.CrossEntropyLossvsnn.BCEWithLogitsLoss), - inną postać etykiet (indeks klasy vs wektor 0/1),
- inny sposób interpretacji wyjścia modelu.
Jakiego wyniku oczekujesz?
Cel użycia modelu wpływa na wszystko: wybór architektury, strategię trenowania, a nawet sposób zapisywania modelu. Zapytaj sam siebie: robisz POC, MVP czy system produkcyjny?
- POC (proof of concept) – ma udowodnić, że da się coś zrobić. Tu wystarczy działający notebook, pretrenowany ResNet, kilka tysięcy przykładów. Chodzi o szybkość, nie perfekcję.
- MVP – pierwsza wersja dla użytkowników. Musisz już zadbać o stabilność, sensowną ewaluację, zapis modelu, pipeline inferencji.
- Produkcja – wymagania rosną: monitoring, retrening, obsługa błędów, integracja z istniejącą infrastrukturą.
Do tego dochodzą ograniczenia sprzętowe. Na czym będziesz trenować i wdrażać modele?
- CPU laptopa – rozbudowany ResNet50 może być koszmarem, lepiej celować w mniejsze architektury (MobileNet, ResNet18) lub mniejsze rozdzielczości.
- Pojedyncze GPU – największy sweet spot; większość standardowych modeli zejdzie w rozsądnym czasie.
- Edge / mobilnie – model musi być lekki, często kwantyzowany (int8), ograniczona pamięć.
Warto też z góry określić oczekiwane metryki. Czy wystarczy ci accuracy, czy potrzebujesz precyzji/recall dla każdej klasy, może F1 dla klasy „wada”? Dla mocno niezbalansowanych danych accuracy może wprowadzać w błąd. Dla klas rzadkich bardziej przydatna bywa recall (ile wad wykrywamy) niż ogólny accuracy.
Gdy znasz swój cel, łatwiej zdecydować, czy iść w prostą CNN zbudowaną od zera, czy w mocny transfer learning z ResNet, oraz ile pracy trzeba włożyć w przygotowanie danych i infrastrukturę.

Środowisko pracy i struktura projektu w PyTorch
Konfiguracja środowiska (GPU, biblioteki, wersje)
Zanim uruchomisz pierwsze importy PyTorch, warto uporządkować środowisko. W jakim środowisku pracujesz teraz – lokalny komputer, serwer, Colab?
Minimalny setup to:
- Python 3.9+ (lub inna wspierana przez twoją wersję PyTorch),
- PyTorch (
torch) w wersji kompatybilnej z twoim sprzętem, torchvision– modele i transformacje do obrazów,- opcjonalnie CUDA, jeśli masz GPU NVIDII.
Przykładowa instalacja (dla GPU, wersje dopasuj do aktualnej dokumentacji PyTorch):
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
Jeśli nie masz GPU lub nie chcesz się z nim teraz zmagać:
pip install torch torchvision torchaudio
Przydaje się prosty test w Pythonie:
import torch
print(torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
print("Device name:", torch.cuda.get_device_name(0))
Typowe problemy instalacyjne to:
- konflikt wersji CUDA z driverami,
- instalacja
torchz nieoficjalnych źródeł, - mieszanie środowisk (globalny Python + venv + conda).
Dobrym nawykiem jest osobne środowisko dla każdego większego projektu. Przykładowy plik requirements.txt:
torch
torchvision
numpy
pillow
matplotlib
tqdm
pyyaml
W conda możesz użyć:
conda create -n pytorch-images python=3.10
conda activate pytorch-images
pip install -r requirements.txt
Organizacja katalogów projektu
Nie potrzeba rozbudowanego frameworka, żeby projekt klasyfikacji obrazów w PyTorch był przejrzysty. Wystarczy kilka czytelnych folderów. Jak obecnie przechowujesz swoje dane i skrypty – wszystko w jednym katalogu, czy masz już strukturę?
Przykładowa struktura:
project_root/
data/
raw/
processed/
train/
class_0/
class_1/
...
val/
class_0/
class_1/
test/
class_0/
class_1/
src/
data/
dataset.py
transforms.py
models/
simple_cnn.py
resnet_transfer.py
training/
train.py
scheduler.py
evaluation/
metrics.py
visualize.py
utils/
config.py
misc.py
configs/
baseline.yaml
resnet_experiment.yaml
logs/
runs/
checkpoints/
baseline_epoch10.pt
resnet_best.pt
notebooks/
Najważniejsze punkty:
- dane – logiczny podział na
train/val/test, klasy jako podfoldery (dlaImageFolder), - moduły – osobne pliki na definicje modeli, dane, trening, ewaluację,
- configi – parametry doświadczeń (lrate, batch_size, architektura) poza kodem, np. YAML,
- checkpointy – zapisy stanu modelu i optymalizatora, najlepiej z nazwą modelu i metryką.
Taki porządek ułatwia nie tylko debugging, ale też powrót do starego eksperymentu po kilku tygodniach. Pytanie pomocnicze: czy wiesz za pół roku, jak odtworzyć konkretny wynik, który dziś „wyszedł dobrze”?
Reproducibility – powtarzalność eksperymentów
Powtarzalność w deep learningu jest względna, ale da się ograniczyć losowość. Przy każdym nowym eksperymencie warto:
- ustawić seedy dla
random,numpyitorch, - zapisać konfigurację eksperymentu (np. JSON/YAML),
- zapisać wersje bibliotek (np.
pip freeze > env.txt).
Przykładowy kod z seedami:
import random
import numpy as np
import torch
def set_seed(seed: int = 42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Opcjonalnie bardziej deterministycznie kosztem wydajności:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
set_seed(42)
Deterministyczne vs szybkie trenowanie to kwestia kompromisu. Ustawienie deterministic=True wpływa na to, że wyniki powinny być bliższe sobie między uruchomieniami, ale często spowalnia operacje na GPU. Przy POC zwykle warto iść w szybkość, przy poważnych eksperymentach – w kontrolowaną losowość.
Konfigurację eksperymentu zapisuj w prostym pliku YAML:
model: resnet18
num_classes: 4
batch_size: 32
learning_rate: 0.001
epochs: 20
input_size: 224
W kodzie możesz to wczytać i logować do pliku obok checkpointu. Dzięki temu za jakiś czas szybko sprawdzisz, jakich parametrów użyłeś dla konkretnego modelu.

Przygotowanie danych: od surowych obrazów do DataLoadera
Organizacja i wczytywanie danych obrazowych
Struktura danych to pierwsza decyzja, która potrafi ułatwić albo utrudnić sobie życie. Jak masz teraz swoje obrazy – w jednym folderze, z etykietami w CSV, czy już podzielone na klasy?
Najprostszy scenariusz to użycie torchvision.datasets.ImageFolder. Wymaga on struktury:
data/train/
class0/
img001.jpg
img002.jpg
class1/
img101.jpg
data/val/
class0/
class1/
Kod wczytywania jest wtedy bardzo prosty:
from torchvision import datasets, transforms
train_dataset = datasets.ImageFolder(
root="data/train",
transform=train_transforms
)
val_dataset = datasets.ImageFolder(
root="data/val",
transform=val_transforms
)
print(train_dataset.classes) # nazwy klas
print(train_dataset.class_to_idx) # mapowanie nazwa -> indeks
ImageFolder sprawdza się świetnie, gdy:
- etykieta = nazwa folderu,
- podział na train/val/test robisz „ręcznie” lub osobnym skryptem,
- nie masz niestandardowych źródeł (bazy danych, archiwa zip, adnotacje w JSON).
Własny Dataset w PyTorch przydaje się, gdy:
- etykiety masz w pliku CSV, JSON albo w bazie danych,
- chcesz dynamicznie generować obrazy (np. wycinać regiony z dużych zdjęć),
Własny Dataset – gdy ImageFolder nie wystarcza
Jeśli twoje dane nie mieszczą się w prostym schemacie „folder = klasa”, potrzebny jest własny Dataset. Masz etykiety w CSV, wiele etykiet na obraz albo osobne maski segmentacji? Wtedy przejmujesz kontrolę.
Jądrem własnego datasetu są trzy elementy: lista ścieżek, lista etykiet i transformacje. Najpierw przygotuj prosty indeks – np. CSV:
image_path,label
images/img001.jpg,0
images/img002.jpg,1
images/img003.jpg,0
Następnie definiujesz klasę:
import csv
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
class CSVDataset(Dataset):
def __init__(self, csv_file, root_dir, transform=None):
self.root_dir = Path(root_dir)
self.transform = transform
self.samples = []
with open(csv_file, newline="") as f:
reader = csv.DictReader(f)
for row in reader:
img_path = self.root_dir / row["image_path"]
label = int(row["label"])
self.samples.append((img_path, label))
def __len__(self):
return len(self.samples)
def __getitem__(self, idx):
img_path, label = self.samples[idx]
image = Image.open(img_path).convert("RGB")
if self.transform:
image = self.transform(image)
return image, label
Gdzie może tobie się to przydać?
- gdy masz dane w chmurze i generujesz tymczasowe ścieżki,
- gdy łączysz kilka zbiorów (np. dane klienta + publiczne benchmarki),
- gdy do każdego obrazu chcesz dodać dodatkowe informacje (np. metadane).
Wiele osób przeciąga moment napisania własnego datasetu – a to często prostsze rozwiązanie niż walka z dopasowaniem danych do ImageFolder. Co u ciebie jest źródłem prawdy o etykietach – folder, CSV, baza?
Transformacje i augmentacja danych
Transformacje decydują, na jakich danych naprawdę trenuje model. Surowe obrazy to dopiero początek. Jak bardzo różnią się obrazy w produkcji od tych, które masz teraz w zbiorze?
Typowa sekwencja transformacji obejmuje:
- zmianę rozmiaru i przycięcie,
- konwersję do tensora,
- normalizację,
- augmentację (dla train).
from torchvision import transforms
input_size = 224
train_transforms = transforms.Compose([
transforms.Resize((input_size, input_size)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomRotation(degrees=10),
transforms.ColorJitter(
brightness=0.2,
contrast=0.2,
saturation=0.2,
hue=0.02
),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
val_transforms = transforms.Compose([
transforms.Resize((input_size, input_size)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
),
])
Normalizacja powyżej używa standardowych statystyk z ImageNet – przydają się szczególnie, gdy korzystasz z modeli pretrenowanych. Jeśli trenujesz od zera i masz dużo własnych danych, możesz policzyć średnią i odchylenie samodzielnie.
Augmentacja powinna symulować realne zniekształcenia: inną orientację, lekki blur, szum, różne oświetlenie. Przykład z praktyki: jeśli w fabryce kamera czasem jest lekko przekrzywiona, RandomRotation naprawdę ratuje wyniki. Zadaj sobie pytanie: jakie „błędy” pojawiają się w twoich zdjęciach na produkcji?
DataLoader – batchowanie, shuffle, num_workers
Kiedy masz już dataset, następnym krokiem jest DataLoader. To on odpowiada za batching, mieszanie próbek, równoległe ładowanie:
from torch.utils.data import DataLoader
batch_size = 32
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True
)
Kilka praktycznych wskazówek:
shuffle=Truetylko dla train – zapobiega przeuczeniu na kolejności danych,num_workerseksperymentalnie – na Windows często <= 4, na Linuksie możesz spokojnie próbować 8–16,pin_memory=Trueprzyspiesza kopiowanie na GPU, gdy używasz CUDA.
Przy pierwszym uruchomieniu warto sprawdzić, czy dataloader działa jak trzeba:
images, labels = next(iter(train_loader))
print(images.shape) # spodziewane: [batch_size, 3, H, W]
print(labels.shape) # [batch_size]
Jeśli ostatnia paczka jest mniejsza (np. 10 zamiast 32), wszystko jest w porządku – to zachowanie domyślne. Gdy z jakiegoś powodu potrzebujesz zawsze pełnych batchy (np. dla DistributedDataParallel), użyj drop_last=True.
Radzenie sobie z niezbalansowanymi klasami
Gdy jedna klasa dominuje, model szybko „uczy się” przewidywać ją zawsze. Accuracy rośnie, biznesowo – porażka. Jak wygląda rozkład twoich klas?
Najprostsza diagnoza to policzenie liczby próbek w każdej klasie:
from collections import Counter
labels = [label for _, label in train_dataset.samples]
class_counts = Counter(labels)
print(class_counts)
Jeśli różnice są duże, masz kilka opcji:
- zwiększenie liczby danych dla rzadkich klas (oversampling, więcej augmentacji),
- ważenie klas w funkcji kosztu,
- użycie
WeightedRandomSampler.
Przykład z ważonym loss’em (CrossEntropyLoss):
import torch
import torch.nn as nn
class_counts = torch.tensor([class_counts[i] for i in range(len(class_counts))])
class_weights = 1.0 / class_counts.float()
class_weights = class_weights / class_weights.sum() * len(class_weights) # skalowanie
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
Albo sampler, który częściej wybiera rzadkie klasy:
from torch.utils.data import WeightedRandomSampler
samples_labels = [label for _, label in train_dataset.samples]
class_counts = Counter(samples_labels)
class_weights = {cls: 1.0 / count for cls, count in class_counts.items()}
sample_weights = [class_weights[label] for label in samples_labels]
sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True
)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
sampler=sampler,
num_workers=4,
pin_memory=True
)
Jeżeli twoim głównym celem jest „łapanie” rzadkich zdarzeń (np. defektów), lepiej skupić się na recallu dla tych klas niż na globalnym accuracy.
Podgląd i sanity check danych
Zanim przepuścisz dane przez model, dobrze jest po prostu je zobaczyć. Czy transformacje nie odwróciły kolorów, nie obcięły istotnej części elementu? Co byś zobaczył, otwierając kilka losowych obrazów z train_loadera?
import matplotlib.pyplot as plt
import torchvision
def show_batch(images, labels, classes, nrow=8):
grid = torchvision.utils.make_grid(images[:nrow], nrow=nrow, normalize=True, padding=2)
plt.figure(figsize=(12, 3))
plt.imshow(grid.permute(1, 2, 0).cpu())
plt.axis("off")
print("Labels:", [classes[l] for l in labels[:nrow]])
images, labels = next(iter(train_loader))
show_batch(images, labels, train_dataset.classes)
Jedno takie sprawdzenie potrafi wykryć problemy, które inaczej zauważyłbyś dopiero po wielu godzinach treningu: źle załadowane kanały, pomylone etykiety, niepoprawną normalizację.
Definiowanie modeli: od prostej CNN po transfer learning
Prosta konwolucyjna sieć od zera
Na początek warto mieć bazową architekturę, którą rozumiesz od A do Z. Dzięki temu łatwiej diagnozować problemy, zanim wskoczysz w ciężki transfer learning. Jaką masz teraz intuicję o tym, co robi warstwa konwolucyjna?
Przykładowa, prosta CNN dla obrazów 3×64×64:
import torch.nn as nn
import torch.nn.functional as F
class SimpleCNN(nn.Module):
def __init__(self, num_classes: int):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2), # 32 x 32 x 32
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2), # 64 x 16 x 16
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2), # 128 x 8 x 8
)
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(128 * 8 * 8, 256),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(256, num_classes)
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
Taki model jest relatywnie mały i nadaje się na start dla własnych, niewielkich datasetów. Jeśli twoje obrazy mają inny rozmiar niż 64×64, odpowiednio zmień pooling lub policz nowy rozmiar wejścia do warstwy liniowej.
Sprawdzenie rozmiarów (zanim cokolwiek uruchomisz na poważnie):
model = SimpleCNN(num_classes=4)
x = torch.randn(1, 3, 64, 64)
out = model(x)
print(out.shape) # spodziewane: [1, 4]
Jeżeli dostajesz błąd w warstwie liniowej, to zwykle znak, że rozmiary po conv/pool nie zgadzają się z tym, co założyłeś. Policz to ręcznie albo dodaj tymczasowe print(x.shape) w metodzie forward.
Konfiguracja urządzenia i przenoszenie modelu na GPU
Nawet najlepsza architektura nic nie da, jeśli model „zostanie” na CPU, a dane na GPU albo odwrotnie. Widziałeś kiedyś błąd „expected device cuda:0 but got cpu”? To właśnie ten przypadek.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=num_classes).to(device)
for images, labels in train_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
...
Kluczowe jest konsekwentne wysyłanie wszystkiego na to samo urządzenie. Jeden pominięty .to(device) potrafi zatrzymać trening na długo.
Definicja pętli treningowej i walidacyjnej
Warto zbudować jedną, prostą pętlę treningową, której będziesz używać niezależnie od modelu. Potem tylko podmieniasz architekturę. Jak wygląda twoja obecna pętla – masz ją spójną czy kopiujesz kod między notatnikami?
import torch.optim as optim
from tqdm import tqdm
def train_one_epoch(model, train_loader, criterion, optimizer, device):
model.train()
running_loss = 0.0
correct = 0
total = 0
for images, labels in tqdm(train_loader, desc="Train", leave=False):
images = images.to(device)
labels = labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * images.size(0)
_, preds = outputs.max(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
epoch_loss = running_loss / total
epoch_acc = correct / total
return epoch_loss, epoch_acc
@torch.no_grad()
def evaluate(model, val_loader, criterion, device):
model.eval()
running_loss = 0.0
correct = 0
total = 0
for images, labels in tqdm(val_loader, desc="Val", leave=False):
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
running_loss += loss.item() * images.size(0)
_, preds = outputs.max(1)
correct += (preds == labels).sum().item()
total += labels.size(0)
epoch_loss = running_loss / total
epoch_acc = correct / total
return epoch_loss, epoch_acc
Prosty loop po epokach:
num_epochs = 10
best_val_acc = 0.0
for epoch in range(num_epochs):
train_loss, train_acc = train_one_epoch(
model, train_loader, criterion, optimizer, device
)
val_loss, val_acc = evaluate(
model, val_loader, criterion, device
)
print(
f"Epoch {epoch+1}/{num_epochs} "
f"- train_loss: {train_loss:.4f}, train_acc: {train_acc:.4f}, "
f"val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}"
)
if val_acc > best_val_acc:
best_val_acc = val_acc
torch.save(model.state_dict(), "checkpoints/best_simple_cnn.pt")
Jedna rada z praktyki: loguj metryki nie tylko do konsoli, ale i do pliku albo narzędzia typu TensorBoard/W&B. Po kilku dniach intensywnych eksperymentów trudno odtworzyć historię „z głowy”.
Transfer learning z modelami torchvision
Gdy masz ograniczony dataset, znacznie lepsze rezultaty niż „goła” CNN daje transfer learning. Wykorzystujesz modele trenowane na ImageNet i dopasowujesz je do swojego zadania. Jak duży masz dziś zbiór – kilkaset, kilka tysięcy, czy więcej obrazów na klasę?
Dla klasycznej klasyfikacji obrazów wygodnym wyborem jest resnet18:
ResNet18 – od wczytania do dostosowania pod własne klasy
Jeżeli masz typowe zdjęcia RGB i liczba klas jest relatywnie mała, resnet18 to solidny punkt startowy. Najpierw wczytanie modelu z wagami uczonymi na ImageNet:
from torchvision import models
num_classes = len(train_dataset.classes)
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
# liczba cech na wejściu do ostatniej warstwy
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
model = model.to(device)
Takie podejście trenuje cały model (tzw. fine-tuning end-to-end). Jeżeli masz mało danych lub trening jest zbyt wolny, spróbuj najpierw zamrozić wcześniejsze warstwy:
for param in model.parameters():
param.requires_grad = False
# odmrażamy tylko ostatnią warstwę
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
for param in model.fc.parameters():
param.requires_grad = True
model = model.to(device)
optimizer = optim.Adam(model.fc.parameters(), lr=1e-3)
Jakim budżetem czasowym dysponujesz? Jeżeli trenujesz na CPU lub słabym GPU, zamrożenie większości sieci potrafi skrócić trening kilkukrotnie.
Dostosowanie transformacji pod modele z ImageNet
Większość gotowych architektur z torchvision zakłada konkretne przeskalowanie i normalizację (średnia i odchylenie z ImageNet). Zamiast przepisywać te wartości ręcznie, możesz użyć gotowych presetów:
from torchvision import transforms, models
weights = models.ResNet18_Weights.IMAGENET1K_V1
preprocess = weights.transforms() # automatycznie: Resize, CenterCrop, ToTensor, Normalize
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=weights.meta["mean"], std=weights.meta["std"]),
])
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=weights.meta["mean"], std=weights.meta["std"]),
])
Jeżeli wcześniej używałeś transformacji pod 64×64, musisz przebudować ImageFolder z nowymi transformacjami i stworzyć nowe DataLoadery. Masz już to zautomatyzowane w funkcji czy zmieniasz ręcznie w kilku miejscach?
Strategia fine-tuningu: zamrozić, odmrażać czy trenować wszystko?
Dobrą praktyką jest stopniowe „odmrażanie” sieci, zamiast od razu trenować wszystkie warstwy. Trzy często stosowane warianty:
- Frozen backbone – uczysz tylko ostatnią warstwę; szybkie, dobre na start diagnostyczny.
- Partial fine-tuning – odmrażasz kilka ostatnich bloków (np.
layer4w ResNet). - Full fine-tuning – trenujesz całość przy mniejszym learning rate.
Przykład odmrażania tylko ostatniego bloku ResNet:
model = models.resnet18(weights=weights)
for param in model.parameters():
param.requires_grad = False
# odmrażamy layer4 i fc
for param in model.layer4.parameters():
param.requires_grad = True
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
params_to_optimize = list(model.layer4.parameters()) + list(model.fc.parameters())
optimizer = optim.Adam(params_to_optimize, lr=1e-4)
Jeżeli widzisz, że model bardzo szybko dochodzi do ~80–90% accuracy i potem stoi w miejscu, spróbuj albo odmrozić więcej warstw, albo zejść z learning rate i wydłużyć trening.
Schedule learning rate i early stopping
Przy fine-tuningu duże skoki learning rate potrafią „psuć” pretrenowane wagi. Stąd przydają się proste schedulery i wstrzymywanie treningu, gdy walidacja się nie poprawia.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode="max", # optymalizujemy val_acc
factor=0.1,
patience=3,
verbose=True
)
best_val_acc = 0.0
patience = 7
epochs_no_improve = 0
for epoch in range(num_epochs):
train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
val_loss, val_acc = evaluate(model, val_loader, criterion, device)
scheduler.step(val_acc)
if val_acc > best_val_acc:
best_val_acc = val_acc
epochs_no_improve = 0
torch.save(model.state_dict(), "checkpoints/best_resnet18.pt")
else:
epochs_no_improve += 1
if epochs_no_improve >= patience:
print("Early stopping triggered")
break
Czy już korzystasz z jakiegokolwiek scheduler’a? Jeżeli nie, zacznij właśnie od ReduceLROnPlateau – jest prosty i dobrze reaguje na stagnację metryki.
Diagnostyka: learning curves, confusion matrix, per-class metrics
Sam accuracy często nie mówi całej prawdy. Jeżeli biznesowo interesują cię konkretne klasy (np. „defekt”), musisz patrzeć na metryki per-klasa.
Podstawowy zrzut predykcji i etykiet z walidacji:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
@torch.no_grad()
def get_all_preds(model, data_loader, device):
model.eval()
all_labels = []
all_preds = []
for images, labels in data_loader:
images = images.to(device)
outputs = model(images)
_, preds = outputs.max(1)
all_labels.append(labels.cpu().numpy())
all_preds.append(preds.cpu().numpy())
return np.concatenate(all_labels), np.concatenate(all_preds)
y_true, y_pred = get_all_preds(model, val_loader, device)
print(classification_report(y_true, y_pred, target_names=train_dataset.classes))
Jeżeli widzisz, że jedna klasa ma recall ~0, zastanów się: czy masz jej wystarczająco w danych? Czy augmentacja nie „psuje” jej charakterystycznych cech? W takich sytuacjach często pomaga:
- silniejsza augmentacja tylko dla większościowych klas,
- doczytanie dodatkowych przykładów trudnej klasy,
- zmiana progu decyzyjnego (np. zamiast argmax używasz progu na prawdopodobieństwo dla danej klasy).
Prosta macierz pomyłek do wizualnej inspekcji:
import seaborn as sns
import matplotlib.pyplot as plt
cm = confusion_matrix(y_true, y_pred, normalize="true")
plt.figure(figsize=(6, 5))
sns.heatmap(cm, annot=True, cmap="Blues",
xticklabels=train_dataset.classes,
yticklabels=train_dataset.classes,
fmt=".2f")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.tight_layout()
plt.show()
Interpretacja predykcji: najbardziej pewne i najbardziej mylące przykłady
Gdy model praktycznie „stoi” z jakością, dobrze jest obejrzeć obrazy, na których najbardziej się myli lub jest najmniej pewny. Na tej podstawie możesz stwierdzić, czy potrzebujesz lepszych danych, innej etykiety, czy zmiany architektury.
@torch.no_grad()
def collect_predictions_with_scores(model, data_loader, device):
model.eval()
all_images = []
all_labels = []
all_preds = []
all_probs = []
for images, labels in data_loader:
images = images.to(device)
outputs = model(images)
probs = torch.softmax(outputs, dim=1)
confs, preds = probs.max(1)
all_images.append(images.cpu())
all_labels.append(labels)
all_preds.append(preds.cpu())
all_probs.append(confs.cpu())
return (
torch.cat(all_images),
torch.cat(all_labels),
torch.cat(all_preds),
torch.cat(all_probs),
)
images, labels, preds, confs = collect_predictions_with_scores(
model, val_loader, device
)
# indeksy najbardziej mylnych, ale bardzo pewnych predykcji
wrong = preds != labels
wrong_indices = torch.where(wrong)[0]
sorted_wrong = wrong_indices[confs[wrong_indices].argsort(descending=True)]
top_k = 16
to_show = sorted_wrong[:top_k]
def show_misclassified(images, labels, preds, idxs, classes, nrow=4):
subset = images[idxs]
grid = torchvision.utils.make_grid(subset, nrow=nrow, normalize=True, padding=2)
plt.figure(figsize=(12, 8))
plt.imshow(grid.permute(1, 2, 0))
plt.axis("off")
print("True:", [classes[labels[i]] for i in idxs])
print("Pred:", [classes[preds[i]] for i in idxs])
show_misclassified(images, labels, preds, to_show, train_dataset.classes)
Na takich obrazach często widać realne problemy: źle opisane klasy, niejednoznaczne przypadki, różnice między zdjęciami z produkcji a tymi z datasetu.
Przygotowanie modelu do inference: tryb eval, eksport i wersjonowanie
Kiedy masz już zadowalający model, przychodzi moment na włączenie go w realny przepływ danych. Masz już określone, jak ten model będzie używany: batchowo, online, czy w aplikacji mobilnej?
Najprostszy scenariusz: serwis w Pythonie, który ładuje zapisane wagi i przyjmuje ścieżkę do obrazu.
from PIL import Image
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model_checkpoint(path, num_classes):
model = models.resnet18(weights=None)
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
state_dict = torch.load(path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
def predict_image(path, model, transform, classes):
img = Image.open(path).convert("RGB")
x = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
logits = model(x)
probs = torch.softmax(logits, dim=1)
conf, pred = probs.max(1)
return classes[pred.item()], conf.item()
Rzut oka na zachowanie modelu na pojedynczym obrazie z produkcji potrafi ujawnić rozjazd między train/val a realnym światem – np. inne oświetlenie, rozdzielczość, kompresję.
Eksport do TorchScript i ONNX
Jeżeli model ma trafić do środowiska, gdzie nie chcesz trzymać całego Pythona (np. C++, serwis wysokowydajny), przydaje się eksport do TorchScript albo ONNX. Najpierw prosty TorchScript:
example_input = torch.randn(1, 3, 224, 224).to(device)
traced = torch.jit.trace(model, example_input)
traced.save("artifacts/resnet18_traced.pt")
Do ONNX, z myślą o inference np. w ONNX Runtime:
onnx_path = "artifacts/resnet18.onnx"
dummy_input = torch.randn(1, 3, 224, 224, device=device)
torch.onnx.export(
model,
dummy_input,
onnx_path,
input_names=["input"],
output_names=["logits"],
dynamic_axes={"input": {0: "batch_size"}, "logits": {0: "batch_size"}},
opset_version=17
)
Przed integracją zawsze uruchom test: porównaj wyjścia modelu PyTorch i ONNX/TorchScript dla kilku losowych obrazów. Jeżeli różnice są większe niż rząd 1e-4 w logitach, coś jest nie tak z eksportem lub preprocessingiem.
Opakowanie modelu w prosty serwis HTTP (FastAPI)
Gdy środowisko docelowe to backend w Pythonie, naturalnym wyborem jest API REST. Jakim językiem planujesz obsłużyć inference – zostajesz przy Pythonie czy chcesz oddzielny serwis?
from fastapi import FastAPI, UploadFile, File
import uvicorn
import io
app = FastAPI()
weights = models.ResNet18_Weights.IMAGENET1K_V1
inference_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=weights.meta["mean"], std=weights.meta["std"]),
])
classes = train_dataset.classes # zapisz je przy treningu i wczytaj tutaj
model = load_model_checkpoint("checkpoints/best_resnet18.pt", num_classes=len(classes))
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
content = await file.read()
img = Image.open(io.BytesIO(content)).convert("RGB")
x = inference_transform(img).unsqueeze(0).to(device)
with torch.no_grad():
logits = model(x)
probs = torch.softmax(logits, dim=1)[0]
top_prob, top_idx = torch.max(probs, dim=0)
return {
"class": classes[top_idx.item()],
"probability": float(top_prob.item())
}
# if __name__ == "__main__":
# uvicorn.run(app, host="0.0.0.0", port=8000)
Takie API możesz najpierw odpalić lokalnie i przetestować przez curl lub prosty skrypt klienta, a dopiero potem myśleć o konteneryzacji i wdrożeniu na serwerze.
Wersjonowanie modeli i kontrola eksperymentów
Gdy liczba eksperymentów rośnie, łatwo się pogubić: który checkpoint był trenowany z wagami klas, który bez augmentacji, a który na nowej wersji danych. Jak obecnie zapisujesz swoje eksperymenty – notujesz w notatniku, czy masz format nazw plików?
Prosta konwencja nazewnicza i JSON z metadanymi bardzo pomagają:
import json
from datetime import datetime
from pathlib import Path
exp_name = "resnet18_balanced_aug"
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = Path("runs") / f"{exp_name}_{timestamp}"
run_dir.mkdir(parents=True, exist_ok=True)
config = {
"model": "resnet18",
"img_size": 224,
"optimizer": "Adam",
"lr": 1e-4,
"batch_size": batch_size,
"class_weights": True,
"train_samples": len(train_dataset),
"val_samples": len(val_dataset),
}
with open(run_dir / "config.json", "w") as f:
json.dump(config, f, indent=2)
# zapisywanie najlepszego modelu
best_path = run_dir / "best.pt"
torch.save(model.state_dict(), best_path)
Później, gdy zobaczysz, że „model z września” był lepszy niż wszystko, co zbudowałeś później, łatwo odtworzysz konfigurację i ustawienia treningu.
Dostosowanie do ograniczonych zasobów: kwantyzacja, pruning, mniejsze architektury
Nie zawsze docelowa platforma to serwer z GPU. Jeżeli celem jest uruchomienie na CPU, w lekkim kontenerze lub na krawędzi, przyda się odchudzenie modelu. Jakie są twoje ograniczenia: czas predykcji, pamięć, czy obie rzeczy na raz?
Najczęściej zadawane pytania (FAQ)
Od czego zacząć projekt klasyfikacji obrazów w PyTorch: od modelu czy od problemu biznesowego?
Zacznij od sprecyzowania problemu biznesowego, a nie od wyboru architektury. Zadaj sobie pytanie: co dokładnie ma robić model na obrazie – ma odrzucać wadliwe produkty, kategoryzować zdjęcia w sklepie, czy może tylko wykrywać obecność określonego obiektu? Bez tego szybko wpadniesz w pułapkę „budowania CNN dla samego budowania”.
Dopiero gdy masz jasny cel (jakie decyzje będą podejmowane na podstawie wyniku modelu, kto z tego korzysta, w jakim procesie), przechodzisz do projektu klas, zbierania danych i dopasowania architektury. Zapytaj siebie: jaki rezultat ma pomóc użytkownikowi lub biznesowi, a co jest jedynie ciekawostką techniczną?
Ile klas powinien mieć mój model i jak dobrać ich szczegółowość?
Najpierw odpowiedz: kto i w jakim procesie będzie używał tych klas? Jeśli robisz kontrolę jakości na linii produkcyjnej, zwykle wystarcza prosty podział „OK” vs „wada” plus kilka najważniejszych typów defektów. Nadmierne rozdrobnienie („mikropęknięcia”, „zarysowania boczne”, „zarysowania dolne”) rzadko przekłada się na realną wartość, za to mocno utrudnia etykietowanie.
Z drugiej strony zbyt mało klas też niewiele daje, np. w e‑commerce: „odzież” vs „nie odzież” nie rozwiązuje problemu filtrowania produktów. Lepszy podział to „buty”, „spodnie”, „koszulki”, „sukienki” itd. Zadaj sobie kontrolne pytanie: czy użytkownik podjąłby inną decyzję, gdyby klasa była bardziej szczegółowa? Jeśli nie – prawdopodobnie możesz uprościć schemat.
Ile danych potrzebuję do klasyfikacji obrazów w PyTorch i co z niezbalansowanymi klasami?
Praktyczne minimum to przynajmniej 100–200 dobrze oznaczonych przykładów na klasę, choć im więcej, tym stabilniej będzie się trenować model. Jeśli nie jesteś jeszcze na tym etapie, zapytaj siebie: czy lepiej najpierw doprecyzować definicje klas i proces etykietowania, zamiast od razu pisać kod? Bałagan w etykietach zabija model szybciej niż mały dataset.
Przy silnie niezbalansowanych danych (np. 95% klasy „OK”, kilka procent rzadkich wad) klasyczne accuracy będzie złudnie wysokie, bo model może „zawsze mówić OK”. Wtedy:
- rozważ ważenie klas w funkcji straty lub oversampling rzadkich przypadków,
- patrz na metryki typu precision/recall dla klas rzadkich zamiast na samo accuracy.
Zadaj sobie pytanie: jakie błędy są dla ciebie groźniejsze – fałszywe alarmy czy niewykryte wady?
Jak rozpoznać, czy mój problem to klasyfikacja, detekcja, czy segmentacja?
Przejdź przez trzy krótkie pytania:
- Czy na jednym obrazie zwykle interesuje cię jeden dominujący obiekt?
- Czy nie liczy się dokładne położenie, tylko informacja „co jest na obrazku”?
- Czy nie potrzebujesz kształtu obiektu, tylko samą klasę?
Jeśli na wszystko odpowiadasz „tak”, klasyfikacja jest właściwym wyborem. W każdym innym przypadku warto rozważyć detekcję lub segmentację.
Jeśli chcesz wiedzieć, ile obiektów jest na obrazie i gdzie dokładnie się znajdują (np. samochody na parkingu) – potrzebujesz detekcji. Gdy kluczowy jest kształt (np. obrys guza na MRI) – segmentacji. Zastanów się: czy twoja decyzja biznesowa zależy od samej obecności obiektu, czy również od jego lokalizacji/rozmiaru?
Czym różni się klasyfikacja multi-class od multi-label w PyTorch i jak wybrać właściwą?
Najważniejsze pytanie kontrolne: czy jeden obraz może należeć do wielu klas jednocześnie? Jeśli nie (np. „pies” albo „kot” albo „ptak”) – używasz klasyfikacji multi-class. Jeśli tak (np. „jest pies” i jednocześnie „jest kot” na jednym zdjęciu) – potrzebujesz multi-label.
W PyTorch przekłada się to na:
- inną funkcję straty: zwykle
nn.CrossEntropyLossdla multi-class inn.BCEWithLogitsLossdla multi-label, - inną postać etykiet: pojedynczy indeks klasy vs wektor 0/1,
- inny sposób interpretacji wyjścia: softmax (prawdopodobieństwo dla każdej z wykluczających się klas) vs niezależne sigmoidy dla każdej etykiety.
Zanim wybierzesz architekturę, spisz kilka przykładowych obrazów i przypisz im etykiety „tak jak naprawdę mają wyglądać w systemie”. To bardzo szybko ujawnia, czy masz przypadki multi-label.
Jak skonfigurować środowisko do klasyfikacji obrazów w PyTorch (CPU, GPU, biblioteki)?
Najpierw odpowiedz: na czym faktycznie będziesz trenować model – laptop (CPU), pojedyncze GPU, a może środowisko typu Colab? Od tego zależy wersja PyTorch i to, czy instalujesz build z CUDA. Podstawowy zestaw to:
- Python 3.9+ (lub wersja zgodna z dokumentacją PyTorch),
torch,torchvision, opcjonalnietorchaudio,- dodatki:
numpy,pillow,matplotlib,tqdm,pyyaml.
Po instalacji od razu uruchom prosty test torch.cuda.is_available(), żeby upewnić się, że GPU jest widoczne.
Warto trzymać osobne wirtualne środowisko (venv lub conda) na każdy większy projekt i zapisać zależności w requirements.txt. Zadaj sobie pytanie: czy chcesz mieć możliwość łatwego odtworzenia setupu za pół roku na innym serwerze? Jeśli tak, kontrola wersji bibliotek i izolacja środowiska to obowiązek, nie opcja.
Jak zorganizować katalogi projektu klasyfikacji obrazów w PyTorch, żeby nie utonąć w chaosie?
Dobrze uporządkowana struktura na starcie oszczędza sporo czasu później. Dla prostego projektu wystarczy podział na:
data/– z podfolderamiraw/,processed/,train/,val/,test/, najlepiej z podkatalogami klas,src/– osobne moduły na datasety/transformacje, modele, trenowanie, ewaluację i narzędzia,conf/– pliki konfiguracyjne, jeśli stosujesz konfigurację zewnętrzną.
Zapytaj siebie: czy po miesiącu bez problemu znajdziesz miejsce, gdzie definiujesz transformacje, a gdzie wczytujesz dane? Jeśli odpowiedź brzmi „niekoniecznie”, to sygnał, że warto rozdzielić kod i dane w jasne, powtarzalne foldery.






