Creating a price model using k-Nearest Neighbours + Genetic Algorithm

Chapter 8 of Programming Collective Intelligence (PCI) explains the usage and implementation of the k-Nearest Neighbours (k-NN) algorithm.

Simply put:

k-NN is usually described as a classification algorithm: it looks at the k nearest neighbours of an item to determine what class the item belongs to. In this chapter the same idea is used to estimate a number (a price) from the neighbours' values. To determine which neighbours to use, the algorithm needs a distance / similarity score function, in this example Euclidean distance.
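
To make the idea concrete, here is a minimal, self-contained sketch of a k-NN estimate for a numeric value. The class name KnnSketch, its methods and the sample data are all made up for illustration; this is not the PCI code or the repository code, which follows further down.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch of k-NN estimation; names and data are illustrative only.
class KnnSketch {

    // Straight-line (Euclidean) distance between two equal-length points.
    static double distance(double[] a, double[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += (a[i] - b[i]) * (a[i] - b[i]);
        }
        return Math.sqrt(sum);
    }

    // Average the known values of the k points closest to the query.
    static double estimate(double[][] inputs, double[] values, double[] query, int k) {
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < inputs.length; i++) {
            order.add(i);
        }
        // Sort the row indices by how close each row is to the query.
        order.sort(Comparator.comparingDouble(i -> distance(inputs[i], query)));
        double total = 0.0;
        for (int n = 0; n < k; n++) {
            total += values[order.get(n)];
        }
        return total / k;
    }

    public static void main(String[] args) {
        double[][] inputs = {{1.0}, {2.0}, {3.0}, {10.0}};
        double[] values = {10.0, 20.0, 30.0, 100.0};
        // The two nearest neighbours of 2.4 are 2.0 and 3.0.
        System.out.println(estimate(inputs, values, new double[]{2.4}, 2)); // prints 25.0
    }
}
```

With k = 2 the far-away point at 10.0 is ignored entirely, which is the whole appeal of the approach: only nearby items get a say.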

PCI takes it a little further to help with accuracy in some scenarios. It uses a weighted average of the neighbours, and then uses either simulated annealing or a genetic algorithm to determine the best weights, building on my earlier post, Optimization techniques – Simulated Annealing & Genetic Algorithms. (As with all the previous chapters, the code is in my GitHub repository.)
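
As a rough illustration of the weighted average, the sketch below gives each neighbour a Gaussian weight based on its distance, so closer neighbours count for more. WeightedAverageSketch, the sigma of 5.0 and the sample numbers are assumptions made for the example.

```java
// Hypothetical sketch of a distance-weighted average; names and data are illustrative.
class WeightedAverageSketch {

    // Gaussian weight: 1.0 at distance 0, falling smoothly towards 0 as distance grows.
    static double gaussianWeight(double distance, double sigma) {
        return Math.exp(-(distance * distance) / (2 * sigma * sigma));
    }

    // Weighted average of the neighbours' values, weighting each by its distance.
    static double weightedAverage(double[] distances, double[] values, double sigma) {
        double total = 0.0;
        double totalWeight = 0.0;
        for (int i = 0; i < distances.length; i++) {
            double weight = gaussianWeight(distances[i], sigma);
            total += weight * values[i];
            totalWeight += weight;
        }
        return totalWeight == 0.0 ? 0.0 : total / totalWeight;
    }

    public static void main(String[] args) {
        double[] distances = {1.0, 10.0};
        double[] values = {20.0, 60.0};
        // A plain average would be 40.0; the weighted one is pulled towards the closer neighbour.
        System.out.println(weightedAverage(distances, values, 5.0));
    }
}
```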

The similarity score function looks like this (slightly different from the one used earlier, which was inverted to return 1 when the inputs are equal):

package net.briandupreez.pci.chapter8;

import java.util.List;

public class EuclideanDistanceScore {

    /**
     * Determine distance between list of points.
     *
     * @param list1 first list
     * @param list2 second list
     * @return the Euclidean distance between the two lists; 0 means identical, larger means further apart.
     */
    public static double distanceList(final List<Double> list1, final List<Double> list2) {
        if (list1.size() != list2.size()) {
            throw new RuntimeException("Same number of values required.");
        }

        double sumOfAllSquares = 0;
        for (int i = 0; i < list1.size(); i++) {
            sumOfAllSquares += Math.pow(list2.get(i) - list1.get(i), 2);
        }
        return Math.sqrt(sumOfAllSquares);
    }
}

I updated the simulated annealing and genetic algorithm code, as I had originally implemented them using ints (lesson learnt: when doing anything to do with ML or AI, stick to doubles).
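
As an aside before the listing, here is one way integer arithmetic can quietly lose information. This is an assumed illustration, not necessarily the exact problem I ran into, and the class and method names are made up.

```java
// Hypothetical illustration of int truncation versus double arithmetic.
class IntVersusDouble {

    // Integer division drops the fractional part.
    static int intAverage(int a, int b) {
        return (a + b) / 2;
    }

    // The same calculation with doubles keeps it.
    static double doubleAverage(double a, double b) {
        return (a + b) / 2;
    }

    public static void main(String[] args) {
        System.out.println(intAverage(3, 4));       // prints 3
        System.out.println(doubleAverage(3.0, 4.0)); // prints 3.5
    }
}
```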

package net.briandupreez.pci.chapter8;

import org.javatuples.Pair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.SortedMap;
import java.util.TreeMap;

/**
 * Created with IntelliJ IDEA.
 * User: bdupreez
 * Date: 2013/07/05
 * Time: 9:08 PM
 */
public class Optimization {

    public List<Pair<Integer, Integer>> createDomain() {

        final List<Pair<Integer, Integer>> domain = new ArrayList<>(4);
        for (int i = 0; i < 4; i++) {
            final Pair<Integer, Integer> pair = new Pair<>(0, 10);
            domain.add(pair);
        }
        return domain;
    }

    /**
     * Simulated Annealing
     *
     * @param domain list of tuples with min and max
     * @return (global minimum)
     */
    public Double[] simulatedAnnealing(final List<Pair<Integer, Integer>> domain, final double startingTemp, final double cool, final int step) {

        double temp = startingTemp;
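        // Create a random starting solution that lies inside the domain
        final Random random = new Random();
        Double[] sol = new Double[domain.size()];
        for (int r = 0; r < domain.size(); r++) {
            final int min = domain.get(r).getValue0();
            final int max = domain.get(r).getValue1();
            sol[r] = (double) (min + random.nextInt(max - min + 1));
        }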

        while (temp > 0.1) {
            // Pick a random index (nextInt's bound is exclusive, so every index can be chosen)
            int i = random.nextInt(domain.size());

            // Pick a random change between -step and +step
            int direction = random.nextInt(2 * step + 1) - step;

            Double[] cloneSolr = sol.clone();
            cloneSolr[i] += direction;
            if (cloneSolr[i] < domain.get(i).getValue0()) {
                cloneSolr[i] = Double.valueOf(domain.get(i).getValue0());
            } else if (cloneSolr[i] > domain.get(i).getValue1()) {
                cloneSolr[i] = Double.valueOf(domain.get(i).getValue1());
            }

            //calc current and new cost
            double currentCost = scheduleCost(sol);
            double newCost = scheduleCost(cloneSolr);
            System.out.println("Current: " + currentCost + " New: " + newCost);

            double probability = Math.pow(Math.E, -(newCost - currentCost) / temp);

            // Is it better, or does it make the probability cutoff?
            if (newCost < currentCost || Math.random() < probability) {
                sol = cloneSolr;
            }
            temp = temp * cool;
        }
        return sol;
    }

    public double scheduleCost(Double[] sol) {
        NumPredict numPredict = new NumPredict();
        final List<Map<String,List<Double>>> rescale = numPredict.rescale(numPredict.createWineSet2(), Arrays.asList(sol));
        return numPredict.crossValidate(rescale,0.1,100);
    }

    public Double[] geneticAlgorithm(final List<Pair<Integer, Integer>> domain, final int populationSize,
                                     final int step, final double elite, final int maxIter, final double mutProb) {

        List<Double[]> pop = createPopulation(domain.size(), populationSize);

        final int topElite = (int) (elite * populationSize);

        final SortedMap<Double, Double[]> scores = new TreeMap<>();
        for (int i = 0; i < maxIter; i++) {
            for (final Double[] run : pop) {
                scores.put(scheduleCost(run), run);
            }
            pop = determineElite(topElite, scores);

            while (pop.size() < populationSize) {

                final Random random = new Random();
                if (Math.random() < mutProb) {

                    final int ran = random.nextInt(topElite);
                    pop.add(mutate(domain, pop.get(ran), step));

                } else {
                    final int ran1 = random.nextInt(topElite);
                    final int ran2 = random.nextInt(topElite);
                    pop.add(crossover(pop.get(ran1), pop.get(ran2), domain.size()));
                }
            }
            System.out.println(scores);
        }

        return scores.entrySet().iterator().next().getValue();


    }

    /**
     * Grab the elites
     *
     * @param topElite how many
     * @param scores   sorted on score
     * @return best ones
     */
    private List<Double[]> determineElite(int topElite, SortedMap<Double, Double[]> scores) {

        // Values iterate in ascending key order, so the lowest costs come first.
        // Copying them one by one also avoids a null key when there are fewer scores than topElite.
        final List<Double[]> elites = new ArrayList<>();
        for (final Double[] candidate : scores.values()) {
            if (elites.size() == topElite) {
                break;
            }
            elites.add(candidate);
        }
        return elites;
    }

    /**
     * Create a population
     *
     * @param arraySize the array size
     * @param popSize   the population size
     * @return a random population
     */
    private List<Double[]> createPopulation(final int arraySize, final int popSize) {
        final List<Double[]> returnList = new ArrayList<>();
        for (int i = 0; i < popSize; i++) {
            Double[] sol = new Double[arraySize];
            Random random = new Random();
            for (int r = 0; r < arraySize; r++) {
                sol[r] = Double.valueOf(random.nextInt(8));
            }
            returnList.add(sol);
        }
        return returnList;
    }

    /**
     * Mutate a value.
     *
     * @param domain the domain
     * @param vec    the data to be mutated
     * @param step   the step
     * @return mutated array
     */
    private Double[] mutate(final List<Pair<Integer, Integer>> domain, final Double[] vec, final int step) {
        final Random random = new Random();
        // nextInt's bound is exclusive, so this can pick any index
        int i = random.nextInt(domain.size());
        Double[] retArr = vec.clone();
        if (Math.random() < 0.5 && (vec[i] - step) > domain.get(i).getValue0()) {
            retArr[i] -= step;
        } else if (vec[i] + step < domain.get(i).getValue1()) {
            retArr[i] += step;
        }
        // Return the mutated copy, not the untouched original
        return retArr;
    }

    /**
     * Cross over parts of each array
     *
     * @param arr1 array 1
     * @param arr2 array 2
     * @param max  exclusive upper bound for the crossover index
     * @return new array
     */
    private Double[] crossover(final Double[] arr1, final Double[] arr2, final int max) {
        final Random random = new Random();
        int i = random.nextInt(max);
        return concatArrays(Arrays.copyOfRange(arr1, 0, i), Arrays.copyOfRange(arr2, i, arr2.length));
    }

    /**
     * Concat 2 arrays
     *
     * @param first  first
     * @param second second
     * @return new combined array
     */
    private Double[] concatArrays(final Double[] first, final Double[] second) {
        Double[] result = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, result, first.length, second.length);
        return result;
    }
}
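
The acceptance test in simulatedAnnealing is the part that tends to confuse people, so here it is pulled out on its own. This is a small sketch with made-up names and numbers: the same cost increase is accepted less and less often as the temperature drops.

```java
// Hypothetical sketch of the annealing acceptance rule; names and numbers are illustrative.
class AnnealingAcceptanceSketch {

    // Chance of accepting a candidate, given both costs and the current temperature.
    // For a better (cheaper) candidate the value is above 1, so it is always accepted.
    static double acceptanceProbability(double currentCost, double newCost, double temp) {
        return Math.exp(-(newCost - currentCost) / temp);
    }

    public static void main(String[] args) {
        double currentCost = 100.0;
        double worseCost = 105.0;
        // Same cost increase, evaluated as the temperature cools.
        for (double temp : new double[]{100.0, 10.0, 1.0}) {
            System.out.println("temp=" + temp + " p=" + acceptanceProbability(currentCost, worseCost, temp));
        }
    }
}
```

Early on, when the temperature is high, worse solutions are accepted fairly often, which lets the search escape local minima; towards the end it behaves almost like plain hill climbing.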

Then finally, putting it all together, here is my Java implementation of the PCI example:

package net.briandupreez.pci.chapter8;

import org.javatuples.Pair;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Random;

/**
 * NumPredict.
 * User: bdupreez
 * Date: 2013/08/12
 * Time: 8:29 AM
 */
public class NumPredict {


    /**
     * Determine the wine price.
     *
     * @param rating rating
     * @param age    age
     * @return the price
     */
    public double winePrice(final double rating, final double age) {
        final double peakAge = rating - 50;

        //Calculate the price based on rating
        double price = rating / 2;

        if (age > peakAge) {
            //goes bad in 10 years
            price *= 5 - (age - peakAge) / 2;
        } else {
            //increases as it reaches its peak
            price *= 5 * ((age + 1)) / peakAge;
        }

        if (price < 0) {
            price = 0.0;
        }

        return price;
    }


    /**
     * Data Generator
     *
     * @return data
     */
    @SuppressWarnings("unchecked")
    public List<Map<String, List<Double>>> createWineSet1() {

        final List<Map<String, List<Double>>> wineList = new ArrayList<>();
        for (int i = 0; i < 300; i++) {

            double rating = Math.random() * 50 + 50;
            double age = Math.random() * 50;

            double price = winePrice(rating, age);
            price *= (Math.random() * 0.2 + 0.9);

            final Map<String, List<Double>> map = new HashMap<>();
            final List<Double> input = new LinkedList<>();
            input.add(rating);
            input.add(age);
            map.put("input", input);
            final List<Double> result = new ArrayList<>();
            result.add(price);
            map.put("result", result);
            wineList.add(map);

        }


        return wineList;
    }


    /**
     * Data Generator
     *
     * @return data
     */
    @SuppressWarnings("unchecked")
    public List<Map<String, List<Double>>> createWineSet2() {

        final List<Map<String, List<Double>>> wineList = new ArrayList<>();
        for (int i = 0; i < 300; i++) {

            double rating = Math.random() * 50 + 50;
            double age = Math.random() * 50;
            final Random random = new Random();
            double aisle = (double) random.nextInt(20);
            double[] sizes = new double[]{375.0, 750.0, 1500.0};
            double bottleSize = sizes[random.nextInt(3)];


            double price = winePrice(rating, age);
            price *= (bottleSize / 750);
            price *= (Math.random() * 0.2 + 0.9);

            final Map<String, List<Double>> map = new HashMap<>();
            final List<Double> input = new LinkedList<>();
            input.add(rating);
            input.add(age);
            input.add(aisle);
            input.add(bottleSize);
            map.put("input", input);
            final List<Double> result = new ArrayList<>();
            result.add(price);
            map.put("result", result);
            wineList.add(map);

        }


        return wineList;
    }


    /**
     * Rescale
     *
     * @param data  data
     * @param scale the scales
     * @return scaled data
     */
    public List<Map<String, List<Double>>> rescale(final List<Map<String, List<Double>>> data, final List<Double> scale) {
        final List<Map<String, List<Double>>> scaledData = new ArrayList<>();
        for (final Map<String, List<Double>> dataItem : data) {
            final List<Double> scaledList = new LinkedList<>();
            for (int i = 0; i < scale.size(); i++) {
                scaledList.add(scale.get(i) * dataItem.get("input").get(i));
            }
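            // Copy the item so the caller's data is left unmodified
            final Map<String, List<Double>> scaledItem = new HashMap<>(dataItem);
            scaledItem.put("input", scaledList);
            scaledData.add(scaledItem);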
        }
        return scaledData;
    }


    /**
     * Determine all the distances from a list
     *
     * @param data all the data
     * @param vec1 one list
     * @return all the distances
     */
    public List<Pair<Double, Integer>> determineDistances(final List<Map<String, List<Double>>> data, final List<Double> vec1) {
        final List<Pair<Double, Integer>> distances = new ArrayList<>();
        int i = 1;
        for (final Map<String, List<Double>> map : data) {
            final List<Double> vec2 = map.get("input");

            distances.add(new Pair<>(EuclideanDistanceScore.distanceList(vec1, vec2), i++));
        }

        Collections.sort(distances);

        return distances;
    }


    /**
     * Use kNN to estimate a new price
     *
     * @param data all the data
     * @param vec1 new fields to price
     * @param k    the amount of neighbours
     * @return the estimated price
     */
    public double knnEstimate(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) {
        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        double avg = 0.0;

        // Use exactly k neighbours, matching the division by k below
        for (int i = 0; i < k; i++) {
            int idx = distances.get(i).getValue1();
            avg += data.get(idx - 1).get("result").get(0);
        }
        avg = avg / k;
        return avg;
    }


    /**
     * KNN using a weighted average of the neighbours
     *
     * @param data the dataset
     * @param vec1 the data to price
     * @param k    number of neighbours
     * @return the weighted price
     */
    public double weightedKnn(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k) {
        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        double avg = 0.0;
        double totalWeight = 0.0;

        // Use exactly k neighbours
        for (int i = 0; i < k; i++) {
            double dist = distances.get(i).getValue0();
            int idx = distances.get(i).getValue1();
            double weight = guassianWeight(dist, 5.0);

            avg += weight * data.get(idx - 1).get("result").get(0);
            totalWeight += weight;
        }
        if (totalWeight == 0.0) {
            return 0.0;
        }
        avg = avg / totalWeight;

        return avg;
    }


    /**
     * Gaussian weight function: a smoother weight curve that doesn't go to 0
     *
     * @param distance the distance
     * @param sigma    sigma
     * @return weighted value
     */
    public double guassianWeight(final double distance, final double sigma) {
        double alteredDistance = -(Math.pow(distance, 2));
        double sigmaSize = (2 * Math.pow(sigma, 2));
        return Math.pow(Math.E, (alteredDistance / sigmaSize));
    }

    /**
     * Split the data for cross validation.
     *
     * @param data        the data to split
     * @param testPercent % of data to use for the tests
     * @return a tuple 0 - training, 1 - test
     */
    @SuppressWarnings("unchecked")
    public Pair<List, List> divideData(final List<Map<String, List<Double>>> data, final double testPercent) {
        final List trainingList = new ArrayList();
        final List testList = new ArrayList();

        for (final Map<String, List<Double>> dataItem : data) {
            if (Math.random() < testPercent) {
                testList.add(dataItem);
            } else {
                trainingList.add(dataItem);
            }
        }

        return new Pair(trainingList, testList);
    }


    /**
     * Tests the algorithm against the test set and returns the mean squared error;
     * squaring the differences makes large misses more obvious
     *
     * @param trainingSet the training set
     * @param testSet     the test set
     * @return the error
     */
    @SuppressWarnings("unchecked")
    public double testAlgorithm(final List trainingSet, final List testSet) {
        double error = 0.0;
        final List<Map<String, List<Double>>> typedSet = (List<Map<String, List<Double>>>) testSet;
        for (final Map<String, List<Double>> testData : typedSet) {
            double guess = weightedKnn(trainingSet, testData.get("input"), 3);
            error += Math.pow((testData.get("result").get(0) - guess), 2);
        }
        return error / testSet.size();
    }

    /**
     * This runs iterations of the test, and returns an averaged score
     *
     * @param data        the data
     * @param testPercent % test
     * @param trials      number of iterations
     * @return result
     */
    public double crossValidate(final List<Map<String, List<Double>>> data, final double testPercent, final int trials) {

        double error = 0.0;
        for (int i = 0; i < trials; i++) {
            final Pair<List, List> trainingPair = divideData(data, testPercent);
            error += testAlgorithm(trainingPair.getValue0(), trainingPair.getValue1());
        }
        return error / trials;
    }


    /**
     * Gives the probability that an item is in a price range between 0 and 1
     * Adds up the weights of the neighbours that fall in the range and divides by the total weight
     *
     * @param data the data
     * @param vec1 the input
     * @param k    the number of neighbours
     * @param low  low amount of range
     * @param high the high amount
     * @return probability between 0 and 1
     */
    public double probabilityGuess(final List<Map<String, List<Double>>> data, final List<Double> vec1, final int k,
                                   final double low, final double high) {

        final List<Pair<Double, Integer>> distances = determineDistances(data, vec1);
        double neighbourWeights = 0.0;
        double totalWeights = 0.0;

        for (int i = 0; i < k; i++) {
            double dist = distances.get(i).getValue0();
            int index = distances.get(i).getValue1();
            double weight = guassianWeight(dist, 5);
            // determineDistances numbers items from 1, so convert back to a 0-based index
            final List<Double> result = data.get(index - 1).get("result");
            double v = result.get(0);

            //check if the point is in the range.
            if (v >= low && v <= high) {
                neighbourWeights += weight;
            }
            totalWeights += weight;
        }
        if (totalWeights == 0) {
            return 0;
        }
        return neighbourWeights / totalWeights;
    }

}
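
The reason it is worth optimising the scale values passed to rescale is that an irrelevant input can dominate the distance. In createWineSet2 the aisle has no influence on the price, yet unscaled it makes two otherwise identical wines look far apart. The sketch below shows the effect; RescaleSketch, its method and the sample wines are assumptions for the example, not part of the code above.

```java
// Hypothetical sketch: how per-input scales change the distance; names and data are illustrative.
class RescaleSketch {

    // Euclidean distance after multiplying each input difference by its scale.
    static double scaledDistance(double[] a, double[] b, double[] scale) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            double diff = scale[i] * (a[i] - b[i]);
            sum += diff * diff;
        }
        return Math.sqrt(sum);
    }

    public static void main(String[] args) {
        // rating, age, aisle: identical wines shelved in different aisles
        double[] wineA = {90.0, 10.0, 1.0};
        double[] wineB = {90.0, 10.0, 20.0};
        System.out.println(scaledDistance(wineA, wineB, new double[]{1.0, 1.0, 1.0})); // prints 19.0
        System.out.println(scaledDistance(wineA, wineB, new double[]{1.0, 1.0, 0.0})); // prints 0.0
    }
}
```

A scale of 0 removes an input from the distance altogether, which is the kind of answer you would hope the annealing or genetic search finds for the aisle.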

While reading up some more on k-NN I also stumbled upon the following blog posts:

 
