"""Generate character-sequence -> word-sequence data files.

Random fixed-width slices of ``pride-and-prejudice.txt`` are written, one per
line, as the inputs; the lowercase words found in each slice are written,
comma-separated, as the matching targets.  Words that are not among the first
``VOCAB_SIZE`` entries of ``words`` are replaced by in-vocabulary words.

Usage: python generate_c2w_data.py NTRAIN NVAL VOCAB_SIZE
"""
import collections
import os
import re
import string
import sys

import numpy as np

MAX_LINE_SIZE = 80
MAX_WORDS_IN_LINE = 20

# A "word" is a run of two or more lowercase ASCII letters.
_WORD_RE = re.compile(r'[a-z]{2,}')

all_chars = ""
with open('pride-and-prejudice.txt') as f:
    all_chars = f.read().replace('\n', ' ')
all_words = _WORD_RE.findall(all_chars.lower())
# NOTE: the order of this list comes from iterating a set of strings, so it
# may differ between interpreter runs (string hash randomisation).
words = list(set(all_words))

# Overwritten by generate_data(); defaults to the whole vocabulary so that
# generate_pair() can also be called on its own.
VOCAB_SIZE = len(words)

# (vocab_size, vocabulary list, vocabulary set) for the last VOCAB_SIZE seen.
_vocab_cache = (None, [], frozenset())


def _vocabulary():
    """Return ``(vocab_list, vocab_set)`` for the current ``VOCAB_SIZE``."""
    global _vocab_cache
    if _vocab_cache[0] != VOCAB_SIZE:
        vocab = words[:VOCAB_SIZE]
        _vocab_cache = (VOCAB_SIZE, vocab, frozenset(vocab))
    return _vocab_cache[1], _vocab_cache[2]


def generate_pair():
    """Build one (character line, word list) sample from the corpus.

    Returns:
        A tuple ``(chars, word_list)``: ``chars`` is a stripped slice of the
        corpus shorter than ``MAX_LINE_SIZE`` in which every out-of-vocabulary
        word has been replaced, and ``word_list`` is the list of lowercase
        words (two or more letters) found in that slice, in order.
    """
    vocab, known = _vocabulary()

    # Grab a slice of the input file of size MAX_LINE_SIZE (including the
    # sentinel spaces added at both ends).
    index = np.random.randint(0, len(all_chars) - MAX_LINE_SIZE)
    cquery = ' ' + all_chars[index:index + MAX_LINE_SIZE - 2] + ' '

    # Replacements are drawn from the first half of the vocabulary.  Integer
    # division keeps the bound an int; the clamp keeps the range non-empty
    # and inside the real vocabulary.
    pool_size = max(1, len(vocab) // 2)

    # Replace unknown words with known ones.
    for w in set(_WORD_RE.findall(cquery.lower())):
        if w in known:
            continue
        # Replace ALL occurrences in query with the same replacement word.
        other = vocab[np.random.randint(0, pool_size)]
        # Zero-width lookarounds do not consume the delimiters, so repeated
        # occurrences separated by a single character are all matched.
        pattern = r'(?<![a-z])' + w + r'(?![a-z])'
        # Match on the lowercased text, splice into the original; assumes
        # lower() preserves string length (true for ASCII text).
        spans = [m.span() for m in re.finditer(pattern, cquery.lower())]
        for begin, end in reversed(spans):
            cquery = cquery[:begin] + other + cquery[end:]

    # Make sure the size of all chars is less than MAX_LINE_SIZE by cutting
    # at the last space inside the limit.  cquery always starts with a space,
    # so rfind() cannot return -1 here.
    if len(cquery) >= MAX_LINE_SIZE:
        last_sp = cquery[:MAX_LINE_SIZE].rfind(' ')
        cquery = cquery[:last_sp]

    # OK, now that we have the sequence of chars, find its sequence of words
    # [TODO] Remember to remove stop words
    list_of_words = _WORD_RE.findall(cquery.lower())

    return cquery.strip(), list_of_words


def _write_pairs(x_path, y_path, count):
    """Write ``count`` generated samples to ``x_path`` and ``y_path``."""
    with open(x_path, 'w') as fx, open(y_path, 'w') as fy:
        for _ in range(count):
            query, ans = generate_pair()
            fx.write(query + '\n')
            fy.write(','.join(ans) + '\n')


def generate_data(ntrain, nval, vocab_size, data_folder, train_x, train_y, val_x, val_y):
    """Write the training and validation files.

    Args:
        ntrain: number of samples in the training set.
        nval: number of samples in the validation set.
        vocab_size: number of entries of ``words`` treated as known.
        data_folder: directory created if it does not exist yet.
        train_x, train_y: output paths for training inputs / targets.
        val_x, val_y: output paths for validation inputs / targets.

    Raises:
        ValueError: if ``vocab_size`` is smaller than 1.
    """
    if vocab_size < 1:
        raise ValueError('vocab_size must be at least 1, got {0}'.format(vocab_size))

    if not os.path.exists(data_folder):
        os.makedirs(data_folder)

    global VOCAB_SIZE
    VOCAB_SIZE = vocab_size
    _write_pairs(train_x, train_y, ntrain)
    _write_pairs(val_x, val_y, nval)


def main():
    """Command-line entry point.

    Arguments:
        [1]: number of samples in training set
        [2]: number of samples in validation set
        [3]: vocabulary size
    """
    if len(sys.argv) < 4:
        sys.stderr.write('usage: {0} NTRAIN NVAL VOCAB_SIZE\n'.format(sys.argv[0]))
        sys.exit(2)

    ntrain, nval, vocab_size = (int(arg) for arg in sys.argv[1:4])
    # The output folder is suffixed with the vocabulary size as given.
    data_folder = 'c2w_data' + "_" + sys.argv[3]
    train_x = os.path.join(data_folder, 'train_x.txt')
    train_y = os.path.join(data_folder, 'train_y.txt')
    val_x = os.path.join(data_folder, 'val_x.txt')
    val_y = os.path.join(data_folder, 'val_y.txt')
    generate_data(ntrain, nval, vocab_size, data_folder, train_x, train_y, val_x, val_y)


if __name__ == "__main__":
    main()