Iterate & Improve
Model development is iterative. Use insights from evaluation to systematically improve your model.
Improvement Framework
graph TD
A[Evaluate Model] --> B{Performance OK?}
B -->|Yes| C[Deploy]
B -->|No| D[Diagnose Issue]
D --> E{Issue Type?}
E -->|Data| F[Improve Data]
E -->|Model| G[Adjust Model/Config]
E -->|Specific Classes| H[Target Weak Areas]
F --> A
G --> A
H --> A
Diagnose the Problem
Decision Tree
graph TD
A[Low Performance] --> B{Val loss still decreasing?}
B -->|Yes| C[Train longer]
B -->|No| D{Train vs Val gap?}
D -->|Large gap| E[Overfitting]
D -->|Small gap| F{Loss value?}
F -->|High| G[Underfitting]
F -->|Low but accuracy bad| H[Data quality issue]
E --> I[More regularization
More data
Simpler model]
G --> J[Larger model
Less regularization
More epochs]
H --> K[Check labels
Clean data
Balance classes]
Common Issues and Solutions
| Symptom | Likely Cause | Solutions |
|---|---|---|
| Low overall accuracy | Underfitting | Larger model, more epochs, less regularization |
| High train, low val accuracy | Overfitting | More data, augmentation, regularization |
| One class performs poorly | Class imbalance or hard class | More samples, weighted loss, better features |
| Low confidence predictions | Model uncertainty | More training data, better augmentation |
| High loss, decent accuracy | Poor calibration | Temperature scaling, label smoothing |
Strategy 1: Improve Data
Add More Training Data
## Check which classes need more data
stats = client.get_dataset_stats(version_id=version.id)
min_samples = min(stats['label_counts'].values())
target_samples = max(stats['label_counts'].values())
for label, count in stats['label_counts'].items():
if count < target_samples * 0.5: # Less than 50% of max
needed = target_samples - count
print(f"Class '{label}': needs ~{needed} more samples")Use Automated Labeling
Speed up data collection with Automated Labeling:
# Let the current model pre-label fresh data; only predictions at or
# above the confidence threshold become annotations.
processor = client.create_post_processor(
    dataset_id=unlabeled_dataset.id,
    name="Pre-label with current model",
    model_type="classification",
    model_id=model.id,
    output_target="annotations",
    confidence_threshold=0.8,
    auto_create_labels=True,
    enabled=True
)

# Push the unlabeled images into the dataset version.
for path in new_images:
    client.create_dataset_item(version_id=unlabeled_version.id, file_path=path)
# Review and correct labels in web UI, then add to training set
Fix Label Errors
# Export low-confidence and misclassified for review
review_items = []
for pred in predictions:
# Low confidence
if pred.confidence < 0.6:
review_items.append(pred)
# Misclassified
elif pred.prediction != pred.ground_truth:
review_items.append(pred)
print(f"Items to review: {len(review_items)}")
# Create a review dataset
for item in review_items:
# Flag item for review
client.update_dataset_item(
item_id=item.item_id,
metadata={"needs_review": True, "review_reason": "low_confidence_or_error"}
)Balance Classes
# Option 1: oversample minority classes at training time.
config = {
    "oversampling": True,
    "oversampling_strategy": "minority",
}

# Option 2: weight the loss so rare classes count more.
config = {
    "class_weights": "balanced",
}

# Option 3: collect more data for minority classes
# (see automated labeling above)
Strategy 2: Adjust Model Configuration
Try Different Base Models
# Compare different architectures
base_models = ["mobilenet_v2", "efficientnet_b0", "efficientnet_b2", "resnet50"]
for base_model in base_models:
job = client.create_job(
dataset_id=dataset.id,
version_id=version.id,
name=f"Experiment: {base_model}",
config={
"base_model": base_model,
"epochs": 20,
"batch_size": 32,
"early_stopping": True
}
)
print(f"Started: {base_model}")Tune Hyperparameters
# Systematic hyperparameter search over a small grid.
from itertools import product

learning_rates = [0.0001, 0.001, 0.01]
batch_sizes = [16, 32, 64]
dropout_rates = [0.2, 0.3, 0.5]

# One experiment per (learning rate, batch size, dropout) combination.
experiments = [
    {"learning_rate": lr, "batch_size": bs, "dropout": rate}
    for lr, bs, rate in product(learning_rates, batch_sizes, dropout_rates)
]
# Run top experiments (or use hyperparameter search)
for i, exp in enumerate(experiments[:9]): # Top 9 combinations
job = client.create_job(
dataset_id=dataset.id,
version_id=version.id,
name=f"HP Search {i+1}",
config={
"base_model": "efficientnet_b0",
"epochs": 20,
**exp
}
)Adjust Regularization
# If overfitting: increase regularization strength across the board.
config_more_reg = {
    "dropout": 0.5,          # up from 0.2
    "weight_decay": 0.05,    # up from 0.01
    "label_smoothing": 0.2,  # add or increase
    "augmentation": {
        "horizontal_flip": True,
        "rotation": 30,      # up from 15
        "brightness": 0.3,
        "contrast": 0.3,
        "cutout": 0.5,       # newly added
    },
}
# If underfitting: reduce regularization so the model can fit more signal.
config_less_reg = {
"dropout": 0.1, # lower dropout keeps more capacity active
"weight_decay": 0.001, # weaker L2 penalty
"freeze_layers": "none", # unfreeze all layers for full fine-tuning
"epochs": 50 # train longer to reach a better fit
}
Strategy 3: Target Weak Areas
Focus on Problem Classes
# Identify the worst-performing classes by F1 score.
per_class = calculate_per_class_metrics(predictions, labels)
weak_classes = [name for name, metrics in per_class.items() if metrics['f1'] < 0.8]
print(f"Weak classes (F1 < 80%): {weak_classes}")

