Chapter 32: Training vs Inference - Two Different Worlds
Training a model and serving it in production are fundamentally different operations. They have different performance requirements, different cost structures, different failure modes, and different engineering challenges. Understanding this divide is understanding why deploying ML is hard.
Training is offline, batch-oriented, and runs on powerful GPU clusters over hours or days. Accuracy is what matters. Inference is online, request-oriented, and must respond in milliseconds on limited hardware. Latency is what matters. A model that achieves 99% accuracy but takes 10 seconds to respond is useless for most applications.
This chapter explains the training-inference divide, why it creates constraints, and how production systems navigate these trade-offs.
Offline vs Online: Different Performance Worlds
Training is offline batch processing. You have a fixed dataset. You process it multiple times (epochs). You can use large batches (64, 128, 512 examples at once) to maximize GPU utilization. You can take as long as needed: if training takes a week, you wait a week. The goal is to minimize loss on the training set (and generalize to the test set).
Training characteristics:
- Time scale: Hours to weeks
- Throughput over latency: Process millions of examples, don't care about per-example time
- Hardware: High-end GPUs (A100, H100), often 8-64 GPUs in parallel
- Cost structure: Fixed upfront cost (train once, deploy forever… until you retrain)
- Optimization target: Minimize loss function
Inference is online request processing. A user sends a query. The model must respond immediately. You process one request at a time (or small batches). Latency matters more than throughput; users won't wait 10 seconds for a search result. The goal is to minimize response time while maintaining accuracy.
Inference characteristics:
- Time scale: Milliseconds to seconds
- Latency over throughput: Each request must complete quickly
- Hardware: CPUs, GPUs, or specialized accelerators (TPUs, edge devices)
- Cost structure: Per-query cost that scales with traffic
- Optimization target: Minimize latency, minimize cost per query
Latency requirements vary by application:
- Search: 50-100ms total budget (query understanding, retrieval, ranking, rendering)
- Recommendations: 100-200ms (more tolerance than search)
- Chatbots: 1-2 seconds (users tolerate slight delay for conversational UI)
- Real-time systems: <10ms (trading algorithms, autonomous vehicles)
Every 100ms of latency costs engagement. Google found that 500ms of delay reduces traffic by 20%. Amazon found that 100ms of latency costs 1% of sales. Users are impatient. Latency kills.
Example: Google Search latency budget
A Google Search query has ~50ms latency budget from query submission to results displayed. Within this budget:
- ~10ms: Network latency (user to data center)
- ~5ms: Query understanding (spelling correction, intent classification)
- ~20ms: Retrieval and ranking (fetch candidates, score with ML models)
- ~10ms: Snippet generation and rendering
- ~5ms: Network return
The ranking model, potentially the most complex ML component, gets 20ms. If the model takes 200ms, it is unusable. The model must be fast, even if that means sacrificing accuracy. A 95% accurate model that runs in 15ms beats a 99% accurate model that runs in 100ms.
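The budget decomposition above is simple arithmetic, but enforcing it per stage is how serving teams keep the total in check. A minimal sketch of a budget check, using the illustrative numbers from this example (the stage names are hypothetical labels, not Google's internal ones):

```python
# Per-stage latency budget for the search example above (all values in ms).
BUDGET_MS = 50

stage_budget_ms = {
    "network_in": 10,
    "query_understanding": 5,
    "retrieval_and_ranking": 20,
    "snippet_and_render": 10,
    "network_out": 5,
}

def within_budget(stage_times_ms) -> bool:
    """True if measured per-stage latencies fit inside the total budget."""
    return sum(stage_times_ms.values()) <= BUDGET_MS

print(sum(stage_budget_ms.values()))   # 50 -> the budget is fully allocated
```

If any stage overruns, some other stage must give time back; the ranking model's 20ms slice is a hard contract, not a guideline.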
Model Size vs Latency: The Central Trade-Off
Bigger models are generally more accurate. More parameters capture more patterns. But bigger models are slower at inference. Every parameter must be loaded from memory, every layer must be computed. The trade-off is fundamental.
GPT-3 has 175 billion parameters, occupying ~350GB of memory in FP16. Running inference requires loading these parameters and computing matrix multiplications. On a single A100 GPU, GPT-3 inference takes ~1-2 seconds per token. For a 100-token response, that is 100-200 seconds, unusable for chat applications.
OpenAI solves this with model sharding (splitting the model across multiple GPUs) and batching (processing multiple requests simultaneously). But even optimized, GPT-3 inference costs ~$0.002 per 1K tokens. At billions of queries, this adds up.
Model compression techniques trade accuracy for speed:
Quantization: Reduce numerical precision. Standard training uses FP32 (32-bit floating point). Inference can use FP16 (16-bit), INT8 (8-bit integers), or even INT4. Lower precision means:
- 2-4x smaller models: 110M parameter BERT in FP32 is 440MB; in INT8 it is 110MB
- 2-4x faster inference: Fewer bits to move, faster arithmetic
- Small accuracy loss: Typically <1% accuracy drop with INT8
Quantization works because models are over-parameterized. Many weights contribute little to predictions. Rounding them to lower precision has minimal impact.
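To make this concrete, here is a minimal NumPy sketch of symmetric per-tensor INT8 quantization (the layer shape and weight scale are illustrative; production toolchains such as PyTorch quantization or TensorRT also calibrate activations, which this skips):

```python
import numpy as np

def quantize_int8(w: np.ndarray):
    """Map FP32 weights to INT8 with a single per-tensor scale."""
    scale = np.abs(w).max() / 127.0                      # largest weight maps to +/-127
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    return q.astype(np.float32) * scale

rng = np.random.default_rng(0)
w = rng.normal(0, 0.02, size=(768, 768)).astype(np.float32)  # BERT-sized layer

q, scale = quantize_int8(w)
err = np.abs(dequantize(q, scale) - w).max()

print(f"size: {w.nbytes} -> {q.nbytes} bytes (4x smaller)")
print(f"max round-off error: {err:.6f}")                 # bounded by scale / 2
```

The maximum error any weight can incur is half the quantization step, which for typical weight magnitudes is far below the noise the network already tolerates.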
Distillation: Train a small model (student) to mimic a large model (teacher). The teacher model produces soft predictions (probabilities over all classes), which contain more information than hard labels. The student learns to match these predictions.
DistilBERT achieves 97% of BERT's performance with 40% fewer parameters and 60% faster inference. TinyBERT goes further: 96% performance, 7.5x smaller, 9.4x faster. Distilled models are not just smaller; they are often more efficient because they are trained to be small from the start.
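The distillation objective can be sketched as a temperature-softened KL divergence between teacher and student distributions (the T² scaling follows Hinton et al.'s formulation; the logit values below are made up for illustration):

```python
import numpy as np

def softmax(z, T=1.0):
    z = z / T
    z = z - z.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, T=2.0):
    """KL divergence between temperature-softened teacher and student
    distributions, scaled by T^2 so gradients stay comparable across T."""
    p = softmax(teacher_logits, T)          # soft targets from the teacher
    q = softmax(student_logits, T)
    return float((p * (np.log(p) - np.log(q))).sum(axis=-1).mean() * T * T)

teacher      = np.array([[4.0, 1.0, 0.5]])  # confident, but not one-hot
good_student = np.array([[3.8, 1.2, 0.4]])  # roughly agrees with the teacher
bad_student  = np.array([[0.5, 4.0, 1.0]])  # prefers the wrong class

assert distillation_loss(good_student, teacher) < distillation_loss(bad_student, teacher)
```

The temperature T > 1 flattens the teacher's distribution, exposing the relative probabilities of the non-top classes; those "dark knowledge" ratios are exactly the extra signal hard labels lack.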
Pruning: Remove unimportant weights. After training, identify weights close to zero and set them to zero. Sparse models (many zero weights) can be stored and computed more efficiently.
Pruning can remove 50-90% of weights with <1% accuracy loss. But exploiting sparsity requires specialized hardware or libraries (NVIDIA's sparsity support, Intel's Deep Learning Boost). On standard hardware, pruning saves memory but not necessarily compute.
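A minimal sketch of unstructured magnitude pruning, assuming a simple per-tensor threshold (production pruning is usually iterative, with fine-tuning between rounds, which this omits):

```python
import numpy as np

def prune_by_magnitude(w: np.ndarray, sparsity: float) -> np.ndarray:
    """Zero out the fraction `sparsity` of weights with smallest |value|."""
    threshold = np.quantile(np.abs(w), sparsity)
    mask = np.abs(w) >= threshold            # keep only the largest weights
    return w * mask

rng = np.random.default_rng(0)
w = rng.normal(0, 0.02, size=(512, 512)).astype(np.float32)

pruned = prune_by_magnitude(w, sparsity=0.9)
frac_zero = float((pruned == 0).mean())
print(f"{frac_zero:.1%} of weights are now zero")   # ~90%
```

The zeros save memory only if stored in a sparse format, and save compute only on hardware or kernels that skip them, which is the caveat noted above.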
Early exit: Add intermediate classifiers at early layers. For easy examples, the model can exit early without computing all layers. For hard examples, compute the full model. This makes average latency faster while maintaining accuracy on hard cases.
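The early-exit control flow can be sketched as follows, with stand-in layer and classifier-head functions (the confidence threshold and the heads' behavior are illustrative, not a specific published architecture):

```python
import numpy as np

def predict_with_early_exit(x, layers, heads, confidence=0.9):
    """Run layers in order; return (prediction, layers_used) at the first
    head whose top softmax probability clears the confidence threshold."""
    h = x
    probs = None
    for i, (layer, head) in enumerate(zip(layers, heads), start=1):
        h = layer(h)
        probs = head(h)
        if probs.max() >= confidence:          # easy example: exit early
            return int(probs.argmax()), i
    return int(probs.argmax()), len(layers)    # hard example: full depth

# Stand-in layers/heads: head confidence rises with depth for illustration.
layers = [lambda h: h + 1] * 4
confs = [0.6, 0.95, 0.99, 0.99]
heads = [(lambda c: (lambda h: np.array([c, 1 - c])))(c) for c in confs]

pred, used = predict_with_early_exit(np.zeros(1), layers, heads)
print(pred, used)   # exited after 2 of 4 layers
```

Average latency drops in proportion to how many inputs exit early, while the hardest inputs still get the full network.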
Mobile models like MobileNet and EfficientNet are designed for inference. They use depthwise separable convolutions (cheaper than standard convolutions) and carefully balance width, depth, and resolution to maximize accuracy per FLOP.
Caching and Batching: How Systems Survive at Scale
Even with fast models, inference at billions of queries per day requires system-level optimizations: caching and batching.
Caching stores results of previous queries. If a user searches "weather New York," the results are cached. If another user searches the same query within minutes, serve the cached result; no inference needed. Latency drops from 50ms to 5ms. Cost drops to near zero.
Caching works well for:
- Duplicate queries: Many users search the same popular terms
- Near-duplicate queries: Minor variations (case, whitespace) map to same cache key
- Temporal locality: Same user repeats queries
Caching does not work for:
- Personalized queries: Results depend on user context
- Long-tail queries: Unique queries never hit cache
- Time-sensitive queries: Results go stale quickly (news, stock prices)
Cache hit rates vary by application. Google Search might have 30-40% cache hit rate (many unique queries). Netflix recommendations have lower hit rates (personalized). ChatGPT cannot cache much (every conversation is unique).
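A minimal sketch of such a result cache, with key normalization so near-duplicates (case, whitespace) hit the same entry, a TTL for time-sensitive staleness, and LRU eviction (sizes and TTL are illustrative):

```python
import time
from collections import OrderedDict

class InferenceCache:
    def __init__(self, max_entries=10_000, ttl_seconds=300.0):
        self.max_entries = max_entries
        self.ttl = ttl_seconds
        self._store = OrderedDict()               # key -> (timestamp, result)

    @staticmethod
    def key(query: str) -> str:
        return " ".join(query.lower().split())    # normalize case/whitespace

    def get(self, query):
        k = self.key(query)
        entry = self._store.get(k)
        if entry is None:
            return None                            # miss: run the model
        ts, result = entry
        if time.monotonic() - ts > self.ttl:       # stale: evict and re-infer
            del self._store[k]
            return None
        self._store.move_to_end(k)                 # mark as recently used
        return result

    def put(self, query, result):
        self._store[self.key(query)] = (time.monotonic(), result)
        if len(self._store) > self.max_entries:
            self._store.popitem(last=False)        # evict least recently used

cache = InferenceCache()
cache.put("Weather New York", ["sunny"])
print(cache.get("weather  new york"))   # near-duplicate -> cache hit
print(cache.get("weather boston"))      # miss -> None, fall through to model
```

Personalized workloads would need the user ID folded into the key, which immediately destroys hit rates; that is the trade-off behind the list above.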
Batching processes multiple requests together. GPUs are parallel processors; they are most efficient when computing on large batches. Processing 1 request on a GPU is wasteful. Processing 64 requests simultaneously is efficient.
Batching trades latency for throughput. If you batch 64 requests:
- Throughput: 64 requests processed in one forward pass → 64x throughput increase
- Latency: Each request waits for the batch to fill → slight latency increase
Dynamic batching waits a short time (e.g., 10ms) to accumulate requests, then processes the batch. If traffic is high, batches fill quickly. If traffic is low, you wait the full 10ms. This is a latency-throughput trade-off: you accept 10ms extra latency to gain 10-50x throughput.
Production serving systems (TensorFlow Serving, TorchServe, Triton Inference Server) implement dynamic batching automatically. They monitor request arrivals, form batches, and adjust batch size based on latency constraints.
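The dynamic-batching loop these systems implement can be sketched as follows (`model_fn` and the queue-based request/reply plumbing are stand-ins for a real serving stack; real servers also enforce latency SLOs and padding/shape constraints):

```python
import queue
import threading
import time

def batch_worker(requests: queue.Queue, model_fn, max_batch=32, max_wait_s=0.010):
    """Collect requests for up to max_wait_s (or until max_batch arrive),
    then run one batched forward pass and deliver each result."""
    while True:
        batch = [requests.get()]                  # block for the first request
        deadline = time.monotonic() + max_wait_s
        while len(batch) < max_batch:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            try:
                batch.append(requests.get(timeout=remaining))
            except queue.Empty:
                break                              # window closed: partial batch
        inputs = [inp for inp, _ in batch]
        outputs = model_fn(inputs)                # one forward pass for the batch
        for (_, reply), out in zip(batch, outputs):
            reply.put(out)                         # deliver result to its caller

# Demo: three concurrent requests served together by a stand-in model.
reqs: queue.Queue = queue.Queue()
model_fn = lambda xs: [x * 2 for x in xs]
threading.Thread(target=batch_worker, args=(reqs, model_fn), daemon=True).start()

replies = [queue.Queue() for _ in range(3)]
for i, r in enumerate(replies):
    reqs.put((i, r))
print([r.get(timeout=1.0) for r in replies])      # [0, 2, 4]
```

Under high traffic the batch fills before the deadline and throughput dominates; under low traffic each request pays the full 10ms wait, which is exactly the trade-off described above.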
Example: GPT inference with batching
GPT-3 inference on a single request: 1 second per token, 100 tokens = 100 seconds.
GPT-3 inference with batch size 32: 1.5 seconds per token (slightly slower due to larger batch), 100 tokens = 150 seconds total, but 32 requests complete. Per-request time: 150/32 = 4.7 seconds.
Batching made individual requests slower (4.7s vs 1s), but throughput increased 20x. For services with high traffic, this trade-off is essential. The challenge is keeping batches full without making users wait when traffic is low.
Deployment: Why Shipping ML Is Hard
Training produces a model. Deployment makes it serve traffic. Deployment is where most ML projects fail. Models that work perfectly in notebooks break in production.
Deployment challenges:
Model serving infrastructure: You need a service that loads the model, accepts requests, runs inference, and returns results. This sounds simple but involves:
- Model loading: Load GB-sized models into memory at startup (slow)
- Request handling: Parse requests, validate inputs, handle malformed data
- Inference execution: Run the model (GPU vs CPU, batching, caching)
- Response formatting: Convert model outputs to API responses
- Error handling: Timeouts, OOM errors, model crashes
TensorFlow Serving, TorchServe, and Triton abstract some of this, but you still need to configure batching, resource limits, and monitoring.
A/B testing: Before fully replacing an old model, you want to compare it to the new model on real traffic. A/B testing routes a fraction of traffic (e.g., 5%) to the new model, the rest to the old model. Measure metrics: accuracy, latency, user engagement. If the new model is better, gradually increase its traffic. If it is worse, roll back.
A/B testing requires:
- Traffic splitting: Route users deterministically to models (sticky assignment based on user ID)
- Metrics tracking: Log predictions and outcomes for both models
- Statistical testing: Determine if differences are significant
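Deterministic, sticky traffic splitting is usually implemented by hashing the user ID into a bucket; a minimal sketch (salting the hash with an experiment name, as shown, is a common convention to keep experiments independent, not a requirement):

```python
import hashlib

def assign_model(user_id: str, experiment: str, treatment_fraction: float) -> str:
    """Sticky assignment: the same user always gets the same model."""
    digest = hashlib.sha256(f"{experiment}:{user_id}".encode()).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64   # uniform in [0, 1)
    return "new_model" if bucket < treatment_fraction else "old_model"

# Deterministic: repeated calls route the same user the same way.
assert assign_model("user42", "ranker_v2", 0.05) == assign_model("user42", "ranker_v2", 0.05)

# And roughly 5% of users land in treatment.
share = sum(assign_model(f"user{i}", "ranker_v2", 0.05) == "new_model"
            for i in range(10_000)) / 10_000
print(f"treatment share: {share:.3f}")   # close to 0.05
```

Hashing instead of random sampling is what makes the assignment sticky: no per-user state to store, and the user sees a consistent experience across sessions.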
Shadow deployment: Run the new model alongside the old model, but only log its predictions; do not serve them to users. This lets you measure the new model's performance without risking user experience. If the new model makes wildly wrong predictions, you catch it before users see it.
Shadow mode is safer than A/B testing but does not measure user impact (engagement, satisfaction). It measures model metrics (accuracy, precision, recall), not business metrics.
Gradual rollout: Start with 1% of traffic, monitor for issues, increase to 5%, 10%, 50%, 100%. If problems arise (latency spikes, errors, bad predictions), roll back immediately. Gradual rollout limits blast radius: if the model fails, only 1% of users are affected.
Canary deployment: Deploy the new model to a single region or data center first. Monitor closely. If it works, deploy globally. If it fails, only one region is affected.
Rollback mechanisms: Models fail in production. Rollback must be fast. Keep the old model loaded in memory so you can switch with a config change. Do not require redeploying code or restarting services.
Example: Netflix recommendation rollout
Netflix tests new recommendation algorithms carefully:
- Offline evaluation: Test on historical data (watch patterns, ratings)
- Online A/B test: Route 1% of users to new algorithm
- Measure engagement: Watch time, retention, satisfaction surveys
- Gradual rollout: If metrics improve, increase to 10%, 50%, 100%
- Rollback plan: Keep old algorithm running, switch back if needed
Even a 1% improvement in engagement is worth millions. But deploying a worse algorithm costs millions. Rigorous testing is mandatory.
Cost at Scale: When Inference Dominates
Training costs are one-time. Inference costs are ongoing, per query. At billions of queries, inference dominates total cost.
Example: GPT-4 deployment costs
Assume:
- Training cost: $100 million (one-time, estimate)
- Inference cost: $0.03 per 1K tokens
- Average query: 200 tokens input + 200 tokens output = 400 tokens
- Cost per query: $0.012
At 1 billion queries/day:
- Daily inference cost: $12 million
- Annual inference cost: $4.4 billion
Inference costs dwarf training costs at scale. At this rate the cumulative inference bill passes the entire $100 million training cost in under ten days, and after a year it is more than 40x larger.
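The arithmetic can be reproduced directly (all figures are this chapter's illustrative estimates, not actual pricing commitments):

```python
# Cost model for the GPT-4 deployment example above.
training_cost = 100e6                 # $, one-time estimate
cost_per_1k_tokens = 0.03             # $
tokens_per_query = 400                # 200 input + 200 output
queries_per_day = 1e9

cost_per_query = cost_per_1k_tokens * tokens_per_query / 1000
daily = cost_per_query * queries_per_day
annual = daily * 365

print(f"per query: ${cost_per_query:.3f}")                    # $0.012
print(f"daily:     ${daily / 1e6:.0f}M")                      # $12M
print(f"annual:    ${annual / 1e9:.1f}B")                     # $4.4B
print(f"annual inference / training: {annual / training_cost:.0f}x")
```

Every term in this model is a lever: halve tokens per query, cost per token, or query volume, and the annual bill halves with it, which is why the optimizations below pay for themselves.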
This is why companies optimize inference aggressively: quantization, distillation, caching, batching. A 2x speedup halves inference costs. At billions of queries, that is millions of dollars saved.
Inference hardware matters. GPUs are powerful but expensive ($10K-30K per device, high power consumption). TPUs (Google's Tensor Processing Units) are specialized for inference: faster and cheaper per query. Edge devices (phones, cameras, cars) use even cheaper hardware (ARM CPUs, mobile GPUs). Model size and latency must fit the hardware constraints.
Example: Tesla Autopilot inference
Tesla's Full Self-Driving (FSD) computer has two custom AI chips, each capable of 36 TOPS (trillion operations per second). The model must run on this hardware:
- Process 8 camera feeds at 36 FPS
- Run perception (object detection), prediction (trajectory forecasting), and planning
- Total latency budget: <100ms (real-time control requirement)
The model cannot be too big (must fit in on-device memory) or too slow (must meet latency). Tesla uses custom-designed networks optimized for their hardware. Training uses large GPUs in data centers, but inference runs on $1K custom chips in cars.
The constraints are hardware, not model capability. A better model that does not fit the hardware is useless.
Figure 32.1: Training vs inference comparison. Training is offline batch processing optimized for accuracy on powerful hardware. Inference is online request processing optimized for latency on constrained hardware. Deployment bridges these two worlds, requiring trade-offs between model size, accuracy, and speed.
Engineering Takeaway
Training happens once (or periodically); inference happens billions of times, so optimize for the common case. Every millisecond of inference latency affects millions of users. Every megabyte of model size increases serving costs. Training budgets can be large (a one-time cost of millions is acceptable); inference budgets must be tiny (on the order of $0.001 per query). The model that wins is not the most accurate: it is the model that achieves acceptable accuracy at acceptable latency and cost.
Latency kills user experience: every 100ms matters for engagement. Users are impatient. Google, Amazon, and Facebook have measured this precisely: latency costs traffic, sales, and engagement. A slightly more accurate model that is 2x slower loses to a slightly less accurate model that is 2x faster. Latency is a hard constraint, not a nice-to-have. If your model does not meet latency requirements, it does not deploy.
Model size is your enemy at inference time: compression techniques are mandatory. Big models are accurate but slow and expensive. Quantization, distillation, pruning, and efficient architectures reduce model size without destroying accuracy. INT8 quantization gives 4x speedup with <1% accuracy loss. Distillation gives 2-3x speedup with 3-5% accuracy loss. These techniques are not optional; they are required to serve at scale.
Caching and batching are essential for surviving at scale. Caching eliminates inference for duplicate queries: a free speedup if your workload has repetition. Batching increases GPU utilization by 10-50x, trading slight latency for massive throughput. Without these optimizations, serving billions of queries is economically infeasible. Every major ML serving system implements caching and dynamic batching.
Deployment is risky: gradual rollouts, A/B testing, and rollback are mandatory. Models that pass offline evaluation fail in production. Users behave differently than test sets predict. Edge cases appear. Latency degrades under load. Gradual rollout (1% → 5% → 50% → 100%) limits blast radius. A/B testing measures real impact on users, not just model metrics. Rollback lets you revert instantly when things break. Deploy without these safeguards, and you will have outages.
Monitoring inference is harder than monitoring training: track latency, throughput, error rates, and data drift. Training monitoring is straightforward: loss goes down, accuracy goes up. Inference monitoring is multi-dimensional: p50/p95/p99 latency, queries per second, error rates, cache hit rates, model prediction distributions (drift detection). Inference failures are subtle: latency spikes at 3am, prediction quality degrades slowly, edge cases increase. Real-time monitoring catches problems before users complain.
Why MLOps exists: production ML is reliability engineering, not model training. MLOps is the discipline of operating ML systems in production: model serving, deployment pipelines, monitoring, retraining, versioning, rollback. It is DevOps for ML. The hard problems are not "How do I train a model?"; they are "How do I serve a model to a billion users with 99.9% uptime, 50ms latency, and without breaking the bank?" MLOps is the engineering that makes ML production-ready.
References and Further Reading
TFX: A TensorFlow-Based Production-Scale Machine Learning Platform Baylor, D., Breck, E., Cheng, H.-T., Fiedel, N., Foo, C. Y., Haque, Z., Haykal, S., Ispir, M., Jain, V., Koc, L., et al. (2017). KDD 2017
Why it matters: This paper from Google describes TFX (TensorFlow Extended), the production ML infrastructure used internally at Google. It covers data validation, feature engineering, training orchestration, model analysis, and serving at scale. TFX handles billions of queries per day across products like Search, YouTube, and Gmail. The paper emphasizes that the model is a small part of the system; data validation, monitoring, and serving infrastructure are the majority of the work. TFX became the blueprint for many companies building ML platforms.
Clipper: A Low-Latency Online Prediction Serving System Crankshaw, D., Wang, X., Zhou, G., Franklin, M. J., Gonzalez, J. E., & Stoica, I. (2017). NSDI 2017
Why it matters: Clipper is a research system from UC Berkeley that addresses the latency challenge of inference. It introduces adaptive batching (dynamically adjusting batch size to meet latency SLOs) and model caching. Clipper showed that careful system design can reduce latency by 2-4x while increasing throughput by 10x. The paper highlights that inference is a systems problem, not just a model problem; caching, batching, and scheduling matter as much as model architecture.
Data Management Challenges in Production Machine Learning Polyzotis, N., Roy, S., Whang, S. E., & Zinkevich, M. (2017). SIGMOD 2017
Why it matters: This Google paper examines the data management challenges of production ML. It describes how models trained offline must serve online, how data distributions shift, and how monitoring must detect these changes. The paper introduces the concept of "data debugging": tracking data lineage, validating schemas, detecting anomalies. It argues that production ML is fundamentally about data engineering, and that most failures are data failures. The paper influenced the design of tools like TensorFlow Data Validation and Facets.
The next chapter examines evaluation: why accuracy is not enough, why benchmarks mislead, and how production metrics differ from research metrics.