So this is indeed nontrivial. I was wondering if there is a fast heuristic algorithm for performing a grouped stratified split on a multilabel dataset.
Stratification is usually performed to ensure a balanced label distribution across the train, val, and test splits. This post describes how you can do a stratified split for a multilabel dataset using skmultilearn. Grouping is usually performed when some samples in the dataset depend on each other and you don't want information about the validation set to leak into the training set. The visualization here nicely shows why we would want a grouped stratified split, but it only covers the binary/multiclass case. Grouping itself does not depend on the multilabel nature of the dataset, but I am still looking for a way to combine the two for a multilabel dataset.
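To make the two requirements concrete, here is a small pure-Python helper that checks a candidate split for both properties: the fraction of samples carrying each label in every split (stratification) and the set of groups that appear in more than one split (leakage). The name `split_report` and the input layout (one label set, one group id, and one split name per sample) are hypothetical choices for illustration.

```python
from collections import Counter, defaultdict


def split_report(labels, groups, assignment):
    """Summarize a split of a multilabel dataset.

    labels: sequence of label sets, one per sample
    groups: sequence of group ids, one per sample
    assignment: sequence of split names ("train", "val", ...), one per sample
    Returns (per-split label fractions, set of groups found in >1 split).
    """
    sizes = Counter(assignment)
    counts = defaultdict(Counter)
    splits_of_group = defaultdict(set)
    for labs, g, s in zip(labels, groups, assignment):
        counts[s].update(labs)
        splits_of_group[g].add(s)
    # Fraction of samples in each split that carry each label.
    # Labels absent from a split simply do not appear in its dict.
    fractions = {
        s: {lab: c / sizes[s] for lab, c in counts[s].items()} for s in sizes
    }
    # A group spanning several splits leaks information between them.
    leaking = {g for g, ss in splits_of_group.items() if len(ss) > 1}
    return fractions, leaking
```

A good grouped stratified split keeps `leaking` empty while keeping the per-split fractions close to each other.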
In practice, assume grouping takes priority over stratification in the final split. That is, if it is not possible to achieve both, which is quite likely, we are strict about grouping and lenient about stratification.
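One possible fast heuristic, sketched here under stated assumptions rather than as an established algorithm, is to stratify at the group level instead of the sample level: aggregate label counts per group, then assign whole groups one at a time, largest first, to whichever split's sample and label totals move closest to their targets. Because groups are never divided, grouping holds by construction and only stratification is approximate, which matches the priority above. The function name `grouped_multilabel_split`, the largest-first ordering, and the scaled squared-deviation cost are all illustrative choices; the idea is in the spirit of iterative stratification but is not skmultilearn's implementation.

```python
from collections import Counter, defaultdict


def grouped_multilabel_split(labels, groups, ratios):
    """Assign whole groups to splits, greedily balancing size and label counts.

    labels: sequence of label sets, one per sample
    groups: sequence of hashable group ids, one per sample
    ratios: dict of split name -> target fraction (assumed to sum to 1)
    Returns the split name assigned to each sample, in input order.
    """
    # Aggregate sample count and label occurrences per group.
    g_size = Counter(groups)
    g_labels = defaultdict(Counter)
    for labs, g in zip(labels, groups):
        g_labels[g].update(labs)

    n = len(groups)
    total = Counter()
    for counts in g_labels.values():
        total.update(counts)

    # Targets: each split should hold its ratio of samples and of every label.
    target_size = {s: r * n for s, r in ratios.items()}
    target_lab = {s: {lab: r * t for lab, t in total.items()} for s, r in ratios.items()}
    cur_size = {s: 0 for s in ratios}
    cur_lab = {s: Counter() for s in ratios}

    def added_cost(s, g):
        # Increase in squared deviation from the targets if group g joins split s.
        # Each term is scaled by its dataset total so labels and sizes are comparable.
        d = cur_size[s] - target_size[s]
        cost = ((d + g_size[g]) ** 2 - d ** 2) / n ** 2
        for lab, v in g_labels[g].items():
            d = cur_lab[s][lab] - target_lab[s][lab]
            cost += ((d + v) ** 2 - d ** 2) / total[lab] ** 2
        return cost

    # Place large groups first, since they are the hardest to fit.
    where = {}
    for g in sorted(g_size, key=g_size.get, reverse=True):
        best = min(ratios, key=lambda s: added_cost(s, g))
        where[g] = best
        cur_size[best] += g_size[g]
        cur_lab[best].update(g_labels[g])
    return [where[g] for g in groups]
```

Dividing each term by its dataset total weighs every label by its relative imbalance rather than its raw count; without that scaling the sample-count term tends to dominate, and whole labels can end up missing from the smaller splits. The work is one aggregation pass over the samples plus one sort and one pass over the groups, so it should scale to large datasets.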
My current workaround is to use skmultilearn to generate a stratified split, then adjust the assignment with a simple greedy for loop so that all samples of each group end up in the same split. This is slow and often suboptimal.
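For comparison, a fix-up pass of this kind could look roughly like the following sketch, which assumes the rule is a majority vote: each group is moved wholly into the split that already holds most of its samples. The function name `enforce_groups` is hypothetical, and the sample-level assignment is assumed to come from any stratified splitter, such as skmultilearn's.

```python
from collections import Counter, defaultdict


def enforce_groups(assignment, groups):
    """Move every group wholly into the split holding most of its samples.

    assignment: sequence of split names per sample (from a stratified splitter)
    groups: sequence of group ids per sample
    """
    votes = defaultdict(Counter)
    for s, g in zip(assignment, groups):
        votes[g][s] += 1
    # Majority vote per group; Counter.most_common breaks ties in favour
    # of the split first encountered for that group.
    winner = {g: c.most_common(1)[0][0] for g, c in votes.items()}
    return [winner[g] for g in groups]
```

This pass is linear in the number of samples, but because it ignores labels, every group it moves can pull the splits away from both the target ratios and the label balance the stratified splitter achieved, which is one reason the result tends to be suboptimal.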