Rating: 7.9/10.
Book about design patterns specific to machine learning training and productionization. Design patterns are useful since they’re tried-and-tested solutions to recurring problems. Even though I’ve used ML in my work for several years, some of these patterns were still new to me.
The book is aimed at ML practitioners in industry rather than researchers, so some of it was not very relevant to me. The explanations of the design patterns were clear, but the authors gave lots of code examples using TensorFlow, Keras, and BigQuery. Since there are many technologies in this space, you’re probably not using this exact combination, so the code examples can mostly be skimmed. Another thing I found annoying was that the references were given as O’Reilly shortcodes rather than in a standard citation format.
Ch2: Data representation design patterns
The form of the input as seen by the model is important for model performance. Numerical features are often scaled to [-1, 1] to improve optimization stability, and log or Box-Cox transforms are useful for reducing skew.
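Here is a minimal plain-Python sketch of both transforms on made-up data. The helper names `scale_to_unit_range` and `box_cox` are hypothetical. In practice the min and max (and the Box-Cox λ) would be computed on the training set and then reused at serving time.

```python
import math

def scale_to_unit_range(values):
    """Linearly map values from [min, max] to [-1, 1]."""
    lo, hi = min(values), max(values)
    return [2 * (v - lo) / (hi - lo) - 1 for v in values]

def box_cox(x, lam):
    """Box-Cox transform for x > 0; lam == 0 reduces to the natural log."""
    if lam == 0:
        return math.log(x)
    return (x ** lam - 1) / lam

# Toy right-skewed feature: one huge value dominates the range
prices = [10.0, 12.0, 15.0, 20.0, 1000.0]
scaled = scale_to_unit_range(prices)      # most values bunch up near -1
logged = [box_cox(p, 0) for p in prices]  # log compresses the long tail
print(scaled)
print(logged)
```

Scaling alone does not fix the skew, because the four typical prices all land close to -1. Applying the log (or Box-Cox with a tuned λ) first and then scaling spreads them out.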
The Hashed Feature pattern is useful for categorical features with high cardinality: you hash each feature value and take it modulo the number of buckets, which maps it into a much smaller set of buckets. The disadvantage is that all values that end up in the same bucket are treated the same (hash collisions). You should also use regularization so that the weights of buckets with zero elements are pushed toward zero instead of becoming unstable.
The Embedding pattern is any learned mapping of an input into a lower-dimensional vector space. It can be applied to categorical features (e.g., by training an autoencoder), text, or images.
The Feature Cross pattern breaks two features into buckets and takes their Cartesian product as a single categorical feature. This lets a linear model learn patterns (like the XOR example) that would otherwise require a nonlinear model.
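Below is a small sketch of hashing a high-cardinality categorical value into a fixed number of buckets. I use MD5 from Python's `hashlib` as a stable stand-in hash. That is my own choice and not necessarily the hash the book uses. Python's built-in `hash()` is avoided because string hashing is randomized per process, so bucket assignments would change between runs. The airport codes are toy data.

```python
import hashlib

def hash_bucket(value: str, num_buckets: int) -> int:
    """Map a categorical value to one of num_buckets buckets, stably across runs."""
    digest = hashlib.md5(value.encode("utf-8")).hexdigest()
    return int(digest, 16) % num_buckets

airports = ["DTW", "LBB", "SNA", "MSO", "ABQ"]
buckets = {code: hash_bucket(code, 3) for code in airports}
print(buckets)  # with 5 values and 3 buckets, at least two values must collide
```

The number of buckets is a hyperparameter. Fewer buckets means more collisions, while more buckets means more (possibly empty) buckets to regularize.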
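Here is a toy version of the XOR case. For illustration, each feature is bucketed at zero (an assumption). The weights are written down directly rather than trained. The point is to show that a linear model over the one-hot crossed feature can represent XOR, which no linear function of the raw x1 and x2 can do.

```python
from itertools import product

def bucketize(x):
    """Two buckets per feature: negative vs. non-negative."""
    return "neg" if x < 0 else "pos"

def feature_cross(x1, x2):
    """Cartesian product of the two bucket labels, as one categorical value."""
    return f"{bucketize(x1)}_x_{bucketize(x2)}"

# Vocabulary of the crossed feature: 2 x 2 = 4 possible cells
VOCAB = [f"{a}_x_{b}" for a, b in product(["neg", "pos"], repeat=2)]

def one_hot(cell):
    return [1.0 if cell == v else 0.0 for v in VOCAB]

# XOR-style toy data: label is 1 when exactly one coordinate is negative
points = [(-2.0, -1.0), (-1.5, 3.0), (2.5, -0.5), (1.0, 2.0)]
labels = [0, 1, 1, 0]

# Each one-hot column is active in exactly one cell, so a linear model can
# fit XOR by giving each cell's weight the label observed in that cell.
cell_weight = {feature_cross(x1, x2): float(y) for (x1, x2), y in zip(points, labels)}
weights = [cell_weight.get(v, 0.0) for v in VOCAB]

def linear_predict(x1, x2):
    """Dot product of the weights with the one-hot crossed feature."""
    return sum(w * f for w, f in zip(weights, one_hot(feature_cross(x1, x2))))
```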
The Multimodal Input pattern lets you handle problems where the input is a combination of categorical, text, and image data: in this case, you can extract embeddings separately and concatenate them together, then add a dense layer on top.
Ch3: Problem representation design patterns
The Reframing pattern says to consider different ways of framing the problem. For example, predicting rainfall might intuitively be a regression problem, but by framing it as a classification into buckets, you can get a distribution of possible values and not just a single value.
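A sketch of the reframing step, assuming some hypothetical rainfall bucket boundaries. Regression targets are turned into class labels, and a classifier's output then becomes a distribution over buckets. The probabilities here are made up and merely stand in for a softmax output.

```python
import bisect

# Hypothetical rainfall bucket boundaries in mm; 4 boundaries -> 5 classes:
# [0, 0.1), [0.1, 2.5), [2.5, 10), [10, 50), [50, inf)
BOUNDARIES = [0.1, 2.5, 10.0, 50.0]

def to_bucket(rain_mm):
    """Turn a regression target into a class label (bucket index)."""
    return bisect.bisect_right(BOUNDARIES, rain_mm)

# Training labels become class indices instead of real numbers
train_labels = [to_bucket(r) for r in [0.0, 0.3, 4.2, 12.0, 75.0]]

# A classifier's softmax output (made-up numbers) is a distribution over
# buckets, so you can answer questions a single regression output cannot:
probs = [0.50, 0.20, 0.20, 0.08, 0.02]
p_at_least_2_5mm = sum(probs[2:])
most_likely_bucket = max(range(len(probs)), key=probs.__getitem__)
```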
The Multilabel pattern deals with problems where each instance may have multiple labels. In this case you can still use a classification setup, but apply a sigmoid at the last layer instead of a softmax, and choose a threshold for predicting each label that accounts for that label’s base rate.
The Ensemble pattern combines multiple models. Common types of ensembles include bagging (training models on different subsets of the data), boosting (a sequence of models that iteratively improves the prediction), and stacking (a second-level model that takes the outputs of other models as its input).
The Cascade pattern is when you feed the outputs of one model into another. It’s best avoided since it adds complexity, but it is sometimes useful. Later models should be trained on the predictions of earlier models rather than on clean ground-truth data, since predictions are what they will receive at inference time.
The Neutral Class pattern applies when the labels for examples in the middle of the range are essentially arbitrary: you assign them all to a neutral class instead of trying to learn their differences.
The Rebalancing pattern deals with problems where the classes are imbalanced. You need to choose an appropriate metric, like F1 score instead of accuracy. Downsampling means throwing away some samples of the majority class so the classes are less imbalanced, and it can be combined with a bagging ensemble. Upsampling the minority class is also possible, e.g., using SMOTE.
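A small sketch contrasting softmax and sigmoid on the same (hypothetical) final-layer outputs. The labels, logits, and per-label thresholds are all made up for illustration. In practice the thresholds would be tuned per label on validation data.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(zs):
    m = max(zs)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in zs]
    total = sum(exps)
    return [e / total for e in exps]

labels = ["cat", "dog", "outdoor"]
logits = [2.0, 1.5, 0.3]  # hypothetical final-layer outputs for one image

# Softmax makes labels compete: the probabilities always sum to 1.
softmax_probs = softmax(logits)

# Sigmoid scores each label independently, so several can be on at once.
sigmoid_probs = [sigmoid(z) for z in logits]

# Hypothetical per-label thresholds, e.g. tuned per label on validation data
thresholds = {"cat": 0.7, "dog": 0.7, "outdoor": 0.6}
predicted = [l for l, p in zip(labels, sigmoid_probs) if p >= thresholds[l]]
```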
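Here is a toy illustration of why accuracy misleads on imbalanced data, followed by a simple downsampling helper. The 1% positive rate, the `downsample` helper, and its 10% keep fraction are assumptions for the example.

```python
import random

def f1_score(y_true, y_pred, positive=1):
    """F1 for the positive (minority) class."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t != positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# Toy imbalanced labels: 1% positives (e.g. fraud)
y_true = [1] * 10 + [0] * 990
y_always_negative = [0] * len(y_true)  # a useless model

accuracy = sum(t == p for t, p in zip(y_true, y_always_negative)) / len(y_true)
f1 = f1_score(y_true, y_always_negative)  # accuracy looks great, F1 exposes it

def downsample(examples, labels, majority=0, keep_fraction=0.1, seed=42):
    """Keep every minority example and a random fraction of the majority class."""
    rng = random.Random(seed)
    return [(x, y) for x, y in zip(examples, labels)
            if y != majority or rng.random() < keep_fraction]

balanced = downsample(list(range(len(y_true))), y_true)
```

Calling `downsample` with different seeds yields different majority-class subsets. Training one model per subset and averaging them is one way to get the bagging combination mentioned above.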
Ch4: Model training patterns
The Useful Overfitting pattern is for problems where all possible inputs and their outputs are known beforehand, so we don’t care about generalization. The model essentially encodes a giant lookup table, but using less space. Another use case is as a debugging technique: check that the model can overfit a small batch.
The Checkpointing pattern is to save intermediate states during a long training process. You need to save not just the model weights, but also the optimizer state and anything else needed to resume training. It’s best to define an epoch in terms of a fixed number of optimizer steps so that it isn’t affected by changes to dataset size and batch size.
The Transfer Learning pattern applies large pretrained models like ResNet or BERT to your task. You can use the pretrained model either as a feature extractor (generating embeddings for your input) or fine-tune it.
The Distribution Strategy pattern covers how to train on multiple GPUs / TPUs. In data parallelism, each batch is split across multiple workers, each worker computes gradients on its share, and the gradients are aggregated into a single update to the model weights, which every worker then uses. The synchronous strategy waits for all workers to finish their share of the batch before updating (typically aggregating with all-reduce), while the asynchronous strategy lets workers send updates to a central parameter server independently, which makes it more robust against slow workers or failures. It is also possible to split the model itself across multiple machines (model parallelism).
The Hyperparameter Tuning pattern is an additional outer loop to help you choose hyperparameters. Libraries can help with this through grid search, random search, and Bayesian optimization.
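A sketch of both ideas in plain Python. The checkpoint interval is derived from a fixed training budget (in examples), not from the dataset size. Each checkpoint stores the step counter and optimizer state alongside the weights. The budget numbers, the pickle file format, and the momentum-style optimizer state are all hypothetical choices for illustration.

```python
import os
import pickle
import tempfile

# Hypothetical budget: total examples to show the model, fixed up front
TOTAL_TRAINING_EXAMPLES = 14_300_000
BATCH_SIZE = 100
NUM_CHECKPOINTS = 15

# An "epoch" here is a fixed number of optimizer steps, so checkpoints land at
# the same points even if the dataset grows.
steps_per_epoch = TOTAL_TRAINING_EXAMPLES // (BATCH_SIZE * NUM_CHECKPOINTS)

def save_checkpoint(path, step, weights, optimizer_state):
    """Save everything needed to resume training, not just the weights."""
    with open(path, "wb") as f:
        pickle.dump({"step": step, "weights": weights,
                     "optimizer_state": optimizer_state}, f)

def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)

# Toy state for a momentum-style optimizer on a two-element weight vector
ckpt_path = os.path.join(tempfile.mkdtemp(), "ckpt.pkl")
save_checkpoint(ckpt_path, step=steps_per_epoch, weights=[0.5, -1.2],
                optimizer_state={"momentum": [0.01, -0.03]})
restored = load_checkpoint(ckpt_path)
```

If the batch size changes, `steps_per_epoch` changes too, but each epoch still corresponds to the same number of training examples. The checkpoints therefore stay comparable across runs.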
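Here is a single-process simulation of synchronous data parallelism for a toy linear regression. Each "worker" computes the gradient on its shard of the batch. `all_reduce_mean` stands in for the all-reduce, and everyone applies the same averaged update. This only simulates the math: no real devices or communication are involved, and the asynchronous/parameter-server variant is not shown.

```python
def grad_mse(w, b, xs, ys):
    """Gradient of mean squared error for y ~ w*x + b on one worker's shard."""
    n = len(xs)
    dw = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
    db = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
    return dw, db

def all_reduce_mean(grads):
    """Stand-in for all-reduce: average the workers' gradients elementwise."""
    k = len(grads)
    return tuple(sum(g[i] for g in grads) / k for i in range(len(grads[0])))

def sync_step(w, b, xs, ys, num_workers, lr=0.01):
    """One synchronous step: shard the batch, compute gradients per worker,
    wait for all of them, average, then apply the same update everywhere."""
    shard = len(xs) // num_workers  # assumes the batch divides evenly
    grads = [grad_mse(w, b, xs[i * shard:(i + 1) * shard], ys[i * shard:(i + 1) * shard])
             for i in range(num_workers)]
    dw, db = all_reduce_mean(grads)
    return w - lr * dw, b - lr * db

xs = [float(i) for i in range(8)]
ys = [3.0 * x + 1.0 for x in xs]  # toy data from y = 3x + 1
w, b = 0.0, 0.0
for _ in range(3000):
    w, b = sync_step(w, b, xs, ys, num_workers=4)
```

With equal-sized shards, the averaged per-worker gradient equals the full-batch gradient. That equivalence is why synchronous data parallelism behaves like training on one big batch.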
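A minimal random-search outer loop. The `validation_loss` function is a hypothetical stand-in for "train a model with these hyperparameters and return its validation loss", since real training would be far too slow for a sketch. The learning rate is sampled on a log scale, which is a common choice.

```python
import math
import random

def validation_loss(learning_rate, num_layers):
    """Hypothetical stand-in for training + evaluation; its minimum is at
    learning_rate=0.01 and num_layers=3."""
    return (math.log10(learning_rate) + 2) ** 2 + 0.1 * (num_layers - 3) ** 2

def random_search(num_trials, seed=0):
    rng = random.Random(seed)
    best_loss, best_params = math.inf, None
    for _ in range(num_trials):
        params = {
            "learning_rate": 10 ** rng.uniform(-4, 0),  # sample on a log scale
            "num_layers": rng.randint(1, 6),
        }
        loss = validation_loss(**params)
        if loss < best_loss:
            best_loss, best_params = loss, params
    return best_loss, best_params

best_loss, best_params = random_search(num_trials=200)
```

Grid search would replace the random sampling with nested loops over fixed values. Bayesian optimization would instead choose each new trial based on the results of previous ones.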
Ch5: Design patterns for resilient serving
The Stateless Serving Function pattern wraps a model in a stateless REST endpoint that can be called by clients written in many languages. This lets you benefit from autoscaling technology, and you can adjust the output format to be friendly to clients. It’s also possible to wrap the model in a library instead of a REST endpoint.
The Batch Serving pattern groups together many queries to the model and processes them in a batch instead of one at a time. This lets you leverage big data tools like BigQuery.
The Continuous Model Evaluation pattern means monitoring the model’s performance after it’s deployed to catch data or concept drift. The ground truth may be available implicitly, or you can get human labels. You can set the model to retrain when performance drops below a threshold, or periodically.
The Two-phase Prediction pattern handles the tradeoff between model size and performance on low-storage edge devices. You can run a small model on the device and only send the subset of inputs that need the full model to the cloud, eg: an edge model to recognize “Hey Siri” and a cloud model to do the speech recognition.
The Keyed Predictions pattern is to have the client pass a key along with each instance, so that you can return the results in any order. This is useful because preserving the input order is expensive in distributed settings.
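A sketch of keyed predictions using a thread pool. Results come back in completion order rather than input order, and the client matches them up by key. The `model` function and the keys are hypothetical placeholders.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def model(x):
    """Hypothetical stand-in for the real model."""
    return 2 * x

def predict_keyed(instances):
    """instances: list of (key, features). Returns (key, prediction) pairs in
    completion order, which may differ from the input order."""
    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = {pool.submit(model, x): key for key, x in instances}
        return [(futures[f], f.result()) for f in as_completed(futures)]

requests = [("user-17", 1.5), ("user-03", -2.0), ("user-42", 10.0)]
results = dict(predict_keyed(requests))  # the client looks results up by key
```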
Ch6: Reproducibility design patterns
The Transform pattern avoids a mismatch between the preprocessing applied at training time and at inference time by coupling the preprocessing logic with the model (eg: as a layer in Keras). You can cache the transformations to avoid repeating them during training.
The Repeatable Splitting pattern gives you reproducible train-test splits. One way is to hash some column of the data and take the result modulo the number of buckets. You want to avoid correlations between train and test, and hashing on a column such as the date or user ID puts all rows that share that value (eg: the same date or the same user) into the same bucket, and hence into the same split.
The Bridged Schema pattern is used when the data schema changes and you don’t want to wait until there is enough data in the new schema to train a new model on it alone. For example, when [card, cash] becomes [debit card, credit card, cash], one solution is to replace old instances of “card” with a static vector such as [0.6 debit, 0.4 credit].
The Windowed Inference pattern is for when a model needs a rolling window of past history to make a prediction. Rather than recomputing over the whole window for every prediction, it may be more efficient to periodically update the model from the window and pass that model to the function making the inference.
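A framework-free sketch of the idea. The class below is hypothetical and is not the Keras approach the book describes. It learns its normalization statistics once at training time and applies them inside `predict`, so callers always pass raw inputs and cannot apply different preprocessing at serving time.

```python
import statistics

class ModelWithTransform:
    """Bundles preprocessing with a tiny linear model so serving reuses the
    exact statistics computed at training time."""

    def fit(self, xs, ys):
        # Transform statistics are learned once, from training data only.
        self.mean = statistics.fmean(xs)
        self.std = statistics.pstdev(xs)
        zs = [self._transform(x) for x in xs]
        # Least-squares fit of y = w*z + b on the transformed feature
        z_mean, y_mean = statistics.fmean(zs), statistics.fmean(ys)
        cov = sum((z - z_mean) * (y - y_mean) for z, y in zip(zs, ys))
        var = sum((z - z_mean) ** 2 for z in zs)
        self.w = cov / var
        self.b = y_mean - self.w * z_mean
        return self

    def _transform(self, x):
        return (x - self.mean) / self.std

    def predict(self, x):
        # Raw input in, prediction out: callers never apply preprocessing.
        return self.w * self._transform(x) + self.b

model = ModelWithTransform().fit([1.0, 2.0, 3.0, 4.0], [3.0, 5.0, 7.0, 9.0])
```

Saving this one object (e.g., with pickle) saves the preprocessing and the weights together. That coupling is the whole point of the pattern.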
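A sketch of hash-based splitting in plain Python. I use SHA-256 from `hashlib` as the stable hash; that is my own choice, and a warehouse would typically use its built-in fingerprint function instead. The rows and the 80/20 split are made up.

```python
import hashlib

def split_for(key: str, train_pct: int = 80) -> str:
    """Deterministically assign a row to 'train' or 'test' from a key column."""
    bucket = int(hashlib.sha256(key.encode("utf-8")).hexdigest(), 16) % 100
    return "train" if bucket < train_pct else "test"

rows = [
    {"date": "2024-01-05", "value": 1},
    {"date": "2024-01-05", "value": 2},
    {"date": "2024-01-06", "value": 3},
]
# Splitting on the date keeps all rows from the same day together.
splits = [split_for(r["date"]) for r in rows]
```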
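A sketch of bridging the payment-type example. Old "card" rows are encoded with a fractional vector in the new three-value schema. Here I assume the 0.6/0.4 split is estimated from rows already collected under the new schema. The row counts and the `bridged_encoding` helper are hypothetical.

```python
# Old schema: ["card", "cash"]; new schema: ["debit_card", "credit_card", "cash"]
NEW_VOCAB = ["debit_card", "credit_card", "cash"]

def bridged_encoding(payment_type, card_split):
    """One-hot encode in the new schema; old 'card' rows get the card split."""
    if payment_type == "card":
        return [card_split["debit_card"], card_split["credit_card"], 0.0]
    return [1.0 if payment_type == v else 0.0 for v in NEW_VOCAB]

# Estimate the split from (toy) rows collected under the new schema
new_rows = ["debit_card"] * 6 + ["credit_card"] * 4 + ["cash"] * 5
n_cards = sum(1 for p in new_rows if p != "cash")
card_split = {k: new_rows.count(k) / n_cards for k in ("debit_card", "credit_card")}
```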
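A sketch of the idea for a simple anomaly detector. It keeps a rolling window of recent values, and the "model" (window mean and standard deviation) is refreshed only every few observations, not recomputed on every call. The class, the window size, the refresh interval, and the 3-sigma threshold are all hypothetical choices.

```python
import statistics
from collections import deque

class WindowedAnomalyDetector:
    """Rolling-window detector whose statistics are refreshed periodically."""

    def __init__(self, window_size=100, refresh_every=10, threshold=3.0):
        self.window = deque(maxlen=window_size)
        self.refresh_every = refresh_every
        self.threshold = threshold
        self.seen = 0
        self.mean, self.std = None, None

    def _refresh_model(self):
        # Recompute the window statistics (the "model") from current history.
        if len(self.window) >= 2:
            self.mean = statistics.fmean(self.window)
            self.std = statistics.pstdev(self.window)

    def observe(self, value):
        """Return True if value looks anomalous under the current model."""
        is_anomaly = (self.std is not None and self.std > 0
                      and abs(value - self.mean) > self.threshold * self.std)
        self.window.append(value)
        self.seen += 1
        if self.seen % self.refresh_every == 0:
            self._refresh_model()
        return is_anomaly

detector = WindowedAnomalyDetector(window_size=50, refresh_every=10)
normal_flags = [detector.observe(v) for v in [10.0, 11.0, 9.0] * 10]
spike_flag = detector.observe(100.0)
```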
The Workflow Pipeline pattern uses tools like TFX or Airflow to create a pipeline that gets the data, trains the model, and deploys it in one step. This makes the workflow more reproducible and documented.
The Feature Store pattern decouples the feature engineering step from model training. Ad-hoc features are created on a per-project basis, so they are not easy to share and have to be duplicated for other projects. A feature store instead keeps them in a central place for any model to use. Feast lets you retrieve features in batch (for training) or in real time (for inference).
The Model Versioning pattern is to deploy different versions of a model to different REST endpoints so that you can have multiple versions running concurrently (eg: for backward compatibility or A/B testing purposes).
Ch7: Responsible AI
The Heuristic Benchmark pattern compares the model against a very simple heuristic (like predicting the median) so that the actual model’s performance can be communicated relative to it.
The Explainable Predictions pattern uses explainable AI tools to improve trust in the model. If the model is simple, like linear regression, its parameters can be interpreted directly. Otherwise, Shapley values and Integrated Gradients are useful tools (the latter requires the model to be differentiable). The SHAP library gives feature attribution values for a prediction, telling you how much each input feature contributed.
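A toy example of reporting a model against a median baseline. The targets are made up, and `model_mae` is a hypothetical number standing in for the real model's error on the same test set.

```python
import statistics

def mae(y_true, y_pred):
    return sum(abs(t - p) for t, p in zip(y_true, y_pred)) / len(y_true)

train_targets = [12.0, 15.0, 14.0, 30.0, 13.0]  # toy data
test_targets = [14.0, 16.0, 11.0]

median = statistics.median(train_targets)  # the heuristic "model"
baseline_mae = mae(test_targets, [median] * len(test_targets))
model_mae = 1.2  # hypothetical MAE of the real model on the same test set
improvement = 1 - model_mae / baseline_mae  # relative error reduction
```

With these numbers, the baseline MAE is 5/3 and the model reduces error by 28% relative to always predicting the median. A statement like that is easier for stakeholders to interpret than a raw MAE.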
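To make "how much each input feature contributed" concrete, here is a brute-force computation of exact Shapley values for one prediction. Features outside a coalition are set to a baseline of zeros (an assumption), and the toy `model` is hypothetical. The cost is exponential in the number of features, which is why the SHAP library relies on approximations and model-specific algorithms instead.

```python
from itertools import combinations
from math import factorial

def shapley_values(model, x, baseline):
    """Exact Shapley values for model(x); features absent from a coalition
    are replaced by their baseline value."""
    n = len(x)

    def value(coalition):
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return model(z)

    phis = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            for subset in combinations(others, size):
                s = set(subset)
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi += weight * (value(s | {i}) - value(s))  # marginal contribution
        phis.append(phi)
    return phis

def model(z):
    """Hypothetical model with an interaction between features 0 and 2."""
    return 3 * z[0] + 2 * z[1] + z[0] * z[2]

x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
phis = shapley_values(model, x, baseline)
```

The attributions add up to model(x) minus model(baseline), and the interaction term's contribution is split evenly between features 0 and 2.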
The Fairness Lens pattern analyzes the model for fairness when the data is biased. The What-If Tool lets you explore the model’s output and performance across different slices of the data.
Ch8: Connected patterns
The ML life cycle is an iterative process of aligning the models with business objectives so that they provide business value. As a business matures in how it uses ML, the workflow moves from manual experimentation to more automated pipelines, and different design patterns come into play.