"""Training entry point for the name-generation recurrent neural networks.

Loads per-country datasets from ``data/datasets``, trains one RNN per
category (female / male / surname), and samples generated names into a
SQLite database at ``data/names.db``.
"""

import json
import os
import random
import sqlite3
import types
import warnings

# Project-local helpers: processing functions (uncapitalise, capitalise,
# deserialise, weighted_choice, create_table, ...), the RNN model, and the
# learn_names/sample training API.  NOTE(review): `torch` is assumed to be
# re-exported by one of these star imports — confirm against util/rnn.
from util import *
from rnn import *

cuda = False
num_processes = 12


class Country:
    """One country's name datasets and the RNN models trained on them."""

    def __init__(self, path):
        """Initialise datasets, processing pipelines and per-category RNNs.

        Args:
            path: (str) directory holding this country's dataset files
        """
        self.path = path
        self.datasets = {
            "female": os.path.join(path, "female.txt"),
            "male": os.path.join(path, "male.txt"),
            "surname": os.path.join(path, "surname.txt"),
        }

        # initialise the pre and post process function lists
        self.pre_process = []
        self.post_process = []

        # load the data file (populates the processing pipelines)
        self.load_data()

        # load the alphabet file
        self.alphabet = self.load_alphabet()

        # initialise one RNN model per dataset category
        hidden_size = 128
        self.rnn = {}
        for dataset in self.datasets:
            self.rnn[dataset] = RNN(
                len(self.alphabet), hidden_size, len(self.alphabet))

    def load_alphabet(self):
        """Load the alphabet from the alphabet file.

        Returns:
            alphabet: (str[]) the letters/characters to use while training

        Raises:
            Exception: if the alphabet file does not exist
        """
        alphabet_path = os.path.join(self.path, "alphabet.txt")

        if not os.path.exists(alphabet_path):
            raise Exception(
                f"The alphabet file {alphabet_path} could not be found")

        with open(alphabet_path, "r") as alphabet_file:
            # One letter/character per line.  NOTE(review): a trailing
            # newline yields a final empty entry — confirm this is intended.
            return alphabet_file.read().split("\n")

    def load_data(self):
        """Load the pre/post-processing pipelines from ``data.json``.

        Falls back to default pipelines when the file is absent.

        Raises:
            Exception: if a listed function is missing or not callable
        """
        data_path = os.path.join(self.path, "data.json")
        if os.path.exists(data_path):
            with open(data_path, "r") as data_file:
                config = json.load(data_file)

            # match the named functions against the imported globals
            self.pre_process = [
                self._resolve_function(name) for name in config["pre"]]
            self.post_process = [
                self._resolve_function(name) for name in config["post"]]
        else:
            # default pre and post processing functions
            self.pre_process = [uncapitalise]
            self.post_process = [deserialise, capitalise]

    @staticmethod
    def _resolve_function(name):
        """Look up a processing function by name in the module globals.

        Raises:
            Exception: if the name is unknown or not a plain function
        """
        if name not in globals():
            raise Exception(
                f"The function '{name}' was not loaded or does not exist")
        func = globals()[name]
        if type(func) is not types.FunctionType:
            raise Exception(f"The function '{name}' is not a function")
        return func

    def get_names(self, category):
        """List all the names from a given category file.

        Args:
            category: (str) the category to select names from
        Returns:
            data: (str[]) all of the names from the given category file
        """
        with open(self.datasets[category], "r") as datafile:
            return datafile.read().split("\n")

    def get_all(self):
        """List all names in all categories.

        Returns:
            data: (str[]) every name in this country's datasets
        """
        return [name for k in self.datasets for name in self.get_names(k)]

    def postprocess(self, name):
        """Post-process a name after sampling.

        Args:
            name: the name output from the recurrent neural network
        Returns:
            name: the name after the post-process pipeline has been applied
        """
        for func in self.post_process:
            name = func(name)
        return name

    def preprocess(self, name):
        """Pre-process a name for training.

        Args:
            name: the name loaded from the dataset
        Returns:
            name: the name after the pre-process pipeline has been applied
        """
        for func in self.pre_process:
            name = func(name)
        return name

    def train(self, category):
        """Train a neural network on the given dataset.

        Args:
            category: (str) the category to sample training names from
        """
        # select the RNN model to be training on
        rnn = self.rnn[category]

        # load names from that dataset and pre-process them
        print("preprocessing names...")
        names = [self.preprocess(name) for name in self.get_names(category)]
        print(f"processed {len(names)} names!")

        # calculate optimum number of iterations (80% of the whole dataset)
        iters = int(len(names) * 0.8)

        learn_names(rnn, names, self.alphabet, iterations=iters,
                    num_processes=num_processes)

    def sample(self, category, start_letter):
        """Sample a name from the neural network with a given starting letter.

        Args:
            category: (str) the category to sample generated names from
            start_letter: (str) the first letter of the generated name
        Returns:
            name: the post-processed output from the neural network
        """
        # select the RNN model to be sampling from
        rnn = self.rnn[category]

        # randomise the names that are generated
        rnn.random_factor = 0.7

        # `sample` here resolves to the module-level function from the rnn
        # star import — methods do not shadow globals inside their own body
        name = sample(rnn, self.alphabet, start_letter)

        return self.postprocess(name)

    def load_rnn(self, category, parent_directory):
        """Load the rnn of a given category from its file.

        Args:
            category: (str) the category to load
            parent_directory: (str) where to find the model
        """
        model_file = os.path.join(parent_directory, f"{category}.pt")
        self.rnn[category] = torch.load(model_file)

    def save_rnn(self, category, parent_directory):
        """Save the rnn of a given category to its file.

        Args:
            category: (str) the category to save
            parent_directory: (str) the directory to save the model file to
        """
        rnn = self.rnn[category]
        model_file = os.path.join(parent_directory, f"{category}.pt")
        torch.save(rnn, model_file)