# Strategies for weak classes:
# 1. Collect more samples
# 2. Review and fix mislabeled data
# 3. Add class-specific augmentation
# 4. Consider merging similar classes
Analyze Confusion Patterns
# Find commonly confused class pairs in the confusion matrix.
confusion_pairs = []
for actual in labels:
    for predicted in labels:
        if actual == predicted:
            continue  # skip the diagonal (correct predictions)
        count = confusion[actual][predicted]
        if count > 5:  # more than 5 confusions
            confusion_pairs.append((actual, predicted, count))

# Most frequent confusions first.
confusion_pairs.sort(key=lambda pair: -pair[2])
print("Most Confused Pairs:")
for actual, predicted, count in confusion_pairs[:5]:
    print(f" {actual} confused as {predicted}: {count} times")

# Solutions:
# - Add more distinguishing examples
# - Review if classes should be merged
# - Add specific augmentation for these classes
Hard Example Mining
# Collect the hardest examples: wrong predictions, plus correct ones
# the model was unsure about.
hard_examples = []
for pred in predictions:
    misclassified = pred.prediction != pred.ground_truth
    uncertain = pred.confidence < 0.6  # correct but low confidence
    if misclassified or uncertain:
        hard_examples.append({
            "item_id": pred.item_id,
            "actual": pred.ground_truth,
            "predicted": pred.prediction,
            "confidence": pred.confidence
        })
print(f"Hard examples: {len(hard_examples)}")
# Create a "hard examples" dataset version with extra weightStrategy 4: Advanced Techniques
Use Model Distillation
If you have a larger model that performs well, distill to a smaller one:
# Step 1: train the large teacher model.
large_job = client.create_job(
    dataset_id=dataset.id,
    version_id=version.id,
    name="Large Teacher Model",
    config={"base_model": "efficientnet_b4", "epochs": 30}
)
# Step 2: use the large model to label more data,
# then train the small model on the expanded dataset.
# See: /guides/model-distillation/
Ensemble Models
Combine predictions from multiple models:
# Train several identically-configured models, varying only the random
# seed, so their errors decorrelate.
models = []
for member in range(3):
    job = client.create_job(
        dataset_id=dataset.id,
        version_id=version.id,
        name=f"Ensemble Member {member + 1}",
        config={
            "base_model": "efficientnet_b0",
            "seed": 42 + member,  # a different seed per member
            "epochs": 20
        }
    )
    models.append(job)
# Ensemble predictions (voting or averaging)
def ensemble_predict(client, model_ids, file_path):
predictions = []
for model_id in model_ids:
pred = client.predict(model_id=model_id, item=file_path)
predictions.append(pred)
# Majority voting
from collections import Counter
votes = Counter(p.prediction for p in predictions)
return votes.most_common(1)[0][0]Progressive Training
Gradually increase difficulty:
# Curriculum stages: each stage admits lower-confidence (harder)
# examples than the one before.
#   Stage 1: easy examples (high confidence from teacher)
#   Stage 2: medium difficulty examples
#   Stage 3: hard examples
stages = [
    {"confidence_threshold": threshold, "epochs": 10}
    for threshold in (0.95, 0.80, 0.50)
]
for i, stage in enumerate(stages):
# Filter dataset to examples above confidence threshold
# Train incrementally
passTrack Experiments
Keep records of all experiments:
import json
from datetime import datetime

def log_experiment(name, config, results, notes=""):
    """Append an experiment record to ``experiments.json`` in the CWD.

    Args:
        name: Human-readable experiment name.
        config: Training configuration dict used for the run.
        results: Metrics dict (accuracies, training time, etc.).
        notes: Optional free-form observations.
    """
    experiment = {
        "name": name,
        "timestamp": datetime.now().isoformat(),
        "config": config,
        "results": results,
        "notes": notes
    }
    # Load the existing log; start fresh if it doesn't exist yet.
    # encoding="utf-8" keeps the log portable across platforms
    # (the default encoding is locale-dependent on Windows).
    try:
        with open("experiments.json", "r", encoding="utf-8") as f:
            experiments = json.load(f)
    except FileNotFoundError:
        experiments = []
    experiments.append(experiment)
    with open("experiments.json", "w", encoding="utf-8") as f:
        json.dump(experiments, f, indent=2)
    print(f"Logged experiment: {name}")
# Log each experiment
log_experiment(
name="EfficientNet-B0 + Strong Augmentation",
config=job.config,
results={
"val_accuracy": 0.923,
"test_accuracy": 0.918,
"training_time": "45 minutes"
},
notes="Added cutout augmentation, improved from 0.89 to 0.92"
)Iteration Checklist
- Document baseline - Record current performance
- Identify bottleneck - Is it data, model, or config?
- Change one thing - Isolate variables
- Run experiment - Same evaluation protocol
- Compare results - Statistical significance?
- Document findings - What worked, what didn’t
- Repeat - Until requirements met
When to Stop
Stop iterating when:
- ✅ Performance meets requirements
- ✅ Additional improvements are marginal
- ✅ Cost of more data/compute exceeds benefit
- ✅ Deadline requires deployment
Next Steps
Once your model meets requirements: