Skip to content

Commit 3431f6a

Browse files
committed
minimum length of encoder output for CTC training
There might be consecutive repetitions of a symbol in the reference, and in that case the CTC alignment must put a blank in between, so that the reverse mapping of the aligned symbols produces the original reference. I realised this recently while playing with the CTC aligner from torchaudio on the noisy yodas2 dataset. To illustrate: for "a a b c c d e f" - len(tokens) is 8 - but, because of the duplications 'a a', 'c c', the minimum length of the encoder output is 10 - the shortest valid CTC alignment is: "a ∅ a b c ∅ c d e f"
1 parent 0904e49 commit 3431f6a

File tree

2 files changed

+85
-36
lines changed

2 files changed

+85
-36
lines changed

egs/librispeech/ASR/zipformer/ctc_align.py

Lines changed: 68 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import torch.nn as nn
5353
from asr_datamodule import LibriSpeechAsrDataModule as AsrDataModule
5454
from lhotse import set_caching_enabled
55+
from lhotse.cut import Cut
5556
from torchaudio.functional import (
5657
forced_align,
5758
merge_tokens,
@@ -166,11 +167,16 @@ def get_parser():
166167
)
167168

168169
parser.add_argument(
169-
"dataset_manifests",
170+
"--max-utt-duration",
171+
type=float,
172+
default=60.0,
173+
help="Maximal duration of an utterance in seconds, used in cut-set filtering.",
174+
)
175+
176+
parser.add_argument(
177+
"dataset_manifest",
170178
type=str,
171-
nargs="+",
172-
help="CutSet manifests to be aligned (CutSet with features and transcripts). "
173-
"Each CutSet as a separate arg : `manifest1 mainfest2 ...`",
179+
help="CutSet manifests to be aligned (CutSet with features and transcripts).",
174180
)
175181

