Learn to normalize characters given a line, using the model of the no-learning version

This commit is contained in:
Crista Lopes
2019-11-26 00:21:25 -08:00
parent e2c531fc5c
commit a18c4c6980

@@ -1,5 +1,5 @@
 from keras.models import Model
-from keras import layers
+from keras import layers, metrics
 from keras.layers import Input, Dense
 from keras.utils import plot_model
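
Note: the newly imported metrics module is not used in any hunk shown here; presumably it feeds an accuracy metric into the compile step, which lives outside this diff. A minimal sketch of what that compile call might look like (an assumption, not part of this commit):

    # Hypothetical compile call -- the real one is not shown in this diff.
    # categorical_accuracy measures per-character accuracy of the softmax outputs.
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=[metrics.categorical_accuracy])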
@@ -39,18 +39,16 @@ def decode_one_hot(x):
""" """
s = [] s = []
for onehot in x: for onehot in x:
one_index = np.where(onehot == 1) # one_index is a tuple of two things one_index = np.argmax(onehot)
if len(one_index[0]) > 0: c = indices_char[one_index]
n = one_index[0][0] s.append(c)
c = indices_char[n]
s.append(c)
return ''.join(s) return ''.join(s)
def build_model(): def build_model():
print('Build model...') print('Build model...')
# Normalize every character in the input, using a shared dense model # Normalize every character in the input, using a shared dense model
n_layer = Dense(INPUT_VOCAB_SIZE) n_layer = Dense(INPUT_VOCAB_SIZE, activation = "softmax")
raw_inputs = [] raw_inputs = []
normalized_outputs = [] normalized_outputs = []
for _ in range(0, LINE_SIZE): for _ in range(0, LINE_SIZE):
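
Note: the decode change and the softmax change go together. Dense(INPUT_VOCAB_SIZE, activation="softmax") emits a probability distribution over the vocabulary, so no entry is ever exactly 1 and the old np.where(onehot == 1) test finds nothing on predicted rows; np.argmax recovers the most probable character instead. A standalone sketch (the toy indices_char below is invented for illustration, not the script's real table):

    import numpy as np

    indices_char = {0: 'a', 1: 'b', 2: 'c'}           # toy mapping for illustration

    softmax_row = np.array([0.05, 0.90, 0.05])        # what a softmax layer emits
    print(np.where(softmax_row == 1)[0])              # [] -- exact-match test fails
    print(indices_char[int(np.argmax(softmax_row))])  # 'b' -- most probable character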
@@ -101,7 +99,7 @@ plot_model(model, to_file='normalization.png', show_shapes=True)
 # Train the model each generation and show predictions against the validation
 # dataset.
 val_gen2 = input_generator(1)
-for iteration in range(1, 500):
+for iteration in range(1, 12):
     print()
     print('-' * 50)
     print('Iteration', iteration)
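
Note: input_generator is defined outside this diff. Given that build_model wires up LINE_SIZE separate per-character inputs, it must yield lists of LINE_SIZE arrays, one per character position. A hypothetical shape sketch (the name and the random data below are placeholders, not the real generator):

    import numpy as np

    def input_generator_sketch(batch_size):
        # Yields (inputs, targets): each a list of LINE_SIZE arrays of shape
        # (batch_size, INPUT_VOCAB_SIZE) -- one array per character slot.
        while True:
            xs = [np.random.rand(batch_size, INPUT_VOCAB_SIZE) for _ in range(LINE_SIZE)]
            ys = [np.zeros((batch_size, INPUT_VOCAB_SIZE)) for _ in range(LINE_SIZE)]
            yield xs, ys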
@@ -112,40 +110,33 @@ for iteration in range(1, 500):
                         steps_per_epoch = 20,
                         validation_data = val_gen,
                         validation_steps = 10, workers=1)

-    # Select 10 samples from the validation set at random so we can visualize
-    # errors.
-    # print(batch_y)
-    # print(preds)
+    # Select samples from the validation set at random so we can visualize errors.
     batch_x, batch_y = next(val_gen2)
     for i in range(len(batch_y)):
         preds = model.predict(batch_x)
         expected = batch_y[i]
         prediction = preds[i]
-        #print(preds)
-        # preds[preds>=0.5] = 1
-        # preds[preds<0.5] = 0
-        #q = ctable.decode(query)
         correct = decode_one_hot(expected)
         guess = decode_one_hot(prediction)
-        print('T', correct)
-        print('G', guess)
+        print('T:', correct)
+        print('G:', guess)

-#with open(sys.argv[1]) as f:
-#    for line in f:
-#        if line.isspace(): continue
-#        onehots = encode_one_hot(line)
-#        data = [[] for _ in range(LINE_SIZE)]
-#        for i, c in enumerate(onehots):
-#            data[i].append(c)
-#        for j in range(len(onehots), LINE_SIZE):
-#            data[j].append(np.zeros((INPUT_VOCAB_SIZE)))
-#        inputs = [np.array(e) for e in data]
-#        preds = model.predict(inputs)
-#        normal = decode_one_hot(preds[0])
-#        print(decode_one_hot(onehots))
-#        print(normal)
+with open(sys.argv[1]) as f:
+    for line in f:
+        if line.isspace(): continue
+        onehots = encode_one_hot(line)
+        data = [[] for _ in range(LINE_SIZE)]
+        for i, c in enumerate(onehots):
+            data[i].append(c)
+        for j in range(len(onehots), LINE_SIZE):
+            data[j].append(np.zeros((INPUT_VOCAB_SIZE)))
+        inputs = [np.array(e) for e in data]
+        preds = model.predict(inputs)
+        normal = decode_one_hot(preds[0])
+        print(decode_one_hot(onehots))
+        print(normal)
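
Note: the newly enabled block relies on encode_one_hot and the char/index tables, both defined outside this diff. A plausible sketch of the counterparts it assumes (the exact definitions in the repository may differ):

    import numpy as np
    import string

    # Assumed vocabulary setup -- not taken from this diff.
    chars = sorted(string.printable)
    char_indices = {c: i for i, c in enumerate(chars)}
    indices_char = {i: c for i, c in enumerate(chars)}
    INPUT_VOCAB_SIZE = len(chars)

    def encode_one_hot(line):
        # One one-hot row per character; decode_one_hot above is its inverse.
        rows = []
        for c in line:
            row = np.zeros(INPUT_VOCAB_SIZE)
            row[char_indices[c]] = 1
            rows.append(row)
        return rows

The block then pads each line with zero vectors up to LINE_SIZE and feeds one (1, INPUT_VOCAB_SIZE) array per character position, so the script would be run as, e.g., python normalize.py some_text.txt (script and file names hypothetical).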