-
Notifications
You must be signed in to change notification settings - Fork 21
Expand file tree
/
Copy pathdata_generator.py
More file actions
40 lines (33 loc) · 1.46 KB
/
data_generator.py
File metadata and controls
40 lines (33 loc) · 1.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
Generate training set, test set and challenge set.
Convert the MPD's data format into the format used by our models.
"""
import sys
import json
import os
import numpy as np
import argparse
from utils.spotify_reader import *
def fullpaths_generator(path):
    """Return the paths of all entries directly inside ``path``.

    Each name reported by ``os.listdir`` is joined onto ``path`` with
    ``os.path.join``, so a ``path`` that already ends in a separator does
    not yield a doubled separator. Names are sorted before joining because
    ``os.listdir`` returns them in arbitrary order; sorting keeps the file
    order stable from one run (and one machine) to the next.

    Args:
        path: Directory to list.

    Returns:
        A list with one ``path``-prefixed entry per name in the directory,
        sorted by name. Nothing is filtered: sub-directories and hidden
        files are included just like regular files. An empty directory
        gives an empty list.

    Raises:
        OSError: Propagated from ``os.listdir`` (for example when ``path``
            does not exist or is not a directory).
    """
    return [os.path.join(path, filename) for filename in sorted(os.listdir(path))]
if __name__ == '__main__':
    # Command-line interface: where the MPD inputs live, where outputs go,
    # and the two minimum-count thresholds handed to the training reader.
    parser = argparse.ArgumentParser(description="args")
    parser.add_argument(
        '--datadir', type=str, default='./data',
        help="directory where the outputs are stored",
    )
    parser.add_argument(
        '--mpd_tr', type=str, default='./mpd_train',
        help="train mpd path",
    )
    parser.add_argument(
        '--mpd_te', type=str, default='./mpd_test',
        help="test mpd path",
    )
    parser.add_argument(
        '--mincount_trk', type=int, default=5,
        help='minimum count of tracks',
    )
    parser.add_argument(
        '--mincount_art', type=int, default=3,
        help='minimum count of artists',
    )
    args = parser.parse_args()

    # Training set: built from every entry found in the training MPD folder.
    train_fold = Spotify_train(
        fullpaths_generator(args.mpd_tr),
        args.mincount_trk,
        args.mincount_art,
        args.datadir,
    )

    # Passing the literal string 'NULL' as --mpd_te skips the test sets.
    if args.mpd_te != 'NULL':
        test_fullpaths = fullpaths_generator(args.mpd_te)
        is_shuffle = True
        # NOTE(review): presumably this is the folder the training step
        # wrote into -- confirm against Spotify_train / Spotify_test.
        train_dir = args.datadir+'/train'
        # One test set is produced per seed count.
        for test_seed_num in (1, 5, 10, 25, 50, 100):
            test_fold = Spotify_test(
                test_fullpaths,
                train_dir,
                test_seed_num,
                args.datadir,
                is_shuffle,
            )
            # Drop the reference before the next iteration builds another.
            del test_fold