// SPDX-FileCopyrightText: Copyright (c) 2011, Duane Merrill. All rights reserved.
// SPDX-FileCopyrightText: Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
// SPDX-License-Identifier: BSD-3

/**
 * @file
 * cub::DispatchSegmentedRadixSort provides the dispatch layer for device-wide, segmented radix sort:
 * a radix sort applied to each segment of a sequence of data items residing within
 * device-accessible memory.
 */

#pragma once

#include <cub/config.cuh>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cub/device/dispatch/kernels/kernel_segmented_radix_sort.cuh>
#include <cub/device/dispatch/tuning/tuning_radix_sort.cuh>
#include <cub/util_debug.cuh>
#include <cub/util_device.cuh>
#include <cub/util_math.cuh>
#include <cub/util_type.cuh>

#include <cuda/__cmath/ceil_div.h>
#include <cuda/std/__algorithm/max.h>
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/__type_traits/is_same.h>
#include <cuda/std/cstdint>
#include <cuda/std/limits>

// suppress warnings triggered by #pragma unroll:
// "warning: loop not unrolled: the optimizer was unable to perform the requested transformation; the transformation
// might be disabled or specified as part of an unsupported transformation ordering [-Wpass-failed=transform-warning]"
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_CLANG("-Wpass-failed")

CUB_NAMESPACE_BEGIN

namespace detail::radix_sort
{
/// Default provider of the kernels and element sizes consumed by `DispatchSegmentedRadixSort`.
///
/// It exposes two kernel getters, both naming `DeviceSegmentedRadixSortKernel` with identical template
/// arguments except for the second (boolean) one, plus the byte sizes of the key and value types.
/// The dispatch layer only talks to this object through those four members, so a replacement type
/// offering the same members can be plugged in via the `KernelSource` template parameter.
template <typename MaxPolicyT,
          SortOrder Order,
          typename KeyT,
          typename ValueT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT,
          typename SegmentSizeT,
          typename DecomposerT>
struct DeviceSegmentedRadixSortKernelSource
{
  // Getter `SegmentedRadixSortKernel()`: the kernel instantiated with the boolean argument set to `false`.
  // NOTE(review): the flag presumably selects the regular (non-alternate) digit-size policy -- confirm
  // against the kernel's definition.
  CUB_DEFINE_KERNEL_GETTER(
    SegmentedRadixSortKernel,
    DeviceSegmentedRadixSortKernel<
      MaxPolicyT,
      false,
      Order,
      KeyT,
      ValueT,
      BeginOffsetIteratorT,
      EndOffsetIteratorT,
      SegmentSizeT,
      DecomposerT>);

  // Getter `AltSegmentedRadixSortKernel()`: same kernel template with the boolean argument set to `true`.
  CUB_DEFINE_KERNEL_GETTER(
    AltSegmentedRadixSortKernel,
    DeviceSegmentedRadixSortKernel<
      MaxPolicyT,
      true,
      Order,
      KeyT,
      ValueT,
      BeginOffsetIteratorT,
      EndOffsetIteratorT,
      SegmentSizeT,
      DecomposerT>);

  /// Size in bytes of one key (`sizeof(KeyT)`)
  CUB_RUNTIME_FUNCTION static constexpr size_t KeySize()
  {
    return sizeof(KeyT);
  }

  /// Size in bytes of one value (`sizeof(ValueT)`)
  CUB_RUNTIME_FUNCTION static constexpr size_t ValueSize()
  {
    return sizeof(ValueT);
  }
};
} // namespace detail::radix_sort

/******************************************************************************
 * Segmented dispatch
 ******************************************************************************/

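/**
 * @brief Utility class for dispatching the appropriately-tuned kernels for segmented device-wide
 * radix sort
 *
 * @tparam Order
 *   Whether to sort in ascending or descending order
 *
 * @tparam KeyT
 *   Key type
 *
 * @tparam ValueT
 *   Value type
 *
 * @tparam BeginOffsetIteratorT
 *   Random-access input iterator type for reading segment beginning offsets @iterator
 *
 * @tparam EndOffsetIteratorT
 *   Random-access input iterator type for reading segment ending offsets @iterator
 *
 * @tparam SegmentSizeT
 *   Integer type to index items within a segment
 *
 * @tparam PolicyHub
 *   Tuning policy hub; its `MaxPolicy` member type is the default entry point of the policy chain
 *
 * @tparam DecomposerT
 *   Type of the decomposer object forwarded to the sorting kernels
 *
 * @tparam KernelSource
 *   Provider of the regular and alternate segmented radix sort kernels and of the key/value sizes
 *
 * @tparam KernelLauncherFactory
 *   Factory used to query the PTX version and to create kernel launchers
 */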
template <SortOrder Order,
          typename KeyT,
          typename ValueT,
          typename BeginOffsetIteratorT,
          typename EndOffsetIteratorT,
          typename SegmentSizeT,
          typename PolicyHub    = detail::radix::policy_hub<KeyT, ValueT, SegmentSizeT>,
          typename DecomposerT  = detail::identity_decomposer_t,
          typename KernelSource = detail::radix_sort::DeviceSegmentedRadixSortKernelSource<
            typename PolicyHub::MaxPolicy,
            Order,
            KeyT,
            ValueT,
            BeginOffsetIteratorT,
            EndOffsetIteratorT,
            SegmentSizeT,
            DecomposerT>,
          typename KernelLauncherFactory = CUB_DETAIL_DEFAULT_KERNEL_LAUNCHER_FACTORY>
struct DispatchSegmentedRadixSort
{
  //------------------------------------------------------------------------------
  // Constants
  //------------------------------------------------------------------------------

  // Whether this is a keys-only sort (true when ValueT is NullType) rather than a key-value sort
  static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;

  //------------------------------------------------------------------------------
  // Parameter members
  //------------------------------------------------------------------------------

  /// Device-accessible allocation of temporary storage.  When nullptr, the required allocation size
  /// is written to `temp_storage_bytes` and no work is done.
  void* d_temp_storage;

  /// Reference to size in bytes of `d_temp_storage` allocation
  size_t& temp_storage_bytes;

  /// Double-buffer whose current buffer contains the unsorted input keys and, upon return, is
  /// updated to point to the sorted output keys
  DoubleBuffer<KeyT>& d_keys;

  /// Double-buffer whose current buffer contains the unsorted input values and, upon return, is
  /// updated to point to the sorted output values
  DoubleBuffer<ValueT>& d_values;

  /// Number of items to sort
  ::cuda::std::int64_t num_items;

  /// The number of segments that comprise the sorting data
  ::cuda::std::int64_t num_segments;

  /// Random-access input iterator to the sequence of beginning offsets of length `num_segments`,
  /// such that <tt>d_begin_offsets[i]</tt> is the first element of the <em>i</em><sup>th</sup>
  /// data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>
  BeginOffsetIteratorT d_begin_offsets;

  /// Random-access input iterator to the sequence of ending offsets of length `num_segments`,
  /// such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup>
  /// data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>. If <tt>d_end_offsets[i]-1</tt>
  /// <= <tt>d_begin_offsets[i]</tt>, the <em>i</em><sup>th</sup> segment is considered empty.
  EndOffsetIteratorT d_end_offsets;

  /// The beginning (least-significant) bit index needed for key comparison
  int begin_bit;

  /// The past-the-end (most-significant) bit index needed for key comparison
  int end_bit;

  /// CUDA stream to launch kernels within.  Default is stream<sub>0</sub>.
  cudaStream_t stream;

  /// PTX version
  int ptx_version;

  /// Whether it is okay to overwrite the source buffers
  bool is_overwrite_okay;

  /// Decomposer object forwarded unchanged to the sorting kernel on every launch
  DecomposerT decomposer;

  /// Provider of the regular/alternate kernel entry points and of the key size used to size
  /// temporary storage
  KernelSource kernel_source;

  /// Factory used to create the kernel launchers and to initialize the kernel configurations
  KernelLauncherFactory launcher_factory;

  //------------------------------------------------------------------------------
  // Constructors
  //------------------------------------------------------------------------------

  /// Constructor. Stores every argument in the member of the same name; no CUDA work is performed.
  ///
  /// `temp_storage_bytes`, `d_keys` and `d_values` are kept as references, so the referenced objects
  /// must outlive this dispatch object. `decomposer`, `kernel_source` and `launcher_factory` default
  /// to value-initialized instances when omitted.
  ///
  /// Note that `is_overwrite_okay` precedes `stream` and `ptx_version` in the parameter list, whereas
  /// the members are declared (and therefore initialized) in the order `stream`, `ptx_version`,
  /// `is_overwrite_okay`.
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE DispatchSegmentedRadixSort(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    DoubleBuffer<KeyT>& d_keys,
    DoubleBuffer<ValueT>& d_values,
    ::cuda::std::int64_t num_items,
    ::cuda::std::int64_t num_segments,
    BeginOffsetIteratorT d_begin_offsets,
    EndOffsetIteratorT d_end_offsets,
    int begin_bit,
    int end_bit,
    bool is_overwrite_okay,
    cudaStream_t stream,
    int ptx_version,
    DecomposerT decomposer                 = {},
    KernelSource kernel_source             = {},
    KernelLauncherFactory launcher_factory = {})
      : d_temp_storage(d_temp_storage)
      , temp_storage_bytes(temp_storage_bytes)
      , d_keys(d_keys)
      , d_values(d_values)
      , num_items(num_items)
      , num_segments(num_segments)
      , d_begin_offsets(d_begin_offsets)
      , d_end_offsets(d_end_offsets)
      , begin_bit(begin_bit)
      , end_bit(end_bit)
      , stream(stream)
      , ptx_version(ptx_version)
      , is_overwrite_okay(is_overwrite_okay)
      , decomposer(decomposer)
      , kernel_source(kernel_source)
      , launcher_factory(launcher_factory)
  {}

  //------------------------------------------------------------------------------
  // Multi-segment invocation
  //------------------------------------------------------------------------------

  /// Run a single digit pass over all segments, starting at bit `current_bit`.
  ///
  /// The segmented kernel held by `pass_config` is launched once per chunk of segments, each launch
  /// using a grid size equal to the number of segments in the chunk and the block size recorded in
  /// `pass_config.segmented_config`. Chunks hold at most `INT32_MAX` segments; the offset iterators
  /// are advanced between chunks.
  ///
  /// @param[in] d_keys_in          Keys read by this pass
  /// @param[out] d_keys_out        Keys written by this pass
  /// @param[in] d_values_in        Values read by this pass
  /// @param[out] d_values_out      Values written by this pass
  /// @param[in,out] current_bit    First bit of the digit to process; advanced by the digit width on success
  /// @param[in] pass_config        Kernel, launch configuration and radix width to use
  ///
  /// @return `cudaSuccess`, or the first launch / debug-sync error (in which case `current_bit` is unchanged)
  template <typename PassConfigT>
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t InvokePass(
    const KeyT* d_keys_in,
    KeyT* d_keys_out,
    const ValueT* d_values_in,
    ValueT* d_values_out,
    int& current_bit,
    PassConfigT& pass_config)
  {
    // Digit width for this pass: the configured radix width, capped by the bits that remain to be sorted
    int digit_bits = ::cuda::std::min(pass_config.radix_bits, (end_bit - current_bit));

    // Segment indices within one launch must be representable in this type
    using per_invocation_segment_offset_t = ::cuda::std::int32_t;

    // Largest number of segments handed to a single kernel launch
    constexpr auto chunk_capacity =
      static_cast<::cuda::std::int64_t>(::cuda::std::numeric_limits<per_invocation_segment_offset_t>::max());

    // How many launches are required to cover every segment
    const auto chunk_count = ::cuda::ceil_div(num_segments, chunk_capacity);

    // Offset iterators positioned at the first segment of the chunk being launched
    BeginOffsetIteratorT chunk_begin_offsets = d_begin_offsets;
    EndOffsetIteratorT chunk_end_offsets     = d_end_offsets;

    for (::cuda::std::int64_t chunk_index = 0; chunk_index < chunk_count; ++chunk_index)
    {
      const auto first_segment  = chunk_index * chunk_capacity;
      const auto chunk_segments = ::cuda::std::min(chunk_capacity, num_segments - first_segment);
      const bool is_last_chunk  = !(chunk_index + 1 < chunk_count);

// Log kernel configuration
#ifdef CUB_DEBUG_LOG
      _CubLog(
        "Invoking segmented_kernels<<<%lld, %lld, 0, %lld>>>(), "
        "%lld items per thread, %lld SM occupancy, "
        "current segment offset %lld, current bit %d, bit_grain %d\n",
        (long long) chunk_segments,
        (long long) pass_config.segmented_config.block_threads,
        (long long) stream,
        (long long) pass_config.segmented_config.items_per_thread,
        (long long) pass_config.segmented_config.sm_occupancy,
        (long long) first_segment,
        current_bit,
        digit_bits);
#endif

      launcher_factory(
        static_cast<unsigned int>(chunk_segments), pass_config.segmented_config.block_threads, 0, stream)
        .doit(pass_config.segmented_kernel,
              d_keys_in,
              d_keys_out,
              d_values_in,
              d_values_out,
              chunk_begin_offsets,
              chunk_end_offsets,
              current_bit,
              digit_bits,
              decomposer);

      // A failed launch is reported through the runtime's last-error state
      if (const cudaError_t launch_status = CubDebug(cudaPeekAtLastError()); cudaSuccess != launch_status)
      {
        return launch_status;
      }

      // Step the offset iterators to the next chunk, unless this was the final one
      if (!is_last_chunk)
      {
        chunk_begin_offsets += chunk_segments;
        chunk_end_offsets += chunk_segments;
      }

      // Sync the stream if specified to flush runtime errors
      if (const cudaError_t sync_status = CubDebug(detail::DebugSyncStream(stream)); cudaSuccess != sync_status)
      {
        return sync_status;
      }
    }

    // Every segment has been processed for this digit: move on to the next one
    current_bit += digit_bits;

    return cudaSuccess;
  }

  /// Everything needed to launch one flavor of sorting pass: the kernel entry point, its launch
  /// configuration, and the digit width (in bits and in number of distinct digit values).
  template <typename SegmentedKernelT>
  struct PassConfig
  {
    SegmentedKernelT segmented_kernel;
    detail::KernelConfig segmented_config;
    int radix_bits;
    int radix_digits;

    /// Fill in this configuration.
    ///
    /// Records the kernel and digit width, derives `radix_digits` as `2^radix_bits`, and initializes
    /// `segmented_config` from the kernel, the policy and the launcher factory.
    ///
    /// @return The status reported by `segmented_config.Init`
    template <typename SegmentedPolicyT>
    CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t InitPassConfig(
      SegmentedKernelT segmented_kernel,
      int radix_bits,
      SegmentedPolicyT policy                = {},
      KernelLauncherFactory launcher_factory = {})
    {
      // Digit geometry first, then the kernel handle (the three stores are independent)
      this->radix_bits       = radix_bits;
      this->radix_digits     = 1 << radix_bits;
      this->segmented_kernel = segmented_kernel;

      const auto init_status = CubDebug(segmented_config.Init(segmented_kernel, policy, launcher_factory));
      return init_status;
    }
  };

  /**
   * @brief Invocation (run multiple digit passes)
   *
   * Sizes (or aliases) the temporary storage, plans the sequence of digit passes and runs them,
   * finally updating the selectors of `d_keys` / `d_values` so that their current buffers refer to
   * the sorted output. When `d_temp_storage` is nullptr only the required size is computed (at least
   * one byte is always requested) and no kernel is launched.
   *
   * Buffer usage: when `is_overwrite_okay` is set, passes ping-pong between the two caller-provided
   * buffers. Otherwise a temporary buffer is used as the third buffer and arranged, depending on the
   * parity of the number of passes, so that the input buffer is never written and the final pass
   * writes into the caller's alternate buffer.
   *
   * On any failure the selectors of `d_keys` / `d_values` are left untouched.
   *
   * @tparam ActivePolicyT
   *   Umbrella policy active for the target device
   *
   * @tparam SegmentedKernelT
   *   Function type of cub::DeviceSegmentedRadixSortKernel
   *
   * @param[in] segmented_kernel
   *   Kernel function pointer to parameterization of cub::DeviceSegmentedRadixSortKernel
   *
   * @param[in] alt_segmented_kernel
   *   Alternate kernel function pointer to parameterization of
   *   cub::DeviceSegmentedRadixSortKernel
   *
   * @param[in] policy
   *   Policy object queried for the regular (`Segmented()`) and alternate (`AltSegmented()`)
   *   sub-policies and their radix widths (`RadixBits()`)
   *
   * @return `cudaSuccess`, or the first error reported while configuring, allocating or launching
   */
  template <typename ActivePolicyT, typename SegmentedKernelT>
  CUB_RUNTIME_FUNCTION _CCCL_VISIBILITY_HIDDEN _CCCL_FORCEINLINE cudaError_t
  InvokePasses(SegmentedKernelT segmented_kernel, SegmentedKernelT alt_segmented_kernel, ActivePolicyT policy = {})
  {
    cudaError error = cudaSuccess;
    do
    {
      // Init regular and alternate kernel configurations
      PassConfig<SegmentedKernelT> pass_config, alt_pass_config;
      error = pass_config.InitPassConfig(
        segmented_kernel, policy.RadixBits(policy.Segmented()), policy.Segmented(), launcher_factory);
      if (cudaSuccess != error)
      {
        break;
      }
      error = alt_pass_config.InitPassConfig(
        alt_segmented_kernel, policy.RadixBits(policy.AltSegmented()), policy.AltSegmented(), launcher_factory);
      if (cudaSuccess != error)
      {
        break;
      }

      // Temporary storage allocation requirements. Element sizes are obtained from the kernel source for
      // both keys and values, so that a custom KernelSource governs the sizing of both buffers.
      void* allocations[2]       = {};
      size_t allocation_sizes[2] = {
        // bytes needed for 3rd keys buffer
        (is_overwrite_okay) ? 0 : num_items * kernel_source.KeySize(),

        // bytes needed for 3rd values buffer
        (is_overwrite_okay || (KEYS_ONLY)) ? 0 : num_items * kernel_source.ValueSize(),
      };

      // Alias the temporary allocations from the single storage blob (or compute the necessary size of the blob)
      error = CubDebug(detail::alias_temporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes));
      if (cudaSuccess != error)
      {
        break;
      }

      // Return if the caller is simply requesting the size of the storage allocation
      if (d_temp_storage == nullptr)
      {
        // Never report a zero-byte requirement
        if (temp_storage_bytes == 0)
        {
          temp_storage_bytes = 1;
        }
        return cudaSuccess;
      }

      // Pass planning.  Run passes of the alternate digit-size configuration until we have an even multiple of our
      // preferred digit size
      int radix_bits         = policy.RadixBits(policy.Segmented());
      int alt_radix_bits     = policy.RadixBits(policy.AltSegmented());
      int num_bits           = end_bit - begin_bit;
      int num_passes         = ::cuda::std::max(::cuda::ceil_div(num_bits, radix_bits), 1); // num_bits may be zero
      bool is_num_passes_odd = num_passes & 1;
      int max_alt_passes     = (num_passes * radix_bits) - num_bits;
      // Passes that start below this bit use the alternate configuration
      int alt_end_bit = ::cuda::std::min(end_bit, begin_bit + (max_alt_passes * alt_radix_bits));

      // Buffers used from the second pass on. The "current" buffer is the destination of the first pass.
      //   overwrite okay:        current = caller's alternate, alternate = caller's current
      //   no overwrite, odd:     current = caller's alternate, alternate = temporary
      //   no overwrite, even:    current = temporary,          alternate = caller's alternate
      DoubleBuffer<KeyT> d_keys_remaining_passes(
        (is_overwrite_okay || is_num_passes_odd) ? d_keys.Alternate() : static_cast<KeyT*>(allocations[0]),
        (is_overwrite_okay)   ? d_keys.Current()
        : (is_num_passes_odd) ? static_cast<KeyT*>(allocations[0])
                              : d_keys.Alternate());

      DoubleBuffer<ValueT> d_values_remaining_passes(
        (is_overwrite_okay || is_num_passes_odd) ? d_values.Alternate() : static_cast<ValueT*>(allocations[1]),
        (is_overwrite_okay)   ? d_values.Current()
        : (is_num_passes_odd) ? static_cast<ValueT*>(allocations[1])
                              : d_values.Alternate());

      // Run first pass, consuming from the input's current buffers
      int current_bit = begin_bit;

      error = CubDebug(InvokePass(
        d_keys.Current(),
        d_keys_remaining_passes.Current(),
        d_values.Current(),
        d_values_remaining_passes.Current(),
        current_bit,
        (current_bit < alt_end_bit) ? alt_pass_config : pass_config));
      if (cudaSuccess != error)
      {
        break;
      }

      // Run remaining passes, ping-ponging between the two remaining-pass buffers.
      // InvokePass advances current_bit on success.
      while (current_bit < end_bit)
      {
        error = CubDebug(InvokePass(
          d_keys_remaining_passes.Current(),
          d_keys_remaining_passes.Alternate(),
          d_values_remaining_passes.Current(),
          d_values_remaining_passes.Alternate(),
          current_bit,
          (current_bit < alt_end_bit) ? alt_pass_config : pass_config));
        if (cudaSuccess != error)
        {
          break;
        }

        // The buffers just written become the source of the next pass (keys and values flip in lockstep)
        d_keys_remaining_passes.selector ^= 1;
        d_values_remaining_passes.selector ^= 1;
      }

      // The break above only leaves the inner loop: bail out here as well so that a failed pass does not
      // flip the caller's selectors towards buffers that do not hold sorted data
      if (cudaSuccess != error)
      {
        break;
      }

      // Update selector
      if (!is_overwrite_okay)
      {
        num_passes = 1; // Sorted data always ends up in the other vector
      }

      d_keys.selector   = (d_keys.selector + num_passes) & 1;
      d_values.selector = (d_values.selector + num_passes) & 1;
    } while (0);

    return error;
  }

  //------------------------------------------------------------------------------
  // Chained policy invocation
  //------------------------------------------------------------------------------

  /// Entry point called with the policy selected for the target device.
  ///
  /// Trivial problems are answered immediately: when there are no items or no segments, or when no
  /// bits are to be sorted and the source buffers may be overwritten, nothing is launched. In that
  /// case a size query (`d_temp_storage == nullptr`) reports a one-byte requirement. Every other
  /// problem is forwarded to `InvokePasses` with the regular and alternate kernels obtained from
  /// `kernel_source`.
  ///
  /// @return `cudaSuccess` for trivial problems, otherwise the status of `InvokePasses`
  template <typename ActivePolicyT>
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT policy = {})
  {
    const bool is_empty_problem     = (num_items == 0) || (num_segments == 0);
    const bool is_in_place_identity = (begin_bit == end_bit) && is_overwrite_okay;

    if (!(is_empty_problem || is_in_place_identity))
    {
      // Force kernel code-generation in all compiler passes
      return InvokePasses(kernel_source.SegmentedRadixSortKernel(),
                          kernel_source.AltSegmentedRadixSortKernel(),
                          detail::radix::MakeRadixSortPolicyWrapper(policy));
    }

    // Nothing to do; a pure size query still needs a non-zero answer
    if (d_temp_storage == nullptr)
    {
      temp_storage_bytes = 1;
    }
    return cudaSuccess;
  }

  //------------------------------------------------------------------------------
  // Dispatch entrypoints
  //------------------------------------------------------------------------------

  /**
   * @brief Internal dispatch routine
   *
   * Queries the PTX version through `launcher_factory`, builds a dispatch object from the arguments
   * (using a value-initialized decomposer) and hands it to the policy chain rooted at `max_policy`.
   *
   * @param[in] d_temp_storage
   *   Device-accessible allocation of temporary storage.  When nullptr, the required allocation size
   *   is written to `temp_storage_bytes` and no work is done.
   *
   * @param[in,out] temp_storage_bytes
   *   Reference to size in bytes of `d_temp_storage` allocation
   *
   * @param[in,out] d_keys
   *   Double-buffer whose current buffer contains the unsorted input keys and, upon return, is
   *   updated to point to the sorted output keys
   *
   * @param[in,out] d_values
   *   Double-buffer whose current buffer contains the unsorted input values and, upon return, is
   *   updated to point to the sorted output values
   *
   * @param[in] num_items
   *   Number of items to sort
   *
   * @param[in] num_segments
   *   The number of segments that comprise the sorting data
   *
   * @param[in] d_begin_offsets
   *   Random-access input iterator to the sequence of beginning offsets of length
   *   `num_segments`, such that <tt>d_begin_offsets[i]</tt> is the first element of the
   *   <em>i</em><sup>th</sup> data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>
   *
   * @param[in] d_end_offsets
   *   Random-access input iterator to the sequence of ending offsets of length `num_segments`,
   *   such that <tt>d_end_offsets[i]-1</tt> is the last element of the <em>i</em><sup>th</sup>
   *   data segment in <tt>d_keys_*</tt> and <tt>d_values_*</tt>.
   *   If <tt>d_end_offsets[i]-1</tt> <= <tt>d_begin_offsets[i]</tt>,
   *   the <em>i</em><sup>th</sup> segment is considered empty.
   *
   * @param[in] begin_bit
   *   The beginning (least-significant) bit index needed for key comparison
   *
   * @param[in] end_bit
   *   The past-the-end (most-significant) bit index needed for key comparison
   *
   * @param[in] is_overwrite_okay
   *   Whether it is okay to overwrite the source buffers
   *
   * @param[in] stream
   *   CUDA stream to launch kernels within.  Default is stream<sub>0</sub>.
   *
   * @param[in] kernel_source
   *   Provider of the kernels and element sizes, passed on to the dispatch object
   *
   * @param[in] launcher_factory
   *   Factory used to query the PTX version here and passed on to the dispatch object
   *
   * @param[in] max_policy
   *   Head of the policy chain; its `Invoke(ptx_version, dispatch)` selects the policy to run
   *
   * @return `cudaSuccess`, or the first error reported by the PTX version query or the policy chain
   */
  template <typename MaxPolicyT = typename PolicyHub::MaxPolicy>
  CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch(
    void* d_temp_storage,
    size_t& temp_storage_bytes,
    DoubleBuffer<KeyT>& d_keys,
    DoubleBuffer<ValueT>& d_values,
    ::cuda::std::int64_t num_items,
    ::cuda::std::int64_t num_segments,
    BeginOffsetIteratorT d_begin_offsets,
    EndOffsetIteratorT d_end_offsets,
    int begin_bit,
    int end_bit,
    bool is_overwrite_okay,
    cudaStream_t stream,
    KernelSource kernel_source             = {},
    KernelLauncherFactory launcher_factory = {},
    MaxPolicyT max_policy                  = {})
  {
    // The PTX version drives the policy selection below
    int ptx_version = 0;
    if (const cudaError_t version_status = CubDebug(launcher_factory.PtxVersion(ptx_version));
        cudaSuccess != version_status)
    {
      return version_status;
    }

    // Bundle the problem description into a dispatch functor; the decomposer is value-initialized
    DispatchSegmentedRadixSort dispatcher(
      d_temp_storage,
      temp_storage_bytes,
      d_keys,
      d_values,
      num_items,
      num_segments,
      d_begin_offsets,
      d_end_offsets,
      begin_bit,
      end_bit,
      is_overwrite_okay,
      stream,
      ptx_version,
      {},
      kernel_source,
      launcher_factory);

    // Let the policy chain pick the tuning for this PTX version and call back into the functor
    const cudaError_t dispatch_status = CubDebug(max_policy.Invoke(ptx_version, dispatcher));
    return dispatch_status;
  }
};

CUB_NAMESPACE_END

_CCCL_DIAG_POP
