#include <gtest/gtest.h>

#include <algorithm>
#include <set>
#include <sstream>
#include <stack>
#include <string>
#include <utility>
#include <vector>

#include <gbwtgraph/algorithms.h>
#include <gbwtgraph/gfa.h>
#include <gbwtgraph/path_cover.h>

#include "shared.h"

using namespace gbwtgraph;

namespace
{

//------------------------------------------------------------------------------

// (n = 4, k = 3)-paths generated for components.gfa.

std::vector<std::set<gbwt::vector_type>> correct_paths =
{
  {
    {
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(11, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(12, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(14, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(15, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(17, false))
    },
    {
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(11, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(12, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(14, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(16, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(17, false))
    },
    {
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(11, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(13, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(14, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(15, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(17, false))
    },
    {
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(11, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(13, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(14, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(16, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(17, false))
    }
  },
  {
    {
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(21, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(22, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(24, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(25, false))
    },
    {
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(21, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(22, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(24, false)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(23, true)),
      static_cast<gbwt::vector_type::value_type>(gbwt::Node::encode(21, true))
    }
  }
};

//------------------------------------------------------------------------------

// Test fixture for path_cover_gbwt(): loads components.gfa into a GBWT index and
// builds a GBWTGraph from it before each test.
class PathCoverTest : public ::testing::Test
{
public:
  gbwt::GBWT index;
  GBWTGraph graph;
  // Number of components in the test graph, taken from the expected paths.
  size_t components = 0;

  void SetUp() override
  {
    auto gfa_parse = gfa_to_gbwt("components.gfa");
    // The parsed index is discarded afterwards, so move it instead of copying.
    this->index = std::move(*(gfa_parse.first));
    // The graph is built from this->index, so it must be created after the
    // index is in its final place.
    this->graph = GBWTGraph(this->index, *(gfa_parse.second));
    this->components = correct_paths.size();
  }
};

TEST_F(PathCoverTest, CorrectPaths)
{
  const size_t paths_per_component = 4;
  const size_t context_length = 3;
  // Every generated path is stored in both orientations.
  const gbwt::size_type expected_sequences = this->components * paths_per_component * 2;

  gbwt::GBWT cover = path_cover_gbwt(this->graph, paths_per_component, context_length);
  ASSERT_EQ(cover.sequences(), expected_sequences) << "Wrong number of sequences in the path cover GBWT";

  // Paths may flip orientation, so each one is canonicalized as the smaller of
  // itself and its reverse complement. Forward orientations have even ids and
  // the paths of a component are consecutive.
  std::vector<std::set<gbwt::vector_type>> result(this->components);
  size_t seq_id = 0;
  for(std::set<gbwt::vector_type>& component_paths : result)
  {
    for(size_t j = 0; j < paths_per_component; j++, seq_id += 2)
    {
      gbwt::vector_type forward = cover.extract(seq_id);
      gbwt::vector_type reverse;
      gbwt::reversePath(forward, reverse);
      component_paths.insert(std::min(forward, reverse));
    }
  }

  // Both sides are ordered sets, so they can be compared element by element.
  for(size_t i = 0; i < this->components; i++)
  {
    ASSERT_EQ(result[i].size(), correct_paths[i].size()) << "Wrong number of distinct paths for component " << i;
    auto correct_iter = correct_paths[i].begin();
    for(auto result_iter = result[i].begin(); result_iter != result[i].end(); ++result_iter, ++correct_iter)
    {
      EXPECT_EQ(*result_iter, *correct_iter) << "Wrong path in component " << i;
    }
  }
}

TEST_F(PathCoverTest, Metadata)
{
  const size_t context_length = 3;
  const size_t paths_per_component = 4;

  gbwt::GBWT cover = path_cover_gbwt(this->graph, paths_per_component, context_length);
  ASSERT_TRUE(cover.hasMetadata()) << "Path cover GBWT contains no metadata";

  // The generated metadata has one sample / haplotype per path within a component,
  // one contig per component, and a name for every generated path.
  const size_t expected_paths = paths_per_component * this->components;
  EXPECT_EQ(cover.metadata.samples(), paths_per_component) << "Wrong number of samples in the metadata";
  EXPECT_EQ(cover.metadata.contigs(), this->components) << "Wrong number of contigs in the metadata";
  EXPECT_EQ(cover.metadata.haplotypes(), paths_per_component) << "Wrong number of haplotypes in the metadata";
  EXPECT_TRUE(cover.metadata.hasPathNames()) << "No path names in the metadata";
  EXPECT_EQ(cover.metadata.paths(), expected_paths) << "Wrong number of path names in the metadata";
}

//------------------------------------------------------------------------------

// Test fixture for local_haplotypes(): loads components.gfa into a GBWT index and
// builds a GBWTGraph from it before each test.
class LocalHaplotypesTest : public ::testing::Test
{
public:
  gbwt::GBWT index;
  GBWTGraph graph;
  // Number of components in the test graph, taken from the expected paths.
  size_t components = 0;

  void SetUp() override
  {
    auto gfa_parse = gfa_to_gbwt("components.gfa");
    // The parsed index is discarded afterwards, so move it instead of copying.
    this->index = std::move(*(gfa_parse.first));
    this->graph = GBWTGraph(this->index, *(gfa_parse.second));
    this->components = correct_paths.size();
  }

  // Search states for the same path in two indexes, plus the path length in nodes.
  struct SearchState
  {
    gbwt::SearchState left, right;
    size_t depth;

    // Extends the path with node `to` in both indexes.
    SearchState successor(const gbwt::GBWT& left_index, const gbwt::GBWT& right_index, gbwt::node_type to) const
    {
      return { left_index.extend(this->left, to), right_index.extend(this->right, to), this->depth + 1 };
    }

    // The path occurs either in both indexes or in neither.
    bool ok() const { return (this->left.empty() == this->right.empty()); }

    // The path occurs in neither index, so none of its extensions can occur either.
    bool absent() const { return (this->left.empty() && this->right.empty()); }

    gbwt::node_type node() const { return this->left.node; }
  };

  // Does the second index have the same depth-k extensions as the first?
  // Starting from every node of `baseline`, follows `baseline` edges to enumerate
  // paths of at most k nodes. Returns false as soon as one of them occurs in
  // exactly one of the two indexes.
  bool same_extensions(const gbwt::GBWT& baseline, const gbwt::GBWT& candidate, size_t k) const
  {
    for(gbwt::node_type node = baseline.firstNode(); node < baseline.sigma(); node++)
    {
      std::stack<SearchState> states;
      states.push({ baseline.find(node), candidate.find(node), static_cast<size_t>(1) });
      while(!(states.empty()))
      {
        SearchState state = states.top(); states.pop();
        if(!(state.ok())) { return false; }
        // A path missing from both indexes cannot produce a mismatch further down,
        // so there is nothing left to check in this subtree.
        if(state.absent() || state.depth >= k) { continue; }
        for(const auto& edge : baseline.edges(state.node()))
        {
          if(edge.first == gbwt::ENDMARKER) { continue; }
          states.push(state.successor(baseline, candidate, edge.first));
        }
      }
    }
    return true;
  }
};

TEST_F(LocalHaplotypesTest, CorrectSubPaths)
{
  const size_t context_length = 3;
  const size_t paths_per_component = 4;
  // Every generated path is stored in both orientations.
  const gbwt::size_type expected_sequences = this->components * paths_per_component * 2;

  gbwt::GBWT cover = local_haplotypes(this->graph, this->index, paths_per_component, context_length);

  // The cover must have the expected size and the same alphabet as the input index.
  ASSERT_EQ(cover.sequences(), expected_sequences) << "Wrong number of sequences in the local haplotype GBWT";
  ASSERT_EQ(cover.sigma(), this->index.sigma()) << "Wrong alphabet size in the local haplotype GBWT";
  ASSERT_EQ(cover.effective(), this->index.effective()) << "Wrong effective alphabet size in the local haplotype GBWT";

  // Comparing in both directions checks that the cover has exactly the
  // k-subpaths of the input: none missing and none added.
  const bool all_correct_subpaths = this->same_extensions(this->index, cover, context_length);
  EXPECT_TRUE(all_correct_subpaths) << "Missing " << context_length << "-subpaths in the local haplotype GBWT";
  const bool no_extra_subpaths = this->same_extensions(cover, this->index, context_length);
  EXPECT_TRUE(no_extra_subpaths) << "Additional " << context_length << "-subpaths in the local haplotype GBWT";
}

TEST_F(LocalHaplotypesTest, Metadata)
{
  const size_t context_length = 3;
  const size_t paths_per_component = 4;

  gbwt::GBWT cover = local_haplotypes(this->graph, this->index, paths_per_component, context_length);
  ASSERT_TRUE(cover.hasMetadata()) << "Local haplotype GBWT contains no metadata";

  // Expected layout: one sample / haplotype per path within a component, one
  // contig per component, and a name for every generated path.
  const size_t expected_paths = paths_per_component * this->components;
  EXPECT_EQ(cover.metadata.samples(), paths_per_component) << "Wrong number of samples in the metadata";
  EXPECT_EQ(cover.metadata.contigs(), this->components) << "Wrong number of contigs in the metadata";
  EXPECT_EQ(cover.metadata.haplotypes(), paths_per_component) << "Wrong number of haplotypes in the metadata";
  EXPECT_TRUE(cover.metadata.hasPathNames()) << "No path names in the metadata";
  EXPECT_EQ(cover.metadata.paths(), expected_paths) << "Wrong number of path names in the metadata";
}

TEST_F(LocalHaplotypesTest, Frequencies)
{
  const size_t context_length = 3;
  const size_t paths_per_component = 4;

  // Two 3-node subpaths of the second component. The test expects the cover to
  // contain the first at least as many times as the second.
  const std::vector<gbwt::node_type> frequent_path =
  {
    gbwt::Node::encode(21, false), gbwt::Node::encode(22, false), gbwt::Node::encode(24, false)
  };
  const std::vector<gbwt::node_type> rare_path =
  {
    gbwt::Node::encode(22, false), gbwt::Node::encode(24, false), gbwt::Node::encode(25, false)
  };

  gbwt::GBWT cover = local_haplotypes(this->graph, this->index, paths_per_component, context_length);

  // The size of a search state is the number of occurrences of the searched path.
  gbwt::SearchState frequent_state = cover.find(frequent_path.begin(), frequent_path.end());
  gbwt::SearchState rare_state = cover.find(rare_path.begin(), rare_path.end());
  EXPECT_GE(frequent_state.size(), rare_state.size()) << "Local haplotype frequencies do not reflect true frequencies";
}

TEST_F(LocalHaplotypesTest, RevertToPathCover)
{
  const size_t context_length = 3;
  const size_t paths_per_component = 4;
  // Every generated path is stored in both orientations.
  const gbwt::size_type expected_sequences = this->components * paths_per_component * 2;

  // Reference covers: local haplotypes and path cover over the full graph.
  gbwt::GBWT haplotype_cover = local_haplotypes(this->graph, this->index, paths_per_component, context_length);
  ASSERT_EQ(haplotype_cover.sequences(), expected_sequences) << "Wrong number of sequences in the local haplotype GBWT";
  gbwt::GBWT path_cover = path_cover_gbwt(this->graph, paths_per_component, context_length);
  ASSERT_EQ(path_cover.sequences(), expected_sequences) << "Wrong number of sequences in the path cover GBWT";

  // Haplotypes exist only for the first component. For the second component,
  // local_haplotypes() is expected to fall back to a path cover.
  auto first_component = gfa_to_gbwt("first_component.gfa");
  gbwt::GBWT mixed_cover = local_haplotypes(this->graph, *(first_component.first), paths_per_component, context_length);
  ASSERT_EQ(mixed_cover.sequences(), expected_sequences) << "Wrong number of sequences in the mixed cover GBWT";

  // First component: the paths should match the local haplotype cover.
  for(size_t offset = 0; offset < paths_per_component; offset++)
  {
    const gbwt::size_type seq_id = gbwt::Path::encode(offset, false);
    gbwt::vector_type path = mixed_cover.extract(seq_id);
    gbwt::vector_type correct_path = haplotype_cover.extract(seq_id);
    EXPECT_EQ(path, correct_path) << "Wrong path " << offset << " in the first component";
  }

  // Second component: the paths should match the path cover.
  for(size_t offset = 0; offset < paths_per_component; offset++)
  {
    const gbwt::size_type seq_id = gbwt::Path::encode(paths_per_component + offset, false);
    gbwt::vector_type path = mixed_cover.extract(seq_id);
    gbwt::vector_type correct_path = path_cover.extract(seq_id);
    EXPECT_EQ(path, correct_path) << "Wrong path " << offset << " in the second component";
  }
}

//------------------------------------------------------------------------------

// Test fixture for augment_gbwt(): loads components.gfa and can build GBWT indexes
// that contain the original paths of a chosen subset of its components.
class AugmentTest : public ::testing::Test
{
public:
  gbwt::GBWT index;
  GBWTGraph graph;
  // Node ids in each weakly connected component of the graph.
  std::vector<std::vector<nid_t>> components;
  // Sample and haplotype count written into the metadata by create_gbwt();
  // presumably the number of paths per component in components.gfa.
  size_t samples = 0;

  void SetUp() override
  {
    auto gfa_parse = gfa_to_gbwt("components.gfa");
    // The parsed index is discarded afterwards, so move it instead of copying.
    this->index = std::move(*(gfa_parse.first));
    this->graph = GBWTGraph(this->index, *(gfa_parse.second));
    this->components = weakly_connected_components(this->graph);
    this->samples = 2;
  }

  // Returns the index of the component containing node `id`, or
  // this->components.size() if no component contains it.
  size_t component_of(nid_t id) const
  {
    for(size_t component = 0; component < this->components.size(); component++)
    {
      const std::vector<nid_t>& curr = this->components[component];
      if(std::find(curr.begin(), curr.end(), id) != curr.end()) { return component; }
    }
    return this->components.size();
  }

  // Builds a dynamic GBWT with the paths of this->index that start in one of
  // `components_present`. Each copied path gets a path name in the metadata:
  // - sample: the order of the path within its component;
  // - contig: the rank of the component among the present components;
  // - phase and count: 0.
  // If `names` is true, sample and contig names are stored; otherwise only their
  // counts. When `components_present` is empty, the sample, contig, and haplotype
  // counts are left unset.
  gbwt::DynamicGBWT create_gbwt(const std::set<size_t>& components_present, bool names) const
  {
    size_t node_width = gbwt::bit_length(this->index.sigma() - 1);
    size_t total_length = this->index.size();
    gbwt::GBWTBuilder builder(node_width, total_length);
    builder.index.addMetadata();

    // Number of paths copied so far from each component (the next sample id).
    std::vector<size_t> samples_per_component(this->components.size(), 0);
    // Contig id of each component: its rank among the present components.
    std::vector<size_t> component_to_rank(this->components.size(), 0);
    for(size_t i = 0, rank = 0; i < this->components.size(); i++)
    {
      component_to_rank[i] = rank;
      if(components_present.find(i) != components_present.end()) { rank++; }
    }

    // Even sequence ids are the forward orientations of the paths.
    for(gbwt::size_type i = 0; i < this->index.sequences(); i += 2)
    {
      gbwt::vector_type sequence = this->index.extract(i);
      // front() is undefined on an empty path, and an empty path has no component.
      if(sequence.empty()) { continue; }
      size_t component = this->component_of(gbwt::Node::id(sequence.front()));
      // Also guards the per-component vectors against a start node outside every component.
      if(component >= this->components.size()) { continue; }
      if(components_present.find(component) == components_present.end()) { continue; }

      builder.insert(sequence, true);
      builder.index.metadata.addPath(
      {
        static_cast<gbwt::PathName::path_name_type>(samples_per_component[component]),
        static_cast<gbwt::PathName::path_name_type>(component_to_rank[component]),
        static_cast<gbwt::PathName::path_name_type>(0),
        static_cast<gbwt::PathName::path_name_type>(0)
      });
      samples_per_component[component]++;
    }

    builder.finish();
    if(!(components_present.empty()))
    {
      if(names)
      {
        std::vector<std::string> sample_names, contig_names;
        for(size_t i = 0; i < this->samples; i++)
        {
          sample_names.push_back("sample_" + std::to_string(i));
        }
        builder.index.metadata.setSamples(sample_names);
        for(size_t contig : components_present)
        {
          contig_names.push_back("contig_" + std::to_string(contig));
        }
        builder.index.metadata.setContigs(contig_names);
      }
      else
      {
        builder.index.metadata.setSamples(this->samples);
        builder.index.metadata.setContigs(components_present.size());
      }
      builder.index.metadata.setHaplotypes(this->samples);
    }
    return builder.index;
  }

  // Formats a set of component ids as "{ a b ... }" for failure messages.
  std::string set_name(const std::set<size_t>& components) const
  {
    std::stringstream ss;
    ss << "{";
    for(size_t i : components) { ss << " " << i; }
    ss << " }";
    return ss.str();
  }
};

TEST_F(AugmentTest, CorrectPaths)
{
  const size_t context_length = 3;
  const size_t paths_per_component = 4;
  const std::vector<std::set<size_t>> component_sets
  {
    { }, { 0 }, { 1 }, { 0, 1 }
  };

  for(const std::set<size_t>& components_present : component_sets)
  {
    const std::string name = this->set_name(components_present);
    gbwt::DynamicGBWT original = this->create_gbwt(components_present, false);
    gbwt::DynamicGBWT augmented = original;

    // Augmentation should add paths exactly for the components the original lacks.
    const size_t expected_components = this->components.size() - components_present.size();
    size_t covered_components = augment_gbwt(this->graph, augmented, paths_per_component, context_length);
    ASSERT_EQ(covered_components, expected_components) << "Wrong number of covered components for components " << name;

    // Check the number of sequences: each new path is added in both orientations.
    size_t expected_sequences = original.sequences() + 2 * expected_components * paths_per_component;
    ASSERT_EQ(augmented.sequences(), expected_sequences) << "Wrong number of sequences for components " << name;

    // The original paths must come first and be unchanged.
    size_t path_id = 0;
    for(; path_id < original.metadata.paths(); path_id++)
    {
      const size_t seq_id = gbwt::Path::encode(path_id, false);
      bool same_paths = (original.extract(seq_id) == augmented.extract(seq_id));
      EXPECT_TRUE(same_paths) << "Wrong original path " << path_id << " for components " << name;
    }

    // The generated paths follow, grouped by missing component. They are compared
    // in canonical orientation (the smaller of a path and its reverse).
    for(size_t component = 0; component < this->components.size(); component++)
    {
      if(components_present.count(component) > 0) { continue; }
      for(size_t i = 0; i < paths_per_component; i++, path_id++)
      {
        gbwt::vector_type forward = augmented.extract(gbwt::Path::encode(path_id, false));
        gbwt::vector_type reverse;
        gbwt::reversePath(forward, reverse);
        gbwt::vector_type path = std::min(forward, reverse);
        const std::set<gbwt::vector_type>& expected = correct_paths[component];
        bool correct_path = (expected.find(path) != expected.end());
        EXPECT_TRUE(correct_path) << "Wrong augmented path " << i << " in component " << component << " for components " << name;
      }
    }
  }
}

TEST_F(AugmentTest, PathNames)
{
  const size_t context_length = 3;
  const size_t paths_per_component = 4;
  const std::vector<std::set<size_t>> component_sets
  {
    { }, { 0 }, { 1 }, { 0, 1 }
  };

  for(const std::set<size_t>& components_present : component_sets)
  {
    const std::string name = this->set_name(components_present);
    gbwt::DynamicGBWT original = this->create_gbwt(components_present, false);
    gbwt::DynamicGBWT augmented = original;

    // New paths are generated only for the components missing from the original.
    const size_t expected_components = this->components.size() - components_present.size();
    augment_gbwt(this->graph, augmented, paths_per_component, context_length);

    // Every original and generated path must have a name.
    ASSERT_TRUE(augmented.hasMetadata()) << "No metadata for components " << name;
    ASSERT_TRUE(augmented.metadata.hasPathNames()) << "No path names for components " << name;
    size_t expected_paths = original.metadata.paths() + expected_components * paths_per_component;
    ASSERT_EQ(augmented.metadata.paths(), expected_paths) << "Wrong number of path names for components " << name;

    // Original path names come first and must be unchanged.
    size_t path_id = 0;
    for(; path_id < original.metadata.paths(); path_id++)
    {
      bool same_names = (original.metadata.path(path_id) == augmented.metadata.path(path_id));
      EXPECT_TRUE(same_names) << "Wrong original path name " << path_id << " for components " << name;
    }

    // Generated names use new samples after the original ones, and new contigs
    // after the original ones in order of the missing components.
    size_t rank = 0;
    for(size_t component = 0; component < this->components.size(); component++)
    {
      if(components_present.count(component) > 0) { continue; }
      for(size_t i = 0; i < paths_per_component; i++, path_id++)
      {
        gbwt::PathName path_name = augmented.metadata.path(path_id);
        EXPECT_EQ(path_name.sample, original.metadata.samples() + i) << "Wrong sample for augmented path " << i << " in component " << component << " for components " << name;
        EXPECT_EQ(path_name.contig, components_present.size() + rank) << "Wrong contig for augmented path " << i << " in component " << component << " for components " << name;
      }
      rank++;
    }
  }
}

TEST_F(AugmentTest, SamplesAndContigs)
{
  const size_t context_length = 3;
  const size_t paths_per_component = 4;
  const std::vector<std::set<size_t>> component_sets
  {
    { }, { 0 }, { 1 }, { 0, 1 }
  };

  for(const std::set<size_t>& components_present : component_sets)
  {
    for(bool names : { false, true })
    {
      // An empty original GBWT never gets names in its metadata, so that case is skipped.
      if(names && components_present.empty()) { continue; }

      const std::string name = this->set_name(components_present);
      gbwt::DynamicGBWT original = this->create_gbwt(components_present, names);
      gbwt::DynamicGBWT augmented = original;

      // New paths are generated only for the components missing from the original.
      const size_t expected_components = this->components.size() - components_present.size();
      augment_gbwt(this->graph, augmented, paths_per_component, context_length);

      // Augmentation adds paths_per_component samples if it adds anything at all,
      // and a contig for every missing component. Name presence is preserved.
      size_t expected_samples = original.metadata.samples();
      if(expected_components > 0) { expected_samples += paths_per_component; }
      ASSERT_TRUE(augmented.hasMetadata()) << "No metadata for components " << name;
      ASSERT_EQ(augmented.metadata.hasSampleNames(), names) << "Sample names for components " << name;
      ASSERT_EQ(augmented.metadata.samples(), expected_samples) << "Wrong number of samples for components " << name;
      ASSERT_EQ(augmented.metadata.hasContigNames(), names) << "Contig names for components " << name;
      ASSERT_EQ(augmented.metadata.contigs(), this->components.size()) << "Wrong number of contigs for components " << name;

      if(!names) { continue; }

      // Sample names: original samples first, then the generated ones.
      size_t sample_id = 0;
      for(; sample_id < original.metadata.samples(); sample_id++)
      {
        std::string expected_name = "sample_" + std::to_string(sample_id);
        EXPECT_EQ(augmented.metadata.sample(sample_id), expected_name) << "Wrong name for original sample " << sample_id << " for components " << name;
      }
      if(expected_components > 0)
      {
        for(size_t i = 0; i < paths_per_component; i++, sample_id++)
        {
          std::string expected_name = "path_cover_" + std::to_string(i);
          EXPECT_EQ(augmented.metadata.sample(sample_id), expected_name) << "Wrong name for augmented sample " << i << " for components " << name;
        }
      }

      // Contig names: original contigs in component order, then one per missing component.
      size_t contig_id = 0;
      for(size_t i = 0; i < this->components.size(); i++)
      {
        if(components_present.count(i) == 0) { continue; }
        std::string expected_name = "contig_" + std::to_string(i);
        EXPECT_EQ(augmented.metadata.contig(contig_id), expected_name) << "Wrong name for original contig " << i << " for components " << name;
        contig_id++;
      }
      if(expected_components > 0)
      {
        for(size_t i = 0; i < this->components.size(); i++)
        {
          if(components_present.count(i) > 0) { continue; }
          std::string expected_name = "component_" + std::to_string(i);
          EXPECT_EQ(augmented.metadata.contig(contig_id), expected_name) << "Wrong name for augmented contig " << i << " for components " << name;
          contig_id++;
        }
      }
    }
  }
}

//------------------------------------------------------------------------------

} // namespace
