#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
# -----------------------------------------------------------------------------
# This is illustrative code developed for tutorial purposes; it is not
# intended for scientific use and is not guaranteed to be accurate or correct.
"""
Usage:
    forecast INTERVAL ITERATIONS

Arguments:
    INTERVAL: The period between forecasts in minutes.
    ITERATIONS: The number of forecasts to produce.

Example:
    forecast 60 5  # Produce 5 forecasts at one hour intervals
                   # i.e. T+1, T+2, T+3, T+4, T+5

Environment Variables:
    PT3H: The cycle point 3 hours prior.
    PT6H: The cycle point 6 hours prior.
    SPLINE_LEVEL (int): The beta-spline level to use for interpolating
        rainfall data onto the grid (either 0 or 1).
    WEIGHTING (list): The weighting to give to each observation set.

"""

import os
from subprocess import Popen, PIPE
import sys

import util


# Conversion factor from miles to kilometers (multiply miles by this value).
# NOTE(review): the international mile is exactly 1.609344 km; the trailing
# "001" looks unintentional - confirm before changing as it alters forecasts.
MI2KM = 1.609344001


def get_wind_data(wind_file_path, wind_cycles):
    """Load the x and y wind components for each requested cycle point.

    Args:
        wind_file_path (str): Template for the path to a wind data file.
            It is filled in with `str.format` using the keyword arguments
            `cycle` (the cycle point) and `xy` (the value 'x' or 'y').
        wind_cycles (list): Cycle points (as strings) to load wind data for.

    Return:
        list - One 2-tuple per cycle point, in the order given, holding
        the result of `util.read_csv` for the 'x' file and the 'y' file
        i.e. [(wind_x, wind_y), ...]

    """
    return [
        # The 'x' file is always read before the 'y' file.
        tuple(
            util.read_csv(wind_file_path.format(cycle=point, xy=component))
            for component in ('x', 'y')
        )
        for point in wind_cycles
    ]


def apply_weighting(wind_fields, weightings):
    """Return a single wind field combined from the provided fields using the
    provided weightings.

    Every field is multiplied by its weighting and the results are summed
    cell by cell (so the result is an average only if the weightings sum
    to one). The combined field is also passed to `util.plot_vector_grid`
    under the file name 'weighted-wind-data.png'.

    Args:
        wind_fields (list): List of wind field data as returned by
            get_wind_data().
        weightings (list): List of the weighting to apply to each field in the
            generated result (e.g. [1] would just use the first field and
            [0.5, 0.5] would combine the first two in equal parts). Fields
            without a corresponding weighting are ignored, as are surplus
            weightings.

    Return:
        tuple - (weighted_wind_x, weighted_wind_y)

    Raises:
        IndexError: If `wind_fields` is empty (the grid size is read from
            its first entry).

    """
    # The grid size is taken from the x component of the first field; every
    # weighted field is indexed over this same range.
    dim_y = len(wind_fields[0][0])
    dim_x = len(wind_fields[0][0][0])

    # Generate matrices for the x and y components of the weighted wind.
    weighted_x = util.generate_matrix(dim_x, dim_y, 0.)
    weighted_y = util.generate_matrix(dim_x, dim_y, 0.)

    # Pair each field with its weighting once, rather than looking the
    # weighting up (and handling a failed lookup) for every grid cell.
    # zip() stops at the shorter sequence, so fields that have no weighting
    # contribute nothing to the result.
    for weighting, (wind_x, wind_y) in zip(weightings, wind_fields):
        for itt_y in range(dim_y):
            for itt_x in range(dim_x):
                weighted_x[itt_y][itt_x] += weighting * wind_x[itt_y][itt_x]
                weighted_y[itt_y][itt_x] += weighting * wind_y[itt_y][itt_x]

    util.plot_vector_grid('weighted-wind-data.png', weighted_x, weighted_y)

    return weighted_x, weighted_y


def push_rainfall(rainfall, wind_data, step, resolution, spline_level):
    """Generate forecasts using rainfall and wind data.

    A simple system which "pushes" the rainfall by the wind i.e. if there is a
    10km/h wind blowing from the west assume that the rain would move 10km to
    the east every hour.

    Each rainfall cell becomes a point (x, y, value). The first item yielded
    is the interpolation of the points at their starting positions; every
    following item is produced after moving the points by one further
    `step`. The generator never terminates by itself, the caller decides
    how many forecasts to take.

    The `DOMAIN` environment variable is read (once) when the second item is
    requested; it supplies the `lng1`/`lat1` origin used to work out the
    size of a grid cell at each point.

    Args:
        rainfall (list): A 2D list of rainfall data.
        wind_data (tuple): A 2-tuple containing the x and y components of the
            wind each as a 2D list. Values are scaled by MI2KM, so they are
            presumably in miles per hour - TODO confirm.
        step (int): The number of minutes between each forecast.
        resolution (float): The size of each grid cell in degrees.
        spline_level (int): The order of beta-spline to use for the
            interpolation.

    Yields:
        list - A 2D list of predicted rainfall data

    """
    dim_y = len(rainfall)
    dim_x = len(rainfall[0])

    x_values = []
    y_values = []
    z_values = []

    for itt_y in range(dim_y):
        for itt_x in range(dim_x):
            x_values.append(itt_x * resolution)
            y_values.append((dim_y - itt_y) * resolution)  # y-axis is flipped.
            z_values.append(rainfall[itt_y][itt_x])

    yield util.interpolate_grid(list(zip(x_values, y_values, z_values)),
                                dim_x, dim_y, resolution, resolution,
                                spline_level)

    # Values which are the same for every point and every step. The domain
    # used to be re-parsed for each point inside the IndexError handler, which
    # was wasteful and meant an IndexError raised while parsing would have
    # been mistaken for "every point has left the grid".
    domain = util.parse_domain(os.environ['DOMAIN'])
    half_cell = resolution / 2.
    hours = step / 60.
    wind_x = wind_data[0]
    wind_y = wind_data[1]

    while True:
        out_of_bounds = []
        for itt in range(len(x_values)):
            lng = domain['lng1'] + x_values[itt]
            lat = domain['lat1'] + y_values[itt]

            # Size of the grid cell centred on this point, used to convert
            # the distance travelled into a change in degrees.
            width = util.great_arc_distance(
                (lng - half_cell, lat), (lng + half_cell, lat))
            height = util.great_arc_distance(
                (lng, lat - half_cell), (lng, lat + half_cell))

            # Only the wind lookups signal that a point has left the grid.
            # NOTE(review): the coordinates (in degrees) are truncated and
            # used directly as grid indices, negative indices wrap around
            # rather than raising, and the y component is looked up using
            # the already updated x position. This is kept as it was to
            # avoid changing the forecasts - confirm whether it is intended.
            try:
                x_values[itt] += (
                    (wind_x[int(y_values[itt])][int(x_values[itt])] *
                     MI2KM * hours) * (resolution / width))
                y_values[itt] -= (
                    (wind_y[int(y_values[itt])][int(x_values[itt])] *
                     MI2KM * hours) * (resolution / height))
            except IndexError:
                out_of_bounds.append(itt)

        # Indices were collected in ascending order; delete from the end so
        # that the remaining indices stay valid.
        for itt in reversed(out_of_bounds):
            del x_values[itt]
            del y_values[itt]
            del z_values[itt]

        yield util.interpolate_grid(list(zip(x_values, y_values, z_values)),
                                    dim_x, dim_y, resolution, resolution,
                                    spline_level)


