Skip to content

Commit 89e8d0d

Browse files
authored
feat(tools): Analyze dataset (wenet-e2e#1452)
* feat(tools): Analyze dataset * fix(tools): Analyze dataset, bugfix * fix(tools): Analyze dataset, lintfix
1 parent c6391c0 commit 89e8d0d

File tree

1 file changed

+155
-0
lines changed

1 file changed

+155
-0
lines changed

tools/analyze_dataset.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright (c) 2022 Horizon Inc. (authors: Xingchen Song)
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
"""Analyze Dataset, Duration/TextLength/Speed etc."""
18+
19+
import argparse
20+
import logging
21+
import queue
22+
import threading
23+
24+
import torchaudio
25+
26+
from wenet.utils.file_utils import read_lists
27+
28+
29+
def get_args():
    """Parse the command-line options of the dataset analysis tool.

    The parsed namespace is echoed to stdout before it is returned.

    Returns:
        argparse.Namespace: carries ``data_type``, ``data_list``,
        ``wav_scp``, ``text`` and ``num_thread``.
    """
    arg_parser = argparse.ArgumentParser(description='Analyze dataset')
    arg_parser.add_argument('--data_type',
                            choices=['wav_scp', 'raw', 'shard'],
                            default='wav_scp',
                            help='dataset type')
    # The three input-file options share one shape: optional, default None.
    for flag, usage in (('--data_list', 'used in raw/shard mode'),
                        ('--wav_scp', 'used in wav_scp mode'),
                        ('--text', 'used in wav_scp mode')):
        arg_parser.add_argument(flag, default=None, help=usage)
    arg_parser.add_argument('--num_thread', default=4, type=int,
                            help='number of threads')
    parsed = arg_parser.parse_args()
    print(parsed)
    return parsed
46+
47+
48+
def query_dict(wavs_queue, datas, wavs, texts):
    """Worker loop: drain ``wavs_queue`` and record per-utterance statistics.

    For every key taken from the queue that also has a transcript, the audio
    file is loaded with ``torchaudio.load`` and the record
    ``[duration, text_length, speed, key]`` is appended to ``datas``.
    Keys without a transcript, and files whose duration is zero, are skipped
    with a warning.

    Args:
        wavs_queue: ``queue.Queue`` holding the utterance keys that still
            have to be processed; may be shared by several worker threads.
        datas: Output list, extended in place.
        wavs: Mapping from utterance key to audio file path.
        texts: Mapping from utterance key to transcript.
    """
    while True:
        # ``empty()`` followed by a blocking ``get()`` is a check-then-act
        # race: another worker can take the last key in between, which would
        # leave this thread blocked forever. A non-blocking get performs the
        # check and the removal as one step.
        try:
            key = wavs_queue.get_nowait()
        except queue.Empty:
            return
        if key not in texts:
            logging.warning("{} not in text, pass".format(key))
            continue
        waveform, sample_rate = torchaudio.load(wavs[key])
        # Duration in seconds: samples of the first channel / sample rate.
        dur = len(waveform[0]) / sample_rate
        if dur <= 0:
            # An empty file would otherwise raise ZeroDivisionError below and
            # silently terminate this worker.
            logging.warning("{} has zero duration, pass".format(key))
            continue
        text_length = len(texts[key])
        speed = text_length / dur
        datas.append([dur, text_length, speed, key])
59+
60+
61+
def _log_stats(datas, column, label, value_fmt, avg_fmt, average):
    """Log max / P99 / P75 / P50 / P25 / min / avg for one field of ``datas``.

    Args:
        datas: Non-empty list of ``[duration, text_length, speed, wav_id]``
            records. It is sorted in place by ``column``.
        column: Index of the field to rank by (0, 1 or 2).
        label: Field name used in the log lines, e.g. ``"duration"``.
        value_fmt: ``str.format`` template for one value, unit included.
        avg_fmt: ``str.format`` template for the average, unit included.
        average: Pre-computed average reported on the last line.
    """
    datas.sort(key=lambda record: record[column])
    num_datas = len(datas)
    largest, smallest = datas[-1], datas[0]
    logging.info("==================")
    logging.info(("max {}: " + value_fmt + " (wav_id: {})").format(
        label, largest[column], largest[3]))
    # int(n * fraction) is always < n for fraction < 1, so the index is valid
    # for any non-empty list.
    for name, fraction in (("P99", 0.99), ("P75", 0.75),
                           ("P50", 0.5), ("P25", 0.25)):
        logging.info(("{} {}: " + value_fmt).format(
            name, label, datas[int(num_datas * fraction)][column]))
    logging.info(("min {}: " + value_fmt + " (wav_id: {})").format(
        label, smallest[column], smallest[3]))
    logging.info(("avg {}: " + avg_fmt).format(label, average))


def main():
    """Analyze a dataset and log duration / text-length / speed statistics.

    Only ``--data_type wav_scp`` is implemented: ``--wav_scp`` supplies
    ``<key> <audio path>`` lines and ``--text`` supplies ``<key> <transcript>``
    lines. Audio files are measured by ``--num_thread`` worker threads running
    ``query_dict``; the summary is written through ``logging``.

    Raises:
        NotImplementedError: for the ``shard`` and ``raw`` data types.
        AssertionError: when a file option required by the mode is missing.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')

    if args.data_type in ("shard", "raw"):
        assert args.data_list is not None, \
            "--data_list is required in raw/shard mode"
        raise NotImplementedError("Feel free to make a PR :)")

    assert args.wav_scp is not None, "--wav_scp is required in wav_scp mode"
    assert args.text is not None, "--text is required in wav_scp mode"
    logging.info("Start Analyze {}".format(args.wav_scp))

    datas = []  # List of [duration, text_length, speed, id]
    wavs, texts = {}, {}
    wavs_queue = queue.Queue()

    # wavs & wavs_queue
    for line in read_lists(args.wav_scp):
        fields = line.strip().split()
        if not fields:
            continue  # blank line
        if len(fields) < 2:
            logging.warning("malformed wav.scp line, pass: {}".format(line))
            continue
        wavs[fields[0]] = fields[1]
        wavs_queue.put(fields[0])

    # texts
    for line in read_lists(args.text):
        fields = line.strip().split(maxsplit=1)
        if not fields:
            continue  # blank line
        # A key without any transcript is kept with an empty text.
        texts[fields[0]] = fields[1] if len(fields) > 1 else ""

    # threads; at least one worker so the queue is always processed
    threads = [
        threading.Thread(target=query_dict,
                         args=(wavs_queue, datas, wavs, texts))
        for _ in range(max(1, args.num_thread))
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    if not datas:
        # Nothing matched between wav.scp and text: indexing datas[0] /
        # datas[-1] and the averages below would fail.
        logging.warning("no utterance could be analyzed, nothing to report")
        return

    total_dur = sum(x[0] for x in datas)
    total_len = sum(x[1] for x in datas)
    num_datas = len(datas)

    _log_stats(datas, 0, "duration", "{:.3f} s", "{:.3f} s",
               total_dur / num_datas)
    _log_stats(datas, 1, "text length", "{}", "{:.3f}",
               total_len / num_datas)
    # Average speed is the corpus-level ratio, not the mean of the
    # per-utterance speeds.
    _log_stats(datas, 2, "speed", "{:.3f} char/s", "{:.3f} char/s",
               total_len / total_dur)
152+
153+
154+
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
155+
main()

0 commit comments

Comments
 (0)