
forward() got an unexpected keyword argument 'log_probs' #220

@ChaofanTao

Description


Environment info

  • Platform: Linux
  • Python version: 3.9.18
  • PyTorch version (GPU?): 2.0.0+cu118
  • Using GPU in script?: yes

Information

I want to train ContextNet on the LibriSpeech dataset. Here is my training script, located in openspeech/scripts (the first time, I set dataset.dataset_download=True to download the dataset):

# sh scripts/train.sh 
python3 ./openspeech_cli/hydra_train.py \
    dataset=librispeech \
    dataset.dataset_download=False \
    dataset.dataset_path=$DATASET_PATH \
    dataset.manifest_file_path=$MANIFEST_FILE_PATH \
    tokenizer=libri_subword \
    model=contextnet \
    audio=fbank \
    lr_scheduler=warmup_reduce_lr_on_plateau \
    trainer=gpu \
    criterion=cross_entropy

It fails with the following error:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 69, in _wrap
    fn(i, *args)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/strategies/launchers/multiprocessing.py", line 139, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1190, in _run_train
    self._run_sanity_check()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1262, in _run_sanity_check
    val_loop.run()
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 1480, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/strategies/ddp_spawn.py", line 288, in validation_step
    return self.model(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1156, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1110, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/pytorch_lightning/overrides/base.py", line 110, in forward
    return self._forward_module.validation_step(*inputs, **kwargs)
  File "/home/mnt/cftao/openspeech/openspeech/models/contextnet/model.py", line 133, in validation_step
    return self.collect_outputs(
  File "/home/mnt/cftao/openspeech/openspeech/models/openspeech_ctc_model.py", line 73, in collect_outputs
    loss = self.criterion(
  File "/home/mnt/cftao/anaconda3/envs/speech/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'log_probs'

How can I solve this problem? Thanks.
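The last frames of the traceback show `collect_outputs` in `openspeech_ctc_model.py` calling `self.criterion(...)`, and the `TypeError` is raised once that call reaches the criterion's `forward()`. One plausible reading, assumed here and not confirmed against the openspeech source, is that the CTC-based ContextNet model passes CTC-style keyword arguments such as `log_probs`, while `criterion=cross_entropy` selects a criterion whose `forward()` declares different parameters. The pure-Python sketch below reproduces that failure mode without openspeech or PyTorch. `Criterion`, `CrossEntropyLikeCriterion`, `CTCLikeCriterion`, `call_like_ctc_model`, and `accepts_kwargs` are hypothetical stand-ins, and every keyword name other than `log_probs` is an assumption.

```python
import inspect


class Criterion:
    """Hypothetical module-like base class (not an openspeech class).

    Calling the object dispatches to self.forward(), loosely mimicking how
    torch.nn.Module.__call__ reaches forward() via the _call_impl frames above.
    """

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)


class CrossEntropyLikeCriterion(Criterion):
    # Assumed signature: a cross-entropy criterion taking logits and targets only.
    def forward(self, logits, targets):
        return "cross_entropy"


class CTCLikeCriterion(Criterion):
    # Assumed signature: a CTC criterion taking log-probabilities plus lengths.
    def forward(self, log_probs, targets, input_lengths, target_lengths):
        return "ctc"


def call_like_ctc_model(criterion):
    """Hypothetical stand-in for a CTC model passing CTC-style kwargs to its criterion."""
    return criterion(
        log_probs=[[-0.1, -2.3]],  # dummy values; only the keyword names matter here
        targets=[1],
        input_lengths=[1],
        target_lengths=[1],
    )


def accepts_kwargs(criterion, names):
    """Return True if criterion.forward() declares every parameter in `names`."""
    params = inspect.signature(criterion.forward).parameters
    return all(name in params for name in names)


if __name__ == "__main__":
    try:
        call_like_ctc_model(CrossEntropyLikeCriterion())
    except TypeError as err:
        # The message names 'log_probs'; its exact prefix varies by Python version.
        print(f"TypeError: {err}")
    print(call_like_ctc_model(CTCLikeCriterion()))
```

If this mismatch is indeed the cause, selecting a criterion whose `forward()` accepts `log_probs` (a CTC-style criterion, under the same assumption) should avoid the error. A signature check like `accepts_kwargs` is one way to catch such a mismatch before launching a long multi-GPU run.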
