//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDA_STREAM_REF
#define _CUDA_STREAM_REF

// clang-format off
/*
    stream_ref synopsis
namespace cuda {
class stream_ref {
    using value_type = cudaStream_t;

    stream_ref() = default;
    stream_ref(cudaStream_t stream_) noexcept : stream(stream_) {}

    stream_ref(int) = delete;
    stream_ref(nullptr_t) = delete;

    [[nodiscard]] value_type get() const noexcept;

    void wait() const;

    [[nodiscard]] bool ready() const;

    [[nodiscard]] friend bool operator==(stream_ref, stream_ref);
    [[nodiscard]] friend bool operator!=(stream_ref, stream_ref);

private:
  cudaStream_t stream = 0; // exposition only
};
}  // cuda
*/

#ifdef LIBCUDACXX_ENABLE_EXPERIMENTAL_MEMORY_RESOURCE

#include <cuda_runtime_api.h> // cuda_runtime_api needs to come first
// clang-format on

#include <cuda/std/detail/__config>

#include <cuda/std/detail/__pragma_push>

#include <cuda/std/type_traits>

_LIBCUDACXX_BEGIN_NAMESPACE_CUDA

/**
 * \brief A non-owning wrapper for a `cudaStream_t`.
 *
 * `stream_ref` is a non-owning "view" type similar to `std::span` or
 * `std::string_view`. \see https://en.cppreference.com/w/cpp/container/span and
 * \see https://en.cppreference.com/w/cpp/string/basic_string_view
 *
 */
class stream_ref
{
private:
  ::cudaStream_t __stream{0};

public:
  using value_type = ::cudaStream_t;

  /**
   * \brief Constructs a `stream_view` of the "default" CUDA stream.
   *
   * For behavior of the default stream,
   * \see
   * https://docs.nvidia.com/cuda/cuda-runtime-api/stream-sync-behavior.html
   *
   */
  stream_ref() = default;

  /**
   * \brief Constructs a `stream_view` from a `cudaStream_t` handle.
   *
   * This constructor provides implicit conversion from `cudaStream_t`.
   *
   * \note: It is the callers responsibilty to ensure the `stream_view` does not
   * outlive the stream identified by the `cudaStream_t` handle.
   *
   */
  constexpr stream_ref(value_type __stream_) noexcept
      : __stream{__stream_}
  {}

  /// Disallow construction from an `int`, e.g., `0`.
  stream_ref(int) = delete;

  /// Disallow construction from `nullptr`.
  stream_ref(_CUDA_VSTD::nullptr_t) = delete;

  /**
   * \brief Compares two `stream_view`s for equality
   *
   * \note Allows comparison with `cudaStream_t` due to implicit conversion to
   * `stream_view`.
   *
   * \param lhs The first `stream_view` to compare
   * \param rhs The second `stream_view` to compare
   * \return true if equal, false if unequal
   */
  _LIBCUDACXX_NODISCARD_ATTRIBUTE friend bool constexpr operator==(stream_ref __lhs, stream_ref __rhs) noexcept
  {
    return __lhs.__stream == __rhs.__stream;
  }

  /**
   * \brief Compares two `stream_view`s for inequality
   *
   * \note Allows comparison with `cudaStream_t` due to implicit conversion to
   * `stream_view`.
   *
   * \param lhs The first `stream_view` to compare
   * \param rhs The second `stream_view` to compare
   * \return true if unequal, false if equal
   */
  _LIBCUDACXX_NODISCARD_ATTRIBUTE friend bool constexpr operator!=(stream_ref __lhs, stream_ref __rhs) noexcept
  {
    return __lhs.__stream != __rhs.__stream;
  }

  /// Returns the wrapped `cudaStream_t` handle.
  _LIBCUDACXX_NODISCARD_ATTRIBUTE constexpr value_type get() const noexcept { return __stream; }

  /**
   * \brief Synchronizes the wrapped stream.
   *
   * \throws cuda::cuda_error if synchronization fails.
   *
   */
  void wait() const
  {
    const auto __result = ::cudaStreamQuery(get());
    switch (__result)
    {
      case ::cudaSuccess:
        return;
      default:
        ::cudaGetLastError(); // Clear CUDA error state
#ifndef _LIBCUDACXX_NO_EXCEPTIONS
        throw cuda::cuda_error{__result, "Failed to synchronize stream."};
#else
        // _LIBCUDACXX_UNREACHABLE;
#endif
    }
  }

  /**
   * \brief Queries if all operations on the wrapped stream have completed.
   *
   * \throws cuda::cuda_error if the query fails.
   *
   * \return `true` if all operations have completed, or `false` if not.
   */
  _LIBCUDACXX_NODISCARD_ATTRIBUTE bool ready() const
  {
    const auto __result = ::cudaStreamQuery(get());
    switch (__result)
    {
      case ::cudaSuccess:
        return true;
      case ::cudaErrorNotReady:
        return false;
      default:
        ::cudaGetLastError(); // Clear CUDA error state
#ifndef _LIBCUDACXX_NO_EXCEPTIONS
        throw cuda::cuda_error{__result, ""};
#else
        // _LIBCUDACXX_UNREACHABLE;
#endif
        return false;
    }
  }
};

_LIBCUDACXX_END_NAMESPACE_CUDA

#include <cuda/std/detail/__pragma_pop>

#endif // LIBCUDACXX_ENABLE_EXPERIMENTAL_MEMORY_RESOURCE

#endif //_CUDA_STREAM_REF
