Module degann.networks.metrics

Expand source code
from abc import ABC
from typing import Callable

import tensorflow as tf
from tensorflow import keras

from degann.networks.losses import get_all_loss_functions


# class MaxDeviation(tf.keras.losses.Loss, ABC):
#     def __init__(
#         self,
#         reduction=tf.keras.losses.Reduction.NONE,
#         name="max_deviation",
#         treeshold=0.05,
#         **kwargs
#     ):
#         super(MaxDeviation, self).__init__(reduction=reduction, name=name, **kwargs)
#         self.treeshold = treeshold
#
#     def __call__(self, y_true, y_pred, sample_weight=None):
#         y = tf.math.divide((y_true - y_pred), tf.where(y_true == 0.0, 1.0, y_true))
#         loss = tf.math.reduce_max(tf.where(tf.abs(y) <= self.treeshold, 0.0, abs(y)))
#         return loss
#
#
# class MeanDeviation(tf.keras.losses.Loss, ABC):
#     def __init__(
#         self,
#         reduction=tf.keras.losses.Reduction.NONE,
#         name="mean_deviation",
#         treeshold=0.05,
#         **kwargs
#     ):
#         super(MeanDeviation, self).__init__(reduction=reduction, name=name, **kwargs)
#         self.treeshold = treeshold
#
#     def __call__(self, y_true, y_pred, sample_weight=None):
#         y = tf.math.divide((y_true - y_pred), tf.where(y_true == 0.0, 1.0, y_true))
#         loss = tf.math.reduce_mean(
#             tf.where(tf.abs(y) <= self.treeshold, self.treeshold, abs(y))
#         )
#         return loss


_metrics: dict = {
    "RootMeanSquaredError": keras.metrics.RootMeanSquaredError(),
    # "MaxDeviation": MaxDeviation(),
    # "MeanDeviation": MeanDeviation(),
}

_metrics = dict(get_all_loss_functions(), **_metrics)


def get_metric(name: str):
    """
    Get metric by name
    Parameters
    ----------
    name: str
        Name of metric

    Returns
    -------
    metric_class: tf.keras.losses.Loss
        Result metric
    """
    return _metrics.get(name)


def get_all_metric_functions() -> dict[str, Callable]:
    """
    Get all metrics
    Parameters
    ----------

    Returns
    -------
    metric_class: dict[str, tf.keras.losses.Loss]
        All metrics
    """
    return _metrics

Functions

def get_all_metric_functions() ‑> dict[str, typing.Callable]

Get all metrics Parameters


Returns

metric_class : dict[str, tf.keras.losses.Loss]
All metrics
Expand source code
def get_all_metric_functions() -> dict[str, Callable]:
    """
    Get all metrics
    Parameters
    ----------

    Returns
    -------
    metric_class: dict[str, tf.keras.losses.Loss]
        All metrics
    """
    return _metrics
def get_metric(name: str)

Get metric by name Parameters


name : str
Name of metric

Returns

metric_class : tf.keras.losses.Loss
Result metric
Expand source code
def get_metric(name: str):
    """
    Get metric by name
    Parameters
    ----------
    name: str
        Name of metric

    Returns
    -------
    metric_class: tf.keras.losses.Loss
        Result metric
    """
    return _metrics.get(name)