BlockNormReduceCrossWarpSync< Problem_, Policy_ > Struct Template Reference

BlockNormReduceCrossWarpSync< Problem_, Policy_ > Struct Template Reference

Composable Kernel: ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ > Struct Template Reference
ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ > Struct Template Reference

#include <block_norm_reduce.hpp>

Public Types

using Problem = remove_cvref_t<Problem_>
using BlockShape = typename Problem::BlockShape
using smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>

Public Member Functions

template<typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void operator() (MeanDistributedTensor_ &mean_tensor, VarDistributedTensor_ &var_tensor, int &count, void *smem)

Static Public Member Functions

template<typename MeanDistributedTensor_>
static CK_TILE_DEVICE constexpr index_t GetReduceWarps ()
template<typename MeanDistributedTensor_>
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize ()

Static Public Attributes

static constexpr bool kFastFDiv = Problem::kFastFDiv
static constexpr bool kWelford = Problem::kWelford

Member Typedef Documentation

◆ BlockShape

template<typename Problem_, typename Policy_ = void>
using ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ >::BlockShape = typename Problem::BlockShape

◆ Problem

template<typename Problem_, typename Policy_ = void>
using ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_>

◆ smem_dtype

template<typename Problem_, typename Policy_ = void>
using ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ >::smem_dtype = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>

Member Function Documentation

◆ GetReduceWarps()

template<typename Problem_, typename Policy_ = void>
template<typename MeanDistributedTensor_>
CK_TILE_DEVICE constexpr index_t ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ >::GetReduceWarps ( )
inline static constexpr

◆ GetSmemSize()

template<typename Problem_, typename Policy_ = void>
template<typename MeanDistributedTensor_>
CK_TILE_HOST_DEVICE constexpr index_t ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ >::GetSmemSize ( )
inline static constexpr

◆ operator()()

template<typename Problem_, typename Policy_ = void>
template<typename MeanDistributedTensor_, typename VarDistributedTensor_>
CK_TILE_DEVICE void ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ >::operator() ( MeanDistributedTensor_ & mean_tensor,
VarDistributedTensor_ & var_tensor,
int & count,
void * smem )
inline

Member Data Documentation

◆ kFastFDiv

template<typename Problem_, typename Policy_ = void>
bool ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ >::kFastFDiv = Problem::kFastFDiv
static constexpr

◆ kWelford

template<typename Problem_, typename Policy_ = void>
bool ck_tile::BlockNormReduceCrossWarpSync< Problem_, Policy_ >::kWelford = Problem::kWelford
static constexpr

The documentation for this struct was generated from the following file: