Skip to content

Commit bce6ec1

Browse files
tenpercent and claude committed
Optimize tensor descriptor functor template instantiation
Replace inline lambdas with named functor structs in transform_tensor_descriptor to reduce template instantiation overhead and improve compile times.

Changes:
- Add three named functors in tensor_descriptor.hpp:
  - convert_visible_to_hidden_id: maps visible dimension ID to hidden ID
  - convert_visible_ids_to_hidden_ids: maps sequence of visible IDs to hidden IDs
  - generate_arithmetic_sequence_from_scan: generates consecutive hidden dim ID ranges
- Add utility functions in sequence_helper.hpp and tuple_helper.hpp:
  - unpack_and_merge_sequences(): unpacks tuple of sequences and merges them
  - generate_identity_sequences(): creates Tuple<Sequence<0>, Sequence<1>, ...>
- Update 14 call sites across threadwise transfer, wrapper, and device files to use generate_identity_sequences() instead of generate_tuple with lambdas
- Add comprehensive unit tests:
  - unit_sequence_helper.cpp: tests for new utility functions
  - unit_tensor_descriptor_functors.cpp: tests for new functors

Co-Authored-By: Claude <[email protected]>
1 parent f16d910 commit bce6ec1

19 files changed

Lines changed: 549 additions & 73 deletions

include/ck/tensor_description/tensor_descriptor.hpp

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,9 @@ struct TensorDescriptor
3636

3737
__host__ __device__ static constexpr index_t GetNumOfHiddenDimension()
3838
{
39-
constexpr auto all_low_dim_ids = unpack(
40-
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, LowerDimensionIdss{});
39+
constexpr auto all_low_dim_ids = unpack_and_merge_sequences(LowerDimensionIdss{});
4140

42-
constexpr auto all_up_dim_ids = unpack(
43-
[](auto&&... xs) constexpr { return merge_sequences(xs...); }, UpperDimensionIdss{});
41+
constexpr auto all_up_dim_ids = unpack_and_merge_sequences(UpperDimensionIdss{});
4442

4543
constexpr auto all_dim_ids = merge_sequences(all_low_dim_ids, all_up_dim_ids);
4644

@@ -311,6 +309,41 @@ struct lambda_get_up_dim_num
311309
}
312310
};
313311

312+
// Maps a visible dimension ID to its corresponding hidden dimension ID.
//
// Stateless functor (no data members): the descriptor is supplied only through the
// OldTensorDescriptor template argument, so a default-constructed instance is all a caller needs.
//
// operator() takes one visible dimension ID as an index_t and returns the entry stored at that
// position of OldTensorDescriptor::GetVisibleDimensionIds(), fetched with At().
// It is the per-element callable that convert_visible_ids_to_hidden_ids hands to
// transform_sequences.
// NOTE(review): relies on At(index_t) being usable in a constant expression - confirm against
// the Sequence implementation.
313+
template <typename OldTensorDescriptor>
314+
struct convert_visible_to_hidden_id
315+
{
316+
__host__ __device__ constexpr auto operator()(index_t low_dim_visible_id) const
317+
{
318+
return OldTensorDescriptor::GetVisibleDimensionIds().At(low_dim_visible_id);
319+
}
320+
};
321+
322+
// Maps a sequence of visible IDs to their corresponding hidden IDs.
//
// Stateless functor parameterized on OldTensorDescriptor. operator() accepts one argument of any
// type LowDimVisibleIds (taken by value) and returns the result of
// transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{}, low_dim_visible_ids),
// i.e. the per-ID conversion applied across the given sequence.
// transform_tensor_descriptor passes an instance of this functor to transform_tuples so that it is
// applied to each entry of NewLowerDimensionOldVisibleIdss.
// NOTE(review): LowDimVisibleIds is presumably always a Sequence<...> - confirm against callers.
323+
template <typename OldTensorDescriptor>
324+
struct convert_visible_ids_to_hidden_ids
325+
{
326+
template <typename LowDimVisibleIds>
327+
__host__ __device__ constexpr auto operator()(LowDimVisibleIds low_dim_visible_ids) const
328+
{
329+
return transform_sequences(convert_visible_to_hidden_id<OldTensorDescriptor>{},
330+
low_dim_visible_ids);
331+
}
332+
};
333+
334+
// Generates consecutive ranges of hidden dimension IDs for each transform's upper dimensions.
//
// Template parameters:
//   OldHiddenDimNumber - offset added to both scan entries read below.
//   UpDimNumbersScan   - default-constructible type queried with At(); at the call site in
//                        transform_tensor_descriptor it is the type of up_dim_numbers_scan
//                        (Sequence<0> merged with an inclusive scan of up_dim_numbers).
//
// operator()(I) ignores the argument's value and uses only its type: I{} selects the scan entry.
//   start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{})
//   end   = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{})
// It returns a default-constructed arithmetic_sequence_gen<start, end, 1>::type.
// NOTE(review): presumably arithmetic_sequence_gen yields the half-open range [start, end) with
// step 1, and I is a Number<i> supplied by generate_tuple - confirm against their definitions.
335+
template <index_t OldHiddenDimNumber, typename UpDimNumbersScan>
336+
struct generate_arithmetic_sequence_from_scan
337+
{
338+
template <typename I>
339+
__host__ __device__ constexpr auto operator()(I) const
340+
{
341+
constexpr index_t start = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{});
342+
constexpr index_t end = OldHiddenDimNumber + UpDimNumbersScan{}.At(I{} + Number<1>{});
343+
return typename arithmetic_sequence_gen<start, end, 1>::type{};
344+
}
345+
};
346+
314347
template <typename OldTensorDescriptor,
315348
typename NewTransforms,
316349
typename NewLowerDimensionOldVisibleIdss,
@@ -327,11 +360,11 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
327360
NewTransforms::Size() == NewUpperDimensionNewVisibleIdss::Size(),
328361
"wrong! inconsitent number of transform");
329362

