import sys
import threading
import tensorflow as tf
from tensorflow import keras
from . import model
from . import simsave
from . import drawmap


def train(m, filename):
  """Fit the heightmap model on a saved simulation and render a preview map.

  Reads the save at `filename`, fits `m.heightmap()` on the training data
  derived from it, saves `m` under 'models', then evaluates the heightmap
  model on a 2048x1024 equirectangular grid of inputs and writes the
  colourized result to 'map.png' in the current working directory.

  Args:
    m: Object exposing `heightmap()` (returning something with `fit` that
      is also callable on an input tensor) and `save(path)`. The script
      below passes a `model.ModelSet`.
    filename: Path of the save file; it is opened in text mode.
  """
  # Close the save file deterministically instead of leaking the handle.
  # It stays open until the training data has been built, in case
  # SimResult reads from it lazily -- NOTE(review): confirm that nothing
  # reads from the file after model.training_data() returns.
  with open(filename, 'r') as save_file:
    source = simsave.SimResult(save_file)
    training_data = model.training_data(source)

  # training_data[0] is used as the inputs and training_data[1] as the targets.
  m.heightmap().fit(x=training_data[0], y=training_data[1], batch_size=100, epochs=5000)
  m.save('models')

  width, height = 2048, 1024
  # Flatten the grid into one row of 3 values per pixel for the model call.
  inputs = tf.reshape(drawmap.inputs_equirectangular(width, height),
                      (width * height, 3))
  # Back to image layout: (rows, columns, 1 channel).
  outputs = tf.reshape(m.heightmap()(inputs), (height, width, 1))
  # Scale by 255 and convert to uint8 before PNG encoding.
  outputs = tf.cast(drawmap.colourize_heightmap(outputs) * 255, tf.uint8)
  png_bytes = tf.io.encode_png(outputs).numpy()

  # The context manager closes the file even if the write fails.
  with open("map.png", 'wb') as image_file:
    image_file.write(png_bytes)


# Script body: the first command-line argument is the save file to train on.
if len(sys.argv) < 2:
  print("Expecting filename of tectonic.js save", file=sys.stderr)
else:
  m = model.ModelSet()
  # Training runs on a background thread while drawmap.run(m) runs on this
  # one; both are handed the same ModelSet instance.
  bg_thread = threading.Thread(target=train, args=(m, sys.argv[1]))
  bg_thread.start()
  drawmap.run(m)
  # Once drawmap.run returns, wait for training to finish.
  bg_thread.join()