summaryrefslogtreecommitdiff
path: root/src/neuralnetwork/training.py
diff options
context:
space:
mode:
authordavidovski <david@sendula.com>2022-11-30 10:06:56 +0000
committerdavidovski <david@sendula.com>2022-11-30 10:06:56 +0000
commit290c68795d8100cc97b8b53d80f331e536fc71b1 (patch)
treebf0068c4c9121406df9bc90f5c159fd93de8a61e /src/neuralnetwork/training.py
Added files to repository (HEAD, main)
Diffstat (limited to 'src/neuralnetwork/training.py')
-rw-r--r--src/neuralnetwork/training.py373
1 files changed, 373 insertions, 0 deletions
diff --git a/src/neuralnetwork/training.py b/src/neuralnetwork/training.py
new file mode 100644
index 0000000..85c6dbd
--- /dev/null
+++ b/src/neuralnetwork/training.py
@@ -0,0 +1,373 @@
import json
import os
import random
import types
import warnings

from util import *
from rnn import *
+
+cuda = False
+num_processes = 12
+
+
class Country:
    """A single country's name datasets and the RNN models trained on them.

    Wraps the female/male/surname dataset files found under *path*, the
    country's alphabet file, and the configurable pre/post-processing
    pipelines applied to names before training and after sampling.
    """

    def __init__(self, path):
        self.path = path
        self.datasets = {
            "female": os.path.join(path, "female.txt"),
            "male": os.path.join(path, "male.txt"),
            "surname": os.path.join(path, "surname.txt"),
        }

        # initialise the pre and post process function lists
        self.pre_process = []
        self.post_process = []

        # load the data file (fills in pre_process / post_process)
        self.load_data()

        # load the alphabet file
        self.alphabet = self.load_alphabet()

        # initialise one RNN model per dataset; input and output sizes
        # both equal the alphabet size
        hidden_size = 128
        self.rnn = {}
        for dataset in self.datasets:
            self.rnn[dataset] = RNN(
                len(self.alphabet), hidden_size, len(self.alphabet))

    def load_alphabet(self):
        """Load the alphabet from the alphabet file.

        Returns:
            alphabet: (str[]) the letters/characters to use while training

        Raises:
            Exception: if the alphabet file could not be found
        """
        alphabet_path = os.path.join(self.path, "alphabet.txt")

        # check if the alphabet file exists, if not, raise an exception
        if not os.path.exists(alphabet_path):
            raise Exception(
                f"The alphabet file {alphabet_path} could not be found")

        with open(alphabet_path, "r") as alphabet_file:
            # one letter/character per line; splitlines() (unlike
            # split("\n")) does not produce a spurious empty entry when
            # the file ends with a trailing newline
            return alphabet_file.read().splitlines()

    def _resolve_functions(self, names):
        """Resolve a list of function names against the imported globals.

        Args:
            names: (str[]) function names listed in data.json

        Returns:
            (function[]) the matching function objects, in order

        Raises:
            Exception: if a name is missing from globals() or does not
                refer to a plain function
        """
        functions = []
        for name in names:
            if name not in globals():
                # the function was not loaded, throw an exception
                raise Exception(
                    f"The function '{name}' was not loaded or does not exist")

            func = globals()[name]

            # check that the requested object really is a function
            if type(func) is not types.FunctionType:
                raise Exception(
                    f"The function '{name}' is not a function")

            functions.append(func)
        return functions

    def load_data(self):
        """Load the pre/post-processing configuration from data.json.

        Falls back to the default processing pipelines when the country
        has no data.json file.
        """
        data_path = os.path.join(self.path, "data.json")
        if os.path.exists(data_path):
            with open(data_path, "r") as data_file:
                j = json.loads(data_file.read())

            # match the function names listed in the json file with the
            # imported global functions
            self.pre_process = self._resolve_functions(j["pre"])
            self.post_process = self._resolve_functions(j["post"])
        else:
            # load the default pre and post processing functions
            # (presumably provided by `from util import *` — confirm)
            self.pre_process = [uncapitalise]
            self.post_process = [deserialise, capitalise]

    def get_names(self, category):
        """List all the names from a given category file.

        Args:
            category: (str) the category to select names from

        Returns:
            data: (str[]) every line of the given category file

        splitlines() avoids a trailing "" pseudo-name when the dataset
        file ends with a newline.
        """
        with open(self.datasets[category], "r") as datafile:
            return datafile.read().splitlines()

    def get_all(self):
        """List all names in all categories.

        Returns:
            data: (str[]) every name across this country's datasets
        """
        return [name for k in self.datasets for name in self.get_names(k)]

    def postprocess(self, name):
        """Post-process a name after sampling.

        Args:
            name: the name output from the recurrent neural network

        Returns:
            name: the name after every post-processing function is applied
        """
        for f in self.post_process:
            name = f(name)
        return name

    def preprocess(self, name):
        """Pre-process a name for training.

        Args:
            name: the name loaded from the dataset

        Returns:
            name: the name after every pre-processing function is applied
        """
        for f in self.pre_process:
            name = f(name)
        return name

    def train(self, category):
        """Train the RNN for a category on that dataset's names.

        Args:
            category: (str) the category to sample training names from
        """
        # select the RNN model to be training on
        rnn = self.rnn[category]

        # load names from that dataset and pre-process them
        print("preprocessing names...")
        names = [self.preprocess(name) for name in self.get_names(category)]
        print(f"processed {len(names)} names!")

        # use 80% of the dataset size as the iteration count
        iters = int(len(names) * 0.8)

        # start training (learn_names comes from the star imports)
        learn_names(rnn, names, self.alphabet, iterations=iters,
                    num_processes=num_processes)

    def sample(self, category, start_letter):
        """Sample a name from the neural network with a given starting letter.

        Args:
            category: (str) the category to sample generated names from
            start_letter: (str) the letter the generated name starts with

        Returns:
            name: the post-processed name generated by the network
        """
        # select the RNN model to be sampling from
        rnn = self.rnn[category]

        # set the random factor of the RNN to randomise generated names
        rnn.random_factor = 0.7

        # this calls the module-level sample() from the star imports, not
        # this method (function-body name lookup skips the class scope)
        name = sample(rnn, self.alphabet, start_letter)

        # post process the name and return
        return self.postprocess(name)

    def load_rnn(self, category, parent_directory):
        """Load the RNN for a category from its saved model file.

        Args:
            category: (str) the category to load
            parent_directory: (str) where to find the model
        """
        model_file = os.path.join(parent_directory, f"{category}.pt")
        self.rnn[category] = torch.load(model_file)

    def save_rnn(self, category, parent_directory):
        """Save the RNN of a given category to its model file.

        Args:
            category: (str) the category to save
            parent_directory: (str) the directory to save the model file to
        """
        rnn = self.rnn[category]
        model_file = os.path.join(parent_directory, f"{category}.pt")
        torch.save(rnn, model_file)
