---
title: Trainer
keywords: fastai
sidebar: home_sidebar
summary: "The Trainer class inherits `tf.keras.Model` and contains everything a model needs for training. It exposes the `trainer.cyclic_fit` method, which trains the model using a **cyclical learning rate**, as proposed by Leslie Smith."
description: "The Trainer class inherits `tf.keras.Model` and contains everything a model needs for training. It exposes the `trainer.cyclic_fit` method, which trains the model using a **cyclical learning rate**, as proposed by Leslie Smith."
nb_path: "nbs/03_trainer.ipynb"
---
{% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}
{% endraw %} {% raw %}

create_classifier[source]

create_classifier(base_model_fn:callable, num_classes:int, weights='imagenet', dropout=0, include_top=False, name=None)

{% endraw %} {% raw %}
{% endraw %} {% raw %}

create_cnn[source]

create_cnn(base_model:Union[str, Model], num_classes:int, drop_out=0.5, keras_applications:bool=True, pooling:str='avg', weights:Optional[str]='imagenet', name=None)

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class Trainer[source]

Trainer(*args, **kwargs) :: Model

The Trainer class inherits tf.keras.Model and contains everything a model needs for training. It exposes the trainer.cyclic_fit method, which trains the model using a cyclical learning rate, as proposed by Leslie Smith.

Arguments:

- ds: Dataset object
- model: object of type tf.keras.Model
- num_classes (int, None): number of classes in the dataset. If None, it is inferred automatically from the Dataset.

{% endraw %} {% raw %}
{% endraw %} {% raw %}

class InterpretModel[source]

InterpretModel(gradcam_pp:bool, learner:Trainer, clone:bool=False)

{% endraw %} {% raw %}
{% endraw %} {% raw %}
# Root directory of the demo image dataset. This is a machine-specific absolute
# path (a user home directory); point it at your own data before running.
path = '/Users/aniket/Pictures/data/train'
{% endraw %} {% raw %}
from glob import glob
from chitra.core import IMAGENET_LABELS

def load_files(path):
    """Collect the files of a ``<path>/<class>/images/<file>`` directory tree.

    Returns the list produced by ``glob`` for that pattern: empty when nothing
    matches, in whatever order ``glob`` yields.
    """
    pattern = f'{path}/*/images/*'
    return glob(pattern)

def get_label(path):
    """Return the third-from-last '/'-separated segment of ``path``.

    For a ``<root>/<class>/images/<file>`` layout this is the class directory
    name. Only '/' is treated as a separator; a path with fewer than three
    segments raises IndexError.
    """
    # At most 4 pieces are needed: the last three segments plus the remainder.
    segments = path.rsplit('/', 3)
    return segments[-3]


# Build the dataset from the directory tree at `path`. Only `path` is passed:
# the `load_files` / `get_label` helpers defined in this cell are not handed to
# `Dataset`. NOTE(review): confirm whether they were meant to be passed in.
ds = Dataset(path)
No item present in the image size list
{% endraw %} {% raw %}
# Pass a complete, already-built MobileNetV2 (include_top=True) with
# keras_applications=False. As the printed output below shows, the model is
# then returned as-is and the num_classes argument (1000) is ignored.
model = create_cnn(tf.keras.applications.MobileNetV2(include_top=True), 1000, keras_applications=False)
num_classes is ignored. returning the passed model as it is.
{% endraw %} {% raw %}
# Bundle the dataset and model into a Trainer; num_classes=1000 matches the
# class count given to create_cnn above.
trainer = Trainer(ds=ds, model=model, num_classes=1000)
{% endraw %} {% raw %}
# Compile the trainer; 'sgd' presumably names the optimizer. Prints
# "Model compiled!". NOTE(review): the meaning of the positional `2` is not
# visible here; check the `compile2` signature.
trainer.compile2(2, 'sgd')
Model compiled!
{% endraw %} {% raw %}
# Interpretation wrapper around `trainer`. The first positional argument is
# `gradcam_pp` (see the InterpretModel signature); False presumably selects
# plain Grad-CAM rather than Grad-CAM++.
interpret = InterpretModel(False, trainer)

# Alternative test image from a Tiny-ImageNet directory layout:
# path = '/data/aniket/tiny-imagenet/data/tiny-imagenet-200/train/n01641577/images/n01641577_100.JPEG'
# NOTE: this rebinds `path`, which previously held the dataset root directory.
path  = '/Users/aniket/Pictures/data/train/cat/2.jpeg'
# Open the sample image; the trailing `;image` makes the notebook display it.
image = Image.open(path);image
{% endraw %} {% raw %}
# Run the interpreter on the loaded image; auto_resize=False presumably skips
# resizing the image to the model's input size first. TODO confirm.
interpret(image, auto_resize=False)
{% endraw %} {% raw %}
# Look up the human-readable ImageNet label for class index 285.
# NOTE(review): presumably the class the interpretation above highlighted; confirm.
IMAGENET_LABELS[285]
{% endraw %}