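""" src/neuralnetwork/training.py

    Train a character-level RNN per country and per name category (female,
    male, surname), then sample newly generated names into a SQLite database.
"""
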
import os
import types
import json
import random
import sqlite3
import warnings

import torch

from util import *
from rnn import *

# training configuration
cuda = False
num_processes = 12  # worker processes passed to learn_names during training


class Country:
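    """ One country's name datasets, its alphabet, and one RNN model per category """
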
    def __init__(self, path):
        self.path = path
        self.datasets = {
            "female": os.path.join(path, "female.txt"),
            "male": os.path.join(path, "male.txt"),
            "surname": os.path.join(path, "surname.txt"),
        }

        # initialise the pre- and post-process function lists
        self.pre_process = []
        self.post_process = []

        # load the data file
        self.load_data()

        # load the alphabet file
        self.alphabet = self.load_alphabet()

        # initialise one rnn model per dataset; input and output sizes both
        # equal the alphabet size, since each model reads and emits characters
        # drawn from this country's alphabet
        hidden_size = 128
        self.rnn = {}

        for dataset in self.datasets:
            self.rnn[dataset] = RNN(
                len(self.alphabet), hidden_size, len(self.alphabet))

    """ Load the alphabet from the alphabet file
        Returns:
            alphabet: (str[]) the list of the letters/characters to use while training
    """

    def load_alphabet(self):
        alphabet_path = os.path.join(self.path, "alphabet.txt")

        # check that the alphabet file exists; if not, raise an exception
        if os.path.exists(alphabet_path):
            with open(alphabet_path, "r") as alphabet_file:
                # one letter/character per line; splitlines() avoids a trailing
                # empty entry when the file ends with a newline
                letters = alphabet_file.read().splitlines()
                return letters
        else:
            raise FileNotFoundError(
                f"The alphabet file {alphabet_path} could not be found")

    """ load the data from the data file
    """

    def load_data(self):
        data_path = os.path.join(self.path, "data.json")
        if os.path.exists(data_path):
            with open(data_path, "r") as data_file:
                j = json.loads(data_file.read())
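                # expected shape, matching the defaults used below:
                #   {"pre": ["uncapitalise"], "post": ["deserialise", "capitalise"]}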

                # match the names listed in the json file against the functions
                # imported into this module (e.g. from util)
                def resolve(name):
                    func = globals().get(name)
                    if func is None:
                        # the function was never imported, or does not exist
                        raise Exception(
                            f"The function '{name}' was not loaded or does not exist")

                    # check that the requested object really is a function
                    if not isinstance(func, types.FunctionType):
                        raise Exception(f"The object '{name}' is not a function")
                    return func

                self.pre_process = [resolve(pre) for pre in j["pre"]]
                self.post_process = [resolve(post) for post in j["post"]]

        else:
            # load the default pre and post processing functions
            self.pre_process = [uncapitalise]
            self.post_process = [deserialise, capitalise]

    """ List all the names from a given category file
        Args:
            category: (str) the category to select names from
        Returns:
            data: (str[]) an array containing all of the names from the given category file
    """

    def get_names(self, category):
        with open(self.datasets[category], "r") as datafile:
            # one name per line; drop empty lines left by trailing newlines
            return [name for name in datafile.read().splitlines() if name]

    """ List all names in all categories
        Returns:
            data: (str[]) an array with all of the names in this country's datasets
    """

    def get_all(self):
        return [name for k in self.datasets for name in self.get_names(k)]

    """ Pre-process a name for training
        Args:
            name: the name loaded from the dataset
        Returns:
            name: the name after being processed
    """

    def postprocess(self, name):
        for f in self.post_process:
            name = f(name)
        return name

    """ Post-process a name after sampling
        Args:
            name: the name output from the recurrent neural network
        Returns:
            name: the name after being processed
    """

    def preprocess(self, name):
        for f in self.pre_process:
            name = f(name)
        return name

    """ Train a neural network on the given dataset
        Args:
            category: (str) the category to sample training names from
    """

    def train(self, category):
        # select the RNN model to be training on
        rnn = self.rnn[category]

        # load names from that dataset and pre-process them
        print("preprocessing names...")
        names = [self.preprocess(name) for name in self.get_names(category)]
        print(f"processed {len(names)} names!")

        # use 80% of the dataset size as the number of training iterations
        iters = int(len(names) * 0.8)

        # start training
        learn_names(rnn, names, self.alphabet, iterations=iters,
                    num_processes=num_processes)

    """ Sample a name from the neural network with a given starting letter
        Args:
            category: (str) the category to sample generated names from
        Returns:
            name: the output from the neural network
    """

    def sample(self, category, start_letter):

        # select the RNN model to be sampling from
        rnn = self.rnn[category]

        # add randomness while sampling so repeated calls generate varied names
        rnn.random_factor = 0.7

        # call the rnn sample function to generate a single name
        name = sample(rnn, self.alphabet, start_letter)

        # post process the name and return
        return self.postprocess(name)

    """ Load the rnn from its file
        Args:
            category: (str) the category to load
            parent_directory: (str) where to find the model
    """

    def load_rnn(self, category, parent_directory):
        model_file = os.path.join(parent_directory, f"{category}.pt")
        self.rnn[category] = torch.load(model_file)

    """ Save the rnn of a given category to its file
        Args:
            category: (str) the category to save
            parent_directory: (str) the directory to save the model file to  
    """

    def save_rnn(self, category, parent_directory):
        rnn = self.rnn[category]
        model_file = os.path.join(parent_directory, f"{category}.pt")
        torch.save(rnn, model_file)
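
# Example of driving a Country by hand (the country name "france" here is
# hypothetical; use any directory that exists under data/datasets):
#   country = Country(os.path.join("data", "datasets", "france"))
#   country.train("female")
#   country.save_rnn("female", os.path.join("data", "models", "france"))
#   print(country.sample("female", "a"))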


def get_countries():
    """ Build a Country for every dataset directory under countries_path """
    return {
        country: Country(os.path.join(countries_path, country))
        for country in os.listdir(countries_path)
        if os.path.isdir(os.path.join(countries_path, country))
    }


""" train all of the datasets from a specific country
    Args:
        country: (Country) 
"""


def train_country(country, name):
    datasets = country.datasets
    for dataset in datasets:
        print(f"Training {dataset} in {name}")
        country.train(dataset)

        print(f"Finished training on {dataset}... saving...", end="")
        path = os.path.join("data", "models", name)

        # create the model directory if it does not already exist
        os.makedirs(path, exist_ok=True)

        country.save_rnn(dataset, path)
        print("saved!")


def sample_country(country, country_name, number_of_samples=10000):
    """ Sample names from a country's trained models and save them to the database
        Args:
            country: (Country) the country to sample from
            country_name: (str) the country's name, used for model paths and rows
            number_of_samples: (int) how many samples to attempt per category
    """
    datasets = country.datasets
    for dataset in datasets:
        # ensure that the model exists before sampling
        path = os.path.join("data", "models", country_name)
        if os.path.exists(os.path.join(path, dataset + ".pt")):

            # load the country's rnn
            country.load_rnn(dataset, path)

            # load the names from the country's dataset, and pre-process them
            names = [country.preprocess(name)
                     for name in country.get_names(dataset)]

            # map each starting letter to the number of names that begin with it
            start_letters = {}

            for name in names:
                if len(name) > 0:
                    start_letter = name[0]

                    # increment the letter's count, starting from 0 if unseen
                    start_letters[start_letter] = start_letters.get(start_letter, 0) + 1

            # normalise counts into probabilities: letter_weight = frequency / total
            # (sum over the counted letters so the weights sum to exactly 1)
            total = sum(start_letters.values())

            for letter in start_letters:
                start_letters[letter] = start_letters[letter] / total
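
            # weighted_choice (from util) is assumed to take a {letter: probability}
            # mapping and return one letter drawn with the matching probability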

            # sample names from the RNN
            sampled_names = []

            for i in range(number_of_samples):
                try:
                    letter = weighted_choice(start_letters)
                    sample = country.sample(dataset, letter)
                    sampled_names.append(sample)
                except Exception:
                    # skip any sample the rnn fails to generate
                    pass

            # remove duplicate names; dict.fromkeys keeps first-occurrence
            # order, unlike a set
            sampled_names = list(dict.fromkeys(sampled_names))

            # create a sqlite connection
            connection = sqlite3.connect(database)

            # the connection's context manager commits on success and rolls
            # back on error, but it does not close the connection itself
            with connection:
                cursor = connection.cursor()
                sql = "INSERT INTO names(Name, Origin, Category) VALUES(?, ?, ?)"

                # insert every sampled name along with its origin and category
                cursor.executemany(
                    sql, [(name, country_name, dataset) for name in sampled_names])

            connection.close()

            print(
                f"Saved {len(sampled_names)} names for {country_name}/{dataset}")

        else:
            print(f"the model: {country_name}/{dataset} was not found.")


countries_path = "data/datasets"
database = os.path.join("data", "names.db")
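
# Expected on-disk layout, implied by the paths above and the Country class:
#   data/datasets/<country>/
#       alphabet.txt                          one character per line
#       female.txt, male.txt, surname.txt     one name per line
#       data.json                             optional {"pre": [...], "post": [...]}
#                                             naming processing functions from util
#   data/models/<country>/<category>.pt       saved RNN weights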
if __name__ == "__main__":
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # use the 'spawn' start method so each worker process gets a fresh
        # interpreter rather than a forked copy of this one
        torch.multiprocessing.set_start_method('spawn')

        # List all the directories containing country datasets to populate the countries dictionary
        countries = get_countries()

        country_count = len(countries)
        # Display debug information
        print(f"Loaded {country_count} countries!")

        # list all countries in neat columns
        columns = 4
        width = 14
        i = 0
        for country in countries:
            i += 1

            # pad each name to a fixed width so the columns line up
            print(country.ljust(width), end="")

            # start a new row once a full row of columns has been printed
            if i % columns == 0:
                print()

        # end the listing with a newline if the last row was left incomplete
        if i % columns != 0:
            print()

        # keep asking until the country selection is valid
        selected_countries = []
        while not selected_countries:
            # prompt the user to select a country to train, or train all
            country_selection = input(
                "select the name of a country to train on, or (all) to train on all countries: ")

            # if the user selected all, add every country to the list,
            # otherwise add just the selected country
            if country_selection.lower() == "all":
                selected_countries = list(countries)
            elif country_selection.lower() in countries:
                selected_countries.append(country_selection.lower())
            else:
                print("Country not found, try again")

        choice = input("(t)rain on data, or (s)ample from weights? ")

        # startswith() avoids an IndexError when the user just presses enter
        if choice.lower().startswith("t"):
            for country in selected_countries:
                train_country(countries[country], country)

        elif choice.lower().startswith("s"):
            # create_table (from util) is expected to set up the names table
            create_table(database)
            for country in selected_countries:
                sample_country(countries[country], country)