Train Your Image Classification Model

With your data prepared, it’s time to train a model. This section covers training configuration, monitoring, and understanding results.

Training Overview

graph LR
    A[Prepared Dataset] --> B[Configure Training]
    B --> C[Start Job]
    C --> D[Monitor Progress]
    D --> E[Trained Model]
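The Start Job → Monitor Progress steps boil down to polling the job until it reaches a terminal state. A minimal sketch of such a loop, with a stubbed `get_status` callable standing in for the real SDK status call (`wait_for_job` and the state names are illustrative here, not part of the SDK):

```python
import time

def wait_for_job(get_status, poll_seconds=30, max_polls=1000):
    """Poll get_status() until the job reaches a terminal state."""
    for _ in range(max_polls):
        status = get_status()
        if status in ("completed", "failed"):
            return status
        time.sleep(poll_seconds)
    raise TimeoutError("job did not finish within the polling budget")

# Stubbed status sequence standing in for real SDK calls
states = iter(["queued", "running", "running", "completed"])
print(wait_for_job(lambda: next(states), poll_seconds=0))  # completed
```

In practice you would replace the lambda with a call that re-fetches the job from the API each iteration.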

Model Architectures

Choose the right architecture for your needs:

| Architecture | Accuracy | Speed | Size | Best For |
|---|---|---|---|---|
| ResNet18 | Good | Fast | 45MB | Quick prototypes |
| ResNet34 | Better | Medium | 85MB | Balanced default |
| ResNet50 | Best | Slower | 100MB | Production quality |
| EfficientNet-B0 | Best | Fast | 20MB | Mobile deployment |
| EfficientNet-B3 | Excellent | Medium | 50MB | High accuracy |

# Example: Use EfficientNet for mobile deployment
job = client.create_job(
    dataset_id=dataset.id,
    dataset_version_id=version.id,
    framework="fastai",
    application="image_classification",
    config={
        "architecture": "efficientnet_b0",
        "epochs": 15,
        "export_formats": ["coreml", "tflite"]
    }
)

Data Augmentation

Augmentation creates variations of your images during training, helping the model generalize better.

Available Augmentations

| Augmentation | Description | When to Use |
|---|---|---|
| flip_horizontal | Mirror left-right | Most images |
| flip_vertical | Mirror top-bottom | Satellite imagery |
| rotation | Rotate by degrees | When orientation varies |
| zoom | Random zoom in/out | Product photos |
| lighting | Brightness/contrast | Varying lighting conditions |
| warp | Perspective distortion | Documents, signs |

config = {
    "augmentation": {
        "flip_horizontal": True,
        "flip_vertical": False,
        "rotation": 20,          # Rotate up to 20 degrees
        "zoom": 0.2,             # Zoom up to 20%
        "lighting": 0.3,         # Vary lighting 30%
        "warp": 0.1              # Slight perspective changes
    }
}
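To build intuition for what the flip augmentations do, here is a minimal sketch on a tiny 2×2 "image" represented as a list of pixel rows (plain Python, no imaging library; the real training pipeline applies these transforms to tensors):

```python
def flip_horizontal(img):
    """Mirror each row left-to-right (width axis)."""
    return [row[::-1] for row in img]

def flip_vertical(img):
    """Reverse the row order top-to-bottom (height axis)."""
    return img[::-1]

img = [
    [1, 2],
    [3, 4],
]
print(flip_horizontal(img))  # [[2, 1], [4, 3]]
print(flip_vertical(img))    # [[3, 4], [1, 2]]
```

This is why flip_vertical is off by default: for most natural photos an upside-down copy is not a realistic input, whereas for satellite imagery it is.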

Understanding Training Metrics

Loss

  • Training loss: How well the model fits training data
  • Validation loss: How well the model generalizes
ℹ️ Healthy training: both losses decrease, with validation loss slightly higher than training loss.

Accuracy

  • Top-1 accuracy: % of images where the top prediction is correct
  • Top-5 accuracy: % where correct class is in top 5 predictions
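Both metrics can be computed directly from per-class scores. A minimal sketch in plain Python (`top_k_accuracy` is an illustrative helper, not an SDK function):

```python
def top_k_accuracy(scores, labels, k):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    hits = 0
    for row, label in zip(scores, labels):
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)

# Three samples, four classes; true labels 0, 2, 1
scores = [
    [0.7, 0.1, 0.1, 0.1],   # class 0 ranked first
    [0.4, 0.3, 0.2, 0.1],   # class 2 only ranked third
    [0.1, 0.2, 0.6, 0.1],   # class 1 ranked second
]
labels = [0, 2, 1]

print(top_k_accuracy(scores, labels, 1))  # 1/3 — only the first sample is right at top-1
print(top_k_accuracy(scores, labels, 3))  # 3/3 — every true label is in the top 3
```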

Confusion Matrix

# Get confusion matrix after training
model = client.get_model(job.model_id)
confusion = model.confusion_matrix

# Shows predictions vs actual labels
# High values on diagonal = good performance
# Off-diagonal = confused classes
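To see how the matrix is built, here is a minimal sketch in plain Python (`confusion_matrix` here is an illustrative local helper, separate from the `model.confusion_matrix` attribute):

```python
def confusion_matrix(actual, predicted, num_classes):
    """Rows index the actual class, columns the predicted class."""
    m = [[0] * num_classes for _ in range(num_classes)]
    for a, p in zip(actual, predicted):
        m[a][p] += 1
    return m

actual    = [0, 0, 1, 1, 2, 2]
predicted = [0, 0, 1, 2, 2, 2]
m = confusion_matrix(actual, predicted, num_classes=3)
for row in m:
    print(row)
# [2, 0, 0]
# [0, 1, 1]   <- the off-diagonal 1 means class 1 was once mistaken for class 2
# [0, 0, 2]
```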

Troubleshooting Training

Low Accuracy

| Symptom | Likely Cause | Solution |
|---|---|---|
| < 50% accuracy | Too little data | Add more images |
| Stuck at random guessing | Bad labels | Check for labeling errors |
| Good training, bad validation | Overfitting | More data, more augmentation |

Training Failed

Common failure reasons:

  1. Out of memory: Reduce batch size
  2. Corrupt images: Check dataset for bad files
  3. Empty splits: Ensure train/validation have data
# Check for common issues
version = client.get_dataset_version(version_id)

print(f"Total items: {version.item_count}")
for split in version.splits:
    print(f"  {split.name}: {split.item_count} items")

# Should have items in both train and validation

Best Practices

  1. Start small: Begin with 10 epochs, increase if needed
  2. Use validation: Always split data for honest evaluation
  3. Monitor overfitting: Watch for validation loss increasing
  4. Save checkpoints: Enable model checkpointing for long runs
  5. Experiment: Try different architectures and augmentations
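Practice 3 can be automated with simple patience-based early stopping: stop once validation loss has gone a few epochs without improving. A minimal sketch (`should_stop` is an illustrative helper, not an SDK function):

```python
def should_stop(val_losses, patience=3):
    """True once the best validation loss is `patience` or more epochs old."""
    best_epoch = min(range(len(val_losses)), key=lambda i: val_losses[i])
    return len(val_losses) - 1 - best_epoch >= patience

history = [0.90, 0.70, 0.55, 0.56, 0.58, 0.60]
print(should_stop(history, patience=3))  # True — best loss was 3 epochs ago
print(should_stop(history[:3], patience=3))  # False — still improving
```

Combined with checkpointing (practice 4), this lets you recover the weights from the best epoch rather than the last one.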

Next Step

3. Optimize →