Learn to normalize the characters of a given line, using the model from the no-learning version

This commit is contained in:
Crista Lopes
2019-11-26 00:21:25 -08:00
parent e2c531fc5c
commit a18c4c6980

View File

@@ -1,5 +1,5 @@
from keras.models import Model
from keras import layers
from keras import layers, metrics
from keras.layers import Input, Dense
from keras.utils import plot_model
@@ -39,18 +39,16 @@ def decode_one_hot(x):
"""
s = []
for onehot in x:
one_index = np.where(onehot == 1) # one_index is a tuple of two things
if len(one_index[0]) > 0:
n = one_index[0][0]
c = indices_char[n]
s.append(c)
one_index = np.argmax(onehot)
c = indices_char[one_index]
s.append(c)
return ''.join(s)
def build_model():
print('Build model...')
# Normalize every character in the input, using a shared dense model
n_layer = Dense(INPUT_VOCAB_SIZE)
n_layer = Dense(INPUT_VOCAB_SIZE, activation = "softmax")
raw_inputs = []
normalized_outputs = []
for _ in range(0, LINE_SIZE):
@@ -101,7 +99,7 @@ plot_model(model, to_file='normalization.png', show_shapes=True)
# Train the model each generation and show predictions against the validation
# dataset.
val_gen2 = input_generator(1)
for iteration in range(1, 500):
for iteration in range(1, 12):
print()
print('-' * 50)
print('Iteration', iteration)
@@ -112,40 +110,33 @@ for iteration in range(1, 500):
steps_per_epoch = 20,
validation_data = val_gen,
validation_steps = 10, workers=1)
# Select 10 samples from the validation set at random so we can visualize
# errors.
# print(batch_y)
# print(preds)
# Select samples at random from the val_gen2 generator so we can visualize errors.
batch_x, batch_y = next(val_gen2)
for i in range(len(batch_y)):
preds = model.predict(batch_x)
expected = batch_y[i]
prediction = preds[i]
#print(preds)
# preds[preds>=0.5] = 1
# preds[preds<0.5] = 0
#q = ctable.decode(query)
correct = decode_one_hot(expected)
guess = decode_one_hot(prediction)
print('T', correct)
print('G', guess)
print('T:', correct)
print('G:', guess)
#with open(sys.argv[1]) as f:
# for line in f:
# if line.isspace(): continue
# onehots = encode_one_hot(line)
with open(sys.argv[1]) as f:
for line in f:
if line.isspace(): continue
onehots = encode_one_hot(line)
# data = [[] for _ in range(LINE_SIZE)]
# for i, c in enumerate(onehots):
# data[i].append(c)
# for j in range(len(onehots), LINE_SIZE):
# data[j].append(np.zeros((INPUT_VOCAB_SIZE)))
data = [[] for _ in range(LINE_SIZE)]
for i, c in enumerate(onehots):
data[i].append(c)
for j in range(len(onehots), LINE_SIZE):
data[j].append(np.zeros((INPUT_VOCAB_SIZE)))
# inputs = [np.array(e) for e in data]
inputs = [np.array(e) for e in data]
# preds = model.predict(inputs)
# normal = decode_one_hot(preds[0])
preds = model.predict(inputs)
normal = decode_one_hot(preds[0])
# print(decode_one_hot(onehots))
# print(normal)
print(decode_one_hot(onehots))
print(normal)