https://pytorch.org/audio/main/generated/torchaudio.functional.rnnt_loss.html

 

 

Rnnt loss를 구현하다가 에러를 맞닥트렸다.

RNN Transducer에 대해 공부한다고 했는데, joint이후로 메커니즘이 어떻게 되는지 이해를 못하고 있어 발생한 문제였다.

 

error 문구는 

RuntimeError: output length mismatch

 

possible path에 대한 alignment path를 그릴 때 "null"값이 존재해야한다. 이는 아래 그림을 보면 이해가 될 것.

 

알고리즘 상 null값이 존재하는 자리가 있어야 하기 때문에 rnnt loss의 logits의 shape에 target+1을 해주어야한다.

 

 

아래의 글을 참고하면 쉽게 이해할 수 있을 것이다.

https://lorenlugosch.github.io/posts/2020/11/transducer/

https://github.com/pytorch/audio/issues/3750#issuecomment-1964109967

 

I have some questions about RNNT loss. · Issue #3750 · pytorch/audio

hello I would like to ask you a question that may be somewhat trivial. The shape of logits of RNN T loss is Batch, max_seq_len, max_target_len+1, class. Why is max_target_len+1 here? Shouldn't the ...

github.com

 

계속이해를 못해서 torchaudio issue에다가 올렸고, 어떤 분이 댓글 달아주셨다.

 

다들 오늘도 파이튕

+ Recent posts