Add this one more train/validation data generator
This commit is contained in:
74
36-dnn/generate_c2w_data.py
Normal file
74
36-dnn/generate_c2w_data.py
Normal file
@@ -0,0 +1,74 @@
|
||||
import os, sys
|
||||
import collections
|
||||
import numpy as np
|
||||
import re, string
|
||||
|
||||
MAX_LINE_SIZE = 80
|
||||
MAX_WORDS_IN_LINE = 20
|
||||
|
||||
all_chars = ""
|
||||
with open('pride-and-prejudice.txt') as f:
|
||||
all_chars = f.read().replace('\n', ' ')
|
||||
all_words = re.findall('[a-z]{2,}', all_chars.lower())
|
||||
words = list(set(all_words))
|
||||
|
||||
def generate_pair():
|
||||
# Grab a slice of the input file of size MAX_LINE_SIZE
|
||||
index = np.random.randint(0, len(all_chars) - MAX_LINE_SIZE)
|
||||
cquery = ' ' + all_chars[index:index+MAX_LINE_SIZE - 2] + ' '
|
||||
# Replace unknown words with known ones
|
||||
wquery = set(re.findall('[a-z]{2,}', cquery.lower()))
|
||||
for w in wquery:
|
||||
if w not in words[:VOCAB_SIZE]:
|
||||
# Replace ALL occurrences in query with the same replacement word
|
||||
other = words[np.random.randint(0, VOCAB_SIZE/2)]
|
||||
exp = '[^a-z]' + w + '[^a-z]'
|
||||
indices = [(m.start()+1, m.end()-1) for m in re.finditer(exp, cquery.lower())]
|
||||
for b, e in reversed(indices):
|
||||
cquery = cquery[0:b] + other + cquery[e:]
|
||||
|
||||
# Make sure the size of all chars is less than MAX_LINE_SIZE
|
||||
if len(cquery) >= MAX_LINE_SIZE:
|
||||
last_sp = cquery[:MAX_LINE_SIZE].rfind(' ')
|
||||
cquery = cquery[:last_sp] + ' ' * (MAX_LINE_SIZE - last_sp)
|
||||
|
||||
# OK, now that we have the sequence of chars, find its sequence of words
|
||||
# [TODO] Remember to remove stop words
|
||||
list_of_words = re.findall('[a-z]{2,}', cquery.lower())
|
||||
|
||||
return cquery.strip(), list_of_words
|
||||
|
||||
|
||||
def generate_data(ntrain, nval, vocab_size, data_folder, train_x, train_y, val_x, val_y):
|
||||
if not os.path.exists(data_folder):
|
||||
os.makedirs(data_folder)
|
||||
|
||||
global VOCAB_SIZE
|
||||
VOCAB_SIZE = vocab_size
|
||||
with open(train_x, 'w') as fx, open(train_y, 'w') as fy:
|
||||
for _ in range(0, ntrain):
|
||||
query, ans = generate_pair()
|
||||
fx.write(query + '\n')
|
||||
fy.write(','.join(ans) + '\n')
|
||||
|
||||
with open(val_x, 'w') as fx, open(val_y, 'w') as fy:
|
||||
for _ in range(0, nval):
|
||||
query, ans = generate_pair()
|
||||
fx.write(query + '\n')
|
||||
fy.write(','.join(ans) + '\n')
|
||||
|
||||
def main():
|
||||
# [1]: number of samples in training set
|
||||
# [2]: number of samples in validation set
|
||||
# [3]: vocabulary size
|
||||
data_folder = 'c2w_data'
|
||||
if len(sys.argv) > 3: data_folder = data_folder + "_" + sys.argv[3]
|
||||
train_x = os.path.join(data_folder, 'train_x.txt')
|
||||
train_y = os.path.join(data_folder, 'train_y.txt')
|
||||
val_x = os.path.join(data_folder, 'val_x.txt')
|
||||
val_y = os.path.join(data_folder, 'val_y.txt')
|
||||
generate_data(int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3]), data_folder, train_x, train_y, val_x, val_y)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user