Broadly speaking, most machine learning algorithms fall into one of two categories: linear models or non-linear models. Linear models are easy to interpret, fast to train and deploy, and don’t require exorbitant amounts of compute resources. A linear model learns and produces a weighted sum of the input features, plus a bias (intercept) term, mapping the input features, X, to a single target, f(X):

f(X) = w_1 X_1 + w_2 X_2 + ... + w_n X_n + b
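To make the weighted-sum idea concrete, here is a minimal Python sketch of a linear model; the weights and bias below are made-up placeholders rather than learned values:

```python
import numpy as np

# A linear model is just a weighted sum of the input features plus a bias.
# These weights and this bias are illustrative placeholders, not learned values.
weights = np.array([0.4, -1.2, 0.7])  # one weight per input feature
bias = 0.5

def linear_model(X):
    """Map a feature vector X to a single prediction f(X)."""
    return np.dot(X, weights) + bias

X = np.array([1.0, 2.0, 3.0])
print(linear_model(X))  # 0.4*1.0 + (-1.2)*2.0 + 0.7*3.0 + 0.5 ≈ 0.6
```

In practice, the weights and bias are learned from data, for example with ordinary least squares or gradient descent.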

Linear vs. Non-Linear Models

Real world challenges, however, are complex and seldom have a linear input-output relationship. Hence, while linear models are simple and easy to implement and integrate, they are usually unable to adequately model real world systems. This shortcoming forces us to shift our focus to non-linear models. Many non-linear machine learning models are used to solve complex real world problems, but neural networks and their variants have become remarkably good across a wide range of applications and are rapidly gaining popularity. Neural networks (which contain non-linear activation functions) are more expressive, can learn complex patterns and correlations in the data, and can generalize better than linear models. This robustness and performance, however, comes at a cost that neural networks share with other non-linear models: the data requirement curse.
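As a quick, illustrative sketch (the data, models, and hyperparameters below are arbitrary choices, not a benchmark), a linear model and a small neural network can be compared on data with a non-linear input-output relationship:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.neural_network import MLPRegressor

# Toy data with a non-linear input-output relationship: y = sin(x) + noise.
rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(500, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.1, size=500)

# A linear model can only fit a straight-line trend through this data.
linear = LinearRegression().fit(X, y)

# A small neural network with non-linear activations can capture the curvature.
mlp = MLPRegressor(hidden_layer_sizes=(32, 32), max_iter=2000,
                   random_state=0).fit(X, y)

print("Linear R^2:", round(linear.score(X, y), 3))
print("MLP R^2:   ", round(mlp.score(X, y), 3))
```

The neural network's extra expressiveness is exactly what drives up its appetite for data.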

Most algorithms, especially neural networks and their variants, are designed to model the underlying distribution of the dataset they are trained on. These algorithms can have millions of trainable parameters that, once learned, capture patterns and trends in the training dataset. The training dataset itself, however big, is still just a subset of all the possible instances (the general population). As the complexity of the general population and/or the complexity of the algorithm increases, we require more and more data points to reliably estimate the underlying distribution. Consider two examples:

  1. A model trained to differentiate colored boxes, say red, green and blue boxes, on a manufacturing assembly line.
  2. A model trained to differentiate tissues in a human body to assist with surgery.

We can make an educated guess that the second case will require many more data points to reliably estimate the underlying distribution than the first, simply due to the possible variations and fluctuations in the data. A small dataset can be misleading or non-representative with respect to the underlying trend.

Dataset Size

Thanks to the research community and growing interest in machine learning solutions, we have seen significant growth in the number of high quality benchmark datasets, along with improvements in data quality and annotation quality. The size of these benchmark datasets, however, ranges from a few thousand to millions of data points. Furthermore, with the introduction of powerful techniques like transfer learning and data augmentation, the need for immense datasets is on the decline. However, the question remains: how much data do we require to train a robust machine learning solution?

You might be asking yourself this question for a multitude of reasons:

  1. Data collection: You may not have collected data yet and need to know the cost and time necessary to collect enough data to train a high performing ML model.
  2. Data augmentation: You might have collected some data and need to know how much you need to augment your dataset.
  3. Historical data: You may already have a large dataset and need to know the optimal dataset size to reduce your computation and storage costs.
  4. Transfer learning: You might have a trained model and want to apply the model to a “similar” problem with minimal possible retraining.

In all four of these situations, knowing the required dataset size becomes a bottleneck. The required dataset size changes from one problem to the next and is correlated with the complexity of the problem and of the chosen training algorithm. The bad news upfront: there is currently no way to determine this with 100% accuracy. Real world data has a lot of noise and variation, which makes it very difficult to perfectly sample a training dataset. Add to that the variations in the environment, fluctuations in the data collection sensors, logging errors, data corruption, and storage errors, and knowing the exact required dataset size becomes impossible. Does that mean we should just keep increasing the dataset size in the hope that it will improve model performance and robustness? Fortunately, there are smarter and simpler ways to deal with this bottleneck. One such method is to use a Learning Curve Graph.

Learning Curve Graph

Learning curve graphs, loosely defined, are plots of model performance against experience, time, or some other controlled parameter. They are generally used as a diagnostic tool to assess the incremental performance of a model as the controlled parameter changes. Learning curve graphs have very broad applications, and one of them is estimating the required dataset size; in this case, the controlled parameter is the dataset size. The diagram below shows what you can typically expect to see in such a learning curve graph.
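As a minimal sketch of how such a graph can be produced (the dataset and model below are stand-ins chosen purely for illustration), scikit-learn's learning_curve utility can evaluate a model at several training set sizes:

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import learning_curve

# Evaluate the model at several training set sizes using cross-validation.
# The digits dataset and logistic regression are illustrative stand-ins.
X, y = load_digits(return_X_y=True)
train_sizes, train_scores, val_scores = learning_curve(
    LogisticRegression(max_iter=2000),
    X, y,
    train_sizes=np.linspace(0.1, 1.0, 8),  # fractions of the available training data
    cv=5,
)

# Plot mean validation accuracy against the number of training examples.
plt.plot(train_sizes, val_scores.mean(axis=1), marker="o")
plt.xlabel("Number of training examples")
plt.ylabel("Validation accuracy")
plt.title("Learning curve: model performance vs. dataset size")
plt.show()
```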

[Figure: learning curve graph for dataset size]

Model Performance

The performance of an ML model will typically increase with the dataset size at first. In other words, as the size of the dataset increases, the model learns and updates its estimate of the underlying trends. At some point, the performance of the model saturates, and adding more data does not lead to a significant increase in performance. When the model performance saturates, we can reasonably assume that the general population and the training dataset now have very similar underlying distributions, so further computation and storage yield diminishing returns. Our goal is to use a learning curve graph and interpolation techniques either to estimate the dataset size required for a target performance value or to find the saturation point for maximum performance. In part 2, we will see this approach in action with the Fashion-MNIST dataset and a learning curve graph experiment.
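Before part 2, here is a rough sketch of what the interpolation/extrapolation step can look like. The saturating power-law form and the sample scores below are illustrative assumptions, not measured results:

```python
import numpy as np
from scipy.optimize import curve_fit

def power_law(n, a, b, c):
    # Saturating curve: performance approaches a as the dataset size n grows.
    return a - b * n ** (-c)

# Hypothetical (dataset size, validation accuracy) pairs from learning-curve runs.
sizes = np.array([500, 1000, 2000, 4000, 8000])
scores = np.array([0.71, 0.78, 0.83, 0.86, 0.88])

(a, b, c), _ = curve_fit(power_law, sizes, scores, p0=[0.95, 1.0, 0.5])

target = 0.90
if target < a:
    # Invert the fitted curve: solve a - b * n**(-c) = target for n.
    needed = (b / (a - target)) ** (1 / c)
    print(f"Estimated dataset size for {target:.0%} accuracy: ~{needed:,.0f} examples")
else:
    print("Target is above the fitted saturation level; more data alone may not get there.")
```

The fitted saturation level a also gives a rough sense of where adding more data stops paying off.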