Your validation set won’t tell you if a model generalizes. Here’s what will.


Václav Volhejn
December 1, 2023
February 7, 2023
As we all know from machine learning 101, you should split your dataset into three parts: the training, validation, and test sets. You train your models on the training set. You choose your hyperparameters by selecting the model that performs best on the validation set. Finally, you look at your accuracy (F1 score, ROC curve...) on the test set. And voilà, you’ve just achieved XYZ% accuracy.

This is only half the story. The real-world data your model sees in operation will never exactly match your dataset, and it will shift over time. That means accuracy on real-world data will typically be lower than your training and validation accuracies: this is a non-IID version of the traditional generalization gap.

The fundamental question when testing ML models, then, is how to select the model with the best generalization properties. The gold standard is picking the model with the highest validation accuracy. But your validation set is lying to you: reaching a great validation accuracy doesn’t mean you’re any closer to having a production-ready model. Our experiments below confirm this.

The validation set covers only a small part of the inputs your model will encounter in real-world operation.

Model selection the right way

Instead, you need to go into more depth in your evaluation: measure the model’s robustness to variations in the input image. As a basic example, if the model predicted “tumor” for a certain image, it should also predict “tumor” if we vary the brightness slightly, tilt or flip the image, change the hue, and so on. Since these variations are not semantically meaningful, the model should not change its prediction.
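As a rough illustration, the label-preserving variations mentioned above can be written as simple array transforms. The sketch below is our own and not MLTest’s implementation; it assumes images are NumPy float arrays of shape (H, W, 3) with values in [0, 1], and the function names and perturbation strengths are hypothetical. Tilts and hue changes need interpolation or a color-space conversion, so they are left out to keep the example dependency-free.

```python
import numpy as np

# Hypothetical, semantically neutral perturbations for an RGB image stored as a
# float array of shape (H, W, 3) with values in [0, 1].


def adjust_brightness(img: np.ndarray, delta: float) -> np.ndarray:
    """Shift all pixel values by `delta`, clipping back into [0, 1]."""
    return np.clip(img + delta, 0.0, 1.0)


def horizontal_flip(img: np.ndarray) -> np.ndarray:
    """Mirror the image left to right."""
    return img[:, ::-1, :]


def vertical_flip(img: np.ndarray) -> np.ndarray:
    """Mirror the image top to bottom."""
    return img[::-1, :, :]


def add_gaussian_noise(img: np.ndarray, sigma: float, rng: np.random.Generator) -> np.ndarray:
    """Add zero-mean Gaussian noise with standard deviation `sigma`, then clip."""
    return np.clip(img + rng.normal(0.0, sigma, size=img.shape), 0.0, 1.0)


def make_perturbations(rng: np.random.Generator) -> dict:
    """Map a name to each perturbation; the strengths are illustrative guesses."""
    return {
        "brighter": lambda img: adjust_brightness(img, 0.05),
        "darker": lambda img: adjust_brightness(img, -0.05),
        "hflip": horizontal_flip,
        "vflip": vertical_flip,
        "noise": lambda img: add_gaussian_noise(img, 0.02, rng),
    }
```

Each transform keeps the image shape and value range, so the same model can be run on the original image and its perturbed copy side by side.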

What’s really cool about this approach is that you don’t even need true labels for the image. If the model changes its prediction after a tiny change to the image, that means it’s brittle – no matter the true label.
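Because no labels are involved, the check boils down to comparing the model’s predictions with its own predictions on perturbed inputs. Here is a minimal sketch that measures the fraction of predictions that flip; `predict` is a hypothetical stand-in for any batched classifier returning class indices, and the toy threshold model exists only to make the example runnable.

```python
from typing import Callable

import numpy as np


def flip_rate(
    predict: Callable[[np.ndarray], np.ndarray],
    images: np.ndarray,
    perturb: Callable[[np.ndarray], np.ndarray],
) -> float:
    """Fraction of images whose predicted class changes under `perturb`.

    No ground-truth labels are used: the model is compared only against itself.
    """
    original = predict(images)
    perturbed = predict(np.stack([perturb(img) for img in images]))
    return float(np.mean(original != perturbed))


def toy_predict(batch: np.ndarray) -> np.ndarray:
    """Hypothetical classifier: predict class 1 ("tumor") if mean intensity > 0.5."""
    return (batch.mean(axis=(1, 2, 3)) > 0.5).astype(int)


def brighten(img: np.ndarray) -> np.ndarray:
    """Small brightness increase, clipped to the valid range."""
    return np.clip(img + 0.05, 0.0, 1.0)


# Four constant gray "images" of shape (4, 4, 3); only the one sitting just
# below the decision threshold should change its prediction.
images = np.stack([np.full((4, 4, 3), v) for v in (0.2, 0.49, 0.51, 0.8)])
print(flip_rate(toy_predict, images, brighten))  # → 0.25
```

A model that is robust in the sense described above keeps this rate close to zero for every perturbation it is tested against.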

By measuring robustness, you can observe how the model performs on data beyond the training distribution that is likely to appear in the real world. And this is exactly what generalization means! Robustness tests allow you to go way beyond validation set accuracy and are a great predictor of model performance in the wild. If you want to get a better grasp on which of your models generalize, adding robustness tests is the way to go.

Robustness tests allow you to massively extend your testing coverage with no extra data.

To demonstrate, here’s an example. Camelyon17-WILDS [1, 2] is a histopathology dataset in which the goal is to predict whether the tissue slide contains any tumor tissue or not (binary classification). The training set contains slides from hospitals A, B, and C. The validation set has slides from hospital D, and the test set, hospital E. Images from each hospital are different, so we have a domain generalization problem:

The training, validation and test sets each contain data from different hospitals. The data distributions are therefore different for each one.
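The key detail is that whole hospitals, not individual images, are assigned to splits, so no hospital’s data appears in more than one set. A minimal sketch of such a domain-based split follows; the record format (a dict with a `hospital` key) and the hospital letters are hypothetical placeholders for whatever metadata your dataset carries. The WILDS package already ships Camelyon17 with its official split defined, so in practice you would not build this yourself.

```python
from collections import defaultdict

# Hypothetical mapping mirroring the setup above: hospitals A, B and C for
# training, D for validation, E for testing.
SPLIT_BY_HOSPITAL = {"A": "train", "B": "train", "C": "train", "D": "val", "E": "test"}


def split_by_domain(examples, split_by_hospital=SPLIT_BY_HOSPITAL):
    """Group examples into splits by the hospital (domain) they came from."""
    splits = defaultdict(list)
    for example in examples:
        splits[split_by_hospital[example["hospital"]]].append(example)
    return dict(splits)


examples = [
    {"id": 0, "hospital": "A"},
    {"id": 1, "hospital": "C"},
    {"id": 2, "hospital": "D"},
    {"id": 3, "hospital": "E"},
    {"id": 4, "hospital": "B"},
]
splits = split_by_domain(examples)
print({name: [ex["id"] for ex in exs] for name, exs in splits.items()})
# → {'train': [0, 1, 4], 'val': [2], 'test': [3]}
```

Contrast this with a random IID split, which would mix images from every hospital into every set and hide exactly the shift the test set is meant to expose.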

The goal is to maximize the accuracy on the test set [2]. (Of course, in real medical imaging applications, model evaluation is much more complex. We’re using this simplified setup to illustrate how regular validation set metrics can be deceiving. Using a more involved evaluation only corroborates the problem.)

And this is where things get tricky. Since we’re no longer in the IID world, validation accuracy stops being reliable: performing well on hospital D doesn’t necessarily mean you’ll do well on the more pinkish images from hospital E.

Experiment: validation accuracy vs model robustness

Say you’ve decided to build a classification model for Camelyon17. You only have the training and validation sets, and you have to select the model to deploy in the real world, that is, on the unknown test set.

You experiment with a lot of models and, in the end, narrow the choices down to ResNet-34 and ResNet-101. Obviously, you want to select the architecture that will do better on the test set. To make an informed comparison, you train several models of each architecture with different random seeds.

(To keep performance stable, we froze the backbones of all models and trained only the final classification layer. Otherwise, “the test performance of models trained on this dataset tend to exhibit a large degree of variability over random seeds”, as the WILDS authors mention on GitHub. The other hyperparameters are the defaults from the WILDS package.)
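In this setup (often called linear probing), the backbone only serves as a fixed feature extractor, and the only trainable parameters are those of the final layer. The NumPy sketch below assumes the frozen backbone’s features have already been computed into a matrix and trains a logistic-regression head with mini-batch gradient descent; the function names, learning rate, and epoch count are hypothetical and unrelated to the WILDS defaults. The seed controls both the initialization and the mini-batch order, which is where seed-to-seed variation comes from.

```python
import numpy as np


def train_linear_probe(features, labels, seed, epochs=50, lr=0.1, batch_size=32):
    """Train a binary logistic-regression head on frozen, precomputed features.

    Only `w` and `b` are learned; the backbone that produced `features` is untouched.
    """
    rng = np.random.default_rng(seed)
    n, d = features.shape
    w = rng.normal(0.0, 0.01, size=d)
    b = 0.0
    for _ in range(epochs):
        order = rng.permutation(n)  # seed-dependent mini-batch order
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            x, y = features[idx], labels[idx]
            probs = 1.0 / (1.0 + np.exp(-(x @ w + b)))
            grad = probs - y  # gradient of binary cross-entropy w.r.t. the logits
            w -= lr * x.T @ grad / len(idx)
            b -= lr * grad.mean()
    return w, b


def predict(features, w, b):
    """Class 1 where the logit is positive, else class 0."""
    return (features @ w + b > 0).astype(int)


# Synthetic stand-in for frozen backbone features: 200 samples, 5 dimensions,
# with the label determined by the first feature.
data_rng = np.random.default_rng(0)
X = data_rng.normal(size=(200, 5))
y = (X[:, 0] > 0).astype(float)

# One head per seed, mirroring the multi-seed comparison described above.
for seed in range(3):
    w, b = train_linear_probe(X, y, seed)
    print(seed, np.mean(predict(X, w, b) == y))
```

With a real backbone, `X` would come from running the frozen network once over the training images; the head is cheap enough to retrain for many seeds.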

Now you plot the validation accuracy of these models to decide which one to use. As you can see from the plot, the two groups perform equally well. To be clear, this is not hypothetical – we did train these models and all plots you see here are their actual results.

The two model architectures are indistinguishable in validation accuracy.

When you’re working off validation accuracy alone, you would choose ResNet-34 over ResNet-101: both perform equally well, and ResNet-34 is smaller and faster. But on the test set (AKA the real world), things look very different from the validation set. By choosing ResNet-34, you’ve thrown out a model with 4% higher accuracy. That’s a missed 25% reduction in error probability!

On the test set, however, ResNet-101 has a 25% lower error probability than ResNet-34.
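The link between the two figures is worth spelling out: a relative error reduction is the drop in error rate divided by the baseline error rate. Reading “4% higher accuracy” as 4 percentage points, a 25% relative reduction would correspond to a ResNet-34 test error of about 16%. The accuracies below (84% and 88%) are hypothetical values chosen only to match those two figures, not the measured results.

```python
def relative_error_reduction(acc_baseline: float, acc_better: float) -> float:
    """Relative drop in error rate when moving from the baseline to the better model."""
    err_baseline = 1.0 - acc_baseline
    err_better = 1.0 - acc_better
    return (err_baseline - err_better) / err_baseline


# Hypothetical accuracies: a 4-point accuracy gap on top of a 16% baseline error.
print(round(relative_error_reduction(0.84, 0.88), 4))  # → 0.25
```

The same 4-point gain matters even more on top of a stronger baseline: going from 94% to 98% accuracy removes two thirds of the errors.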

What if you instead do an in-depth evaluation using robustness tests? After you assign a score to the models based on their robustness to lighting changes, blurring, image quality, viewpoint changes, and noise, you see this (lower is better):

The MLTest robustness scores of ResNet-34 and ResNet-101.
The robustness score correctly predicts that ResNet-101 is better (it has a lower risk score) without seeing the test data.
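This post does not spell out how MLTest combines its individual tests, so the following is only a hypothetical illustration of the general idea: compute a flip rate per perturbation family (lighting, blur, image quality, viewpoint, noise), then average them into one number where lower is better. The family names, the equal default weights, and the numbers in the example are all made up.

```python
from typing import Dict, Optional


def risk_score(flip_rates: Dict[str, float], weights: Optional[Dict[str, float]] = None) -> float:
    """Weighted average of per-family flip rates; lower means more robust.

    Families without an explicit weight default to weight 1.0.
    """
    if not flip_rates:
        raise ValueError("need at least one perturbation family")
    weights = weights or {}
    total_weight = sum(weights.get(name, 1.0) for name in flip_rates)
    weighted_sum = sum(weights.get(name, 1.0) * rate for name, rate in flip_rates.items())
    return weighted_sum / total_weight


# Made-up flip rates for a single model, one per perturbation family.
example = {"lighting": 0.04, "blur": 0.10, "image_quality": 0.06, "viewpoint": 0.02, "noise": 0.08}
print(round(risk_score(example), 3))  # → 0.06
```

Weighting lets you emphasize the variations you most expect in deployment, for example color and lighting shifts between hospitals.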

So the robustness risk score correctly predicts that ResNet-101 is more robust (and better on real data) than ResNet-34. The score is also consistent: it rates every ResNet-34 worse than every ResNet-101, across all seeds. These robustness tests are run on the validation set, not on the unknown test set, so there is no data leakage. Instead, they make smarter use of the available data to make the evaluation more representative.
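The consistency claim, that every ResNet-34 is rated worse than every ResNet-101, is a strict separation between the two groups of seeds. A small helper can check it, along with a softer pairwise agreement rate for cases where the groups overlap. The function names and the scores below are hypothetical placeholders, not the results from the plot.

```python
from itertools import product
from typing import Sequence


def strictly_separated(worse: Sequence[float], better: Sequence[float]) -> bool:
    """True if every risk score in `worse` is higher than every score in `better`."""
    return min(worse) > max(better)


def pairwise_agreement(worse: Sequence[float], better: Sequence[float]) -> float:
    """Fraction of (worse, better) pairs that the risk score orders correctly."""
    pairs = list(product(worse, better))
    return sum(w > b for w, b in pairs) / len(pairs)


# Placeholder risk scores, one per random seed (lower is better).
scores_a = [0.31, 0.29, 0.33]
scores_b = [0.22, 0.25, 0.24]
print(strictly_separated(scores_a, scores_b), pairwise_agreement(scores_a, scores_b))  # → True 1.0
```

Pairwise agreement equals 1.0 exactly when the two groups are strictly separated, and it degrades gradually as they start to overlap.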

Getting robustness testing right isn’t straightforward. At Lakera, we’ve already done the hard work and packaged it into MLTest, a part of the Lakera platform that you can use to test your models. In just a few lines of code, you can run MLTest on your model and measure its robustness. MLTest, as the name suggests, runs a battery of diverse tests on your model, including robustness tests. You can explore their individual results using our dashboard, but you also get an overall risk score: a single number that summarizes how well your model did. This includes the robustness tests described above along with a host of other goodies.

💡 Want to assess the generalization capabilities of your own models? You can integrate MLTest in minutes.

To summarize, we showed that testing machine learning models is hard, explained why validation accuracy is flawed, and showed why robustness tests work much better. We demonstrated this on a real example where selecting your model based on validation accuracy would make you miss out on a model with a 25% lower error rate. On the same data, robustness tests easily identify the better model. Now go test some models! Also, feel free to get in touch with us at vv@lakera.ai.