def get_countries():
    """Map every country directory under ``countries_path`` to a Country."""
    return {
        country: Country(os.path.join(countries_path, country))
        for country in os.listdir(countries_path)
        if os.path.isdir(os.path.join(countries_path, country))
    }


def train_country(country, name):
    """Train all of the datasets from a specific country and save the models.

    Args:
        country: (Country) the country whose datasets to train on
        name: (str) the country's name, used for the model directory
    """
    for dataset in country.datasets:
        print(f"Training {dataset} in {name}")
        country.train(dataset)

        print(f"Finished training on {dataset}... saving...", end="")
        path = os.path.join("data", "models", name)

        # check if the path already exists before trying to make directories
        if not os.path.exists(path):
            os.makedirs(path)

        country.save_rnn(dataset, path)
        print("saved!")


def sample_country(country, country_name, number_of_samples=10000):
    """Sample names from a country's trained models into the database.

    Args:
        country: (Country) the country to sample from
        country_name: (str) the country's name (model directory and DB origin)
        number_of_samples: (int) how many names to attempt per category
    """
    for dataset in country.datasets:
        # ensure that the model exists before sampling
        path = os.path.join("data", "models", country_name)
        if not os.path.exists(os.path.join(path, dataset + ".pt")):
            print(f"the model: {country_name}/{dataset} was not found.")
            continue

        # load the country's rnn
        country.load_rnn(dataset, path)

        # load the names from the country's dataset, and pre-process them
        names = [country.preprocess(name)
                 for name in country.get_names(dataset)]

        # count how often each start letter occurs across the dataset
        start_letters = {}
        for name in names:
            if name:
                letter = name[0]
                start_letters[letter] = start_letters.get(letter, 0) + 1

        # turn each count into a weight: letter_weight = frequency / total
        total = len(names)
        for letter in start_letters:
            start_letters[letter] = float(start_letters[letter] / total)

        # sample names from the RNN; sampling is best-effort, so individual
        # failures are skipped rather than aborting the whole batch
        sampled_names = []
        for _ in range(number_of_samples):
            try:
                letter = weighted_choice(start_letters)
                sampled_names.append(country.sample(dataset, letter))
            except Exception:
                continue

        # remove duplicate names while preserving order
        sampled_names = list(dict.fromkeys(sampled_names))

        connection = sqlite3.connect(database)
        # the connection context manager commits on success and rolls
        # back on error; no explicit commit is needed
        with connection:
            cursor = connection.cursor()
            sql = "INSERT INTO names(Name, Origin, Category) VALUES(?, ?, ?)"
            for name in sampled_names:
                cursor.execute(sql, (name, country_name, dataset))
        connection.close()

        print(
            f"Saved {len(sampled_names)} names for {country_name}/{dataset}")


countries_path = "data/datasets"
database = os.path.join("data", "names.db")

if __name__ == "__main__":
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        # allow processes on this model to share memory
        torch.multiprocessing.set_start_method('spawn')

        # list all the directories containing country datasets
        countries = get_countries()

        country_count = len(countries)
        print(f"Loaded {country_count} countries!")

        # list all countries in neat columns
        collumns = 4
        width = 14
        for i, country in enumerate(countries, start=1):
            print(country, end="")
            if i % collumns == 0:
                print("")
            else:
                # separate columns with spaces
                print(" " * (width - len(country)), end="")

        # keep asking until the country selection is valid
        good_selection = False
        while not good_selection:
            country_selection = input(
                "select the name of a country to train on, or (all) to train on all countries: ")

            good_selection = True

            # "all" selects every country; otherwise match a single country
            if country_selection.lower() == "all":
                selected_countries = list(countries)
            elif country_selection.lower() in countries:
                selected_countries = [country_selection]
            else:
                print("Country not found, try again")
                good_selection = False

        choice = input("(t)rain on data, or (s)ample from weights?")

        # slice instead of indexing so empty input does not raise IndexError
        action = choice[:1].lower()
        if action == "t":
            for country in selected_countries:
                train_country(countries[country], country)
        elif action == "s":
            create_table(database)
            for country in selected_countries:
                sample_country(countries[country], country)