/**
 * You should fill in the stub for this code file. The
 * arguments will be given to you.
 *
 * @author YourAndrewIDHere
 */

import java.io.*;
import java.util.*;

public class ProbModelDriver {

	/** Number of simulated runs performed for each (start node, p) pair in sweep mode. */
	private static final int NUM_TRIALS = 10000;

	/**
	 * The sweep uses p = k / P_DIVISIONS for k = 1 .. P_DIVISIONS - 1,
	 * i.e. p = 0.05, 0.10, ..., 0.95 (19 values).
	 */
	private static final int P_DIVISIONS = 20;

	/**
	 * Accumulators for all start nodes that share one degree (neighbor count).
	 * Index {@code k - 1} of each array holds the sum, over those nodes, of the
	 * per-node trial averages obtained at p = k / P_DIVISIONS.
	 */
	private static final class DegreeStats {
		final double[] adoptionSum = new double[P_DIVISIONS - 1];
		final double[] distanceSum = new double[P_DIVISIONS - 1];
		int nodeCount;
	}

	/**
	 * Entry point.
	 * The first argument is the model filename.
	 * The second argument should be the value of p, as an ASCII string.
	 * The third argument should be the NodeID of the initial product user.
	 * If a fourth argument equal to "-test" (compared case-insensitively) is given,
	 * one {@code runTest(p, initialNode)} call is made, its user list is printed, and
	 * the program returns.
	 * Otherwise the p and NodeID arguments are parsed but not used: every node is
	 * used as the starting node for p = 0.05, 0.10, ..., 0.95 with
	 * {@value #NUM_TRIALS} trials each, and per-degree averages are written to
	 * probout.csv (mean number of users returned by runTest) and distout.csv (mean
	 * of those users' {@code distance} field). Each row is
	 * {@code "<degree> , v(0.05), v(0.10), ..."}, sorted by ascending degree.
	 *
	 * @param args command-line arguments as described above
	 * @throws IOException if constructing the network or writing either CSV file
	 *         fails with an I/O error
	 */
	public static void main(String[] args) throws IOException {
		if(args.length < 3)
		{
			System.out.println("Usage: ProbModelDriver netDefFilename pValue initialNode (-test)");
			return;
		}
		
		String networkFilename = args[0];
		double successProb = Double.parseDouble(args[1]);
		int    startNode = Integer.parseInt(args[2]);
		
		ProbNetwork network = new ProbNetwork(networkFilename);
		
		// DO NOT REMOVE THE FOLLOWING CODE.
		if(args.length == 4 && args[3].toLowerCase().equals("-test"))
		{
			ArrayList<Integer> users = network.runTest(successProb, startNode);
			System.out.println(users);
			
			return; // make sure that this return is called!
		}

		// Sweep mode: successProb and startNode are intentionally unused from here on.
		writeResults(runSweep(network));
	}

	/**
	 * Runs {@value #NUM_TRIALS} simulations for every start node and every p in the
	 * sweep, and accumulates the per-node averages by node degree.
	 *
	 * @param network the loaded model; assumes node IDs are 0 .. numNodes - 1
	 *        (the same assumption the original indexing made)
	 * @return accumulators keyed by degree, iterated in ascending degree order
	 */
	private static Map<Integer, DegreeStats> runSweep(ProbNetwork network) {
		Map<Integer, DegreeStats> statsByDegree = new TreeMap<Integer, DegreeStats>();

		for (int node = 0; node < network.numNodes; node++) {
			int degree = network.network.get(node).neighbors.size();
			DegreeStats stats = statsByDegree.get(degree);
			if (stats == null) {
				stats = new DegreeStats();
				statsByDegree.put(degree, stats);
			}
			stats.nodeCount++;

			// Integer stepping avoids accumulating floating-point error in p;
			// k / 20.0 yields exactly the doubles 0.05, 0.10, ..., 0.95.
			for (int k = 1; k < P_DIVISIONS; k++) {
				double p = k / (double) P_DIVISIONS;
				long totalUsers = 0; // long: NUM_TRIALS * users per run can exceed int range
				double totalMeanDistance = 0.0;

				for (int trial = 0; trial < NUM_TRIALS; trial++) {
					ArrayList<Integer> users = network.runTest(p, node);
					// Read everything we need before resetUsers() clears the simulation state.
					totalUsers += users.size();
					totalMeanDistance += meanDistance(network, users);
					network.resetUsers();
				}

				stats.adoptionSum[k - 1] += (double) totalUsers / NUM_TRIALS;
				stats.distanceSum[k - 1] += totalMeanDistance / NUM_TRIALS;
			}
		}
		return statsByDegree;
	}

	/**
	 * Returns the mean of the {@code distance} field over the given users, or 0.0
	 * when the list is empty (no users returned by the run).
	 */
	private static double meanDistance(ProbNetwork network, List<Integer> users) {
		if (users.isEmpty()) {
			return 0.0;
		}
		double sum = 0.0;
		for (int user : users) {
			sum += network.network.get(user).distance;
		}
		return sum / users.size();
	}

	/**
	 * Writes one row per degree to probout.csv (average user count) and
	 * distout.csv (average distance), each value averaged over all start nodes of
	 * that degree. Both files are closed even if a write fails.
	 */
	private static void writeResults(Map<Integer, DegreeStats> statsByDegree) throws IOException {
		BufferedWriter output = new BufferedWriter(new FileWriter("probout.csv"));
		try {
			BufferedWriter dOutput = new BufferedWriter(new FileWriter("distout.csv"));
			try {
				for (Map.Entry<Integer, DegreeStats> entry : statsByDegree.entrySet()) {
					int degree = entry.getKey();
					DegreeStats stats = entry.getValue();

					output.write(degree + " ");
					dOutput.write(degree + " ");
					for (int step = 0; step < stats.adoptionSum.length; step++) {
						output.write(", " + (stats.adoptionSum[step] / stats.nodeCount));
						dOutput.write(", " + (stats.distanceSum[step] / stats.nodeCount));
					}
					output.write("\n");
					dOutput.write("\n");
				}
			} finally {
				dOutput.close(); // close() flushes buffered data first
			}
		} finally {
			output.close();
		}
	}

}

