
Metrics

File: metrics.py

Metric classes for evaluating model performance.

Accuracy

Bases: MeanMetric

Subclass of MeanMetric which defines a standard accuracy update function.

Source code in torchplate/metrics.py
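class Accuracy(MeanMetric):
    """
    Subclass of MeanMetric which defines
    a standard accuracy update function.
    """

    def update(self, logits, labels):
        # Record this batch's accuracy so get() returns the epoch mean.
        # calculate_accuracy is defined elsewhere (not shown on this page).
        self.vals.append(calculate_accuracy(logits, labels))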
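A minimal usage sketch, assuming calculate_accuracy returns the fraction of samples whose highest-scoring logit matches the label. The calculate_accuracy below is a hypothetical pure-Python stand-in working on nested lists (torchplate's own helper is not shown on this page and likely operates on tensors), and MeanMetric is reproduced so the example runs on its own.

```python
from typing import List, Sequence


def calculate_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Hypothetical stand-in: fraction of rows whose argmax equals the label."""
    correct = 0
    for row, label in zip(logits, labels):
        # The index of the largest logit in this row is the predicted class.
        predicted = max(range(len(row)), key=row.__getitem__)
        correct += int(predicted == label)
    return correct / len(labels)


class MeanMetric:
    """Collects per-batch values and averages them over the epoch."""

    def __init__(self) -> None:
        self.vals: List[float] = []

    def update(self, new_val: float) -> None:
        self.vals.append(new_val)

    def reset(self) -> None:
        self.vals = []

    def get(self) -> float:
        return sum(self.vals) / len(self.vals)


class Accuracy(MeanMetric):
    """Stores each batch's accuracy so get() returns the epoch mean."""

    def update(self, logits, labels) -> None:
        self.vals.append(calculate_accuracy(logits, labels))


acc = Accuracy()
# Batch 1: both predictions correct, accuracy 1.0
acc.update([[0.1, 0.9], [2.0, -1.0]], [1, 0])
# Batch 2: one of two predictions correct, accuracy 0.5
acc.update([[0.3, 0.7], [0.6, 0.4]], [0, 0])
print(acc.get())  # 0.75
```

Because update() appends rather than returns, the per-batch accuracies accumulate in vals and get() averages them, which is what the MeanMetric base class expects.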

MeanMetric

Scalar metric that is updated once per batch and read once per epoch to get the average value across the epoch.

Source code in torchplate/metrics.py
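class MeanMetric:
    """
    Scalar metric that is updated once per batch and read once
    per epoch to get the average value across the epoch.
    """

    def __init__(self):
        # Per-batch values collected during the current epoch.
        self.vals = []

    def update(self, new_val):
        # Record the value for one batch.
        self.vals.append(new_val)

    def reset(self):
        # Clear stored values, typically at the start of each epoch.
        self.vals = []

    def get(self):
        # Mean of the recorded batch values; raises ZeroDivisionError
        # if called before any update.
        mean_value = sum(self.vals) / len(self.vals)
        return mean_value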
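A short sketch of the intended per-batch / per-epoch cycle. The class is reproduced so the snippet is self-contained; the loss values and the history list are hypothetical illustration data, not torchplate output.

```python
from typing import List


class MeanMetric:
    """Scalar metric: updated once per batch, averaged once per epoch."""

    def __init__(self) -> None:
        self.vals: List[float] = []

    def update(self, new_val: float) -> None:
        self.vals.append(new_val)

    def reset(self) -> None:
        self.vals = []

    def get(self) -> float:
        return sum(self.vals) / len(self.vals)


# Hypothetical per-batch losses for two epochs.
epoch_losses = [[0.5, 0.25, 0.75], [0.25, 0.25]]

loss_metric = MeanMetric()
history = []
for batch_losses in epoch_losses:
    for loss in batch_losses:
        loss_metric.update(loss)       # once per batch
    history.append(loss_metric.get())  # once per epoch
    loss_metric.reset()                # start the next epoch empty

print(history)  # [0.5, 0.25]
```

Since get() averages the stored batch values, every batch counts equally; if the final batch of an epoch is smaller than the others, the result can differ slightly from a true per-sample mean.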

MeanMetricCustom

Bases: ABC, MeanMetric

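Abstract scalar metric. Subclasses must implement calculate(logits, y), which returns the value for one batch; update() stores that value so get() averages it across the epoch.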

Source code in torchplate/metrics.py
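from abc import ABC, abstractmethod


class MeanMetricCustom(ABC, MeanMetric):
    """
    Abstract scalar metric. Subclasses must implement calculate(logits, y).
    """

    def __init__(self):
        super().__init__()  # initializes self.vals = []

    @abstractmethod
    def calculate(self, logits, y):
        # Subclasses return the scalar value for one batch.
        pass

    def update(self, logits, y):
        # Compute this batch's value and store it for averaging.
        self.vals.append(self.calculate(logits, y))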
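A minimal sketch of a concrete subclass. MeanAbsoluteError is a hypothetical example metric (not part of torchplate) that works on plain lists of floats; MeanMetric and MeanMetricCustom are reproduced so the snippet runs on its own.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence


class MeanMetric:
    def __init__(self) -> None:
        self.vals: List[float] = []

    def update(self, new_val: float) -> None:
        self.vals.append(new_val)

    def reset(self) -> None:
        self.vals = []

    def get(self) -> float:
        return sum(self.vals) / len(self.vals)


class MeanMetricCustom(ABC, MeanMetric):
    # __init__ is inherited from MeanMetric (ABC defines none).

    @abstractmethod
    def calculate(self, logits, y):
        """Return the scalar value for one batch."""

    def update(self, logits, y) -> None:
        self.vals.append(self.calculate(logits, y))


class MeanAbsoluteError(MeanMetricCustom):
    """Hypothetical subclass: mean |prediction - target| for each batch."""

    def calculate(self, logits: Sequence[float], y: Sequence[float]) -> float:
        return sum(abs(p - t) for p, t in zip(logits, y)) / len(y)


mae = MeanAbsoluteError()
mae.update([1.0, 2.0], [1.5, 2.5])  # batch MAE 0.5
mae.update([0.0, 4.0], [1.0, 3.0])  # batch MAE 1.0
print(mae.get())  # 0.75
```

Because calculate is marked with @abstractmethod, instantiating MeanMetricCustom directly raises TypeError; only subclasses that implement calculate can be created.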