330-
constexpr auto all_old_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
331-
NewLowerDimensionOldVisibleIdss{});
363+
constexpr auto all_old_top_ids =
364+
unpack_and_merge_sequences(NewLowerDimensionOldVisibleIdss{});
332365

333-
constexpr auto all_new_top_ids = unpack([](auto... xs) { return merge_sequences(xs...); },
334-
NewUpperDimensionNewVisibleIdss{});
366+
constexpr auto all_new_top_ids =
367+
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
335368

336369
static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
337370
is_valid_sequence_map<decltype(all_new_top_ids)>::value,
@@ -341,17 +374,9 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
341374
// lower dimension's hidden idss
342375
// convert lower dimension visible idss (tuple of sequences) to hidden idss (tuple of
343376
// sequences)
344-
constexpr auto low_dim_hidden_idss = transform_tuples(
345-
// convert lower dimension visible ids (a sequence) to hidden ids (a sequence)
346-
[](auto low_dim_visible_ids) constexpr {
347-
return transform_sequences(
348-
// convert lower dimension visible id to hidden id
349-
[](auto low_dim_visible_id) constexpr {
350-
return OldTensorDescriptor::GetVisibleDimensionIds()[low_dim_visible_id];
351-
},
352-
low_dim_visible_ids);
353-
},
354-
NewLowerDimensionOldVisibleIdss{});
377+
constexpr auto low_dim_hidden_idss =
378+
transform_tuples(convert_visible_ids_to_hidden_ids<OldTensorDescriptor>{},
379+
NewLowerDimensionOldVisibleIdss{});
355380

356381
constexpr index_t num_new_transform = NewTransforms::Size();
357382

@@ -364,22 +389,17 @@ transform_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
364389
constexpr auto up_dim_numbers_scan = merge_sequences(
365390
Sequence<0>{}, inclusive_scan_sequence(up_dim_numbers, math::plus<index_t>{}, Number<0>{}));
366391

392+
using UpDimNumbersScanType = remove_cvref_t<decltype(up_dim_numbers_scan)>;
367393
constexpr auto up_dim_hidden_idss = generate_tuple(
368-
[old_hidden_dim_number, up_dim_numbers_scan](auto i) constexpr {
369-
return
370-
typename arithmetic_sequence_gen<old_hidden_dim_number + up_dim_numbers_scan[i],
371-
old_hidden_dim_number + up_dim_numbers_scan[i + 1],
372-
1>::type{};
373-
},
394+
generate_arithmetic_sequence_from_scan<old_hidden_dim_number, UpDimNumbersScanType>{},
374395
Number<num_new_transform>{});
375396

376397
// new visible dimension's hidden ids
377398
constexpr auto unordered_new_visible_dim_hidden_ids =
378-
unpack([](auto... xs) constexpr { return merge_sequences(xs...); }, up_dim_hidden_idss);
399+
unpack_and_merge_sequences(up_dim_hidden_idss);
379400

380401
constexpr auto new_visible_dim_unordered2ordered =
381-
unpack([](auto... xs) constexpr { return merge_sequences(xs...); },
382-
NewUpperDimensionNewVisibleIdss{});
402+
unpack_and_merge_sequences(NewUpperDimensionNewVisibleIdss{});
383403

384404
constexpr auto new_visible_dim_hidden_ids =
385405
unordered_new_visible_dim_hidden_ids.ReorderGivenOld2New(new_visible_dim_unordered2ordered);

include/ck/tensor_operation/gpu/device/matrix_padder.hpp

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,8 @@ PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoP
4343
},
4444
Number<num_dim>{});
4545

46-
// lower dimension Id
47-
const auto lower_dimss =
48-
generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
49-
50-
// upper dimension Id
46+
// lower/upper dimension Ids
47+
const auto lower_dimss = generate_identity_sequences<num_dim>();
5148
const auto upper_dimss = lower_dimss;
5249

