@@ -437,6 +437,131 @@ def save(self, path: str):
437437 print ("Done." )
438438
439439
class SameInnerSplitSplitter(Splitter):
    r"""
    Splitter subclass that can be used to have multiple training runs of the
    same configuration at model selection time. It is not meant to be combined
    with a double-nested CV, for which the different inner splits are already
    enough to gauge the training stability of each configuration.
    """

    @staticmethod
    def _assert_pairwise_disjoint(*index_groups):
        r"""
        Asserts that no index appears in more than one of the given groups.

        Args:
            index_groups: iterables of hashable indices (e.g., 1-D integer
                arrays), all expressed in the same index space
        """
        seen = set()
        for group in index_groups:
            group = set(group)
            assert seen.isdisjoint(
                group
            ), "Data splits overlap: an index belongs to more than one split"
            seen |= group

    def split(
        self,
        dataset: pydgn.data.dataset.DatasetInterface,
        targets: np.ndarray = None,
    ):
        r"""
        Computes the splits and appends them to the list fields
        ``self.outer_folds`` and ``self.inner_folds``. For each outer fold,
        only the first inner split is computed and it is repeated
        ``self.n_inner_folds`` times.
        IMPORTANT: calling split() sets the seed of numpy, torch, and
        random for reproducibility.

        Args:
            dataset (:class:`~pydgn.data.dataset.DatasetInterface`):
                the Dataset object
            targets (np.ndarray): targets used for stratification.
                Default is ``None``
        """
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        torch.cuda.manual_seed(self.seed)
        random.seed(self.seed)

        stratified = self.stratify
        outer_idxs = np.array(range(len(dataset)))

        outer_splitter = self._get_splitter(
            n_splits=self.n_outer_folds,
            stratified=stratified,
            eval_ratio=self.test_ratio,
        )  # This is the true test (outer test)

        for train_idxs, test_idxs in outer_splitter.split(
            outer_idxs, y=targets if stratified else None
        ):

            # outer_idxs is the ordered range 0..len(dataset)-1, hence
            # positions and dataset indices coincide at the outer level
            assert set(train_idxs) == set(outer_idxs[train_idxs])
            assert set(test_idxs) == set(outer_idxs[test_idxs])

            inner_idxs = outer_idxs[train_idxs]
            inner_targets = (
                targets[train_idxs] if targets is not None else None
            )

            inner_splitter = self._get_splitter(
                n_splits=self.n_inner_folds,
                stratified=stratified,
                eval_ratio=self.inner_val_ratio,
            )  # The inner "test" is, instead, the validation set

            inner_fold_splits = []
            for inner_train_idxs, inner_val_idxs in inner_splitter.split(
                inner_idxs, y=inner_targets if stratified else None
            ):
                # the splitter returns positions relative to inner_idxs:
                # map them back to dataset indices before any comparison
                inner_train = inner_idxs[inner_train_idxs]
                inner_val = inner_idxs[inner_val_idxs]

                # train, validation and test must not share any sample
                self._assert_pairwise_disjoint(
                    inner_train, inner_val, outer_idxs[test_idxs]
                )

                inner_fold = InnerFold(
                    train_idxs=inner_train.tolist(),
                    val_idxs=inner_val.tolist(),
                )

                # we ignore the different inner splits and use only the first
                # one to be reused multiple times (effectively simulating
                # multiple training runs of the same configuration on the same
                # training/validation data split). Every entry of the list
                # refers to the same InnerFold object.
                inner_fold_splits = [inner_fold] * self.n_inner_folds
                break

            self.inner_folds.append(inner_fold_splits)

            # Obtain outer val from outer train in an holdout fashion
            outer_val_splitter = self._get_splitter(
                n_splits=1,
                stratified=stratified,
                eval_ratio=self.outer_val_ratio,
            )
            # NOTE(review): targets are passed here even when stratified is
            # False, unlike the two split() calls above - confirm intended
            outer_train_idxs, outer_val_idxs = list(
                outer_val_splitter.split(inner_idxs, y=inner_targets)
            )[0]

            # train, validation and test must not share any sample
            # (checked on dataset indices, not on relative positions)
            self._assert_pairwise_disjoint(
                inner_idxs[outer_train_idxs],
                inner_idxs[outer_val_idxs],
                outer_idxs[test_idxs],
            )

            # in-place shuffles; the order of these calls determines how the
            # numpy random state is consumed, so it must not be changed
            np.random.shuffle(outer_train_idxs)
            np.random.shuffle(outer_val_idxs)
            np.random.shuffle(test_idxs)
            outer_fold = OuterFold(
                train_idxs=inner_idxs[outer_train_idxs].tolist(),
                val_idxs=inner_idxs[outer_val_idxs].tolist(),
                test_idxs=outer_idxs[test_idxs].tolist(),
            )
            self.outer_folds.append(outer_fold)
563+
564+
440565class TemporalSplitter (Splitter ):
441566 r"""
442567 Reads the entire dataset and returns the targets. In this case, each
0 commit comments