5252import torch .nn as nn
5353from asr_datamodule import LibriSpeechAsrDataModule as AsrDataModule
5454from lhotse import set_caching_enabled
55+ from lhotse .cut import Cut
5556from torchaudio .functional import (
5657 forced_align ,
5758 merge_tokens ,
@@ -166,11 +167,16 @@ def get_parser():
166167 )
167168
168169 parser .add_argument (
169- "dataset_manifests" ,
170+ "--max-utt-duration" ,
171+ type = float ,
172+ default = 60.0 ,
173+ help = "Maximal duration of an utterance in seconds, used in cut-set filtering." ,
174+ )
175+
176+ parser .add_argument (
177+ "dataset_manifest" ,
170178 type = str ,
171- nargs = "+" ,
172- help = "CutSet manifests to be aligned (CutSet with features and transcripts). "
173- "Each CutSet as a separate arg : `manifest1 mainfest2 ...`" ,
179+ help = "CutSet manifests to be aligned (CutSet with features and transcripts)." ,
174180 )
175181
176182 add_model_arguments (parser )
@@ -393,16 +399,17 @@ def align_dataset(
393399
394400def save_alignment_output (
395401 params : AttributeDict ,
396- test_set_name : str ,
402+ dataset_name : str ,
397403 results_dict : Dict [str , List [Tuple [str , List [str ], List [str ]]]],
404+ removed_cut_ids : list [str ],
398405):
399406 """
400407 Save the token alignments and per-utterance confidences.
401408 """
402409
403410 for key , results in results_dict .items ():
404411
405- alignments_filename = params .res_dir / f"alignments-{ test_set_name } .txt"
412+ alignments_filename = params .res_dir / f"alignments-{ dataset_name } .txt"
406413
407414 time_step = 0.04
408415
@@ -425,7 +432,7 @@ def save_alignment_output(
425432
426433 # ---------------------------
427434
428- confidences_filename = params .res_dir / f"confidences-{ test_set_name } .txt"
435+ confidences_filename = params .res_dir / f"confidences-{ dataset_name } .txt"
429436
430437 with open (confidences_filename , "w" , encoding = "utf8" ) as fd :
431438 print (
@@ -458,6 +465,15 @@ def save_alignment_output(
458465 file = fd ,
459466 )
460467
468+ # previously removed by `cuts.filter(remove_long_transcripts)`
469+ for utterance_key in removed_cut_ids :
470+ print (f"{ utterance_key } -2.0 -2.0 "
471+ "-2.0 "
472+ "(0,0,0,0,0) "
473+ "(0,0)" ,
474+ file = fd ,
475+ )
476+
461477 logging .info (f"The confidences are stored in `{ confidences_filename } `" )
462478
463479
@@ -605,37 +621,58 @@ def main():
605621 num_param = sum ([p .numel () for p in model .parameters ()])
606622 logging .info (f"Number of model parameters: { num_param } " )
607623
608- # we need cut ids to display recognition results.
624+ # we need cut_ids to display recognition results.
609625 args .return_cuts = True
610626 asr_datamodule = AsrDataModule (args )
611627
612- # create array of dataloaders (one per test-set)
613- testset_labels = []
614- testset_dataloaders = []
615- for testset_manifest in args .dataset_manifests :
616- label = PurePath (testset_manifest ).name # basename
617- label = label .replace (".jsonl.gz" , "" )
628+ dataset_label = PurePath (args .dataset_manifest ).name # basename
629+ dataset_label = dataset_label .replace (".jsonl.gz" , "" )
618630
619- test_cuts = asr_datamodule .load_manifest (testset_manifest )
620- test_dataloader = asr_datamodule .test_dataloaders (test_cuts )
631+ dataset_cuts = asr_datamodule .load_manifest (args .dataset_manifest )
621632
622- testset_labels .append (label )
623- testset_dataloaders .append (test_dataloader )
633+ def remove_long_transcripts (c : Cut ):
624634
625- # align
626- for test_set , test_dl in zip (testset_labels , testset_dataloaders ):
627- results_dict = align_dataset (
628- dl = test_dl ,
629- params = params ,
630- model = model ,
631- sp = sp ,
632- )
635+ if c .duration > params .max_utt_duration :
636+ logging .warning (
637+ f"Exclude cut with ID { c .id } from training. Duration: { c .duration } "
638+ )
639+ return False
640+
641+ T = ((c .num_frames - 7 ) // 2 + 1 ) // 2
642+ tokens = np .array (sp .encode (c .supervisions [0 ].text , out_type = str ))
643+ num_repeats = np .sum (tokens [1 :] == tokens [:- 1 ])
644+
645+ # For CTC `num_tokens + num_repeats` is needed. otherwise inf. in loss appears.
646+ if T < (len (tokens ) + num_repeats ):
647+ logging .warning (
648+ f"Exclude cut with ID { c .id } from training (too many supervision tokens). "
649+ f"Number of frames (before subsampling): { c .num_frames } . "
650+ f"Number of frames (after subsampling): { T } . "
651+ f"Number of tokens: { len (tokens )} "
652+ )
653+ return False
633654
634- save_alignment_output (
635- params = params ,
636- test_set_name = test_set ,
637- results_dict = results_dict ,
638- )
655+ return True
656+
657+ cut_ids_orig = set (list (dataset_cuts .ids ))
658+ dataset_cuts = dataset_cuts .filter (remove_long_transcripts )
659+ cut_ids_removed = cut_ids_orig - set (list (dataset_cuts .ids ))
660+
661+ dataset_dl = asr_datamodule .test_dataloaders (dataset_cuts )
662+
663+ results_dict = align_dataset (
664+ dl = dataset_dl ,
665+ params = params ,
666+ model = model ,
667+ sp = sp ,
668+ )
669+
670+ save_alignment_output (
671+ params = params ,
672+ dataset_name = dataset_label ,
673+ results_dict = results_dict ,
674+ removed_cut_ids = cut_ids_removed ,
675+ )
639676
640677 logging .info ("Done!" )
641678
0 commit comments