5350
return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -739,8 +739,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
739739
},
740740
Number<nDim>{});
741741

742-
constexpr auto up_dim_idss =
743-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
742+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
744743

745744
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
746745
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_dequant.hpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -894,8 +894,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
894894
},
895895
Number<nDim>{});
896896

897-
constexpr auto up_dim_idss =
898-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
897+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
899898

900899
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
901900
}
@@ -944,8 +943,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
944943
},
945944
Number<nDim>{});
946945

947-
constexpr auto up_dim_idss =
948-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
946+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
949947

950948
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
951949
}
@@ -993,8 +991,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_dequant
993991
},
994992
Number<nDim>{});
995993

996-
constexpr auto up_dim_idss =
997-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
994+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
998995

999996
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
1000997
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -833,8 +833,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
833833
},
834834
Number<nDim>{});
835835

836-
constexpr auto up_dim_idss =
837-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
836+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
838837

839838
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
840839
}
@@ -892,8 +891,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
892891
},
893892
Number<nDim>{});
894893

895-
constexpr auto up_dim_idss =
896-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
894+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
897895

898896
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
899897
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,8 +692,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
692692
},
693693
Number<nDim>{});
694694

695-
constexpr auto up_dim_idss =
696-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
695+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
697696

698697
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
699698
}
@@ -744,8 +743,7 @@ struct ThreadwiseTensorSliceTransfer_v3r2
744743
},
745744
Number<nDim>{});
746745

747-
constexpr auto up_dim_idss =
748-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
746+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
749747

750748
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
751749
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r2.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
514514
},
515515
Number<nDim>{});
516516

517-
constexpr auto up_dim_idss =
518-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
517+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
519518

520519
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
521520
}
@@ -563,8 +562,7 @@ struct ThreadwiseTensorSliceTransfer_v7r2
563562
},
564563
Number<nDim>{});
565564

566-
constexpr auto up_dim_idss =
567-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
565+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
568566

569567
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
570568
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -656,8 +656,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
656656
},
657657
Number<nDim>{});
658658

659-
constexpr auto up_dim_idss =
660-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
659+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
661660

662661
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
663662
}
@@ -706,8 +705,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3
706705
},
707706
Number<nDim>{});
708707

709-
constexpr auto up_dim_idss =
710-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
708+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
711709

712710
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
713711
}

include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -548,8 +548,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
548548
},
549549
Number<nDim>{});
550550

551-
constexpr auto up_dim_idss =
552-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
551+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
553552

554553
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
555554
}
@@ -598,8 +597,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
598597
},
599598
Number<nDim>{});
600599

601-
constexpr auto up_dim_idss =
602-
generate_tuple([&](auto i) { return Sequence<i.value>{}; }, Number<nDim>{});
600+
constexpr auto up_dim_idss = generate_identity_sequences<nDim>();
603601

604602
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
605603
}

include/ck/utility/sequence_helper.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include "ck/utility/functional4.hpp"
67
#include "ck/utility/tuple.hpp"
78

89
namespace ck {
@@ -34,4 +35,21 @@ __host__ __device__ constexpr auto to_sequence(Tuple<Number<Is>...>)
3435
return Sequence<Is...>{};
3536
}
3637

38+
// Functor wrapper for merge_sequences to enable reuse across call sites.
//
// Stateless struct whose variadic operator() takes any number of arguments by value and returns
// merge_sequences(seqs...) unchanged. Being a single named type, it can be passed wherever a
// callable is expected (unpack_and_merge_sequences passes it to unpack) in place of the
// per-call-site lambdas `[](auto... xs) { return merge_sequences(xs...); }` it replaces.
39+
struct merge_sequences_functor
40+
{
41+
template <typename... Seqs>
42+
__host__ __device__ constexpr auto operator()(Seqs... seqs) const
43+
{
44+
return merge_sequences(seqs...);
45+
}
46+
};
47+
48+
// Unpacks tuple of sequences and merges them into a single sequence.
//
// Takes tuple_of_sequences by value and returns unpack(merge_sequences_functor{},
// tuple_of_sequences). It is the shared replacement for the repeated
// `unpack([](auto... xs) { return merge_sequences(xs...); }, tuple)` pattern.
// NOTE(review): presumably unpack invokes the functor with the tuple's elements as separate
// arguments, so the result is merge_sequences over all elements - confirm against unpack in
// functional4.hpp.
49+
template <typename TupleOfSequences>
50+
__host__ __device__ constexpr auto unpack_and_merge_sequences(TupleOfSequences tuple_of_sequences)
51+
{
52+
return unpack(merge_sequences_functor{}, tuple_of_sequences);
53+
}
54+
3755
} // namespace ck

0 commit comments

Comments
 (0)