diff --git a/36-dnn/normalize-chars.py b/36-dnn/normalize-chars.py index 92e71af..c709dcc 100644 --- a/36-dnn/normalize-chars.py +++ b/36-dnn/normalize-chars.py @@ -1,5 +1,5 @@ from keras.models import Model -from keras import layers +from keras import layers, metrics from keras.layers import Input, Dense from keras.utils import plot_model @@ -39,18 +39,16 @@ def decode_one_hot(x): """ s = [] for onehot in x: - one_index = np.where(onehot == 1) # one_index is a tuple of two things - if len(one_index[0]) > 0: - n = one_index[0][0] - c = indices_char[n] - s.append(c) + one_index = np.argmax(onehot) + c = indices_char[one_index] + s.append(c) return ''.join(s) - + def build_model(): print('Build model...') # Normalize every character in the input, using a shared dense model - n_layer = Dense(INPUT_VOCAB_SIZE) + n_layer = Dense(INPUT_VOCAB_SIZE, activation = "softmax") raw_inputs = [] normalized_outputs = [] for _ in range(0, LINE_SIZE): @@ -101,7 +99,7 @@ plot_model(model, to_file='normalization.png', show_shapes=True) # Train the model each generation and show predictions against the validation # dataset. val_gen2 = input_generator(1) -for iteration in range(1, 500): +for iteration in range(1, 12): print() print('-' * 50) print('Iteration', iteration) @@ -112,40 +110,33 @@ for iteration in range(1, 500): steps_per_epoch = 20, validation_data = val_gen, validation_steps = 10, workers=1) - # Select 10 samples from the validation set at random so we can visualize - # errors. -# print(batch_y) -# print(preds) + # Select samples from the a set at random so we can visualize errors. batch_x, batch_y = next(val_gen2) for i in range(len(batch_y)): preds = model.predict(batch_x) expected = batch_y[i] prediction = preds[i] - #print(preds) -# preds[preds>=0.5] = 1 -# preds[preds<0.5] = 0 - #q = ctable.decode(query) correct = decode_one_hot(expected) guess = decode_one_hot(prediction) - print('T', correct) - print('G', guess) + print('T:', correct) + print('G:', guess) -#with open(sys.argv[1]) as f: -# for line in f: -# if line.isspace(): continue -# onehots = encode_one_hot(line) +with open(sys.argv[1]) as f: + for line in f: + if line.isspace(): continue + onehots = encode_one_hot(line) -# data = [[] for _ in range(LINE_SIZE)] -# for i, c in enumerate(onehots): -# data[i].append(c) -# for j in range(len(onehots), LINE_SIZE): -# data[j].append(np.zeros((INPUT_VOCAB_SIZE))) + data = [[] for _ in range(LINE_SIZE)] + for i, c in enumerate(onehots): + data[i].append(c) + for j in range(len(onehots), LINE_SIZE): + data[j].append(np.zeros((INPUT_VOCAB_SIZE))) -# inputs = [np.array(e) for e in data] + inputs = [np.array(e) for e in data] -# preds = model.predict(inputs) -# normal = decode_one_hot(preds[0]) + preds = model.predict(inputs) + normal = decode_one_hot(preds[0]) -# print(decode_one_hot(onehots)) -# print(normal) + print(decode_one_hot(onehots)) + print(normal)