+
+
def get_countries():
    """Build a Country for every dataset directory under countries_path.

    Returns:
        dict mapping directory name -> Country instance
    """
    countries = {}
    for entry in os.listdir(countries_path):
        full_path = os.path.join(countries_path, entry)
        # only directories are country datasets; skip stray files
        if os.path.isdir(full_path):
            countries[entry] = Country(full_path)
    return countries
+
+
+""" train all of the datasets from a specific country
+ Args:
+ country: (Country)
+"""
+
+
+def train_country(country, name):
+ datasets = country.datasets
+ for dataset in datasets:
+ print(f"Training {dataset} in {name}")
+ country.train(dataset)
+
+ print(f"Finished training on {dataset}... saving...", end="")
+ path = os.path.join("data", "models", name)
+
+ # check if the path already exists before trying to make directories
+ if not os.path.exists(path):
+ os.makedirs(path)
+
+ country.save_rnn(dataset, path)
+ print("saved!")
+
+
def _start_letter_weights(names):
    """Map each starting letter of *names* to its relative frequency.

    Args:
        names: (str[]) pre-processed names (empty strings are ignored)

    Returns:
        dict mapping first letter -> count / len(names)
    """
    counts = {}
    for name in names:
        if len(name) > 0:
            letter = name[0]
            counts[letter] = counts.get(letter, 0) + 1

    # note: weights are divided by the total number of names, including
    # empty ones, matching the original behaviour
    total = len(names)
    return {letter: float(count / total) for letter, count in counts.items()}