def get_cyclepoint(offset_hours):
    """Calculate a cyclepoint from a relative offset.

    Run the system call `cylc cyclepoint --offset-hours <offset_hours>` and
    return what it writes to standard output.

    Args:
        offset_hours: The offset in hours. It is converted with `int()`
            before being placed on the command line, so an integer or a
            string holding an integer is accepted.

    Return:
        str - The decoded output of the command with leading and trailing
        whitespace removed.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status. Previously the status was ignored and whatever had been
            printed (typically nothing) was returned as the cycle point.

    """
    from subprocess import CalledProcessError

    command = ['cylc', 'cyclepoint', '--offset-hours', str(int(offset_hours))]
    proc = Popen(command, stdout=PIPE)
    # communicate() waits for the command to finish, so returncode is set.
    stdout = proc.communicate()[0]
    if proc.returncode:
        raise CalledProcessError(proc.returncode, command, output=stdout)
    return stdout.strip().decode()


def main(forecast_interval, forecast_iterations):
    """Produce rainfall forecasts and write them as CSV files and a HTML map.

    Settings are read from the environment: `DOMAIN`, `RESOLUTION`,
    `RAINFALL_FILE`, `WIND_CYCLES`, `WIND_FILE_TEMPLATE`, `MAP_FILE` and
    `MAP_TEMPLATE` are required, `SPLINE_LEVEL` (default 0) and `WEIGHTING`
    (default '1') are optional. `WIND_CYCLES` and `WEIGHTING` are
    comma-separated lists.

    One file named `+PT<hh>H<mm>M.csv` is written per forecast, starting
    with the zero lead time forecast.

    Args:
        forecast_interval (int): The period between forecasts in minutes.
        forecast_iterations (int): The number of forecasts to produce
            after the initial (zero lead time) one.

    """
    environ = os.environ

    # Workflow settings.
    domain = util.parse_domain(environ['DOMAIN'])
    resolution = float(environ['RESOLUTION'])

    # Science settings.
    spline_level = int(environ.get('SPLINE_LEVEL', 0))
    weighting = [
        float(value) for value in environ.get('WEIGHTING', '1').split(',')
    ]

    # File paths; each wind cycle offset is turned into a cycle point.
    rainfall_file = environ['RAINFALL_FILE']
    wind_cycles = [
        get_cyclepoint(offset) for offset in environ['WIND_CYCLES'].split(',')
    ]
    wind_file_template = environ['WIND_FILE_TEMPLATE']
    map_file = environ['MAP_FILE']
    map_template = environ['MAP_TEMPLATE']

    # Load the wind data for the requested cycles and merge it into a single
    # grid using the weighting.
    weighted_wind_data = apply_weighting(
        get_wind_data(wind_file_template, wind_cycles), weighting)

    # Load the rainfall data and set up the forecaster.
    rainfall = util.read_csv(rainfall_file)
    generator = push_rainfall(rainfall, weighted_wind_data, forecast_interval,
                              resolution, spline_level)

    # Take one forecast from the generator per lead time and write it out.
    forecasts = {}
    end_minutes = forecast_interval * (forecast_iterations + 1)
    for lead_minutes in range(0, end_minutes, forecast_interval):
        hours, minutes = divmod(lead_minutes, 60)
        forecast_name = '+PT%02dH%02dM' % (hours, minutes)
        forecast = next(generator)
        forecasts[forecast_name] = forecast
        util.write_csv(forecast_name + '.csv', forecast)

    # Generate html page from forecast data.
    util.generate_html_map(map_file, map_template, forecasts, domain,
                           resolution)


if __name__ == '__main__':
    # Both arguments are required and must be integers. A missing argument
    # raises IndexError and a non-integer one raises ValueError; print the
    # usage message and exit with status 1 in either case.
    try:
        args = [int(sys.argv[1]), int(sys.argv[2])]
    except (IndexError, ValueError):
        print(__doc__)
        sys.exit(1)

    main(*args)
