Train Student Model

Train a small, fast model on the teacher-labeled data.

Choose Student Architecture

The student should be small enough to deploy but large enough to learn the task:

| Teacher         | Student Options               | Expected Gap |
|-----------------|-------------------------------|--------------|
| ViT-Large       | MobileNet v2, EfficientNet-B0 | 3-7%         |
| EfficientNet-B4 | MobileNet v2, EfficientNet-B0 | 2-5%         |
| GPT-4 Vision    | MobileNet v2, EfficientNet-B0 | 5-10%        |
| BERT-Large      | DistilBERT, TinyBERT          | 2-4%         |
| YOLOv5-L        | YOLOv5-S, YOLOv5-N            | 5-10%        |

Selection Criteria

graph TD
    A{Deployment Target?} --> B[Mobile/Edge]
    A --> C[Cloud API]
    A --> D[Browser/WASM]

    B --> E[MobileNet v2<br/>YOLOv5-N]
    C --> F[EfficientNet-B0<br/>DistilBERT]
    D --> G[MobileNet v2<br/>Smallest possible]

Configure Training

from seeme import Client

# Create a SeeMe API client (credentials are presumably picked up from
# the environment or a config file — confirm against the Client docs).
client = Client()

# Fetch the teacher-labeled dataset and the version that holds the
# distillation labels. Replace "distillation-dataset-id" with your own ID.
dataset = client.get_dataset("distillation-dataset-id")
version = client.get_dataset_version(dataset.id, "v1")

# Training configuration for the student model.
student_config = dict(
    # -- Model architecture --
    base_model="mobilenet_v2",  # small, fast backbone suited to deployment
    image_size=224,

    # -- Optimization --
    epochs=30,
    batch_size=32,
    learning_rate=0.001,
    lr_scheduler="cosine",

    # -- Regularization (important for distillation) --
    dropout=0.2,
    label_smoothing=0.1,  # soft labels help the student generalize
    weight_decay=0.01,

    # -- Early stopping --
    early_stopping=True,
    early_stopping_patience=5,
    early_stopping_metric="val_accuracy",

    # -- Data augmentation --
    augmentation=dict(
        horizontal_flip=True,
        rotation=15,
        brightness=0.2,
        contrast=0.2,
    ),
)

Start Training

Create the training job with client.create_job, passing the dataset version and the student_config defined above.

Monitor Training

import time

def monitor_student_training(client, job_id, poll_interval=30):
    """Poll a student training job and print per-epoch progress.

    Blocks until the job leaves the pending/queued/running states,
    printing one table row each time a new epoch's metrics appear.

    Args:
        client: SeeMe API client used to re-fetch job state.
        job_id: Identifier of the training job to watch.
        poll_interval: Seconds to sleep between polls.

    Returns:
        The final job object (inspect .status for the outcome).
    """
    ACTIVE_STATES = ("pending", "queued", "running")
    divider = "-" * 60

    job = client.get_job(job_id)
    print(f"Training: {job.name}")
    print(f"Config: {job.config.get('base_model')}, {job.config.get('epochs')} epochs")
    print(divider)
    print(f"{'Epoch':<8} {'Train Loss':<12} {'Val Loss':<12} {'Val Acc':<12}")
    print(divider)

    reported_epoch = 0
    while job.status in ACTIVE_STATES:
        time.sleep(poll_interval)
        job = client.get_job(job_id)

        metrics = job.metrics or {}
        epoch = metrics.get('epoch', 0)
        # Only print a row when a new epoch has completed since last poll.
        if epoch > reported_epoch:
            reported_epoch = epoch
            row = (f"{reported_epoch:<8} "
                   f"{metrics.get('train_loss', 0):<12.4f} "
                   f"{metrics.get('val_loss', 0):<12.4f} "
                   f"{metrics.get('val_accuracy', 0):<12.2%}")
            print(row)

    print(divider)
    print(f"Status: {job.status}")

    if job.status == "completed":
        print(f"\n✅ Training complete!")
        print(f"   Best epoch: {job.best_epoch}")
        print(f"   Best val accuracy: {job.best_metrics['val_accuracy']:.2%}")
        print(f"   Model ID: {job.model_id}")
    elif job.status == "failed":
        print(f"\n❌ Training failed: {job.error}")

    return job

# Launch the student training job first — the original snippet referenced
# student_job.id before student_job was ever assigned (NameError).
student_job = client.create_job(
    dataset_id=dataset.id,
    version_id=version.id,
    name="Student: MobileNet v2",
    job_type="finetune",
    config=student_config
)

# Monitor training until it finishes (blocks, printing per-epoch progress).
student_job = monitor_student_training(client, student_job.id)

Get the Trained Student

