Very rough learning to count words
This commit is contained in:
69
36-dnn/generate_words.py
Normal file
69
36-dnn/generate_words.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import os, sys
|
||||
import collections
|
||||
import numpy as np
|
||||
import re
|
||||
|
||||
SAMPLE_SIZE = 80
|
||||
VOCAB_SIZE = 10
|
||||
TOP = 5
|
||||
|
||||
stopwords = set(open('../stop_words.txt').read().split(','))
|
||||
all_words = re.findall('[a-z]{2,}', open('../pride-and-prejudice.txt').read().lower())
|
||||
words = list(set([w for w in all_words if w not in stopwords]))
|
||||
|
||||
def generate_pair(with_counts):
|
||||
# Grab a slice of the input file of size SAMPLE_SIZE
|
||||
index = np.random.randint(0, len(all_words) - SAMPLE_SIZE)
|
||||
querytmp = all_words[index:index+SAMPLE_SIZE]
|
||||
# Replace unknown words with known ones
|
||||
query = querytmp
|
||||
for i, w in enumerate(querytmp):
|
||||
if w not in words[:VOCAB_SIZE] and query[i] == w:
|
||||
# Replace ALL occurrences in query with the same replacement word
|
||||
other = words[np.random.randint(0, VOCAB_SIZE/2)]
|
||||
query = [other if v == w else v for v in query]
|
||||
|
||||
counts = collections.Counter(query)
|
||||
top = counts.most_common()
|
||||
if not with_counts:
|
||||
ans = list(list(zip(*top))[0])
|
||||
else:
|
||||
ans = [t[0] + " " + str(t[1]) for t in top]
|
||||
return query, ans
|
||||
|
||||
|
||||
def generate_data(data_folder, ntrain, nval, vocab_size, with_counts):
|
||||
train_x = os.path.join(data_folder, 'train_x.txt')
|
||||
train_y = os.path.join(data_folder, 'train_y.txt')
|
||||
val_x = os.path.join(data_folder, 'val_x.txt')
|
||||
val_y = os.path.join(data_folder, 'val_y.txt')
|
||||
|
||||
if not os.path.exists(data_folder):
|
||||
os.makedirs(data_folder)
|
||||
|
||||
global VOCAB_SIZE
|
||||
VOCAB_SIZE = vocab_size
|
||||
with open(train_x, 'w') as fx, open(train_y, 'w') as fy:
|
||||
for _ in range(0, ntrain):
|
||||
query, ans = generate_pair(with_counts)
|
||||
fx.write(' '.join(query) + '\n')
|
||||
fy.write(','.join(ans) + '\n')
|
||||
|
||||
with open(val_x, 'w') as fx, open(val_y, 'w') as fy:
|
||||
for _ in range(0, nval):
|
||||
query, ans = generate_pair(with_counts)
|
||||
fx.write(' '.join(query) + '\n')
|
||||
fy.write(','.join(ans) + '\n')
|
||||
|
||||
def main():
|
||||
# [1]: number of samples in training set
|
||||
# [2]: number of samples in validation set
|
||||
# [3]: vocabulary size
|
||||
# [4]: output with (1) or without (0) counts
|
||||
data_folder = 'words_data'
|
||||
if len(sys.argv) > 3: data_folder = data_folder + "_" + sys.argv[3]
|
||||
generate_data(data_folder, int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3]), bool(int(sys.argv[4])))
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user