import tensorflow as tf
from tensorflow import keras
import threading
from . import geology
import os

def icosahedral(inputs):
  """Apply a densely cross-wired stack of 12 Dense layers to ``inputs``.

  The 12 hidden nodes are connected as a DAG whose undirected edge set has
  30 edges and gives every node exactly 5 neighbours -- the vertex/edge
  structure of an icosahedron.
  - Nodes with several upstream nodes receive the concatenation of those
    nodes' outputs.
  - Nodes 5-7 are 88-unit ReLU layers; every other node is a 20-unit layer
    with an asinh activation.

  Args:
    inputs: Keras tensor feeding node 1.

  Returns:
    Keras tensor: concatenation of the outputs of nodes 11 and 12
    (20 + 20 = 40 features).
  """
  asinh = tf.math.asinh
  relu = keras.activations.relu
  # One row per hidden node, in creation order:
  # (units, activation, indices of upstream nodes). Index 0 is ``inputs``;
  # index k (1..12) is hidden node k.
  wiring = (
      (20, asinh, (0,)),
      (20, asinh, (1,)),
      (20, asinh, (1, 2)),
      (20, asinh, (1, 2)),
      (88, relu, (2, 3)),
      (88, relu, (2, 4, 5)),
      (88, relu, (5, 6)),
      (20, asinh, (3, 5, 7)),
      (20, asinh, (4, 6, 7)),
      (20, asinh, (7, 8, 9)),
      (20, asinh, (1, 3, 8, 10)),
      (20, asinh, (1, 4, 9, 10, 11)),
  )
  nodes = [inputs]
  for units, activation, sources in wiring:
    # Construct the Dense layer before any concatenate layer so layer
    # creation order (and thus auto-generated names) stays stable.
    dense = keras.layers.Dense(units, activation=activation)
    if len(sources) == 1:
      feed = nodes[sources[0]]
    else:
      feed = keras.layers.concatenate([nodes[i] for i in sources])
    nodes.append(dense(feed))
  # The last two hidden nodes (11 and 12) form the block's output.
  return keras.layers.concatenate(nodes[-2:])


def model(chans=1):
  """Build the uncompiled Keras geology model.

  The network has two independently weighted ``icosahedral`` branches fed
  from the same input. Their outputs are concatenated, mixed by a 20-unit
  asinh layer, and projected to ``chans`` sigmoid outputs.

  Args:
    chans: number of output channels (units of the final sigmoid layer).

  Returns:
    keras.Model named "geology_model". Input shape is (3,); each output
    value lies in (0, 1) because of the sigmoid.
  """
  inputs = keras.Input(shape=(3,))
  # Each icosahedral() call creates fresh Dense layers, so the two branches
  # do not share weights.
  branches = [icosahedral(inputs) for _ in range(2)]
  mixer = keras.layers.Dense(20, activation=tf.math.asinh)
  hidden = mixer(keras.layers.concatenate(branches))
  head = keras.layers.Dense(chans, activation=keras.activations.sigmoid)
  outputs = head(hidden)
  return keras.Model(inputs=inputs, outputs=outputs, name="geology_model")


def shore_focused_loss(y_true, y_pred):
  """Mean squared error computed after an asinh warp centred on 0.1.

  Both tensors are mapped through ``asinh(5 * (v - 0.1)) / 5`` before the
  squared error is averaged over all elements (the result is a scalar).
  - Near v = 0.1, asinh is roughly linear.
  - Far from 0.1, it grows logarithmically.

  So errors close to 0.1 weigh more than equally sized errors far from it.
  NOTE(review): 0.1 is presumably the normalised shoreline level produced
  by ``training_data``; confirm.
  """
  def warp(values):
    return tf.math.asinh((values - 0.1) * 5.) / 5.

  residual = warp(y_true) - warp(y_pred)
  return tf.math.reduce_mean(residual ** 2)


def training_data(simresult):
  """Turn a simulation result into an (inputs, targets) training pair.

  Inputs are ``simresult.vertices``. Targets are the result of
  ``geology.isostatic_displacement(simresult)``, offset by +1 and divided
  by 25000.
  NOTE(review): the vertices are presumably an (N, 3) array to match the
  model's Input(shape=(3,)); the scaling presumably brings displacements
  into the sigmoid's (0, 1) output range. Confirm both against ``geology``.
  """
  features = simresult.vertices
  targets = (geology.isostatic_displacement(simresult) + 1.) / 25000.
  return features, targets

def clone_my_model(source):
  """Clone the architecture of ``source`` via ``keras.models.clone_model``.

  The clone is built inside a custom-object scope that maps the name
  "asinh" to ``tf.math.asinh``. That lets the asinh activations used by
  this module be resolved when layers are rebuilt from their config.
  Per the Keras docs, clone_model creates new layers (new weights) rather
  than sharing weights with ``source``.
  """
  scope = keras.utils.custom_object_scope({"asinh": tf.math.asinh})
  with scope:
    clone = keras.models.clone_model(source)
  return clone

class ModelSet:
    """A lazily built collection of named Keras models saved as a group.

    Models are created on first access by their accessor method and cached
    in ``members`` under their name.
    """

    def __init__(self):
        # name -> compiled keras.Model, filled lazily by accessor methods.
        self.members = {}

    def heightmap(self):
        """Return the 'heightmap' model, building and compiling it on first use.

        The model comes from the module-level ``model()`` factory. It is
        compiled with an RMSprop optimizer and ``shore_focused_loss``.
        """
        if 'heightmap' not in self.members:
            net = model()
            net.compile(
                optimizer=keras.optimizers.RMSprop(),
                loss=shore_focused_loss)
            # Cache only after compile succeeds, so a failed compile does
            # not leave an uncompiled model behind.
            self.members['heightmap'] = net
        return self.members['heightmap']

    def save(self, path):
        """Save every built member to ``path/<name>``.

        Args:
            path: directory path (str or os.PathLike). It and any missing
                parents are created; an existing directory is reused.

        Raises:
            FileExistsError: if ``path`` exists but is not a directory.
        """
        path = os.fspath(path)
        # exist_ok=True tolerates an existing directory but still raises if
        # ``path`` is, e.g., a regular file. The old try/except swallowed
        # that case and deferred the failure to the model.save calls.
        os.makedirs(path, exist_ok=True)
        # Loop variable is not named ``model`` to avoid shadowing the
        # module-level model() factory.
        for name, member in self.members.items():
            member.save(os.path.join(path, name))