# Fetch the trained student model once the job has completed successfully.
if student_job.status == "completed":
    student_model = client.get_model(student_job.model_id)

    summary = [
        f"\nStudent Model: {student_model.id}",
        f"  Name: {student_model.name}",
        f"  Size: {student_model.size_mb:.1f} MB",
        f"  Architecture: {student_model.config.get('base_model')}",
        f"  Val Accuracy: {student_model.metrics.get('accuracy', 0):.2%}",
    ]
    print("\n".join(summary))

Compare Model Sizes

# Fetch the teacher and print a teacher-vs-student size comparison.
teacher_model = client.get_model(teacher_model_id)

divider = "-" * 40
report = [
    "\nModel Size Comparison:",
    divider,
    f"{'Model':<20} {'Size':<15}",
    divider,
    f"{'Teacher':<20} {teacher_model.size_mb:<15.1f} MB",
    f"{'Student':<20} {student_model.size_mb:<15.1f} MB",
    # How many times smaller the student is than the teacher.
    f"{'Reduction':<20} {teacher_model.size_mb / student_model.size_mb:<15.1f}x",
]
print("\n".join(report))

Troubleshooting

Student Accuracy is Too Low

If the student’s accuracy is significantly worse than expected:

# Option 1: give the student more epochs to close the gap
student_config.update(epochs=50, early_stopping_patience=10)

# Option 2: step up to a slightly larger backbone
student_config["base_model"] = "efficientnet_b0"  # instead of mobilenet_v2

# Option 3: loosen regularization so the student can fit the teacher signal
student_config.update(dropout=0.1, weight_decay=0.001)

# Option 4: richer data augmentation
student_config["augmentation"] = dict(
    horizontal_flip=True,
    rotation=30,    # increased from 15
    brightness=0.3,
    contrast=0.3,
    cutout=0.5,     # add random occlusion
)

Student Overfits

If training accuracy is much higher than validation:

# Strengthen regularization
student_config.update(dropout=0.4, weight_decay=0.05, label_smoothing=0.2)

# Add stronger augmentation on top of the existing settings
student_config["augmentation"].update(mixup=0.2, cutout=0.5)

# Train for fewer epochs and stop earlier
student_config.update(epochs=20, early_stopping_patience=3)

Training is Slow

# Speed up training with a single config update:
student_config.update(
    image_size=160,        # smaller inputs (was 224) — faster, slight accuracy cost
    batch_size=64,         # larger batches improve throughput
    mixed_precision=True,  # enable mixed-precision training
)

Train Multiple Students

Compare different architectures:

# Candidate student architectures to compare ("name" is for display only).
student_architectures = [
    {"name": "MobileNet v2", "base_model": "mobilenet_v2", "image_size": 224},
    {"name": "EfficientNet B0", "base_model": "efficientnet_b0", "image_size": 224},
    {"name": "ResNet-18", "base_model": "resnet18", "image_size": 224},
]

# Launch one training job per candidate architecture.
student_jobs = []
for arch in student_architectures:
    # Overlay the architecture settings on the shared config. The display-only
    # "name" key is filtered out so it does not leak into the training config.
    config = {**student_config, **{k: v for k, v in arch.items() if k != "name"}}
    job = client.create_job(
        dataset_id=dataset.id,
        version_id=version.id,
        name=f"Student: {arch['name']}",
        job_type="finetune",
        config=config
    )
    student_jobs.append(job)
    print(f"Started: {arch['name']}")

# Wait for all jobs to finish. "queued" must be included — the original
# check only covered pending/running, so a queued job looked finished.
for job in student_jobs:
    while client.get_job(job.id).status in ["pending", "queued", "running"]:
        time.sleep(60)

# Compare results side by side.
print("\nStudent Comparison:")
print("-" * 60)
print(f"{'Architecture':<20} {'Val Acc':<12} {'Size (MB)':<12}")
print("-" * 60)

for job in student_jobs:
    job = client.get_job(job.id)  # refresh to get final status and metrics
    if job.status == "completed":
        model = client.get_model(job.model_id)
        print(f"{job.config['base_model']:<20} "
              f"{job.best_metrics['val_accuracy']:<12.2%} "
              f"{model.size_mb:<12.1f}")

Best Practices

  1. Start with MobileNet v2 - Good balance of size and accuracy
  2. Use label smoothing - Softens the hard teacher labels so the student generalizes instead of memorizing
  3. Don’t overtrain - Use early stopping to prevent overfitting
  4. Match input size - Use the same image size the base model was trained on
  5. Try multiple architectures - Different students may work better for different tasks
  6. Save checkpoints - You might want a checkpoint from an earlier epoch

Next Step

With your student trained, proceed to Evaluate & Compare to see how it stacks up against the teacher.