diff --git a/CMakeLists.txt b/CMakeLists.txt index a93974407..e2651275a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -85,6 +85,7 @@ set(SOURCE_FILES src/codegen/dispatchers/if_else_dispatcher.cpp src/codegen/dispatchers/sequence_dispatcher.cpp src/codegen/dispatchers/for_dispatcher.cpp + src/codegen/dispatchers/for_each_dispatcher.cpp src/codegen/dispatchers/map_dispatcher.cpp src/codegen/dispatchers/while_dispatcher.cpp src/control_flow/interstate_edge.cpp @@ -150,11 +151,13 @@ set(SOURCE_FILES src/passes/structured_control_flow/loop_normalization.cpp src/passes/structured_control_flow/sequence_fusion.cpp src/passes/structured_control_flow/while_to_for_conversion.cpp + src/passes/structured_control_flow/while_to_for_each_conversion.cpp src/passes/debug_info_propagation.cpp src/serializer/json_serializer.cpp src/structured_control_flow/block.cpp src/structured_control_flow/control_flow_node.cpp src/structured_control_flow/for.cpp + src/structured_control_flow/for_each.cpp src/structured_control_flow/if_else.cpp src/structured_control_flow/map.cpp src/structured_control_flow/return.cpp diff --git a/include/sdfg/analysis/data_dependency_analysis.h b/include/sdfg/analysis/data_dependency_analysis.h index 708ebe4de..be88c2a23 100644 --- a/include/sdfg/analysis/data_dependency_analysis.h +++ b/include/sdfg/analysis/data_dependency_analysis.h @@ -79,6 +79,15 @@ class DataDependencyAnalysis : public Analysis { std::unordered_map>& closed_definitions ); + void visit_for_each( + analysis::Users& users, + analysis::AssumptionsAnalysis& assumptions_analysis, + structured_control_flow::ForEach& for_each, + std::unordered_set& undefined, + std::unordered_map>& open_definitions, + std::unordered_map>& closed_definitions + ); + void visit_if_else( analysis::Users& users, analysis::AssumptionsAnalysis& assumptions_analysis, diff --git a/include/sdfg/builder/structured_sdfg_builder.h b/include/sdfg/builder/structured_sdfg_builder.h index 8faa7aaed..8be50a3da 100644 --- a/include/sdfg/builder/structured_sdfg_builder.h +++ b/include/sdfg/builder/structured_sdfg_builder.h @@ -14,6 +14,7 @@ #include "sdfg/structured_control_flow/return.h" #include "sdfg/structured_control_flow/sequence.h" #include "sdfg/structured_control_flow/while.h" +#include "sdfg/structured_control_flow/for_each.h" #include "sdfg/structured_sdfg.h" #include "sdfg/types/scalar.h" @@ -288,6 +289,38 @@ class StructuredSDFGBuilder : public FunctionBuilder { const DebugInfo& debug_info = DebugInfo() ); + ForEach& add_for_each( + Sequence& parent, + const symbolic::Symbol iterator, + const symbolic::Symbol end, + const symbolic::Symbol update, + const symbolic::Symbol init = SymEngine::null, + const sdfg::control_flow::Assignments& assignments = {}, + const DebugInfo& debug_info = DebugInfo() + ); + + ForEach& add_for_each_after( + Sequence& parent, + ControlFlowNode& child, + const symbolic::Symbol iterator, + const symbolic::Symbol end, + const symbolic::Symbol update, + const symbolic::Symbol init = SymEngine::null, + const sdfg::control_flow::Assignments& assignments = {}, + const DebugInfo& debug_info = DebugInfo() + ); + + ForEach& add_for_each_before( + Sequence& parent, + ControlFlowNode& child, + const symbolic::Symbol iterator, + const symbolic::Symbol end, + const symbolic::Symbol update, + const symbolic::Symbol init = SymEngine::null, + const sdfg::control_flow::Assignments& assignments = {}, + const DebugInfo& debug_info = DebugInfo() + ); + Continue& add_continue( Sequence& parent, const sdfg::control_flow::Assignments& assignments = {}, diff --git a/include/sdfg/codegen/dispatchers/for_each_dispatcher.h b/include/sdfg/codegen/dispatchers/for_each_dispatcher.h new file mode 100644 index 000000000..7703ca56e --- /dev/null +++ b/include/sdfg/codegen/dispatchers/for_each_dispatcher.h @@ -0,0 +1,28 @@ +#pragma once + +#include "sdfg/codegen/dispatchers/block_dispatcher.h" +#include "sdfg/codegen/dispatchers/node_dispatcher_registry.h" +#include "sdfg/codegen/dispatchers/sequence_dispatcher.h" + +namespace sdfg { +namespace codegen { + +class ForEachDispatcher : public NodeDispatcher { +private: + structured_control_flow::ForEach& node_; + +public: + ForEachDispatcher( + LanguageExtension& language_extension, + StructuredSDFG& sdfg, + structured_control_flow::ForEach& node, + InstrumentationPlan& instrumentation_plan + ); + + void dispatch_node( + PrettyPrinter& main_stream, PrettyPrinter& globals_stream, CodeSnippetFactory& library_snippet_factory + ) override; +}; + +} // namespace codegen +} // namespace sdfg diff --git a/include/sdfg/passes/structured_control_flow/while_to_for_each_conversion.h b/include/sdfg/passes/structured_control_flow/while_to_for_each_conversion.h new file mode 100644 index 000000000..61da3cc68 --- /dev/null +++ b/include/sdfg/passes/structured_control_flow/while_to_for_each_conversion.h @@ -0,0 +1,27 @@ +#pragma once + +#include "sdfg/passes/pass.h" + +namespace sdfg { +namespace passes { + +class WhileToForEachConversion : public Pass { + private: + bool can_be_applied(builder::StructuredSDFGBuilder& builder, + analysis::AnalysisManager& analysis_manager, + structured_control_flow::While& loop); + + void apply(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager, + structured_control_flow::Sequence& parent, structured_control_flow::While& loop); + + public: + WhileToForEachConversion(); + + std::string name() override; + + virtual bool run_pass(builder::StructuredSDFGBuilder& builder, + analysis::AnalysisManager& analysis_manager) override; +}; + +} // namespace passes +} // namespace sdfg diff --git a/include/sdfg/serializer/json_serializer.h b/include/sdfg/serializer/json_serializer.h index f8ade8270..e52846418 100644 --- a/include/sdfg/serializer/json_serializer.h +++ b/include/sdfg/serializer/json_serializer.h @@ -8,6 +8,7 @@ #include "sdfg/structured_control_flow/block.h" #include "sdfg/structured_control_flow/control_flow_node.h" #include "sdfg/structured_control_flow/map.h" +#include "sdfg/structured_control_flow/for_each.h" #include "sdfg/structured_control_flow/sequence.h" #include "sdfg/structured_control_flow/while.h" #include "sdfg/structured_sdfg.h" @@ -39,6 +40,7 @@ class JSONSerializer { void continue_node_to_json(nlohmann::json& j, const sdfg::structured_control_flow::Continue& continue_node); void return_node_to_json(nlohmann::json& j, const sdfg::structured_control_flow::Return& return_node); void map_to_json(nlohmann::json& j, const sdfg::structured_control_flow::Map& map_node); + void for_each_to_json(nlohmann::json& j, const sdfg::structured_control_flow::ForEach& for_each_node); void debug_info_to_json(nlohmann::json& j, const sdfg::DebugInfo& debug_info); @@ -104,6 +106,12 @@ class JSONSerializer { sdfg::structured_control_flow::Sequence& parent, control_flow::Assignments& assignments ); + void json_to_for_each_node( + const nlohmann::json& j, + sdfg::builder::StructuredSDFGBuilder& builder, + sdfg::structured_control_flow::Sequence& parent, + control_flow::Assignments& assignments + ); std::unique_ptr json_to_type(const nlohmann::json& j); std::vector> json_to_arguments(const nlohmann::json& j); DebugInfo json_to_debug_info(const nlohmann::json& j); diff --git a/include/sdfg/structured_control_flow/for_each.h b/include/sdfg/structured_control_flow/for_each.h new file mode 100644 index 000000000..c0d5787bd --- /dev/null +++ b/include/sdfg/structured_control_flow/for_each.h @@ -0,0 +1,60 @@ +#pragma once + +#include "sdfg/structured_control_flow/control_flow_node.h" +#include "sdfg/structured_control_flow/sequence.h" +#include "sdfg/symbolic/symbolic.h" + +namespace sdfg { + +namespace builder { +class StructuredSDFGBuilder; +} + +namespace structured_control_flow { + +class ForEach : public ControlFlowNode { + friend class sdfg::builder::StructuredSDFGBuilder; + +protected: + symbolic::Symbol iterator_; + symbolic::Symbol end_; + symbolic::Symbol update_; + symbolic::Symbol init_; + + std::unique_ptr root_; + + ForEach( + size_t element_id, + const DebugInfo& debug_info, + symbolic::Symbol iterator, + symbolic::Symbol end, + symbolic::Symbol update, + symbolic::Symbol init = SymEngine::null + ); + +public: + virtual ~ForEach() = default; + + ForEach(const ForEach& node) = delete; + ForEach& operator=(const ForEach&) = delete; + + void validate(const Function& function) const override; + + const symbolic::Symbol iterator() const; + + const symbolic::Symbol end() const; + + const symbolic::Symbol update() const; + + const symbolic::Symbol init() const; + + bool has_init() const; + + Sequence& root() const; + + void replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) override; + +}; + +} // namespace structured_control_flow +} // namespace sdfg diff --git a/include/sdfg/structured_control_flow/sequence.h b/include/sdfg/structured_control_flow/sequence.h index 1283cfe04..169de724d 100644 --- a/include/sdfg/structured_control_flow/sequence.h +++ b/include/sdfg/structured_control_flow/sequence.h @@ -15,6 +15,7 @@ class StructuredSDFGBuilder; namespace structured_control_flow { +class ForEach; class While; class StructuredLoop; class Sequence; @@ -58,6 +59,7 @@ class Sequence : public ControlFlowNode { friend class sdfg::StructuredSDFG; + friend class sdfg::structured_control_flow::ForEach; friend class sdfg::structured_control_flow::While; friend class sdfg::structured_control_flow::StructuredLoop; diff --git a/include/sdfg/structured_sdfg.h b/include/sdfg/structured_sdfg.h index d34fb04db..eeddeaaef 100644 --- a/include/sdfg/structured_sdfg.h +++ b/include/sdfg/structured_sdfg.h @@ -14,6 +14,7 @@ #include "sdfg/structured_control_flow/block.h" #include "sdfg/structured_control_flow/control_flow_node.h" #include "sdfg/structured_control_flow/for.h" +#include "sdfg/structured_control_flow/for_each.h" #include "sdfg/structured_control_flow/if_else.h" #include "sdfg/structured_control_flow/map.h" #include "sdfg/structured_control_flow/return.h" diff --git a/include/sdfg/visitor/structured_sdfg_visitor.h b/include/sdfg/visitor/structured_sdfg_visitor.h index 0f364717b..e0b3b6f7c 100644 --- a/include/sdfg/visitor/structured_sdfg_visitor.h +++ b/include/sdfg/visitor/structured_sdfg_visitor.h @@ -31,6 +31,8 @@ class StructuredSDFGVisitor { virtual bool accept(structured_control_flow::For& node); + virtual bool accept(structured_control_flow::ForEach& node); + virtual bool accept(structured_control_flow::While& node); virtual bool accept(structured_control_flow::Continue& node); diff --git a/src/analysis/data_dependency_analysis.cpp b/src/analysis/data_dependency_analysis.cpp index ce7e31bc3..033987d64 100644 --- a/src/analysis/data_dependency_analysis.cpp +++ b/src/analysis/data_dependency_analysis.cpp @@ -397,6 +397,113 @@ void DataDependencyAnalysis::visit_for( } } +void DataDependencyAnalysis::visit_for_each( + analysis::Users& users, + analysis::AssumptionsAnalysis& assumptions_analysis, + structured_control_flow::ForEach& for_each, + std::unordered_set& undefined, + std::unordered_map>& open_definitions, + std::unordered_map>& closed_definitions +) { + // Init - Read + if (for_each.has_init()) { + auto sym = for_each.init(); + auto current_user = users.get_user(sym->get_name(), &for_each, Use::READ); + + bool found = false; + for (auto& user : open_definitions) { + if (user.first->container() == sym->get_name()) { + user.second.insert(current_user); + found = true; + } + } + if (!found) { + undefined.insert(current_user); + } + } + + // Condition - Read + if (!symbolic::eq(for_each.end(), symbolic::__nullptr__())) { + auto end = for_each.end(); + auto current_user = users.get_user(end->get_name(), &for_each, Use::READ); + + bool found = false; + for (auto& user : open_definitions) { + if (user.first->container() == end->get_name()) { + user.second.insert(current_user); + found = true; + } + } + if (!found) { + undefined.insert(current_user); + } + } + + std::unordered_map> open_definitions_for_each; + std::unordered_map> closed_definitions_for_each; + std::unordered_set undefined_for_each; + + // Add assumptions for body + visit_sequence( + users, + assumptions_analysis, + for_each.root(), + undefined_for_each, + open_definitions_for_each, + closed_definitions_for_each + ); + + // Update - Read + { + auto update = for_each.update(); + auto current_user = users.get_user(update->get_name(), &for_each, Use::READ); + + bool found = false; + for (auto& user : open_definitions_for_each) { + if (user.first->container() == update->get_name()) { + user.second.insert(current_user); + found = true; + } + } + if (!found) { + undefined_for_each.insert(current_user); + } + } + + // Merge for with outside + + // Scope-local closed definitions + for (auto& entry : closed_definitions_for_each) { + closed_definitions.insert(entry); + } + + for (auto open_read : undefined_for_each) { + // Over-Approximation: Add loop-carried dependencies for all open reads + for (auto& entry : open_definitions_for_each) { + if (entry.first->container() == open_read->container()) { + entry.second.insert(open_read); + } + } + + // Connect to outside + bool found = false; + for (auto& entry : open_definitions) { + if (entry.first->container() == open_read->container()) { + entry.second.insert(open_read); + found = true; + } + } + if (!found) { + undefined.insert(open_read); + } + } + + // Add open definitions from while to outside + for (auto& entry : open_definitions_for_each) { + open_definitions.insert(entry); + } +} + void DataDependencyAnalysis::visit_if_else( analysis::Users& users, analysis::AssumptionsAnalysis& assumptions_analysis, @@ -620,6 +727,8 @@ void DataDependencyAnalysis::visit_sequence( visit_block(users, assumptions_analysis, *block, undefined, open_definitions, closed_definitions); } else if (auto for_loop = dynamic_cast(&child.first)) { visit_for(users, assumptions_analysis, *for_loop, undefined, open_definitions, closed_definitions); + } else if (auto for_each = dynamic_cast(&child.first)) { + visit_for_each(users, assumptions_analysis, *for_each, undefined, open_definitions, closed_definitions); } else if (auto if_else = dynamic_cast(&child.first)) { visit_if_else(users, assumptions_analysis, *if_else, undefined, open_definitions, closed_definitions); } else if (auto while_loop = dynamic_cast(&child.first)) { diff --git a/src/analysis/loop_analysis.cpp b/src/analysis/loop_analysis.cpp index cd5f665e4..1799c6a2f 100644 --- a/src/analysis/loop_analysis.cpp +++ b/src/analysis/loop_analysis.cpp @@ -26,6 +26,9 @@ void LoopAnalysis:: } else if (auto loop_stmt = dynamic_cast(current)) { this->loops_.push_back(loop_stmt); this->loop_tree_[loop_stmt] = parent_loop; + } else if (auto for_each_stmt = dynamic_cast(current)) { + this->loops_.push_back(for_each_stmt); + this->loop_tree_[for_each_stmt] = parent_loop; } if (dynamic_cast(current)) { @@ -42,6 +45,8 @@ void LoopAnalysis:: this->run(while_stmt->root(), while_stmt); } else if (auto for_stmt = dynamic_cast(current)) { this->run(for_stmt->root(), for_stmt); + } else if (auto for_each_stmt = dynamic_cast(current)) { + this->run(for_each_stmt->root(), for_each_stmt); } else if (dynamic_cast(current)) { continue; } else if (dynamic_cast(current)) { @@ -68,6 +73,10 @@ structured_control_flow::ControlFlowNode* LoopAnalysis::find_loop_by_indvar(cons if (loop_stmt->indvar()->get_name() == indvar) { return loop; } + } else if (auto for_each_stmt = dynamic_cast(loop)) { + if (for_each_stmt->iterator()->get_name() == indvar) { + return loop; + } } } return nullptr; diff --git a/src/analysis/scope_analysis.cpp b/src/analysis/scope_analysis.cpp index b6a333861..741a95535 100644 --- a/src/analysis/scope_analysis.cpp +++ b/src/analysis/scope_analysis.cpp @@ -25,6 +25,9 @@ void ScopeAnalysis:: } else if (auto for_stmt = dynamic_cast(current)) { this->scope_tree_[current] = parent_scope; this->run(&for_stmt->root(), current); + } else if (auto for_each_stmt = dynamic_cast(current)) { + this->scope_tree_[current] = parent_scope; + this->run(&for_each_stmt->root(), current); } else if (dynamic_cast(current)) { this->scope_tree_[current] = parent_scope; } else if (dynamic_cast(current)) { diff --git a/src/analysis/users.cpp b/src/analysis/users.cpp index 9669f50c0..3bcd084a5 100644 --- a/src/analysis/users.cpp +++ b/src/analysis/users.cpp @@ -459,6 +459,76 @@ std::pair Users::traverse(structured_control_flow: this->exits_.insert({ret_stmt, this->users_.at(v).get()}); return {v, boost::graph_traits::null_vertex()}; } + } else if (auto for_each_stmt = dynamic_cast(&node)) { + // Add source + auto s = boost::add_vertex(this->graph_); + this->users_.insert({s, std::make_unique(s, "", for_each_stmt, Use::NOP)}); + this->entries_.insert({for_each_stmt, this->users_.at(s).get()}); + auto last = s; + + // Add sink + auto t = boost::add_vertex(this->graph_); + this->users_.insert({t, std::make_unique(t, "", for_each_stmt, Use::NOP)}); + this->exits_.insert({for_each_stmt, this->users_.at(t).get()}); + + // Init + if (for_each_stmt->has_init()) { + auto v_init = boost::add_vertex(this->graph_); + this->add_user(std::make_unique(v_init, for_each_stmt->init()->get_name(), for_each_stmt, Use::READ)); + boost::add_edge(last, v_init, this->graph_); + last = v_init; + + auto v_iter = boost::add_vertex(this->graph_); + this->add_user(std::make_unique(v_iter, for_each_stmt->iterator()->get_name(), for_each_stmt, Use::MOVE) + ); + boost::add_edge(last, v_iter, this->graph_); + last = v_iter; + } + + // Condition + auto v_cond_iterator = boost::add_vertex(this->graph_); + this->add_user(std::make_unique< + User>(v_cond_iterator, for_each_stmt->iterator()->get_name(), for_each_stmt, Use::READ)); + boost::add_edge(last, v_cond_iterator, this->graph_); + last = v_cond_iterator; + + // End + auto end = for_each_stmt->end(); + if (!symbolic::eq(end, symbolic::__nullptr__())) { + auto v = boost::add_vertex(this->graph_); + this->add_user(std::make_unique(v, end->get_name(), for_each_stmt, Use::READ)); + boost::add_edge(last, v, this->graph_); + last = v; + } + + // Case: condition false + boost::add_edge(last, t, this->graph_); + + // Case: condition true -> body + auto subgraph = this->traverse(for_each_stmt->root()); + boost::add_edge(last, subgraph.first, this->graph_); + + // Update + auto v_update = boost::add_vertex(this->graph_); + this->add_user(std::make_unique(v_update, for_each_stmt->update()->get_name(), for_each_stmt, Use::READ)); + if (subgraph.second != boost::graph_traits::null_vertex()) { + boost::add_edge(subgraph.second, v_update, this->graph_); + } + last = v_update; + + auto v_update_iter = boost::add_vertex(this->graph_); + this->add_user(std::make_unique< + User>(v_update_iter, for_each_stmt->iterator()->get_name(), for_each_stmt, Use::MOVE)); + boost::add_edge(last, v_update_iter, this->graph_); + last = v_update_iter; + + // Connect to sink + boost::add_edge(last, t, this->graph_); + + // Back edge + boost::add_edge(t, s, this->graph_); + + return {s, t}; } throw std::invalid_argument("Invalid control flow node type"); diff --git a/src/builder/structured_sdfg_builder.cpp b/src/builder/structured_sdfg_builder.cpp index 40dbbfdf2..a2a9b3a95 100644 --- a/src/builder/structured_sdfg_builder.cpp +++ b/src/builder/structured_sdfg_builder.cpp @@ -1053,6 +1053,91 @@ Map& StructuredSDFGBuilder::add_map_after( return static_cast(*parent.children_.at(index + 1).get()); }; +ForEach& StructuredSDFGBuilder::add_for_each( + Sequence& parent, + const symbolic::Symbol iterator, + const symbolic::Symbol end, + const symbolic::Symbol update, + const symbolic::Symbol init, + const sdfg::control_flow::Assignments& assignments, + const DebugInfo& debug_info +) { + parent.children_ + .push_back(std::unique_ptr(new ForEach(this->new_element_id(), debug_info, iterator, end, update, init) + )); + + // Increment element id for body node + this->new_element_id(); + + parent.transitions_ + .push_back(std::unique_ptr(new Transition(this->new_element_id(), debug_info, parent, assignments)) + ); + + return static_cast(*parent.children_.back().get()); +}; + +ForEach& StructuredSDFGBuilder::add_for_each_before( + Sequence& parent, + ControlFlowNode& child, + const symbolic::Symbol iterator, + const symbolic::Symbol end, + const symbolic::Symbol update, + const symbolic::Symbol init, + const sdfg::control_flow::Assignments& assignments, + const DebugInfo& debug_info +) { + int index = parent.index(child); + if (index == -1) { + throw InvalidSDFGException("StructuredSDFGBuilder: Child not found"); + } + + parent.children_.insert( + parent.children_.begin() + index, + std::unique_ptr(new ForEach(this->new_element_id(), debug_info, iterator, end, update, init)) + ); + + // Increment element id for body node + this->new_element_id(); + + parent.transitions_.insert( + parent.transitions_.begin() + index, + std::unique_ptr(new Transition(this->new_element_id(), debug_info, parent, assignments)) + ); + + return static_cast(*parent.children_.at(index).get()); +}; + +ForEach& StructuredSDFGBuilder::add_for_each_after( + Sequence& parent, + ControlFlowNode& child, + const symbolic::Symbol iterator, + const symbolic::Symbol end, + const symbolic::Symbol update, + const symbolic::Symbol init, + const sdfg::control_flow::Assignments& assignments, + const DebugInfo& debug_info +) { + int index = parent.index(child); + if (index == -1) { + throw InvalidSDFGException("StructuredSDFGBuilder: Child not found"); + } + + parent.children_.insert( + parent.children_.begin() + index + 1, + std::unique_ptr(new ForEach(this->new_element_id(), debug_info, iterator, end, update, init)) + ); + + // Increment element id for body node + this->new_element_id(); + + parent.transitions_.insert( + parent.transitions_.begin() + index + 1, + std::unique_ptr(new Transition(this->new_element_id(), debug_info, parent, assignments)) + ); + + return static_cast(*parent.children_.at(index + 1).get()); +}; + Continue& StructuredSDFGBuilder:: add_continue(Sequence& parent, const sdfg::control_flow::Assignments& assignments, const DebugInfo& debug_info) { // Check if continue is in a loop diff --git a/src/codegen/dispatchers/for_each_dispatcher.cpp b/src/codegen/dispatchers/for_each_dispatcher.cpp new file mode 100644 index 000000000..e8b8d95cc --- /dev/null +++ b/src/codegen/dispatchers/for_each_dispatcher.cpp @@ -0,0 +1,51 @@ +#include "sdfg/codegen/dispatchers/for_each_dispatcher.h" + +namespace sdfg { +namespace codegen { + +ForEachDispatcher::ForEachDispatcher( + LanguageExtension& language_extension, + StructuredSDFG& sdfg, + structured_control_flow::ForEach& node, + InstrumentationPlan& instrumentation_plan +) + : NodeDispatcher(language_extension, sdfg, node, instrumentation_plan), node_(node) { + + }; + +void ForEachDispatcher::dispatch_node( + PrettyPrinter& main_stream, PrettyPrinter& globals_stream, CodeSnippetFactory& library_snippet_factory +) { + types::Pointer ptr_type; + types::Pointer ptr_ptr_type(static_cast(ptr_type)); + + std::string iterator = language_extension_.expression(node_.iterator()); + + main_stream << "for"; + main_stream << "("; + if (node_.has_init()) { + main_stream << iterator; + main_stream << " = "; + main_stream << "*(" << language_extension_.type_cast(language_extension_.expression(node_.init()), ptr_ptr_type) << ")"; + } + main_stream << ";"; + main_stream << iterator; + main_stream << " != "; + main_stream << language_extension_.expression(node_.end()); + main_stream << ";"; + main_stream << iterator; + main_stream << " = "; + main_stream << "*(" << language_extension_.type_cast(language_extension_.expression(node_.update()), ptr_ptr_type) << ")"; + main_stream << ")" << std::endl; + main_stream << "{" << std::endl; + + main_stream.setIndent(main_stream.indent() + 4); + SequenceDispatcher dispatcher(language_extension_, sdfg_, node_.root(), instrumentation_plan_); + dispatcher.dispatch(main_stream, globals_stream, library_snippet_factory); + main_stream.setIndent(main_stream.indent() - 4); + + main_stream << "}" << std::endl; +}; + +} // namespace codegen +} // namespace sdfg diff --git a/src/codegen/dispatchers/node_dispatcher_registry.cpp b/src/codegen/dispatchers/node_dispatcher_registry.cpp index a65606e1c..27a31ea62 100644 --- a/src/codegen/dispatchers/node_dispatcher_registry.cpp +++ b/src/codegen/dispatchers/node_dispatcher_registry.cpp @@ -2,6 +2,7 @@ #include "sdfg/codegen/dispatchers/block_dispatcher.h" #include "sdfg/codegen/dispatchers/for_dispatcher.h" +#include "sdfg/codegen/dispatchers/for_each_dispatcher.h" #include "sdfg/codegen/dispatchers/if_else_dispatcher.h" #include "sdfg/codegen/dispatchers/map_dispatcher.h" #include "sdfg/codegen/dispatchers/sequence_dispatcher.h" @@ -87,6 +88,17 @@ void register_default_dispatchers() { ); } ); + NodeDispatcherRegistry::instance().register_dispatcher( + typeid(structured_control_flow::ForEach), + [](LanguageExtension& language_extension, + StructuredSDFG& sdfg, + structured_control_flow::ControlFlowNode& node, + InstrumentationPlan& instrumentation) { + return std::make_unique( + language_extension, sdfg, static_cast(node), instrumentation + ); + } + ); NodeDispatcherRegistry::instance().register_dispatcher( typeid(structured_control_flow::Map), [](LanguageExtension& language_extension, diff --git a/src/passes/dataflow/byte_reference_elimination.cpp b/src/passes/dataflow/byte_reference_elimination.cpp index 1882c26ad..d3c8df8e1 100644 --- a/src/passes/dataflow/byte_reference_elimination.cpp +++ b/src/passes/dataflow/byte_reference_elimination.cpp @@ -34,6 +34,9 @@ bool ByteReferenceElimination:: } auto move = *moves.begin(); auto move_node = dynamic_cast(move->element()); + if (!move_node) { + continue; + } auto& move_graph = move_node->get_parent(); auto& move_edge = *move_graph.in_edges(*move_node).begin(); auto& move_type = move_edge.base_type(); diff --git a/src/passes/dataflow/dead_reference_elimination.cpp b/src/passes/dataflow/dead_reference_elimination.cpp index 287cd908d..e03d9e17d 100644 --- a/src/passes/dataflow/dead_reference_elimination.cpp +++ b/src/passes/dataflow/dead_reference_elimination.cpp @@ -41,6 +41,9 @@ bool DeadReferenceElimination:: for (auto& move : moves) { auto access_node = dynamic_cast(move->element()); + if (!access_node) { + continue; + } auto& graph = dynamic_cast(access_node->get_parent()); auto& block = dynamic_cast(*graph.get_parent()); builder.clear_node(block, *access_node); diff --git a/src/passes/dataflow/reference_propagation.cpp b/src/passes/dataflow/reference_propagation.cpp index 98f66843c..08810f8c5 100644 --- a/src/passes/dataflow/reference_propagation.cpp +++ b/src/passes/dataflow/reference_propagation.cpp @@ -44,9 +44,12 @@ bool ReferencePropagation::run_pass(builder::StructuredSDFGBuilder& builder, ana // Eliminate views auto uses = users.uses(container); for (auto& move : moves) { - auto& access_node = static_cast(*move->element()); + auto access_node = dynamic_cast(move->element()); + if (!access_node) { + continue; + } auto& dataflow = *move->parent(); - auto& move_edge = *dataflow.in_edges(access_node).begin(); + auto& move_edge = *dataflow.in_edges(*access_node).begin(); // Criterion: Must be a reference memlet if (move_edge.type() != data_flow::MemletType::Reference) { diff --git a/src/passes/debug_info_propagation.cpp b/src/passes/debug_info_propagation.cpp index 3468ac73c..74d61aef5 100644 --- a/src/passes/debug_info_propagation.cpp +++ b/src/passes/debug_info_propagation.cpp @@ -33,6 +33,9 @@ void DebugInfoPropagation::propagate(structured_control_flow::ControlFlowNode* c } else if (auto loop_stmt = dynamic_cast(current)) { this->propagate(&loop_stmt->root()); current_debug_info = DebugInfo::merge(current_debug_info, loop_stmt->root().debug_info()); + } else if (auto for_each_stmt = dynamic_cast(current)) { + this->propagate(&for_each_stmt->root()); + current_debug_info = DebugInfo::merge(current_debug_info, for_each_stmt->root().debug_info()); } else if (auto break_stmt = dynamic_cast(current)) { current_debug_info = DebugInfo::merge(current_debug_info, break_stmt->debug_info()); } else if (auto continue_stmt = dynamic_cast(current)) { diff --git a/src/passes/structured_control_flow/dead_cfg_elimination.cpp b/src/passes/structured_control_flow/dead_cfg_elimination.cpp index de958b592..b1ec99897 100644 --- a/src/passes/structured_control_flow/dead_cfg_elimination.cpp +++ b/src/passes/structured_control_flow/dead_cfg_elimination.cpp @@ -154,6 +154,9 @@ bool DeadCFGElimination::run_pass(builder::StructuredSDFGBuilder& builder, analy } else if (auto map_stmt = dynamic_cast(curr)) { auto& root = map_stmt->root(); queue.push_back(&root); + } else if (auto for_each_stmt = dynamic_cast(curr)) { + auto& root = for_each_stmt->root(); + queue.push_back(&root); } } diff --git a/src/passes/structured_control_flow/sequence_fusion.cpp b/src/passes/structured_control_flow/sequence_fusion.cpp index f142965f9..a623da48f 100644 --- a/src/passes/structured_control_flow/sequence_fusion.cpp +++ b/src/passes/structured_control_flow/sequence_fusion.cpp @@ -50,6 +50,8 @@ bool SequenceFusion::run_pass(builder::StructuredSDFGBuilder& builder, analysis: queue.push_back(&loop_stmt->root()); } else if (auto sloop_stmt = dynamic_cast(current)) { queue.push_back(&sloop_stmt->root()); + } else if (auto for_each_stmt = dynamic_cast(current)) { + queue.push_back(&for_each_stmt->root()); } } diff --git a/src/passes/structured_control_flow/while_to_for_conversion.cpp b/src/passes/structured_control_flow/while_to_for_conversion.cpp index aac54ddc8..6152f0a96 100644 --- a/src/passes/structured_control_flow/while_to_for_conversion.cpp +++ b/src/passes/structured_control_flow/while_to_for_conversion.cpp @@ -250,6 +250,8 @@ bool WhileToForConversion::run_pass(builder::StructuredSDFGBuilder& builder, ana queue.push_back(&loop_stmt->root()); } else if (auto sloop_stmt = dynamic_cast(current)) { queue.push_back(&sloop_stmt->root()); + } else if (auto for_each_stmt = dynamic_cast(current)) { + queue.push_back(&for_each_stmt->root()); } } diff --git a/src/passes/structured_control_flow/while_to_for_each_conversion.cpp b/src/passes/structured_control_flow/while_to_for_each_conversion.cpp new file mode 100644 index 000000000..d9045812d --- /dev/null +++ b/src/passes/structured_control_flow/while_to_for_each_conversion.cpp @@ -0,0 +1,294 @@ +#include "sdfg/passes/structured_control_flow/while_to_for_each_conversion.h" + +#include "sdfg/analysis/data_dependency_analysis.h" +#include "sdfg/builder/structured_sdfg_builder.h" + +namespace sdfg { +namespace passes { + +bool WhileToForEachConversion::can_be_applied( + builder::StructuredSDFGBuilder& builder, + analysis::AnalysisManager& analysis_manager, + structured_control_flow::While& loop +) { + auto& sdfg = builder.subject(); + auto& body = loop.root(); + if (loop.root().size() < 2) { + return false; + } + + // Identify break and continue conditions + auto end_of_body = body.at(body.size() - 1); + if (end_of_body.second.size() > 0) { + return false; + } + auto if_else_stmt = dynamic_cast(&end_of_body.first); + if (!if_else_stmt || if_else_stmt->size() != 2) { + return false; + } + + bool first_is_continue = false; + bool first_is_break = false; + auto& first_branch = if_else_stmt->at(0).first; + if (first_branch.size() != 1) { + return false; + } + auto first_condition = if_else_stmt->at(0).second; + if (dynamic_cast(&first_branch.at(0).first)) { + first_is_break = true; + } else if (dynamic_cast(&first_branch.at(0).first)) { + first_is_continue = true; + } + if (!first_is_break && !first_is_continue) { + return false; + } + + bool second_is_continue = false; + bool second_is_break = false; + auto& second_branch = if_else_stmt->at(1).first; + if (second_branch.size() != 1) { + return false; + } + auto second_condition = if_else_stmt->at(1).second; + if (dynamic_cast(&second_branch.at(0).first)) { + second_is_break = true; + } else if (dynamic_cast(&second_branch.at(0).first)) { + second_is_continue = true; + } + if (!second_is_break && !second_is_continue) { + return false; + } + if (first_is_break == second_is_break) { + return false; + } + + if (symbolic::atoms(first_condition).size() != 2 || + symbolic::atoms(second_condition).size() != 2) { + return false; + } + + // Criterion: Continue condition is an equality between two symbols + auto sym1 = *symbolic::atoms(first_condition).begin(); + auto sym2 = *(++symbolic::atoms(first_condition).begin()); + auto cont_condition = symbolic::Ne(sym1, sym2); + auto cont_condition_alt = symbolic::Eq(symbolic::__false__(), symbolic::Eq(sym1, sym2)); + if (first_is_continue && !(symbolic::eq(cont_condition, first_condition) || symbolic::eq(cont_condition_alt, first_condition))) { + return false; + } + if (second_is_continue && !(symbolic::eq(cont_condition, second_condition) || symbolic::eq(cont_condition_alt, second_condition))) { + return false; + } + if (!symbolic::eq(first_condition, symbolic::Not(second_condition))) { + return false; + } + + // We know that the while body ends with an if-else continue-break structure + // We now check that there is exactly one iterator, which is moved once per + // iteration + // All other variables in the condition must be constants + + auto& all_users = analysis_manager.get(); + analysis::UsersView body_users(all_users, body); + + // Candidates: all symbols in the condition which are pointers + analysis::User* update = nullptr; + for (auto& sym : symbolic::atoms(first_condition)) { + if (symbolic::eq(sym, symbolic::__nullptr__())) { + continue; + } + auto& type = sdfg.type(sym->get_name()); + if (!dynamic_cast(&type)) { + return false; + } + + auto moves = body_users.moves(sym->get_name()); + // Not an iterator + if (moves.empty()) { + continue; + } + // Not well-formed + if (moves.size() > 1) { + return false; + } + // Exactly one iterator + if (update != nullptr) { + return false; + } + update = moves.at(0); + } + if (update == nullptr) { + return false; + } + auto iterator = symbolic::symbol(update->container()); + + // Criterion: Update is a dereference memlet + // iterator = *ptr + auto move_dst = dynamic_cast(update->element()); + auto& graph = move_dst->get_parent(); + auto& block = static_cast(*graph.get_parent()); + auto& move_edge = *graph.in_edges(*move_dst).begin(); + auto& move_src = static_cast(move_edge.src()); + if (move_edge.type() != data_flow::MemletType::Dereference_Src) { + return false; + } + + // Criterion: Update happens right before the condition + if (body.index(block) != body.size() - 2) { + return false; + } + + // No other continue, break or return inside loop body + std::list queue = {&loop.root()}; + while (!queue.empty()) { + auto current = queue.front(); + queue.pop_front(); + if (dynamic_cast(current)) { + return false; + } else if (dynamic_cast(current)) { + return false; + } else if (dynamic_cast(current)) { + return false; + } + + if (auto sequence_stmt = dynamic_cast(current)) { + for (size_t i = 0; i < sequence_stmt->size(); i++) { + queue.push_back(&sequence_stmt->at(i).first); + } + } else if (auto if_else = dynamic_cast(current)) { + // Ignore the if_else_stmt + if (if_else == if_else_stmt) { + continue; + } + for (size_t i = 0; i < if_else->size(); i++) { + queue.push_back(&if_else->at(i).first); + } + } + } + + return true; +} + +void WhileToForEachConversion::apply( + builder::StructuredSDFGBuilder& builder, + analysis::AnalysisManager& analysis_manager, + structured_control_flow::Sequence& parent, + structured_control_flow::While& loop +) { + auto& sdfg = builder.subject(); + auto& body = loop.root(); + + // Identify break and continue conditions + auto last_element = body.at(body.size() - 1); + auto if_else_stmt = dynamic_cast(&last_element.first); + + auto first_condition = if_else_stmt->at(0).second; + auto second_condition = if_else_stmt->at(1).second; + + bool second_is_break = false; + auto& second_branch = if_else_stmt->at(1).first; + if (dynamic_cast(&second_branch.at(0).first)) { + second_is_break = true; + } + + // Identify iterator + auto& all_users = analysis_manager.get(); + analysis::UsersView body_users(all_users, body); + analysis::User* update = nullptr; + for (auto& sym : symbolic::atoms(first_condition)) { + if (symbolic::eq(sym, symbolic::__nullptr__())) { + continue; + } + auto moves = body_users.moves(sym->get_name()); + if (moves.size() == 1) { + update = moves.at(0); + break; + } + } + symbolic::Symbol iterator = symbolic::symbol(update->container()); + symbolic::Symbol end = SymEngine::null; + for (auto& atom : symbolic::atoms(second_condition)) { + if (atom->get_name() != iterator->get_name()) { + end = atom; + break; + } + } + + // Identify update / move statement + auto move_dst = dynamic_cast(update->element()); + auto& graph = move_dst->get_parent(); + auto& move_edge = *graph.in_edges(*move_dst).begin(); + auto& move_src = static_cast(move_edge.src()); + symbolic::Symbol update_ptr = symbolic::symbol(move_src.data()); + + // Remove update from block + builder.remove_child(body, body.size() - 2); + // Remove the if-else statement + builder.remove_child(body, body.size() - 1); + + // find index of while + int while_index = parent.index(loop); + auto& transition = parent.at(while_index).second; + + // Create for-each loop + auto& for_each = builder.add_for_each_after( + parent, + loop, + iterator, + end, + update_ptr, + SymEngine::null, + transition.assignments(), + loop.debug_info() + ); + builder.move_children(body, for_each.root()); + + // Remove while loop + builder.remove_child(parent, while_index); +}; + +WhileToForEachConversion::WhileToForEachConversion() + : Pass() { + + }; + +std::string WhileToForEachConversion::name() { return "WhileToForEachConversion"; }; + +bool WhileToForEachConversion::run_pass(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) { + bool applied = false; + + // Traverse structured SDFG + std::list queue = {&builder.subject().root()}; + while (!queue.empty()) { + auto current = queue.front(); + queue.pop_front(); + + // Add children to queue + if (auto sequence_stmt = dynamic_cast(current)) { + for (size_t i = 0; i < sequence_stmt->size(); i++) { + if (auto match = dynamic_cast(&sequence_stmt->at(i).first)) { + if (this->can_be_applied(builder, analysis_manager, *match)) { + this->apply(builder, analysis_manager, *sequence_stmt, *match); + return true; + } + } + + queue.push_back(&sequence_stmt->at(i).first); + } + } else if (auto if_else = dynamic_cast(current)) { + for (size_t i = 0; i < if_else->size(); i++) { + queue.push_back(&if_else->at(i).first); + } + } else if (auto loop_stmt = dynamic_cast(current)) { + queue.push_back(&loop_stmt->root()); + } else if (auto sloop_stmt = dynamic_cast(current)) { + queue.push_back(&sloop_stmt->root()); + } else if (auto for_each = dynamic_cast(current)) { + queue.push_back(&for_each->root()); + } + } + + return applied; +}; + +} // namespace passes +} // namespace sdfg diff --git a/src/passes/symbolic/symbol_evolution.cpp b/src/passes/symbolic/symbol_evolution.cpp index 31aa5700f..1636ad243 100644 --- a/src/passes/symbolic/symbol_evolution.cpp +++ b/src/passes/symbolic/symbol_evolution.cpp @@ -244,6 +244,8 @@ bool SymbolEvolution::run_pass(builder::StructuredSDFGBuilder& builder, analysis queue.push_back(&loop_stmt->root()); } else if (auto sloop_stmt = dynamic_cast(current)) { queue.push_back(&sloop_stmt->root()); + } else if (auto for_each_stmt = dynamic_cast(current)) { + queue.push_back(&for_each_stmt->root()); } } diff --git a/src/passes/symbolic/symbol_promotion.cpp b/src/passes/symbolic/symbol_promotion.cpp index 4039cc244..51e0f9c36 100644 --- a/src/passes/symbolic/symbol_promotion.cpp +++ b/src/passes/symbolic/symbol_promotion.cpp @@ -289,6 +289,8 @@ bool SymbolPromotion::run_pass(builder::StructuredSDFGBuilder& builder, analysis queue.push_back(&loop_stmt->root()); } else if (auto sloop_stmt = dynamic_cast(current)) { queue.push_back(&sloop_stmt->root()); + } else if (auto for_each_stmt = dynamic_cast(current)) { + queue.push_back(&for_each_stmt->root()); } } diff --git a/src/serializer/json_serializer.cpp b/src/serializer/json_serializer.cpp index 12d3030cd..c5add42fa 100644 --- a/src/serializer/json_serializer.cpp +++ b/src/serializer/json_serializer.cpp @@ -288,6 +288,25 @@ void JSONSerializer::map_to_json(nlohmann::json& j, const structured_control_flo j["root"] = body_json; } +void JSONSerializer::for_each_to_json(nlohmann::json& j, const structured_control_flow::ForEach& for_each_node) { + j["type"] = "for_each"; + j["element_id"] = for_each_node.element_id(); + + j["debug_info"] = nlohmann::json::object(); + debug_info_to_json(j["debug_info"], for_each_node.debug_info()); + + j["iterator"] = expression(for_each_node.iterator()); + j["end"] = expression(for_each_node.end()); + j["update"] = expression(for_each_node.update()); + if (for_each_node.has_init()) { + j["init"] = expression(for_each_node.init()); + } + + nlohmann::json body_json; + sequence_to_json(body_json, for_each_node.root()); + j["root"] = body_json; +} + void JSONSerializer::return_node_to_json(nlohmann::json& j, const structured_control_flow::Return& return_node) { j["type"] = "return"; j["element_id"] = return_node.element_id(); @@ -323,6 +342,8 @@ void JSONSerializer::sequence_to_json(nlohmann::json& j, const structured_contro block_to_json(child_json, *block); } else if (auto for_node = dynamic_cast(&child)) { for_to_json(child_json, *for_node); + } else if (auto for_each_node = dynamic_cast(&child)) { + for_each_to_json(child_json, *for_each_node); } else if (auto sequence_node = dynamic_cast(&child)) { sequence_to_json(child_json, *sequence_node); } else if (auto condition_node = dynamic_cast(&child)) { @@ -702,6 +723,8 @@ void JSONSerializer::json_to_sequence( json_to_block_node(child, builder, sequence, assignments); } else if (child["type"] == "for") { json_to_for_node(child, builder, sequence, assignments); + } else if (child["type"] == "for_each") { + json_to_for_each_node(child, builder, sequence, assignments); } else if (child["type"] == "if_else") { json_to_if_else_node(child, builder, sequence, assignments); } else if (child["type"] == "while") { @@ -915,6 +938,45 @@ void JSONSerializer::json_to_map_node( json_to_sequence(j["root"], builder, map_node.root()); } +void JSONSerializer::json_to_for_each_node( + const nlohmann::json& j, + builder::StructuredSDFGBuilder& builder, + structured_control_flow::Sequence& parent, + control_flow::Assignments& assignments +) { + assert(j.contains("type")); + assert(j["type"].is_string()); + assert(j["type"] == "for_each"); + assert(j.contains("iterator")); + assert(j["iterator"].is_string()); + assert(j.contains("end")); + assert(j["end"].is_string()); + assert(j.contains("update")); + assert(j["update"].is_string()); + if (j.contains("init")) { + assert(j["init"].is_string()); + } + assert(j.contains("root")); + assert(j["root"].is_object()); + + symbolic::Symbol iterator = symbolic::symbol(j["iterator"]); + symbolic::Symbol end = symbolic::symbol(j["end"]); + symbolic::Symbol update = symbolic::symbol(j["update"]); + symbolic::Symbol init = SymEngine::null; + if (j.contains("init")) { + init = symbolic::symbol(j["init"]); + } + + auto& for_each_node = + builder.add_for_each(parent, iterator, end, update, init, assignments, json_to_debug_info(j["debug_info"])); + for_each_node.element_id_ = j["element_id"]; + + assert(j["root"].contains("type")); + assert(j["root"]["type"].is_string()); + assert(j["root"]["type"] == "sequence"); + json_to_sequence(j["root"], builder, for_each_node.root()); +} + void JSONSerializer::json_to_return_node( const nlohmann::json& j, builder::StructuredSDFGBuilder& builder, diff --git a/src/structured_control_flow/for_each.cpp b/src/structured_control_flow/for_each.cpp new file mode 100644 index 000000000..c5d88a7fb --- /dev/null +++ b/src/structured_control_flow/for_each.cpp @@ -0,0 +1,95 @@ +#include "sdfg/structured_control_flow/for_each.h" + +#include "sdfg/function.h" + +namespace sdfg { +namespace structured_control_flow { + +ForEach::ForEach( + size_t element_id, + const DebugInfo& debug_info, + symbolic::Symbol iterator, + symbolic::Symbol end, + symbolic::Symbol update, + symbolic::Symbol init +) + : ControlFlowNode(element_id, debug_info), iterator_(iterator), update_(update), end_(end), init_(init) { + this->root_ = std::unique_ptr(new Sequence(++element_id, debug_info)); +} + +void ForEach::validate(const Function& function) const { + root_->validate(function); + + if (iterator_.is_null()) { + throw InvalidSDFGException("ForEach node has a null iterator."); + } + if (end_.is_null()) { + throw InvalidSDFGException("ForEach node has a null end."); + } + if (update_.is_null()) { + throw InvalidSDFGException("ForEach node has a null update."); + } + + // Criterion: Iterator must be pointer + auto& iterator_type = function.type(iterator_->get_name()); + if (iterator_type.type_id() != types::TypeID::Pointer) { + throw InvalidSDFGException("ForEach iterator must be of pointer type."); + } + + // Criterion: End must be pointer + if (!symbolic::eq(end_, symbolic::__nullptr__())) { + auto& end_type = function.type(end_->get_name()); + if (end_type.type_id() != types::TypeID::Pointer) { + throw InvalidSDFGException("ForEach end must be of pointer type."); + } + } + + // Criterion: Update must be pointer + auto& update_type = function.type(update_->get_name()); + if (update_type.type_id() != types::TypeID::Pointer) { + throw InvalidSDFGException("ForEach update must be of pointer type."); + } + + // Criterion: Init must be pointer + if (!init_.is_null()) { + auto& init_type = function.type(init_->get_name()); + if (init_type.type_id() != types::TypeID::Pointer) { + throw InvalidSDFGException("ForEach init must be of pointer type."); + } + } +}; + +const symbolic::Symbol ForEach::iterator() const { return iterator_; } + +const symbolic::Symbol ForEach::end() const { return end_; } + +const symbolic::Symbol ForEach::update() const { return update_; } + +const symbolic::Symbol ForEach::init() const { return init_; } + +bool ForEach::has_init() const { return !init_.is_null(); } + +Sequence& ForEach::root() const { return *root_; } + +void ForEach::replace(const symbolic::Expression old_expression, const symbolic::Expression new_expression) { + root_->replace(old_expression, new_expression); + + if (symbolic::eq(iterator_, old_expression)) { + iterator_ = SymEngine::rcp_dynamic_cast(new_expression); + } + + if (symbolic::eq(end_, old_expression)) { + end_ = SymEngine::rcp_dynamic_cast(new_expression); + } + + if (symbolic::eq(update_, old_expression)) { + update_ = SymEngine::rcp_dynamic_cast(new_expression); + } + + if (symbolic::eq(init_, old_expression)) { + init_ = SymEngine::rcp_dynamic_cast(new_expression); + } +} + +} // namespace structured_control_flow +} // namespace sdfg diff --git a/src/visitor/structured_sdfg_visitor.cpp b/src/visitor/structured_sdfg_visitor.cpp index 70e65ce24..2a37dcd79 100644 --- a/src/visitor/structured_sdfg_visitor.cpp +++ b/src/visitor/structured_sdfg_visitor.cpp @@ -43,6 +43,14 @@ bool StructuredSDFGVisitor::visit(structured_control_flow::Sequence& parent) { if (this->visit(for_stmt->root())) { return true; } + } else if (auto for_each_stmt = dynamic_cast(¤t)) { + if (this->accept(*for_each_stmt)) { + return true; + } + + if (this->visit(for_each_stmt->root())) { + return true; + } } else if (auto map_stmt = dynamic_cast(¤t)) { if (this->accept(*map_stmt)) { return true; @@ -93,6 +101,8 @@ bool StructuredSDFGVisitor::accept(structured_control_flow::Break& node) { retur bool StructuredSDFGVisitor::accept(structured_control_flow::For& node) { return false; }; +bool StructuredSDFGVisitor::accept(structured_control_flow::ForEach& node) { return false; }; + bool StructuredSDFGVisitor::accept(structured_control_flow::Map& node) { return false; }; } // namespace visitor diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index ca9e3c301..0b4642453 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -49,12 +49,14 @@ set(TEST_FILES passes/structured_control_flow/dead_cfg_elimination_test.cpp passes/structured_control_flow/for2map_test.cpp passes/structured_control_flow/loop_normalization_test.cpp + passes/structured_control_flow/while_to_for_each_conversion_test.cpp passes/debug_info_propagation_test.cpp passes/symbolic/symbol_promotion_test.cpp passes/symbolic/symbol_propagation_test.cpp replace/symbol_replace_test.cpp sdfg_test.cpp serializer/json_serializer_test.cpp + structured_control_flow/for_each_test.cpp structured_control_flow/map_test.cpp structured_sdfg_test.cpp symbolic/assumptions_test.cpp diff --git a/tests/builder/structured_sdfg_builder_test.cpp b/tests/builder/structured_sdfg_builder_test.cpp index fc1d618a4..3e4c61092 100644 --- a/tests/builder/structured_sdfg_builder_test.cpp +++ b/tests/builder/structured_sdfg_builder_test.cpp @@ -667,3 +667,179 @@ TEST(StructuredSDFGBuilderTest, FindElementById_Block) { EXPECT_EQ(builder.find_element_by_id(block.element_id()), &block); } + +TEST(StructuredSDFGBuilderTest, addForEach) { + builder::StructuredSDFGBuilder builder("sdfg_1", FunctionType_CPU); + + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + auto sym_list = symbolic::symbol("list"); + auto sym_iter = symbolic::symbol("iter"); + + auto& root = builder.subject().root(); + EXPECT_EQ(root.element_id(), 0); + + /** + * Doubled linked list with start and end 'list' + * iter: { + * next: ptr; + * value: ... + * } + * for (auto iter = *list; iter != list; iter = *iter) { + * + * } + */ + + auto& scope = builder.add_for_each(root, sym_iter, sym_list, sym_iter, sym_list); + EXPECT_EQ(scope.element_id(), 1); + EXPECT_EQ(scope.root().element_id(), 2); + EXPECT_EQ(root.at(0).second.element_id(), 3); + + EXPECT_EQ(scope.has_init(), true); + EXPECT_TRUE(symbolic::eq(scope.init(), sym_list)); + EXPECT_TRUE(symbolic::eq(scope.iterator(), sym_iter)); + EXPECT_TRUE(symbolic::eq(scope.end(), sym_list)); + EXPECT_TRUE(symbolic::eq(scope.update(), sym_iter)); + + auto sdfg = builder.move(); + + EXPECT_EQ(sdfg->root().size(), 1); + auto child = sdfg->root().at(0); + EXPECT_EQ(&child.first, &scope); + EXPECT_EQ(child.second.size(), 0); +} + +TEST(StructuredSDFGBuilderTest, addForEach_Transition) { + builder::StructuredSDFGBuilder builder("sdfg_1", FunctionType_CPU); + + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + types::Scalar int_desc(types::PrimitiveType::Int64); + builder.add_container("N", int_desc, true); + + auto sym_list = symbolic::symbol("list"); + auto sym_iter = symbolic::symbol("iter"); + + auto& root = builder.subject().root(); + EXPECT_EQ(root.element_id(), 0); + + /** + * Doubled linked list with start and end 'list' + * iter: { + * next: ptr; + * value: ... + * } + * for (auto iter = *list; iter != list; iter = *iter) { + * + * } + */ + + auto& scope = + builder.add_for_each(root, sym_iter, sym_list, sym_iter, sym_list, {{symbolic::symbol("N"), symbolic::zero()}}); + EXPECT_EQ(scope.element_id(), 1); + EXPECT_EQ(scope.root().element_id(), 2); + EXPECT_EQ(root.at(0).second.element_id(), 3); + + EXPECT_EQ(scope.has_init(), true); + EXPECT_TRUE(symbolic::eq(scope.init(), sym_list)); + EXPECT_TRUE(symbolic::eq(scope.iterator(), sym_iter)); + EXPECT_TRUE(symbolic::eq(scope.end(), sym_list)); + EXPECT_TRUE(symbolic::eq(scope.update(), sym_iter)); + + auto sdfg = builder.move(); + + EXPECT_EQ(sdfg->root().size(), 1); + auto child = sdfg->root().at(0); + EXPECT_EQ(&child.first, &scope); + EXPECT_EQ(child.second.size(), 1); + EXPECT_TRUE(symbolic::eq(child.second.assignments().at(symbolic::symbol("N")), symbolic::zero())); +} + +TEST(StructuredSDFGBuilderTest, addForEachBefore) { + builder::StructuredSDFGBuilder builder("sdfg_1", FunctionType_CPU); + + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + types::Scalar int_desc(types::PrimitiveType::Int64); + builder.add_container("N", int_desc, true); + + auto sym_list = symbolic::symbol("list"); + auto sym_iter = symbolic::symbol("iter"); + + auto& root = builder.subject().root(); + EXPECT_EQ(root.element_id(), 0); + + auto& block_base = + builder.add_block(root, control_flow::Assignments{{symbolic::symbol("N"), SymEngine::integer(10)}}); + + auto& scope = builder.add_for_each_before( + root, block_base, sym_iter, sym_list, sym_iter, SymEngine::null, {{symbolic::symbol("N"), symbolic::zero()}} + ); + + EXPECT_EQ(scope.has_init(), false); + EXPECT_TRUE(symbolic::eq(scope.iterator(), sym_iter)); + EXPECT_TRUE(symbolic::eq(scope.end(), sym_list)); + EXPECT_TRUE(symbolic::eq(scope.update(), sym_iter)); + + auto sdfg = builder.move(); + + auto child = sdfg->root().at(0); + EXPECT_EQ(&child.first, &scope); + EXPECT_EQ(child.second.size(), 1); + EXPECT_TRUE(symbolic::eq(child.second.assignments().at(symbolic::symbol("N")), symbolic::zero())); +} + +TEST(StructuredSDFGBuilderTest, addForEachAfter) { + builder::StructuredSDFGBuilder builder("sdfg_1", FunctionType_CPU); + + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + types::Scalar int_desc(types::PrimitiveType::Int64); + builder.add_container("N", int_desc, true); + + auto sym_list = symbolic::symbol("list"); + auto sym_iter = symbolic::symbol("iter"); + + auto& root = builder.subject().root(); + EXPECT_EQ(root.element_id(), 0); + + auto& block_base = + builder.add_block(root, control_flow::Assignments{{symbolic::symbol("N"), SymEngine::integer(10)}}); + auto& block_base2 = + builder.add_block(root, control_flow::Assignments{{symbolic::symbol("N"), SymEngine::integer(10)}}); + + /** + * Doubled linked list with start and end 'list' + * iter: { + * next: ptr; + * value: ... + * } + * for (auto iter = *list; iter != list; iter = *iter) { + * + * } + */ + + auto& scope = builder.add_for_each_after( + root, block_base, sym_iter, sym_list, sym_iter, SymEngine::null, {{symbolic::symbol("N"), symbolic::zero()}} + ); + + EXPECT_EQ(scope.has_init(), false); + EXPECT_TRUE(symbolic::eq(scope.iterator(), sym_iter)); + EXPECT_TRUE(symbolic::eq(scope.end(), sym_list)); + EXPECT_TRUE(symbolic::eq(scope.update(), sym_iter)); + + auto sdfg = builder.move(); + + auto child = sdfg->root().at(1); + EXPECT_EQ(&child.first, &scope); + EXPECT_EQ(child.second.size(), 1); + EXPECT_TRUE(symbolic::eq(child.second.assignments().at(symbolic::symbol("N")), symbolic::zero())); +} diff --git a/tests/passes/structured_control_flow/while_to_for_each_conversion_test.cpp b/tests/passes/structured_control_flow/while_to_for_each_conversion_test.cpp new file mode 100644 index 000000000..edb126840 --- /dev/null +++ b/tests/passes/structured_control_flow/while_to_for_each_conversion_test.cpp @@ -0,0 +1,192 @@ +#include "sdfg/passes/structured_control_flow/while_to_for_each_conversion.h" + +#include + +#include "sdfg/builder/structured_sdfg_builder.h" +#include "sdfg/symbolic/symbolic.h" + +using namespace sdfg; + +TEST(WhileToForEachConversionTest, LinkedList) { + builder::StructuredSDFGBuilder builder("sdfg_test", FunctionType_CPU); + + auto& sdfg = builder.subject(); + auto& root = sdfg.root(); + + // Add containers + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + auto sym_iter = symbolic::symbol("iter"); + auto sym_list = symbolic::symbol("list"); + auto sym_nullptr = symbolic::__nullptr__(); + + // Reinterpret cast pointers to pointer(pointer) for dereferencing + types::Pointer ptr_ptr_desc(static_cast(opaque_desc)); + + // Init: iter = *list + { + auto& block = builder.add_block(root); + auto& list = builder.add_access(block, "list"); + auto& iter = builder.add_access(block, "iter"); + builder.add_dereference_memlet(block, list, iter, true, ptr_ptr_desc); + } + + auto& loop = builder.add_while(root); + auto& body = loop.root(); + + // Update: iter = *iter + auto& block = builder.add_block(body); + auto& iter_in = builder.add_access(block, "iter"); + auto& iter_out = builder.add_access(block, "iter"); + builder.add_dereference_memlet(block, iter_in, iter_out, true, ptr_ptr_desc); + + // Condition: iter != nullptr -> continue + auto& ifelse = builder.add_if_else(body); + auto& continue_scope = builder.add_case(ifelse, symbolic::Ne(sym_iter, sym_nullptr)); + builder.add_continue(continue_scope); + auto& break_scope = builder.add_case(ifelse, symbolic::Eq(sym_iter, sym_nullptr)); + builder.add_break(break_scope); + + // Analysis + analysis::AnalysisManager analysis_manager(builder.subject()); + passes::WhileToForEachConversion conversion_pass; + EXPECT_TRUE(conversion_pass.run(builder, analysis_manager)); + + // Check + auto for_each_node = dynamic_cast(&builder.subject().root().at(1).first); + EXPECT_TRUE(for_each_node != nullptr); + EXPECT_FALSE(for_each_node->has_init()); + EXPECT_TRUE(symbolic::eq(for_each_node->iterator(), sym_iter)); + EXPECT_TRUE(symbolic::eq(for_each_node->end(), sym_nullptr)); + EXPECT_TRUE(symbolic::eq(for_each_node->update(), sym_iter)); +} + +TEST(WhileToForEachConversionTest, DoubleLinkedList) { + builder::StructuredSDFGBuilder builder("sdfg_test", FunctionType_CPU); + + auto& sdfg = builder.subject(); + auto& root = sdfg.root(); + + // Add containers + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + auto sym_iter = symbolic::symbol("iter"); + auto sym_list = symbolic::symbol("list"); + + // Reinterpret cast pointers to pointer(pointer) for dereferencing + types::Pointer ptr_ptr_desc(static_cast(opaque_desc)); + + // Init: iter = *list + { + auto& block = builder.add_block(root); + auto& list = builder.add_access(block, "list"); + auto& iter = builder.add_access(block, "iter"); + builder.add_dereference_memlet(block, list, iter, true, ptr_ptr_desc); + } + + auto& loop = builder.add_while(root); + auto& body = loop.root(); + + // Update: iter = *iter + auto& block = builder.add_block(body); + auto& iter_in = builder.add_access(block, "iter"); + auto& iter_out = builder.add_access(block, "iter"); + builder.add_dereference_memlet(block, iter_in, iter_out, true, ptr_ptr_desc); + + // Condition: iter != nullptr -> continue + auto& ifelse = builder.add_if_else(body); + auto& continue_scope = builder.add_case(ifelse, symbolic::Ne(sym_iter, sym_list)); + builder.add_continue(continue_scope); + auto& break_scope = builder.add_case(ifelse, symbolic::Eq(sym_iter, sym_list)); + builder.add_break(break_scope); + + // Analysis + analysis::AnalysisManager analysis_manager(builder.subject()); + passes::WhileToForEachConversion conversion_pass; + EXPECT_TRUE(conversion_pass.run(builder, analysis_manager)); + + // Check + auto for_each_node = dynamic_cast(&builder.subject().root().at(1).first); + EXPECT_TRUE(for_each_node != nullptr); + EXPECT_FALSE(for_each_node->has_init()); + EXPECT_TRUE(symbolic::eq(for_each_node->iterator(), sym_iter)); + EXPECT_TRUE(symbolic::eq(for_each_node->end(), sym_list)); + EXPECT_TRUE(symbolic::eq(for_each_node->update(), sym_iter)); +} + +TEST(WhileToForEachConversionTest, DoubleLinkedList_NextPtrWithOffset) { + builder::StructuredSDFGBuilder builder("sdfg_test", FunctionType_CPU); + + auto& sdfg = builder.subject(); + auto& root = sdfg.root(); + + // Add containers + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + builder.add_container("next_ptr", opaque_desc); + + auto sym_iter = symbolic::symbol("iter"); + auto sym_list = symbolic::symbol("list"); + auto sym_next_ptr = symbolic::symbol("next_ptr"); + + // Reinterpret cast pointers to pointer(pointer) for dereferencing + types::Pointer ptr_ptr_desc(static_cast(opaque_desc)); + + // Init + { + // next_ptr = list + offset + auto& block1 = builder.add_block(root); + auto& list = builder.add_access(block1, "list"); + auto& next_ptr = builder.add_access(block1, "next_ptr"); + builder.add_reference_memlet(block1, list, next_ptr, {symbolic::integer(4)}, ptr_ptr_desc); + + auto& block2 = builder.add_block(root); + auto& next_ptr2 = builder.add_access(block2, "next_ptr"); + auto& iter = builder.add_access(block2, "iter"); + builder.add_dereference_memlet(block2, next_ptr2, iter, true, ptr_ptr_desc); + + } + + auto& loop = builder.add_while(root); + auto& body = loop.root(); + + // Update + { + // next_ptr = iter + offset + auto& block1 = builder.add_block(body); + auto& iter_in = builder.add_access(block1, "iter"); + auto& next_ptr = builder.add_access(block1, "next_ptr"); + builder.add_reference_memlet(block1, iter_in, next_ptr, {symbolic::integer(4)}, ptr_ptr_desc); + + // iter = *next_ptr + auto& block2 = builder.add_block(body); + auto& next_ptr2 = builder.add_access(block2, "next_ptr"); + auto& iter_out = builder.add_access(block2, "iter"); + builder.add_dereference_memlet(block2, next_ptr2, iter_out, true, ptr_ptr_desc); + } + + // Condition: iter != nullptr -> continue + auto& ifelse = builder.add_if_else(body); + auto& continue_scope = builder.add_case(ifelse, symbolic::Ne(sym_iter, sym_list)); + builder.add_continue(continue_scope); + auto& break_scope = builder.add_case(ifelse, symbolic::Eq(sym_iter, sym_list)); + builder.add_break(break_scope); + + // Analysis + analysis::AnalysisManager analysis_manager(builder.subject()); + passes::WhileToForEachConversion conversion_pass; + EXPECT_TRUE(conversion_pass.run(builder, analysis_manager)); + + // Check + auto for_each_node = dynamic_cast(&builder.subject().root().at(2).first); + EXPECT_TRUE(for_each_node != nullptr); + EXPECT_FALSE(for_each_node->has_init()); + EXPECT_TRUE(symbolic::eq(for_each_node->iterator(), sym_iter)); + EXPECT_TRUE(symbolic::eq(for_each_node->end(), sym_list)); + EXPECT_TRUE(symbolic::eq(for_each_node->update(), sym_next_ptr)); +} diff --git a/tests/structured_control_flow/for_each_test.cpp b/tests/structured_control_flow/for_each_test.cpp new file mode 100644 index 000000000..e66812d85 --- /dev/null +++ b/tests/structured_control_flow/for_each_test.cpp @@ -0,0 +1,102 @@ +#include +#include "sdfg/structured_control_flow/map.h" + +#include "sdfg/codegen/dispatchers/for_each_dispatcher.h" +#include "sdfg/codegen/language_extensions/c_language_extension.h" +#include "sdfg/serializer/json_serializer.h" +#include "sdfg/symbolic/symbolic.h" +#include "sdfg/visitor/structured_sdfg_visitor.h" + +using namespace sdfg; + +TEST(ForEachTest, SerializeDeserialize) { + builder::StructuredSDFGBuilder builder("sdfg_1", FunctionType_CPU); + + auto& sdfg = builder.subject(); + auto& root = sdfg.root(); + + // Add containers + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + auto sym_iter = symbolic::symbol("iter"); + auto sym_list = symbolic::symbol("list"); + auto sym_nullptr = symbolic::__nullptr__(); + + auto& for_each = builder.add_for_each(root, sym_iter, sym_nullptr, sym_iter, sym_list); + + serializer::JSONSerializer serializer; + auto json = serializer.serialize(sdfg); + auto deserialized_sdfg = serializer.deserialize(json); + + auto& deserialized_root = deserialized_sdfg->root(); + auto deserialized_for_each = dynamic_cast(&deserialized_root.at(0).first); + EXPECT_TRUE(deserialized_for_each != nullptr); + EXPECT_TRUE(deserialized_for_each->has_init()); + EXPECT_TRUE(symbolic::eq(deserialized_for_each->init(), sym_list)); + EXPECT_TRUE(symbolic::eq(deserialized_for_each->iterator(), sym_iter)); + EXPECT_TRUE(symbolic::eq(deserialized_for_each->end(), sym_nullptr)); + EXPECT_TRUE(symbolic::eq(deserialized_for_each->update(), sym_iter)); +} + +TEST(ForEachTest, Dispatch) { + builder::StructuredSDFGBuilder builder("sdfg_1", FunctionType_CPU); + + auto& sdfg = builder.subject(); + auto& root = sdfg.root(); + + // Add containers + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + auto sym_iter = symbolic::symbol("iter"); + auto sym_list = symbolic::symbol("list"); + auto sym_nullptr = symbolic::__nullptr__(); + + auto& for_each = builder.add_for_each(root, sym_iter, sym_nullptr, sym_iter, sym_list); + + codegen::CLanguageExtension language_extension; + auto instrumentation = codegen::InstrumentationPlan::none(sdfg); + codegen::ForEachDispatcher dispatcher(language_extension, sdfg, for_each, *instrumentation); + + codegen::PrettyPrinter main_stream; + codegen::PrettyPrinter globals_stream; + codegen::CodeSnippetFactory library_factory; + dispatcher.dispatch_node(main_stream, globals_stream, library_factory); + + EXPECT_EQ(globals_stream.str(), ""); + EXPECT_TRUE(library_factory.snippets().empty()); + EXPECT_EQ(main_stream.str(), "for(iter = *((void* *) list);iter != NULL;iter = *((void* *) iter))\n{\n}\n"); +} + +class ForEachVisitor : public visitor::StructuredSDFGVisitor { +public: + ForEachVisitor(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) + : visitor::StructuredSDFGVisitor(builder, analysis_manager) {} + + bool accept(structured_control_flow::ForEach& node) override { return true; }; +}; + +TEST(StructuredSDFGVisitorTest, ForEach) { + builder::StructuredSDFGBuilder builder("sdfg_1", FunctionType_CPU); + + auto& sdfg = builder.subject(); + auto& root = sdfg.root(); + + // Add containers + types::Pointer opaque_desc; + builder.add_container("list", opaque_desc, true); + builder.add_container("iter", opaque_desc); + + auto sym_iter = symbolic::symbol("iter"); + auto sym_list = symbolic::symbol("list"); + auto sym_nullptr = symbolic::__nullptr__(); + + auto& for_each = builder.add_for_each(root, sym_iter, sym_nullptr, sym_iter, sym_list); + + analysis::AnalysisManager analysis_manager(builder.subject()); + ForEachVisitor visitor(builder, analysis_manager); + EXPECT_TRUE(visitor.visit()); +}