How a machine learning model is trained

Written by Simon Althoff

As AI takes a larger role in society and public discourse, and more and more people are exposed to it, it is ever more important for everyone to understand how AI works. Understanding how machine learning models, the most prominent type of model within AI, are trained gives a much better grasp of the capabilities and limitations of AI. This article gives a high-level explanation of how machine learning models are trained and what this means for data science projects.

What is a machine learning model

A machine learning model is in essence an approximation of a function. A function can in mathematical terms be almost anything, as long as there is an input and an output (with a few conditions). From mathematics, we recognize the form the function takes, f(x) = y, where x is our input and y is the output when f is applied to x. Working with math in school is usually straightforward since we may have an analytical expression for the function, think f(x) = 3*x + 5 for instance. However, in the real world, things are rarely that simple, and we need to make do with models; approximations that mimic the behaviour we are looking for as closely as possible. This is where machine learning comes in. For simpler problems, classical statistical methods may suffice to form models, but for more complex problems we may need more complex methods.

As a visualization, we may have a function mapping points in input space X to points in output space Y. Points in these spaces could be numerical values, as in our example f(x) = 3*x + 5, but could also represent other things. The spaces represent the collections of all possible values that the inputs and outputs can take, respectively.


Figure 1: A function f maps a point in X to a point in Y; a model tries to mimic this behaviour by training on observations of these mappings.

If we know the point x and the function f, we can calculate what the output y is. This would of course be immensely useful if, for instance, f gave the exact demand for a product next month (y) based on the amount sold last month (x). However, as we know, we rarely have the expression for f (if there even is one), and that is where machine learning models come in. By looking at a lot of data points, meaning pairs of corresponding points in X and Y, we can get a picture of how f works. This “picture” of f is formed by training the ML model on the observed data points. The analogy here is how we can have an idea of what a picture looks like, even with parts of it covered. We might not be able to reconstruct the entire image perfectly, but we still have a rough idea of what it depicts.

Figure 2: Even if we can’t see the entire picture, enough “data” is present for us to understand what it depicts. In this case our CEO Jens is telling us to report our OKRs, much like how an ML model uses data to approximate an underlying function

With training done, we hope that the model has learned a generalized approximation of the function, such that when we give it a new input x, which it hasn’t seen before, the output ŷ from the model is reasonably close to the real (unknown) output y. We say reasonably close, since we are dealing with uncertainty in these types of approximations. Uncertainty can come from many sources, but it can reasonably be sorted into the following categories:

  1. Uncertainty in the problem itself, such as random behaviour in the phenomenon we are trying to model. 

  2. Uncertainty from lack of information, such as missing dimensions in X. For example, if you are trying to predict how much it will rain, but you are missing data on temperature, you will likely have higher uncertainty in your predictions. 

  3. Computational and/or measurement errors. To continue the rain prediction example: if we measure temperature with a bad thermometer, and/or perform the calculations on an old slide rule, we should expect higher uncertainty.

Having the right data, of high quality, and a suitable model will help minimize uncertainty, but it cannot be completely eliminated. If we have enough data, and avoid training pitfalls such as overfitting (which we will get to), we can realistically get an approximation that produces likely, but not exact, outputs. Furthermore, there are robust techniques for handling uncertainty within machine learning; you can read about them in our article on uncertainty quantification.

Why understanding training matters

Being aware of data-driven solutions and understanding where to apply them within your organization is important to remain competitive. These solutions offer immense opportunities for automation and optimization in all industries. Meanwhile, having too high hopes in the capabilities of ML may lead to bad investments or, worse, serious failures in critical applications. With a grasp of how training is performed comes an increased understanding of the limitations and possibilities of data utilization. This helps in shaping strategies and identifying valuable use cases, and thus steers AI investments in the right direction. Additionally, it can help in communication between data scientists and other stakeholders within projects, allowing for more streamlined and efficient project timelines. Furthermore, knowing how a model is trained gives a better understanding of what type of data is good or bad for a given use case. Having this knowledge within the organization enables better data gathering, which leads to better ML projects and further increases the likelihood of success.

How model training works

A machine learning model generally has parameters, and when a model is learning, these parameters are tweaked in order to better represent the underlying function. Let’s return to our initial function f(x) = 3x + 5. If we were to try and approximate this function, we would want a model that mimics the behaviour caused by the values 3 and 5. If we made a good choice of model, we would have two parameters, let’s call them u and v, in the approximate function g(x) = ux + v. We would then try to tweak u and v so that g(x) displays a similar behaviour to f(x), which happens when u and v come close to 3 and 5 respectively. The type and number of parameters in a model is dictated by the model type and architecture. The choice of model usually comes down to performance testing of a few models relevant to the problem.
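To make this concrete, here is a minimal sketch (our own illustration in Python, with made-up noisy observations) of recovering the parameters u and v from data generated by f(x) = 3x + 5:

```python
import numpy as np

# Generate "observations" of the unknown function f(x) = 3x + 5,
# with a little noise to mimic real-world measurements.
rng = np.random.default_rng(seed=0)
x = rng.uniform(-10, 10, size=100)
y = 3 * x + 5 + rng.normal(scale=0.5, size=100)

# Our model g(x) = u*x + v has two parameters, u and v.
# Least squares finds the u and v that best fit the observed pairs (x, y).
A = np.column_stack([x, np.ones_like(x)])
(u, v), *_ = np.linalg.lstsq(A, y, rcond=None)

print(f"u is roughly {u:.2f}, v is roughly {v:.2f}")  # close to 3 and 5
```

With enough observations, the fitted parameters land close to 3 and 5, which is the training idea in miniature.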

Different models train in different ways, but we will for simplicity keep to the most general case. To tune the parameters, we need something to tell us how to tune them. This is typically called the loss function, which gives us a metric of how “wrong” a certain output is. To train the model, input data is fed to the model, the model forms a prediction, and the “loss” is calculated by comparing the output of the model to the “actual” output. Once that is done, the parameters of the model are tweaked such that if the same sample were fed through the model again, the resulting loss would be lower.

Figure 3: The input data x is fed to the model, and the model forms a prediction which is compared with the actual output y through the loss function. The loss is then fed back to train the model.
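As a rough sketch of this loop (our own illustration; the mean squared error loss, learning rate and number of steps are arbitrary choices for the example), the same two-parameter model can be trained with plain gradient descent:

```python
import numpy as np

rng = np.random.default_rng(seed=0)
x = rng.uniform(-10, 10, size=100)
y = 3 * x + 5 + rng.normal(scale=0.5, size=100)  # noisy observations of f

u, v = 0.0, 0.0          # start with arbitrary parameter values
learning_rate = 0.01

for step in range(2000):
    y_hat = u * x + v                  # the model forms a prediction
    error = y_hat - y
    loss = np.mean(error ** 2)         # mean squared error: how "wrong" we are

    # Tweak the parameters in the direction that lowers the loss
    # (the gradients of the loss with respect to u and v).
    u -= learning_rate * np.mean(2 * error * x)
    v -= learning_rate * np.mean(2 * error)

print(f"u is roughly {u:.2f}, v is roughly {v:.2f}, final loss {loss:.3f}")
```

Each pass through the loop is exactly the cycle in Figure 3: predict, measure the loss, and nudge the parameters so the same sample would give a lower loss next time.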

It helps to constantly keep in mind that all the model does is minimize the loss function. That is where the magic of machine learning lies, but it also means some things have to be considered. The main point is that the model will not be “smarter” than the data; it is entirely dependent on the information contained in the training data. That means it is critical to have the right information, and enough of it, available in the training data to make the model perform in the desired way. We will cover the most important aspects of data gathering below. But to make the most out of the available data, one should be familiar with the concepts of underfitting and overfitting.

Underfitting

Underfitting is the phenomenon where the trained model cannot capture the complexities of the underlying function it is trying to approximate. Underfitting can happen due to lack of relevant data or poor data quality, things we cover later. But it may also be the result of a machine learning model that is too simple. The model will minimize the loss as much as possible; if model simplicity is the inhibiting factor for further loss minimization, we have underfitted the model. A large contributor to underfitting is thus the wrong choice of tool, like using a hammer to make a hole. While the hammer will do something, the result will be bad, and a more complex tool like a drill is needed to do a proper job.

Overfitting

If a model is overfitted, it means it has learned “too much” about the training data. This happens due to uncertainty in the data, since we may have noise and other factors that are not generalizable. The model will (if complex enough) learn to use the noise (non-relevant data) to predict the output, which leads to performance breakdowns when seeing new data, since the model has been specialized on the data present in training. This happens because the model tries to minimize the loss function over the training data, without considering that there may be non-relevant factors present. Much like cooking: if you keep adding ingredients and make the dish very complex, the result will likely not be what you were aiming for; it is better to keep it simple and stop early. Thus, excessive training of a model that is “too complex” for the task may lead to worse performance, all while requiring more resources due to computational complexity.

Figure 4: The figure visualizes how underfitting and overfitting lead to bad models. The overfitted curve has been trained to the point where it has learned to use the noise in the data, while the underfitted model is not complex enough to capture the curved behaviour.
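The hypothetical example below illustrates the same pattern, with polynomials of different degrees standing in for models of different complexity; the function, noise level and degrees are chosen only for illustration:

```python
import numpy as np

rng = np.random.default_rng(seed=1)

# An unknown, curved function that we only observe through noisy samples.
def f(x):
    return np.sin(3 * x)

x_train = rng.uniform(-1, 1, size=20)
y_train = f(x_train) + rng.normal(scale=0.2, size=20)
x_test = rng.uniform(-1, 1, size=500)
y_test = f(x_test) + rng.normal(scale=0.2, size=500)

# Polynomials of increasing degree play the role of models of increasing complexity.
for degree in (1, 5, 12):
    coeffs = np.polyfit(x_train, y_train, degree)
    train_err = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
    test_err = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)
    print(f"degree {degree:2d}: training error {train_err:.3f}, test error {test_err:.3f}")

# Typical outcome: degree 1 underfits (high error on both sets), degree 12 overfits
# (very low training error but clearly higher test error), while a moderate degree
# generalizes best.
```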

For an interactive overview of how training works and how model complexity relates to how well a model can adapt to a given function, you can try out the TensorFlow Playground.

Other types of learning

There are other types of learning models; however, the concept of minimizing a loss function usually remains. Within reinforcement learning, for instance, we are instead interested in maximizing a mathematically defined reward. We let the model explore different decisions, which are translated into rewards, which are then fed back to the model. Over time, we hope the model arrives at a decision policy that maximizes the reward. So even though the training is different, the loss-minimization concept remains, just slightly tweaked.
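As a toy illustration only (a multi-armed bandit with an epsilon-greedy policy, which is one of the simplest reinforcement learning setups; the reward values are made up), the feedback loop can look like this:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Three possible decisions ("arms"), each with an unknown average reward.
true_rewards = np.array([1.0, 2.0, 1.5])
estimates = np.zeros(3)   # the model's current estimate of each decision's reward
counts = np.zeros(3)
epsilon = 0.1             # how often we explore instead of exploiting

for step in range(5000):
    # Explore occasionally, otherwise pick the decision currently believed best.
    if rng.random() < epsilon:
        action = int(rng.integers(3))
    else:
        action = int(np.argmax(estimates))

    # The environment translates the decision into a (noisy) reward...
    reward = true_rewards[action] + rng.normal(scale=0.5)

    # ...which is fed back to update the estimate for that decision.
    counts[action] += 1
    estimates[action] += (reward - estimates[action]) / counts[action]

print("estimated rewards:", np.round(estimates, 2))
print("best decision:", int(np.argmax(estimates)))  # should settle on the second arm
```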

Data considerations for successful training

Getting good results from ML models first and foremost requires good data. The model will learn from the data, no matter how good or bad it is, meaning that bad data may lead to serious consequences if not handled properly. Below are three critical data dimensions that should be considered when gathering data.

Data quality

Data quality is perhaps the most important aspect when it comes to machine learning success. The data which the model will be based on should be:

  • Clean, without unnecessary noise, errors and missing values 

  • Consistent across units, categorizations or similar qualities

  • Relevant to the use case you are trying to solve

Figure 5: Data that is noisy, inconsistent or irrelevant will decrease the likelihood of a successful ML implementation
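As a small, hypothetical sketch of checking these three properties (the column names and values are invented for the example), a first pass over a dataset could look something like this:

```python
import pandas as pd

# A small, hypothetical dataset with typical quality problems:
# a missing value, an inconsistent unit and an inconsistent category label.
df = pd.DataFrame({
    "temperature": [21.5, None, 22.1, 295.3],          # one missing, one in Kelvin
    "city": ["Stockholm", "stockholm", "Göteborg", "Göteborg"],
    "rainfall_mm": [0.0, 3.2, 1.1, 0.4],
})

# Missing values
print(df.isna().sum())

# Inconsistent categories (case differences count as different values)
print(df["city"].value_counts())

# Obvious unit or outlier issues
print(df["temperature"].describe())

# Simple fixes: normalize categories, drop or impute missing rows.
df["city"] = df["city"].str.capitalize()
df = df.dropna(subset=["temperature"])
```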

One must always assume that the model will find any prevalent connection between data points; if those connections are wrong, performance will break down. A good example is the following:

Say you are training a model to identify whether someone is wearing glasses or not. When you collect the data, you photograph a set of people wearing glasses and a set of people not wearing glasses. Say that the two datasets come from two different cameras, where the one photographing people with glasses has a corrupt pixel. This corruption may not be visible to a person, but the machine learning model will latch on to it instantly. If all pictures with glasses have the corruption, and all pictures without glasses don't, then the corruption will be the strongest factor for predicting whether someone is wearing glasses. The model will likely base the majority of the prediction on the presence of the corruption, since that minimizes the prediction error (loss function). The model will show strong performance on the gathered dataset but will likely break down when encountering data outside it.
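A hedged sketch of this effect is shown below, where an artificial "corrupt pixel" feature perfectly tracks the label in the training data but not in new data; the features and model choice are our own for the illustration:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(seed=0)
n = 1000

# Hypothetical features describing each photo; the label is "wearing glasses".
glasses = rng.integers(0, 2, size=n)
real_features = rng.normal(size=(n, 5)) + 0.3 * glasses[:, None]  # weak real signal

# In the training data, the corrupt pixel perfectly tracks the camera,
# and therefore the label. In new data it does not.
corrupt_pixel_train = glasses
corrupt_pixel_new = rng.integers(0, 2, size=n)

X_train = np.column_stack([real_features, corrupt_pixel_train])
X_new = np.column_stack([real_features, corrupt_pixel_new])

model = LogisticRegression(max_iter=1000).fit(X_train, glasses)
print("accuracy on training-like data:", model.score(X_train, glasses))  # near perfect
print("accuracy on new data:          ", model.score(X_new, glasses))    # much worse
```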

The same goes for relevance: if the data is not relevant, especially in the output space, the model will learn to predict the wrong thing. The more closely the data relates to the underlying function, the better the approximation will be. In many cases, one has to use “heuristic” data due to a lack of the necessary data. For instance, say you are trying to predict the demand for a product, but only have data on the total amount of products sold within a suite that the desired product is part of. You then have to bet on the desired product's demand being strongly correlated with the demand of the entire product suite. The task then becomes forming an approximation of an approximation. Think of it as trying to reconstruct an image from a partially available, hopefully similar, image. The result will be mediocre at best.

Figure 6: Approximating an image from a partially visible, hopefully similar image is difficult, which is why relevant data is crucial for any ML project

Data quantity

Not having enough data is the next big roadblock when it comes to having good data for a project. Simply put, one cannot get the granularity needed for a good approximation if there is too little data.

Figure 7: Too few data points will make it very difficult to produce any good approximation
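As a quick, illustrative sketch (the same kind of toy setup as before, with the function, model and sample sizes chosen arbitrarily), fitting the same model on datasets of different sizes shows the effect:

```python
import numpy as np

rng = np.random.default_rng(seed=2)

def f(x):
    return np.sin(3 * x)

x_test = np.linspace(-1, 1, 500)
y_test = f(x_test)

# Fit the same moderately flexible model (a degree-5 polynomial)
# on datasets of increasing size and compare the error on unseen points.
for n_points in (6, 30, 300):
    x_train = rng.uniform(-1, 1, size=n_points)
    y_train = f(x_train) + rng.normal(scale=0.2, size=n_points)
    coeffs = np.polyfit(x_train, y_train, deg=5)
    test_err = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)
    print(f"{n_points:4d} training points -> test error {test_err:.3f}")

# With only a handful of points, the fit is essentially arbitrary between the
# observations; more data pins the approximation down.
```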

Balanced Data

Balanced data can be a sneaky issue: we want to avoid having too many data points concentrated in a limited “region” of the spaces when training our model. Imbalanced datasets have two major problems:

  1. Areas outside the “high density” regions of the data will suffer from a lack of performance, since the model will not know what to do in those cases.

  2. Models will likely predict output in accordance with the plentiful data, since that has a high chance of giving a low error in training.

This leads to biased models, which can have huge implications for the organizations that use them; read more in our article on biased models.

Figure 8: With half of the image covered, it will be almost impossible to know what should be there. Even with a little data available, this is difficult. Our best bet is to use the data that is visible to us and “guess” what should be in the rest of the image.

Figure 9: This shows the technical issue of imbalanced data. Since the model is trying to “minimize loss”, it will likely predict an output that it has seen often before, because, to the model, that output is more probable.
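The sketch below illustrates the second problem with a made-up, heavily imbalanced dataset and a simple classifier: overall accuracy looks impressive while the rare class is mostly missed.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(seed=0)

# A heavily imbalanced dataset: 95% of the samples belong to class 0.
n_majority, n_minority = 950, 50
X = np.vstack([
    rng.normal(loc=0.0, size=(n_majority, 2)),
    rng.normal(loc=1.0, size=(n_minority, 2)),   # the classes overlap somewhat
])
y = np.array([0] * n_majority + [1] * n_minority)

model = LogisticRegression().fit(X, y)
pred = model.predict(X)

accuracy = np.mean(pred == y)
minority_recall = np.mean(pred[y == 1] == 1)
print(f"accuracy: {accuracy:.2f}")                       # looks great...
print(f"minority class recall: {minority_recall:.2f}")   # ...but the rare class is mostly missed
```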

Conclusion

To succeed in model training there are several things to take into account. When we are trying to approximate an unknown function, good data and a suitable model are critical to get the desired performance. Noisy or inaccurate data will make it harder to get an accurate and general model. Irrelevant data will lead to predicting the wrong thing, essentially approximating an approximation. Imbalanced data will lead to models that are biased or break down in real-world settings. Too few data points will make it impossible to model the underlying function reliably. A model that is too simple will not be able to approximate important characteristics of the function, while a model that is too complex may see noise and data inaccuracies as part of the underlying function.

The model will do what it is told: it will minimize the loss function. This means it is up to us to choose the right model and provide it with the right data, to translate a minimized loss function into a solution to the problem we are trying to solve.
