diff --git a/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp b/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp index a0ebe042c..b8140cd19 100644 --- a/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp +++ b/src/nnfusion/core/kernels/cuda_gpu/cuda_emitter.hpp @@ -287,6 +287,7 @@ namespace nnfusion << ctx->gnode->get_op_type(); log_cache.insert(ctx->gnode->get_op_type()); } + return; } kernel_info = diff --git a/src/nnfusion/engine/pass/graph/kernel_tuning.cpp b/src/nnfusion/engine/pass/graph/kernel_tuning.cpp index b71167944..1bee2b5cd 100644 --- a/src/nnfusion/engine/pass/graph/kernel_tuning.cpp +++ b/src/nnfusion/engine/pass/graph/kernel_tuning.cpp @@ -20,6 +20,7 @@ DEFINE_int64(fkernel_tuning_steps, 0, "Enable automatic kernel tuning for maximu DEFINE_string(ftuning_blocklist, "", "List of op types that skip kernel tuning pass, e.g., \"Softmax,Add\""); +DEFINE_string(ftuning_list, "", "List of op types for kernel tuning pass, e.g., \"Softmax,Add\""); DEFINE_string(fantares_perf_file, "./antares_perf.csv", "File to save Antares kernel performance."); DEFINE_string(ftuning_platform, "", "Antares platform: e.g., win64, xbox, etc."); DECLARE_bool(fantares_mode); @@ -114,6 +115,7 @@ void dump_perf(std::string filename, std::pair>, std::vector>> get_tuning_candidates(std::shared_ptr& graph, + const std::unordered_set tuning_list, const std::unordered_set block_list, std::unordered_map& ir2cnt) { @@ -125,12 +127,18 @@ std::pair>, std::vectorget_name(); } auto n_device_type = (*gnode)["DeviceType"].as(); NNFUSION_CHECK(n_device_type != UNKNOWN); + // filter ops not in TuningList + if (!tuning_list.empty() && tuning_list.find(gnode->get_op_type()) == tuning_list.end()) + { + continue; + } + // filter ops in BlockList if (block_list.find(gnode->get_op_type()) != block_list.end()) { @@ -228,18 +236,39 @@ std::pair>, std::vector status) { auto start = code.find("\n// Saved Perf ="); @@ -410,7 +439,14 @@ bool KernelTuning::run_on_graph(std::shared_ptr& graph) { if (FLAGS_fantares_mode) { + parse_tuning_list(); parse_block_list(); + for (auto item : TuningList) + { + NNFUSION_CHECK(BlockList.find(item) == BlockList.end()) + << "Kernel Tuning Pass: There are same operators in TuningList and " + "TuningBlockList."; + } // register antares kernels anyway here in case kernel selection pass will use them register_antares_kernel(); } @@ -424,7 +460,7 @@ bool KernelTuning::run_on_graph(std::shared_ptr& graph) std::vector> tuning_kernels; std::unordered_map ir2cnt; std::vector> nodes; - std::tie(nodes, tuned_kernels) = get_tuning_candidates(graph, BlockList, ir2cnt); + std::tie(nodes, tuned_kernels) = get_tuning_candidates(graph, TuningList, BlockList, ir2cnt); if (FLAGS_fantares_codegen_server.size() > 0) { @@ -535,4 +571,4 @@ bool KernelTuning::insert_to_kernel_cache(const std::vector>& nodes, std::vector>& tuned_kernels, @@ -51,6 +52,7 @@ namespace nnfusion private: std::unordered_set BlockList; + std::unordered_set TuningList; }; } }