Model Distillation

Train a small, fast model that matches the quality of a large, expensive model. Use the large model as a “teacher” to label data, then train a “student” model on those labels.

Why Distillation?

graph LR
    subgraph "Before Distillation"
        A[Large Model<br/>Accurate but slow<br/>$$$] --> B[Predictions]
    end
    subgraph "After Distillation"
        C[Small Model<br/>Same accuracy<br/>10-100x faster<br/>$] --> D[Predictions]
    end
| Metric | Large Model (Teacher) | Small Model (Student) |
|---|---|---|
| Accuracy | 95% | 93% |
| Latency | 500ms | 10ms |
| Cost per 1000 predictions | $10 | $0.10 |
| Runs on mobile | No | Yes |
| GPU required | Yes | No |
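
The headline numbers above imply large multiplicative savings. A quick back-of-envelope check (figures taken from the comparison above, illustrative rather than measured):

```python
# Figures from the teacher/student comparison above (illustrative, not measured)
teacher = {"accuracy": 0.95, "latency_ms": 500, "cost_per_1k": 10.00}
student = {"accuracy": 0.93, "latency_ms": 10, "cost_per_1k": 0.10}

speedup = teacher["latency_ms"] / student["latency_ms"]
cost_reduction = teacher["cost_per_1k"] / student["cost_per_1k"]
accuracy_gap = teacher["accuracy"] - student["accuracy"]

print(f"{speedup:.0f}x faster, {cost_reduction:.0f}x cheaper, "
      f"{accuracy_gap:.0%} accuracy given up")
# → 50x faster, 100x cheaper, 2% accuracy given up
```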

When to Use Distillation

Good candidates:

  • You have a large model (or LLM) that works well but is too slow/expensive
  • You need to deploy to mobile, edge, or high-throughput APIs
  • You want to reduce inference costs
  • You have unlabeled data you can label with the teacher

Not ideal when:

  • You already have enough labeled data for direct training
  • The task is too complex for a small model
  • Teacher model accuracy is insufficient

The Distillation Process

graph TD
    A[1. Select Teacher Model] --> B[2. Label Data with Teacher]
    B --> C[3. Review Labels]
    C --> D[4. Train Student Model]
    D --> E[5. Evaluate Both]
    E --> F{Student good enough?}
    F -->|Yes| G[6. Deploy Student]
    F -->|No| H[Add more data]
    H --> B

Quick Start

Complete Example

from seeme import Client
import time

client = Client()

# --- Step 1: Setup ---
# Teacher: Large accurate model (or LLM)
teacher_model = client.get_model("large-classifier-id")  # or LLM

# Dataset for distillation
dataset = client.create_dataset(
    name="Distillation Dataset",
    task_type="image_classification"
)
version = client.create_dataset_version(dataset_id=dataset.id, name="v1")

# Create labels
for label in ["cat", "dog", "bird", "other"]:
    client.create_label(version_id=version.id, name=label)

# --- Step 2: Label data with teacher ---
processor = client.create_post_processor(
    dataset_id=dataset.id,
    name="Teacher Labeler",
    model_type="classification",  # or "llm" for LLM teacher
    model_id=teacher_model.id,
    output_target="annotations",
    confidence_threshold=0.8,
    auto_create_labels=True,
    enabled=True
)

# Upload unlabeled images into a train split
# (create_split is assumed here; adjust to your dataset version's split API)
train_split = client.create_split(version_id=version.id, name="train")

import glob
for image_path in glob.glob("./unlabeled/*.jpg"):
    client.create_dataset_item(
        version_id=version.id,
        split_id=train_split.id,
        file_path=image_path
    )

# Wait for labeling to complete
while True:
    jobs = client.get_post_processor_jobs(dataset_id=dataset.id, status="pending")
    if len(jobs) == 0:
        break
    time.sleep(10)

print("Labeling complete!")

# --- Step 3: Review labels (in web UI or programmatically) ---
# Check a sample, correct errors

# --- Step 4: Train student ---
student_job = client.create_job(
    dataset_id=dataset.id,
    version_id=version.id,
    name="Student: MobileNet",
    config={
        "base_model": "mobilenet_v2",  # Small, fast model
        "epochs": 30,
        "learning_rate": 0.001,
        "batch_size": 32
    }
)

# Wait for training
while student_job.status in ["pending", "running"]:
    time.sleep(30)
    student_job = client.get_job(student_job.id)

student_model = client.get_model(student_job.model_id)
print(f"Student trained! Accuracy: {student_job.best_metrics['val_accuracy']:.2%}")

# --- Step 5: Compare teacher vs student ---
# On held-out validation set with ground truth labels
val_dataset = client.get_dataset("validation-dataset-id")

teacher_results = client.evaluate_model(
    model_id=teacher_model.id,
    dataset_id=val_dataset.id,
    split="test"
)

student_results = client.evaluate_model(
    model_id=student_model.id,
    dataset_id=val_dataset.id,
    split="test"
)

print("\nComparison:")
print(f"Teacher accuracy: {teacher_results['accuracy']:.2%}")
print(f"Student accuracy: {student_results['accuracy']:.2%}")
print(f"Gap: {teacher_results['accuracy'] - student_results['accuracy']:.2%}")

# --- Step 6: Deploy student if good enough ---
if student_results['accuracy'] > 0.90:  # Your threshold
    optimized = client.optimize_model(
        model_id=student_model.id,
        target_format="onnx",
        quantize=True
    )
    client.deploy_model(model_id=optimized.id, name="Production Classifier")
    print("Student deployed!")
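
Before deploying, it is worth measuring wall-clock latency directly rather than trusting spec-sheet numbers. A minimal benchmarking sketch; the two predict functions are stand-ins (in practice you would time real prediction calls against each model):

```python
import time

def benchmark(predict_fn, inputs, warmup=2):
    """Average wall-clock latency per prediction, in milliseconds."""
    for x in inputs[:warmup]:      # warm up lazy initialization / caches
        predict_fn(x)
    start = time.perf_counter()
    for x in inputs:
        predict_fn(x)
    elapsed = time.perf_counter() - start
    return elapsed / len(inputs) * 1000

def slow_teacher(x):
    # Stand-in for an expensive teacher call (hypothetical workload)
    return sum(i * i for i in range(50_000))

def fast_student(x):
    # Stand-in for the distilled student (much less work per call)
    return sum(i * i for i in range(1_000))

inputs = list(range(20))
print(f"teacher: {benchmark(slow_teacher, inputs):.3f} ms/prediction")
print(f"student: {benchmark(fast_student, inputs):.3f} ms/prediction")
```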

Teacher Model Options

| Teacher Type | Best For | Example |
|---|---|---|
| Large classifier | When you have a big pre-trained model | EfficientNet-B4, ViT-Large |
| LLM (Ollama) | Flexible labeling with reasoning | Llama, Mistral, Mixtral |
| External LLM | Best accuracy, highest cost | GPT-4, Claude |
| Ensemble | Maximum accuracy | Multiple models voting |
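
The ensemble option combines several teachers by voting. A minimal majority-vote sketch in pure Python; breaking ties by confidence is one reasonable choice, not a fixed rule:

```python
from collections import Counter

def ensemble_label(predictions):
    """Majority vote across several teacher models; ties go to the
    single most confident voter."""
    votes = Counter(p["label"] for p in predictions)
    top, count = votes.most_common(1)[0]
    if count > len(predictions) // 2:
        return top
    return max(predictions, key=lambda p: p["confidence"])["label"]

preds = [
    {"label": "cat", "confidence": 0.9},
    {"label": "cat", "confidence": 0.7},
    {"label": "dog", "confidence": 0.95},
]
print(ensemble_label(preds))  # → cat (2 of 3 teachers agree)
```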

Using an LLM as Teacher

# LLM teacher can handle complex classification with reasoning
processor = client.create_post_processor(
    dataset_id=dataset.id,
    name="LLM Teacher",
    model_type="llm",
    model_id=llm_model.id,
    prompt="""
    Classify this image into exactly one category:
    - cat: Any cat, kitten, or feline
    - dog: Any dog, puppy, or canine
    - bird: Any bird species
    - other: Anything else

    Look carefully at the image and return only the category name.
    """,
    output_target="annotations",
    auto_create_labels=True
)
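
LLM teachers return free text, so responses usually need to be normalized to one of the dataset's labels before they become annotations. A minimal sketch for the four categories in the prompt above; the matching rules are illustrative:

```python
# Labels must match those created on the dataset version
VALID_LABELS = {"cat", "dog", "bird", "other"}

def normalize_response(text: str) -> str:
    """Map a raw LLM response to a valid label, defaulting to 'other'."""
    cleaned = text.strip().lower().rstrip(".")
    if cleaned in VALID_LABELS:
        return cleaned
    # Tolerate verbose answers like "It is a dog"
    for label in VALID_LABELS:
        if label in cleaned.split():
            return label
    return "other"

print(normalize_response("  Cat.\n"))     # → cat
print(normalize_response("It is a dog"))  # → dog
print(normalize_response("unsure"))       # → other
```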

Student Model Selection

Choose the smallest model that can still learn the task:

| Task | Teacher | Student | Typical Gap |
|---|---|---|---|
| Image Classification | ViT-Large | MobileNet v2 | 2-5% |
| Image Classification | EfficientNet-B4 | EfficientNet-B0 | 1-3% |
| Object Detection | YOLOv5-L | YOLOv5-S | 3-8% |
| Text Classification | BERT-Large | DistilBERT | 1-3% |
| Text Classification | GPT-4 | DistilBERT | 3-7% |

Best Practices

  1. Use confidence thresholds - Only keep high-confidence teacher predictions
  2. Review a sample - Teacher labels aren’t perfect; review 5-10%
  3. Evaluate on ground truth - Compare both models on human-labeled test set
  4. Iterate - Add more data where student struggles
  5. Right-size the student - Too small = can’t learn; too big = no benefit
  6. Don’t forget inference cost - The whole point is to reduce it
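
Practice 1 is simple to apply even after labeling has run. A minimal sketch of filtering teacher predictions by confidence before they become training labels (the field names are illustrative):

```python
def filter_by_confidence(predictions, threshold=0.8):
    """Keep only teacher predictions confident enough to train on."""
    return [p for p in predictions if p["confidence"] >= threshold]

preds = [
    {"item": "img1.jpg", "label": "cat", "confidence": 0.95},
    {"item": "img2.jpg", "label": "dog", "confidence": 0.55},
    {"item": "img3.jpg", "label": "bird", "confidence": 0.88},
]
kept = filter_by_confidence(preds, threshold=0.8)
print([p["item"] for p in kept])  # → ['img1.jpg', 'img3.jpg']
```

Items that fall below the threshold are better routed to human review than silently trained on, since low-confidence teacher errors compound in the student.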

Related Guides