Skip to content
This repository was archived by the owner on Nov 9, 2022. It is now read-only.

Commit 91bdf91

Browse files
authored
Merge pull request #1 from dronedeploy/update_callbacks
fix callbacks, add image logging
2 parents 111abad + 570abc5 commit 91bdf91

File tree

3 files changed

+11
-5
lines changed

3 files changed

+11
-5
lines changed

libs/training.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,10 @@ def train_model(dataset):
4343

4444
data = datasets.load_dataset(dataset, size, bs)
4545
encoder_model = models.resnet18
46-
learn = unet_learner(data, encoder_model, path='models', metrics=metrics, wd=wd, bottle=True, pretrained=pretrained, callback_fns=WandbCallback)
46+
learn = unet_learner(data, encoder_model, path='models', metrics=metrics, wd=wd, bottle=True, pretrained=pretrained)
4747

4848
callbacks = [
49+
WandbCallback(learn, log=None, input_type="images"),
4950
MyCSVLogger(learn, filename='baseline_model'),
5051
ExportCallback(learn, "baseline_model", monitor='f_beta'),
5152
MySaveModelCallback(learn, every='epoch', monitor='f_beta')

libs/util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Any
2-
from fastai.callbacks import CSVLogger, Callback, SaveModelCallback, TrackerCallback
2+
from fastai.callbacks import CSVLogger, SaveModelCallback, TrackerCallback
3+
from fastai.callback import Callback
34
from fastai.metrics import add_metrics
45
from fastai.torch_core import dataclass, torch, Tensor, Optional, warn
56
from fastai.basic_train import Learner
@@ -22,6 +23,7 @@ def on_epoch_end(self, epoch:int, **kwargs:Any)->None:
2223
self.best = current
2324
self.learn.export(self.model_path)
2425

26+
# TODO: does this delete some other path or just overwrite?
2527
class MySaveModelCallback(SaveModelCallback):
2628
"""Saves the model after each epoch to potentially resume training.
2729
@@ -48,6 +50,7 @@ def __init__(self, learn, filename='history'):
4850

4951
def on_train_begin(self, **kwargs):
5052
if self.path.exists():
53+
# TODO: does this open a file named "a"...?
5154
self.file = self.path.open('a')
5255
else:
5356
super().on_train_begin(**kwargs)

requirements.txt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
numpy==1.17.1
2-
torch==1.1.0
1+
fastai
32
opencv_python==3.4.3.18
3+
numpy==1.17.1
4+
sklearn
5+
torch
46
typing==3.6.6
5-
fastai==1.0.54
7+
wandb

0 commit comments

Comments
 (0)