176182
add_model_arguments(parser)
@@ -393,16 +399,17 @@ def align_dataset(
393399

394400
def save_alignment_output(
395401
params: AttributeDict,
396-
test_set_name: str,
402+
dataset_name: str,
397403
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
404+
removed_cut_ids: list[str],
398405
):
399406
"""
400407
Save the token alignments and per-utterance confidences.
401408
"""
402409

403410
for key, results in results_dict.items():
404411

405-
alignments_filename = params.res_dir / f"alignments-{test_set_name}.txt"
412+
alignments_filename = params.res_dir / f"alignments-{dataset_name}.txt"
406413

407414
time_step = 0.04
408415

@@ -425,7 +432,7 @@ def save_alignment_output(
425432

426433
# ---------------------------
427434

428-
confidences_filename = params.res_dir / f"confidences-{test_set_name}.txt"
435+
confidences_filename = params.res_dir / f"confidences-{dataset_name}.txt"
429436

430437
with open(confidences_filename, "w", encoding="utf8") as fd:
431438
print(
@@ -458,6 +465,15 @@ def save_alignment_output(
458465
file=fd,
459466
)
460467

468+
# previously removed by `cuts.filter(remove_long_transcripts)`
469+
for utterance_key in removed_cut_ids:
470+
print(f"{utterance_key} -2.0 -2.0 "
471+
"-2.0 "
472+
"(0,0,0,0,0) "
473+
"(0,0)",
474+
file=fd,
475+
)
476+
461477
logging.info(f"The confidences are stored in `{confidences_filename}`")
462478

463479

@@ -605,37 +621,58 @@ def main():
605621
num_param = sum([p.numel() for p in model.parameters()])
606622
logging.info(f"Number of model parameters: {num_param}")
607623

608-
# we need cut ids to display recognition results.
624+
# we need cut_ids to display recognition results.
609625
args.return_cuts = True
610626
asr_datamodule = AsrDataModule(args)
611627

612-
# create array of dataloaders (one per test-set)
613-
testset_labels = []
614-
testset_dataloaders = []
615-
for testset_manifest in args.dataset_manifests:
616-
label = PurePath(testset_manifest).name # basename
617-
label = label.replace(".jsonl.gz", "")
628+
dataset_label = PurePath(args.dataset_manifest).name # basename
629+
dataset_label = dataset_label.replace(".jsonl.gz", "")
618630

619-
test_cuts = asr_datamodule.load_manifest(testset_manifest)
620-
test_dataloader = asr_datamodule.test_dataloaders(test_cuts)
631+
dataset_cuts = asr_datamodule.load_manifest(args.dataset_manifest)
621632

622-
testset_labels.append(label)
623-
testset_dataloaders.append(test_dataloader)
633+
def remove_long_transcripts(c: Cut):
624634

625-
# align
626-
for test_set, test_dl in zip(testset_labels, testset_dataloaders):
627-
results_dict = align_dataset(
628-
dl=test_dl,
629-
params=params,
630-
model=model,
631-
sp=sp,
632-
)
635+
if c.duration > params.max_utt_duration:
636+
logging.warning(
637+
f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
638+
)
639+
return False
640+
641+
T = ((c.num_frames - 7) // 2 + 1) // 2
642+
tokens = np.array(sp.encode(c.supervisions[0].text, out_type=str))
643+
num_repeats = np.sum(tokens[1:] == tokens[:-1])
644+
645+
# For CTC `num_tokens + num_repeats` is needed. otherwise inf. in loss appears.
646+
if T < (len(tokens) + num_repeats):
647+
logging.warning(
648+
f"Exclude cut with ID {c.id} from training (too many supervision tokens). "
649+
f"Number of frames (before subsampling): {c.num_frames}. "
650+
f"Number of frames (after subsampling): {T}. "
651+
f"Number of tokens: {len(tokens)}"
652+
)
653+
return False
633654

634-
save_alignment_output(
635-
params=params,
636-
test_set_name=test_set,
637-
results_dict=results_dict,
638-
)
655+
return True
656+
657+
cut_ids_orig = set(list(dataset_cuts.ids))
658+
dataset_cuts = dataset_cuts.filter(remove_long_transcripts)
659+
cut_ids_removed = cut_ids_orig - set(list(dataset_cuts.ids))
660+
661+
dataset_dl = asr_datamodule.test_dataloaders(dataset_cuts)
662+
663+
results_dict = align_dataset(
664+
dl=dataset_dl,
665+
params=params,
666+
model=model,
667+
sp=sp,
668+
)
669+
670+
save_alignment_output(
671+
params=params,
672+
dataset_name=dataset_label,
673+
results_dict=results_dict,
674+
removed_cut_ids=cut_ids_removed,
675+
)
639676

640677
logging.info("Done!")
641678

egs/librispeech/ASR/zipformer/train.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363

6464
import k2
6565
import optim
66+
import numpy as np
6667
import sentencepiece as spm
6768
import torch
6869
import torch.multiprocessing as mp
@@ -384,7 +385,10 @@ def get_parser():
384385
)
385386

386387
parser.add_argument(
387-
"--base-lr", type=float, default=0.045, help="The base learning rate."
388+
"--base-lr",
389+
type=float,
390+
default=0.045,
391+
help="The base learning rate.",
388392
)
389393

390394
parser.add_argument(
@@ -1407,16 +1411,24 @@ def remove_short_and_long_utt(c: Cut):
14071411
# In ./zipformer.py, the conv module uses the following expression
14081412
# for subsampling
14091413
T = ((c.num_frames - 7) // 2 + 1) // 2
1410-
tokens = sp.encode(c.supervisions[0].text, out_type=str)
1414+
tokens = np.array(sp.encode(c.supervisions[0].text, out_type=str))
1415+
1416+
if args.use_ctc:
1417+
# For CTC `T < num_tokens + num_repeats` is needed, blanks are added.
1418+
num_repeats = np.sum(tokens[1:] == tokens[:-1])
1419+
min_T = len(tokens) + num_repeats
1420+
else:
1421+
# For Transducer `T < num_tokens` is okay.
1422+
min_T = len(tokens)
14111423

1412-
if T < len(tokens):
1424+
if T < min_T:
14131425
logging.warning(
1414-
f"Exclude cut with ID {c.id} from training. "
1426+
f"Exclude cut with ID {c.id} from training (too many supervision tokens). "
14151427
f"Number of frames (before subsampling): {c.num_frames}. "
14161428
f"Number of frames (after subsampling): {T}. "
14171429
f"Text: {c.supervisions[0].text}. "
14181430
f"Tokens: {tokens}. "
1419-
f"Number of tokens: {len(tokens)}"
1431+
f"Number of tokens: {len(tokens)}, min_T: {min_T}"
14201432
)
14211433
return False
14221434

0 commit comments

Comments
 (0)