Review — MetaReg (NeurIPS 2018)

Classifier-based domain generalization method

Kevin Li
5 min read · Apr 14, 2021

More posts about domain generalization here.

For anyone interested in reading this paper, I highly recommend also reading the Understanding deep learning requires rethinking generalization paper to get a full picture of when regularization can and cannot promote generalization.

MetaReg attempts to solve DG problems by applying meta-training to a regularizer. Unlike most papers, which train either the entire “feature extractor — classifier” pipeline or just the feature extractor, this paper aims to train the classifier in a way that promotes generalization. Results are evaluated on both Computer Vision (object recognition) and Natural Language Processing (sentiment analysis) tasks.

Intuition

Traditionally, a training-testing mismatch caused by overfitting can be mitigated by regularization (e.g. weight decay), since it discourages the model from learning an over-complicated decision boundary. This works well when the training and validation/testing sets come from the same distribution, but what about an out-of-distribution mismatch? It’s not clear how we could “smoothen” the decision boundary across distinct domains. Well, then why not make a neural net learn the regularization?

Method

MetaReg deals with homogeneous DG problems by meta-training a regularizer. At the beginning of training, we divide the dataset into p source domains and q target domains that are not accessible during training. The training process involves one shared feature extractor and p classifiers (generically called task networks in the paper), one per source domain, but what we ultimately want is the shared regularizer R(θ) across all p domains. The learned regularizer is essentially a weighted L1 norm, R(θ) = Σ_i φ_i|θ_i|, where the weights φ_i are learnable parameters.
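As a rough illustration, a learned weighted-L1 regularizer could look like the following PyTorch sketch (the class and variable names are mine, not from the paper’s code):

```python
import torch
import torch.nn as nn

class WeightedL1Regularizer(nn.Module):
    """R(theta) = sum_i phi_i * |theta_i|, with meta-learnable weights phi."""

    def __init__(self, num_params):
        super().__init__()
        # One learnable weight phi_i per classifier parameter theta_i.
        self.phi = nn.Parameter(torch.zeros(num_params))

    def forward(self, theta_flat):
        # theta_flat: the task network's (classifier's) parameters
        # flattened into a single vector.
        return torch.sum(self.phi * torch.abs(theta_flat))
```

Note that R(θ) is applied only to the task-network (classifier) parameters; the shared feature extractor is left unregularized.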

Network architecture

So how do we actually train a regularizer? In each iteration, training consists of the following three steps:

  • Update baseline model (MDL)
    We start with a naive multi-domain learning (MDL) model as a baseline, trained with a standard cross-entropy loss.
Loss function for naive multi-domain training across all source domains
  • Episode creation
    Based on the MDL model, we randomly select two distinct domains (a, b) from all p source domains as our meta-train and meta-test pair.
  • Regularizer Update
    In the pseudocode below, the first and second lines are for meta-train, while the third and fourth are for meta-test. β denotes the variable that holds the meta-updated classifier weights. For example, the first line means that at iteration k, β is initialized at the first time-step as θ from domain a. During meta-train, β is updated for l time-steps using both the cross-entropy loss and the regularizer R(β). Then, during meta-test, we assign the classifier’s weights from β (line 3) and update the regularizer parameters φ (line 4) using the meta-test loss on domain b, before proceeding to the next iteration (see the code sketch after this list).
Meta-learning for the regularizer
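To make the episode concrete, here is a minimal, hypothetical PyTorch sketch of one regularizer update, reusing the WeightedL1Regularizer above and assuming a simple linear task network (classifier_forward, feature_extractor, and the batch variables are my own stand-ins, not the paper’s code):

```python
import torch
import torch.nn.functional as F

def classifier_forward(features, beta):
    # Hypothetical functional forward for a linear classifier with
    # explicit parameters beta = [weight, bias].
    w, b = beta
    return features @ w.t() + b

def metareg_episode(theta_a, feature_extractor, batch_a, batch_b,
                    reg, phi_optimizer, inner_lr=1e-3, l_steps=5):
    # Meta-train (lines 1-2): beta starts from domain a's classifier weights.
    beta = [p.detach().clone().requires_grad_(True) for p in theta_a]
    for _ in range(l_steps):
        x, y = batch_a
        logits = classifier_forward(feature_extractor(x), beta)
        theta_flat = torch.cat([b.view(-1) for b in beta])
        loss = F.cross_entropy(logits, y) + reg(theta_flat)
        # create_graph=True keeps the updates differentiable w.r.t. phi.
        grads = torch.autograd.grad(loss, beta, create_graph=True)
        beta = [b - inner_lr * g for b, g in zip(beta, grads)]

    # Meta-test (lines 3-4): evaluate the updated beta on domain b and
    # update only the regularizer parameters phi.
    x, y = batch_b
    meta_loss = F.cross_entropy(
        classifier_forward(feature_extractor(x), beta), y)
    phi_optimizer.zero_grad()
    meta_loss.backward()  # gradient flows through the l inner steps into phi
    phi_optimizer.step()
```

The key design choice is create_graph=True: without it, the inner-loop updates would be treated as constants and φ would receive no gradient from the meta-test loss.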

The overall training procedure thus looks like this. Lines 2–7 correspond to the baseline model update, line 8 creates the episode, lines 10–14 correspond to the meta-train phase for our regularizer R(θ), and lines 15–16 denote the meta-test phase.

Final training procedure

Once the regularizer is fully trained, we can apply it to any conventional “feature extractor — classifier” pipeline: a fresh task network is trained on all source domains with the learned regularizer added to the loss, and then evaluated on the target domains, as sketched below.
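A minimal sketch of this final phase, reusing the regularizer from the earlier snippets (classifier, feature_extractor, optimizer, and combined_source_loader are assumed stand-ins):

```python
import torch
import torch.nn.functional as F

# Final phase (sketch): freeze phi, then train a fresh task network on
# all p source domains with the learned regularizer added to the loss.
for p in reg.parameters():
    p.requires_grad_(False)

for x, y in combined_source_loader:  # assumed loader over all source domains
    logits = classifier(feature_extractor(x))
    theta_flat = torch.cat([p.view(-1) for p in classifier.parameters()])
    loss = F.cross_entropy(logits, y) + reg(theta_flat)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```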

Result and Discussion

The team evaluates MetaReg on the PACS dataset (photo, art painting, cartoon, sketch) for object recognition and on the Amazon Reviews dataset for sentiment classification. Here we will focus only on the PACS experiment.

Result on PACS dataset

For the PACS experiment, the paper also follows the leave-one-domain-out rule for the domain split (for example, domain P as the target, and domains A, C, S as the sources). We can see that on average MetaReg beats MLDG by a solid 2.6%, which suggests that explicit regularization can help solve DG problems.

Thought: Does explicit regularization actually help generalization?

This is the question I had while reading. Sometimes in the deep learning literature, I find the explanation to be a rationalization based simply on observably better results. Specifically, I was wondering whether explicit regularization makes sense as an improvement direction (this is somewhat hindsight, because after MetaReg I don’t recognize any paper that approaches DG with a regularization-based method).

In ICLR 2017, the paper mentioned above showed that deep neural nets can memorize an entire dataset that contains absolutely no pattern (in technical terms, the effective capacity of some reasonable models, such as Inception and AlexNet, is sufficient to shatter the datasets they are trained on). The paper conducts five types of randomization tests, such as random labels or even random pixels, to support this result.

Intuitively, if we first train a neural net on CIFAR-10 to convergence, and then use the same set of hyperparameters to train on the same dataset with randomized labels, one would probably expect training to fail to converge, because there is absolutely no pattern present. Surprisingly, the model converges without much trouble, while its generalization performance is, unsurprisingly, bad. This leads to the emphasis of the paper:

Neural nets are capable of memorizing the dataset, and explicit regularization is neither necessary nor by itself sufficient for good generalization.
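For concreteness, here is a minimal sketch of the random-label setup from that paper (assuming torchvision; the training loop itself would be unchanged from a clean run):

```python
import torch
from torchvision import datasets

# Random-label test: replace every CIFAR-10 training label with a
# uniformly random class, then train with the same hyperparameters
# that worked on the clean labels.
train_set = datasets.CIFAR10(root="./data", train=True, download=True)
train_set.targets = torch.randint(0, 10, (len(train_set),)).tolist()

# A sufficiently large model still drives training accuracy to ~100%
# (pure memorization), while test accuracy stays near chance (~10%).
```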

If the claim holds, does it imply that MetaReg is actually overfitting to the source domains rather than (substantially) generalizing to the target domain? Could the improvement the paper demonstrates simply be attributed to the similarity of the domains (i.e., a small domain shift)? Note the score drop from the cartoon domain to the sketch domain (70% → 59%), while later models achieve 75% accuracy on the sketch domain with other training methods. It’s interesting to think about how regularization’s role in generalization will evolve!

What to Read Next

A natural follow-up to MetaReg is Feature-Critic Networks, which introduces a feature-critic score for updating parameters instead of relying solely on the conventional cross-entropy loss.

Happy coding!
