Hot-keys on this page
r m x p toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1#!/usr/bin/env python3
2from pathlib import Path
3import pandas as pd
4from torch import nn
5from fastai.data.block import DataBlock, TransformBlock
6from fastai.data.transforms import ColReader, RandomSplitter, Transform
7import torchapp as ta
8from torchapp.blocks import BoolBlock, Float32Block
9from torchapp.metrics import logit_accuracy, logit_f1
class Normalize(Transform):
    """
    Standardize inputs with a fixed mean and standard deviation.

    ``encodes`` maps ``x`` to ``(x - mean) / std`` and ``decodes`` inverts it,
    mapping ``x`` back to ``x * std + mean``.

    Args:
        mean: Value subtracted from inputs. ``None`` means no shift (0.0).
        std: Value inputs are divided by. ``None`` means no scaling (1.0).
            A real-valued scalar std that is zero or NaN is replaced by 1.0.
            This covers a constant column, and pandas' sample std of a single
            value, which is NaN. Without it, encoding would produce inf/NaN.
            Non-scalar statistics (arrays/tensors) are stored unchanged.
    """

    def __init__(self, mean=None, std=None):
        import numbers

        # Let the fastai base class initialise its own state (no-op with no args).
        super().__init__()
        self.mean = 0.0 if mean is None else mean
        if std is None:
            std = 1.0
        elif isinstance(std, numbers.Real) and (std == 0 or std != std):
            # ``std != std`` is true only for NaN; numpy float scalars are
            # registered as ``numbers.Real`` so they are covered too.
            std = 1.0
        self.std = std

    def encodes(self, x):
        """Shift by ``mean`` and scale by ``1 / std``."""
        return (x - self.mean) / self.std

    def decodes(self, x):
        """Invert ``encodes``: scale by ``std`` and shift back by ``mean``."""
        return x * self.std + self.mean
class LogisticRegressionApp(ta.TorchApp):
    """
    Creates a basic app to do logistic regression.

    One numeric CSV column (``x``) is used to predict a boolean CSV column
    (``y``) with a single linear unit trained on logits.
    """

    def dataloaders(
        self,
        csv: Path = ta.Param(help="The path to a CSV file with the data."),
        x: str = ta.Param(default="x", help="The column name of the independent variable."),
        y: str = ta.Param(default="y", help="The column name of the dependent variable."),
        validation_proportion: float = ta.Param(
            default=0.2, help="The proportion of the dataset to use for validation."
        ),
        seed: int = ta.Param(default=42, help="The random seed to use for splitting the data."),
        batch_size: int = ta.Param(
            default=32,
            tune=True,
            tune_min=8,
            tune_max=128,
            log=True,
            help="The number of items to use in each batch.",
        ),
    ):
        """
        Build training/validation dataloaders from a CSV file.

        The ``x`` column is normalized using the mean and standard deviation of
        the *training* rows only. Validation rows therefore do not influence the
        input transform.

        Args:
            csv: Path to the CSV file.
            x: Name of the independent-variable column (read as float32).
            y: Name of the dependent-variable column (read as bool).
            validation_proportion: Fraction of rows held out for validation.
            seed: Seed for the random train/validation split.
            batch_size: Number of items per batch.

        Returns:
            The dataloaders produced by ``DataBlock.dataloaders``.
        """
        df = pd.read_csv(csv)

        # Split once, so the normalization statistics and the DataBlock are
        # guaranteed to use the same train/validation partition.
        splits = RandomSplitter(validation_proportion, seed=seed)(df)
        train_values = df[x].iloc[list(splits[0])]

        def fixed_split(_items):
            # DataBlock calls its splitter with the source items (``df`` here);
            # return the precomputed partition instead of re-splitting.
            return splits

        datablock = DataBlock(
            blocks=[
                Float32Block(
                    type_tfms=[Normalize(mean=train_values.mean(), std=train_values.std())]
                ),
                BoolBlock,
            ],
            get_x=ColReader(x),
            get_y=ColReader(y),
            splitter=fixed_split,
        )

        return datablock.dataloaders(df, bs=batch_size)

    def model(self) -> nn.Module:
        """Builds a simple logistic regression model (one input, one logit output)."""
        # No sigmoid here: the loss below expects raw logits.
        return nn.Linear(in_features=1, out_features=1, bias=True)

    def loss_func(self):
        """Binary cross-entropy computed directly on the model's logits."""
        return nn.BCEWithLogitsLoss()

    def metrics(self):
        """Accuracy and F1, both computed from logits."""
        return [logit_accuracy, logit_f1]

    def monitor(self):
        """Name of the metric to monitor; matches ``logit_f1`` from ``metrics``."""
        return "logit_f1"
# Start the app only when this file is executed directly, not when imported.
# NOTE(review): LogisticRegressionApp.main() is presumably torchapp's CLI
# entry point (the ta.Param help strings suggest command-line options) — confirm.
if __name__ == "__main__":
    LogisticRegressionApp.main()