"""
Summary
-------
Basic convolutional model (for now) for classifying the images.

Notes
-----
Created By: Andrew Player
"""
from tensorflow import Tensor, keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
# Build a "mixed_float16" Keras dtype policy and install it as the global
# policy. NOTE: this runs at import time, so importing this module changes the
# process-wide Keras policy, not just the policy of the models defined below.
policy = keras.mixed_precision.Policy("mixed_float16")
keras.mixed_precision.set_global_policy(policy)
def conv2d_block(
    input_tensor: Tensor, num_filters: int, kernel_size: int = 3, strides: int = 1
) -> Tensor:
    """
    Apply one square-kernel 2D convolution followed by a LeakyReLU activation.

    Parameters
    ----------
    input_tensor : Tensor
        Tensor fed into the convolution.
    num_filters : int
        Number of filters used by the convolution.
    kernel_size : int, optional
        Side length of the square convolution kernel.
    strides : int, optional
        Stride applied along both spatial axes.

    Returns
    -------
    Tensor
        The convolution output after the LeakyReLU activation.
    """
    convolution = layers.Conv2D(
        num_filters,
        (kernel_size, kernel_size),
        strides=(strides, strides),
        padding="same",
        kernel_initializer="random_normal",
    )
    feature_map = convolution(input_tensor)

    activation = layers.LeakyReLU()
    return activation(feature_map)
def create_eventnet(
    model_name: str = "model",
    tile_size: int = 512,
    num_filters: int = 32,
    label_count: int = 1,
    learning_rate: float = 0.005,
) -> Model:
    """
    Create and compile a basic convolutional network.

    The network stacks three convolution blocks, reduces the result with
    global average pooling, and finishes with a sigmoid dense layer.

    Parameters
    ----------
    model_name : str, optional
        Name given to the Keras model.
    tile_size : int, optional
        Height and width of the square, single-channel input tiles.
    num_filters : int, optional
        Number of filters in the first two convolution blocks.
    label_count : int, optional
        Number of units in the sigmoid output layer.
    learning_rate : float, optional
        Learning rate passed to the SGD optimizer.

    Returns
    -------
    Model
        The compiled Keras model (SGD optimizer, mean-squared-error loss,
        accuracy and mean-absolute-error metrics).
    """
    # Named `model_input` rather than `input` to avoid shadowing the builtin.
    model_input = layers.Input(shape=(tile_size, tile_size, 1))

    # --------------------------------- #
    # Feature Map Generation            #
    # --------------------------------- #
    c1 = conv2d_block(model_input, num_filters, kernel_size=7, strides=2)
    c2 = conv2d_block(c1, num_filters, kernel_size=3, strides=2)
    # 1x1 convolution collapsing the feature maps to a single channel.
    c3 = conv2d_block(c2, num_filters=1, kernel_size=1, strides=1)

    # --------------------------------- #
    # Global Average Pooling            #
    # --------------------------------- #
    # NOTE(review): keepdims=True keeps the pooled spatial axes, so the model
    # output is presumably (batch, 1, 1, label_count) -- confirm that the
    # training labels are shaped to match before relying on the MSE loss.
    g1 = layers.GlobalAveragePooling2D(keepdims=True, data_format="channels_last")(c3)

    # --------------------------------- #
    # Output Layer                      #
    # --------------------------------- #
    # This module enables the global "mixed_float16" policy; the output layer
    # is pinned to float32 so the sigmoid and the loss are computed at full
    # precision, as recommended for the final layer of mixed-precision models.
    output = layers.Dense(label_count, activation="sigmoid", dtype="float32")(g1)

    # --------------------------------- #
    # Model Creation and Compilation    #
    # --------------------------------- #
    model = Model(inputs=[model_input], outputs=[output], name=model_name)
    model.compile(
        optimizer=keras.optimizers.SGD(learning_rate=learning_rate),
        loss="mean_squared_error",
        metrics=["acc", "mean_absolute_error"],
    )

    return model