def _sample_unique_names(country, dataset, start_letters, number_of_samples):
    """Sample names from the country's RNN, dropping failures and duplicates.

    Args:
        country: (Country) the country to sample from
        dataset: (str) which of the country's datasets to sample
        start_letters: (dict) start letter -> weight for weighted_choice
        number_of_samples: (int) how many samples to attempt

    Returns:
        (str[]) the unique sampled names, in first-seen order
    """
    sampled_names = []
    for _ in range(number_of_samples):
        try:
            letter = weighted_choice(start_letters)
            sampled_names.append(country.sample(dataset, letter))
        except Exception:
            # best effort: a failed sample is simply skipped
            # (was a bare except:, which also swallowed KeyboardInterrupt)
            pass

    # dict.fromkeys removes duplicates while preserving order
    return list(dict.fromkeys(sampled_names))


def _save_names(names, country_name, dataset):
    """Insert sampled names into the names database.

    Args:
        names: (str[]) the names to insert
        country_name: (str) stored in the Origin column
        dataset: (str) stored in the Category column
    """
    connection = sqlite3.connect(database)
    try:
        # the connection context manager commits the transaction on success
        with connection:
            cursor = connection.cursor()
            sql = "INSERT INTO names(Name, Origin, Category) VALUES(?, ?, ?)"
            for name in names:
                cursor.execute(sql, (name, country_name, dataset))
    finally:
        # sqlite3 connections are NOT closed by their context manager;
        # close explicitly so the handle is released even on error
        connection.close()


def sample_country(country, country_name, number_of_samples=10000):
    """Sample names for every dataset of a country and store them in the DB.

    Args:
        country: (Country) the country to sample from
        country_name: (str) the country's name, used for paths and the DB
        number_of_samples: (int) samples attempted per dataset
    """
    for dataset in country.datasets:
        # ensure that the model exists before sampling
        path = os.path.join("data", "models", country_name)
        if not os.path.exists(os.path.join(path, dataset + ".pt")):
            print(f"the model: {country_name}/{dataset} was not found.")
            continue

        # load the country's rnn
        country.load_rnn(dataset, path)

        # load the names from the country's dataset, and pre-process them
        names = [country.preprocess(name)
                 for name in country.get_names(dataset)]

        # weight each start letter by how often real names begin with it
        start_letters = _start_letter_weights(names)

        sampled_names = _sample_unique_names(
            country, dataset, start_letters, number_of_samples)

        _save_names(sampled_names, country_name, dataset)

        print(
            f"Saved {len(sampled_names)} names for {country_name}/{dataset}")
+
+
+countries_path = "data/datasets"
+database = os.path.join("data", "names.db")
+if __name__ == "__main__":
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+
+ # allow processes on this model to share memory
+ torch.multiprocessing.set_start_method('spawn')
+
+ # List all the directories containing country datasets to populate the countries dictionary
+ countries = get_countries()
+
+ country_count = len(countries)
+ # Display debug information
+ print(f"Loaded {country_count} countries!")
+
+ # list all countries in neat collumns
+ collumns = 4
+ width = 14
+ i = 0
+ for country in countries:
+ i += 1
+
+ # print the country and then its index
+ print(country, end="")
+
+ # organise into rows and collumns
+ if i % collumns == 0:
+ print("")
+ else:
+ # separate collumns with spaces
+ print(" " * (width - len(country)), end="")
+
+ # keep asking until the country selection is valid
+ good_selection = False
+ while not good_selection:
+ # prompt user to select a country to train, or train all
+ country_selection = input(
+ "select the name of a country to train on, or (all) to train on all countries: ")
+
+ good_selection = True
+ selected_countries = []
+
+ # if the user selected all, then add all countries to list, if not, add the selected country
+ if country_selection.lower() == "all":
+ [selected_countries.append(country) for country in countries]
+ elif country_selection.lower() in countries:
+ selected_countries.append(country_selection)
+ else:
+ print("Country not found, try again")
+ good_selection = False
+
+ choice = input("(t)rain on data, or (s)ample from weights?")
+
+ if choice.lower()[0] == "t":
+ for country in selected_countries:
+ train_country(countries[country], country)
+
+ elif choice.lower()[0] == "s":
+ create_table(database)
+ for country in selected_countries:
+ sample_country(countries[country], country)