Iterate & Improve

Model development is iterative. Use insights from evaluation to systematically improve your model.

Improvement Framework

graph TD
    A[Evaluate Model] --> B{Performance OK?}
    B -->|Yes| C[Deploy]
    B -->|No| D[Diagnose Issue]

    D --> E{Issue Type?}
    E -->|Data| F[Improve Data]
    E -->|Model| G[Adjust Model/Config]
    E -->|Specific Classes| H[Target Weak Areas]

    F --> A
    G --> A
    H --> A

Diagnose the Problem

Decision Tree

graph TD
    A[Low Performance] --> B{Val loss still decreasing?}
    B -->|Yes| C[Train longer]
    B -->|No| D{Train vs Val gap?}

    D -->|Large gap| E[Overfitting]
    D -->|Small gap| F{Loss value?}

    F -->|High| G[Underfitting]
    F -->|Low but accuracy bad| H[Data quality issue]

    E --> I[More regularization<br/>More data<br/>Simpler model]
    G --> J[Larger model<br/>Less regularization<br/>More epochs]
    H --> K[Check labels<br/>Clean data<br/>Balance classes]

Common Issues and Solutions

| Symptom | Likely Cause | Solutions |
|---|---|---|
| Low overall accuracy | Underfitting | Larger model, more epochs, less regularization |
| High train, low val accuracy | Overfitting | More data, augmentation, regularization |
| One class performs poorly | Class imbalance or hard class | More samples, weighted loss, better features |
| Low confidence predictions | Model uncertainty | More training data, better augmentation |
| High loss, decent accuracy | Poor calibration | Temperature scaling, label smoothing |

Strategy 1: Improve Data

Add More Training Data

# Check which classes need more data
stats = client.get_dataset_stats(version_id=version.id)
label_counts = stats['label_counts']

# A class is "under-represented" if it has fewer than this fraction of the
# largest class's sample count.
UNDERREPRESENTED_RATIO = 0.5

if not label_counts:
    # max() on an empty sequence would raise; report instead.
    print("No labeled samples found in this dataset version.")
else:
    target_samples = max(label_counts.values())
    for label, count in label_counts.items():
        if count < target_samples * UNDERREPRESENTED_RATIO:
            # Suggest enough extra samples to match the largest class.
            needed = target_samples - count
            print(f"Class '{label}': needs ~{needed} more samples")

Use Automated Labeling

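Speed up data collection with Automated Labeling: let your current model pre-label new images, then review and correct its suggestions before adding them to the training set: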

# Pre-label new data with the current model. The post-processor is attached
# to the unlabeled dataset and writes its output as annotations, using a 0.8
# confidence threshold and creating missing labels automatically.
prelabel_settings = dict(
    dataset_id=unlabeled_dataset.id,
    name="Pre-label with current model",
    model_type="classification",
    model_id=model.id,
    output_target="annotations",
    confidence_threshold=0.8,
    auto_create_labels=True,
    enabled=True,
)
processor = client.create_post_processor(**prelabel_settings)

# Upload each new (unlabeled) image into the unlabeled version
for new_image in new_images:
    client.create_dataset_item(
        version_id=unlabeled_version.id,
        file_path=new_image,
    )

# Review and correct labels in web UI, then add to training set

Fix Label Errors

# Collect low-confidence and misclassified predictions for manual review
LOW_CONFIDENCE = 0.6

review_items = [
    pred
    for pred in predictions
    if pred.confidence < LOW_CONFIDENCE or pred.prediction != pred.ground_truth
]

print(f"Items to review: {len(review_items)}")

# Flag each item for review via its metadata, recording *which* check(s)
# fired so reviewers know what to look for.
# NOTE(review): confirm whether update_dataset_item merges into or replaces
# the item's existing metadata.
for item in review_items:
    reasons = []
    if item.confidence < LOW_CONFIDENCE:
        reasons.append("low_confidence")
    if item.prediction != item.ground_truth:
        reasons.append("misclassified")
    client.update_dataset_item(
        item_id=item.item_id,
        metadata={"needs_review": True, "review_reason": ",".join(reasons)},
    )

Balance Classes

# These are alternatives: pick the one that fits your data. Each option gets
# its own name so defining a later option doesn't overwrite an earlier one.

# Option 1: Oversample minority classes
config_oversample = {
    "oversampling": True,
    "oversampling_strategy": "minority",
}

# Option 2: Weighted loss function
config_weighted_loss = {
    "class_weights": "balanced",
}

# Option 3: Collect more data for minority classes
# (see automated labeling above)

Strategy 2: Adjust Model Configuration

Try Different Base Models

# Compare different architectures on the same data and training budget
base_models = ["mobilenet_v2", "efficientnet_b0", "efficientnet_b2", "resnet50"]

# Keep a handle on every job (keyed by architecture) so the runs can be
# compared once they finish.
architecture_jobs = {}
for base_model in base_models:
    job = client.create_job(
        dataset_id=dataset.id,
        version_id=version.id,
        name=f"Experiment: {base_model}",
        config={
            "base_model": base_model,
            "epochs": 20,
            "batch_size": 32,
            "early_stopping": True
        }
    )
    architecture_jobs[base_model] = job
    print(f"Started: {base_model}")

Tune Hyperparameters

# Systematic hyperparameter search
import random
from itertools import product

learning_rates = [0.0001, 0.001, 0.01]
batch_sizes = [16, 32, 64]
dropout_rates = [0.2, 0.3, 0.5]

# Full grid: 3 x 3 x 3 = 27 combinations.
experiments = [
    {"learning_rate": lr, "batch_size": bs, "dropout": dropout}
    for lr, bs, dropout in product(learning_rates, batch_sizes, dropout_rates)
]

# Running the whole grid can be expensive, so launch a subset of it.
# Don't just take experiments[:9]: product() varies the LAST factor fastest,
# so the first 9 combinations all use learning_rate=0.0001 and the other
# learning rates would never be tried. A seeded random sample spreads the
# budget across the grid and stays reproducible.
MAX_RUNS = 9
selected = random.Random(42).sample(experiments, k=min(MAX_RUNS, len(experiments)))

# Keep (hyperparameters, job) pairs so results can be matched back later.
hp_jobs = []
for i, exp in enumerate(selected):
    job = client.create_job(
        dataset_id=dataset.id,
        version_id=version.id,
        name=f"HP Search {i+1}",
        config={
            "base_model": "efficientnet_b0",
            "epochs": 20,
            **exp
        }
    )
    hp_jobs.append((exp, job))

Adjust Regularization

# Overfitting -> push regularization up: heavier dropout/weight decay,
# label smoothing, and stronger augmentation.
strong_augmentation = {
    "horizontal_flip": True,
    "rotation": 30,          # Increase from 15
    "brightness": 0.3,
    "contrast": 0.3,
    "cutout": 0.5,           # Add cutout
}
config_more_reg = {
    "dropout": 0.5,          # Increase from 0.2
    "weight_decay": 0.05,    # Increase from 0.01
    "label_smoothing": 0.2,  # Add or increase
    "augmentation": strong_augmentation,
}

# Underfitting -> relax regularization, unfreeze the backbone, train longer.
config_less_reg = dict(
    dropout=0.1,             # Decrease
    weight_decay=0.001,      # Decrease
    freeze_layers="none",    # Unfreeze all layers
    epochs=50,               # Train longer
)

Strategy 3: Target Weak Areas

Focus on Problem Classes

# Identify worst-performing classes
F1_THRESHOLD = 0.8  # Classes below this F1 score count as "weak"

per_class = calculate_per_class_metrics(predictions, labels)
# Sorted worst-first, so the classes most in need of attention come first.
weak_classes = sorted(
    (label for label, m in per_class.items() if m['f1'] < F1_THRESHOLD),
    key=lambda label: per_class[label]['f1'],
)

# The message is derived from the threshold so the two can't drift apart.
print(f"Weak classes (F1 < {F1_THRESHOLD:.0%}): {weak_classes}")

# Strategies for weak classes:
# 1. Collect more samples
# 2. Review and fix mislabeled data
# 3. Add class-specific augmentation
# 4. Consider merging similar classes

Analyze Confusion Patterns

# Find commonly confused class pairs: every off-diagonal confusion cell
# (true label != predicted label) with more than 5 confusions.
confusion_pairs = [
    (actual, predicted, confusion[actual][predicted])
    for actual in labels
    for predicted in labels
    if actual != predicted and confusion[actual][predicted] > 5
]

# Most frequent confusions first; equal counts keep their original order.
confusion_pairs.sort(key=lambda pair: pair[2], reverse=True)

print("Most Confused Pairs:")
for true_label, predicted_label, n_confused in confusion_pairs[:5]:
    print(f"  {true_label} confused as {predicted_label}: {n_confused} times")

# Solutions:
# - Add more distinguishing examples
# - Review if classes should be merged
# - Add specific augmentation for these classes

Hard Example Mining

# Find hard examples: wrong predictions, plus correct predictions the model
# was unsure about. This snippet uses prediction confidence (not per-item
# loss) as the measure of "hard".
UNCERTAIN_CONFIDENCE = 0.6  # Below this, even a correct prediction counts as hard

hard_examples = [
    {
        "item_id": pred.item_id,
        "actual": pred.ground_truth,
        "predicted": pred.prediction,
        "confidence": pred.confidence,
    }
    for pred in predictions
    if pred.prediction != pred.ground_truth or pred.confidence < UNCERTAIN_CONFIDENCE
]

print(f"Hard examples: {len(hard_examples)}")

# Create a "hard examples" dataset version with extra weight

Strategy 4: Advanced Techniques

Use Model Distillation

If a larger model performs well but is too slow or costly to deploy, distill its knowledge into a smaller one:

# Step 1: train the large "teacher" model
teacher_config = {
    "base_model": "efficientnet_b4",
    "epochs": 30,
}
large_job = client.create_job(
    dataset_id=dataset.id,
    version_id=version.id,
    name="Large Teacher Model",
    config=teacher_config,
)

# Step 2: use the teacher to label more data
# Step 3: train the small "student" model on the expanded dataset
# See: /guides/model-distillation/

Ensemble Models

Combine predictions from multiple models:

# Train several ensemble members that differ only in their random seed
# (vary the config as well if you want more diverse members).
# NOTE(review): `models` holds job handles, while ensemble_predict takes
# model IDs; presumably each finished job's model ID must be looked up first.
models = []
for member in range(3):
    member_config = {
        "base_model": "efficientnet_b0",
        "seed": 42 + member,  # Different random seeds
        "epochs": 20,
    }
    job = client.create_job(
        dataset_id=dataset.id,
        version_id=version.id,
        name=f"Ensemble Member {member + 1}",
        config=member_config,
    )
    models.append(job)

# Ensemble predictions by majority vote
def ensemble_predict(client, model_ids, file_path):
    """Return the majority-vote label from several models for one file.

    Args:
        client: API client exposing ``predict(model_id=..., item=...)`` whose
            result has a ``.prediction`` attribute.
        model_ids: Iterable of model IDs to query; must not be empty.
        file_path: Path of the item to classify.

    Returns:
        The label predicted by the most models. Ties go to the label that
        was seen first, i.e. the one predicted by the earliest model in
        ``model_ids`` (``Counter.most_common`` orders equal counts by first
        insertion).

    Raises:
        ValueError: If ``model_ids`` is empty.
    """
    from collections import Counter

    # Hard-label voting only; averaging would need per-class scores.
    votes = Counter(
        client.predict(model_id=model_id, item=file_path).prediction
        for model_id in model_ids
    )
    if not votes:
        raise ValueError("ensemble_predict needs at least one model id")
    return votes.most_common(1)[0][0]

Progressive Training

Gradually increase the difficulty of the training examples (curriculum learning):

# Stage 1: Train on easy examples (high confidence from teacher)
# Stage 2: Add medium difficulty examples
# Stage 3: Add hard examples
# Thresholds decrease stage by stage, so each stage admits harder examples.
stages = [
    {"confidence_threshold": 0.95, "epochs": 10},
    {"confidence_threshold": 0.80, "epochs": 10},
    {"confidence_threshold": 0.50, "epochs": 10},
]

for stage in stages:
    threshold = stage["confidence_threshold"]
    epochs = stage["epochs"]
    # Placeholder: filter the dataset to examples above `threshold`, then
    # train incrementally (continue from the previous stage) for `epochs`.
    pass

Track Experiments

Keep records of all experiments:

import json
from datetime import datetime


def log_experiment(name, config, results, notes="", *, path="experiments.json"):
    """Append one experiment record to a JSON log file.

    The log is a JSON array of records. A missing or empty file is treated as
    an empty log. The updated log is fully serialized before the file is
    opened for writing, so a serialization failure (e.g. a non-JSON-serializable
    ``config``) leaves the existing log untouched instead of truncated.

    Args:
        name: Short human-readable experiment name.
        config: Training configuration; must be JSON-serializable.
        results: Metrics/outcomes; must be JSON-serializable.
        notes: Free-form notes.
        path: Log file location (default ``experiments.json`` in the current
            working directory).

    Returns:
        The record that was appended.

    Raises:
        TypeError: If ``config`` or ``results`` cannot be serialized to JSON.
        ValueError: If the existing log is not valid JSON
            (``json.JSONDecodeError``) or is not a JSON array.
    """
    experiment = {
        "name": name,
        "timestamp": datetime.now().isoformat(),  # Local time, no UTC offset
        "config": config,
        "results": results,
        "notes": notes
    }

    # Load the existing log (missing or empty file -> start a new log).
    try:
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
    except FileNotFoundError:
        text = ""
    experiments = json.loads(text) if text.strip() else []
    if not isinstance(experiments, list):
        raise ValueError(f"{path} does not contain a JSON array of experiments")

    experiments.append(experiment)

    # Serialize BEFORE opening for writing: open(..., "w") truncates at once,
    # so a json.dump that failed midway would wipe every earlier experiment.
    payload = json.dumps(experiments, indent=2)
    with open(path, "w", encoding="utf-8") as f:
        f.write(payload)

    print(f"Logged experiment: {name}")
    return experiment

# Log each experiment: record its outcome next to the config that produced it
run_results = {
    "val_accuracy": 0.923,
    "test_accuracy": 0.918,
    "training_time": "45 minutes",
}
log_experiment(
    name="EfficientNet-B0 + Strong Augmentation",
    config=job.config,
    results=run_results,
    notes="Added cutout augmentation, improved from 0.89 to 0.92",
)

Iteration Checklist

  1. Document baseline - Record current performance
  2. Identify bottleneck - Is it data, model, or config?
  3. Change one thing - Isolate variables
  4. Run experiment - Same evaluation protocol
  5. Compare results - Check whether the difference is statistically significant rather than noise (e.g., repeat with a few seeds)
  6. Document findings - What worked, what didn’t
  7. Repeat - Until requirements met

When to Stop

Stop iterating when:

  • ✅ Performance meets requirements
  • ✅ Additional improvements are marginal
  • ✅ Cost of more data/compute exceeds benefit
  • ✅ Deadline requires deployment

Next Steps

Once your model meets requirements: