Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion llvm/src/lifting/function_to_sdfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ std::unique_ptr<sdfg::StructuredSDFG> FunctionToSDFG::simplify(std::unique_ptr<s
bool success = pass.run(builder_opt, analysis_manager);
} else if (DOCC_expand != "none") {
LLVM_DEBUG_PRINTLN("Expanding all library nodes");
auto expansion_pass = sdfg::passes::ExpansionPass();
auto expansion_pass = sdfg::passes::MathExpansionPass();
bool expanded = expansion_pass.run(builder_opt, analysis_manager);
}

Expand Down
6 changes: 4 additions & 2 deletions mlir/test/sdfg-json-to-c/sdfg-json-to-c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#include <sdfg/passes/pipeline.h>
#include <sdfg/serializer/json_serializer.h>

#include "sdfg/passes/schedules/expansion_pass.h"

int main(int argc, char* argv[]) {
sdfg::codegen::register_default_dispatchers();
sdfg::serializer::register_default_serializers();
Expand All @@ -29,8 +31,8 @@ int main(int argc, char* argv[]) {
sdfg::builder::StructuredSDFGBuilder builder(*sdfg);
sdfg::analysis::AnalysisManager analysis_manager(builder.subject());

sdfg::passes::Pipeline libnode_expansion = sdfg::passes::Pipeline::expansion();
libnode_expansion.run(builder, analysis_manager);
sdfg::passes::MathExpansionPass math_expansion;
math_expansion.run(builder, analysis_manager);

sdfg::passes::TensorToPointerConversionPass tensor_to_pointer_conversion_pass;
tensor_to_pointer_conversion_pass.run(builder, analysis_manager);
Expand Down
6 changes: 6 additions & 0 deletions python/bindings/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ PYBIND11_MODULE(_sdfg, m) {
py::arg("dump_json") = true,
py::arg("record_for_instrumentation") = false
)
.def(
"set_output_dir",
static_cast<void (PyStructuredSDFG::*)(const std::string&)>(&PyStructuredSDFG::set_output_dir),
py::arg("path"),
"Set the output directory"
)
.def("normalize", &PyStructuredSDFG::normalize, "Normalize the SDFG")
.def(
"schedule",
Expand Down
12 changes: 7 additions & 5 deletions python/bindings/py_structured_sdfg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
#include "sdfg/passes/rpc/daisytuner_rpc_context.h"
#include "sdfg/passes/rpc/rpc_context.h"
#include "sdfg/passes/rpc/rpc_scheduler.h"
#include "sdfg/passes/schedules/expansion_pass.h"
#include "sdfg/passes/targets/target_mapping_pass.h"
#include "sdfg/util/offloading_instrumentation_plan.h"
#include "targets/target_mapping.h"
Expand Down Expand Up @@ -113,6 +114,10 @@ PyStructuredSDFG PyStructuredSDFG::from_sdfg(sdfg::plugins::Context& ctx, std::u

std::string PyStructuredSDFG::name() const { return sdfg_->name(); }

void PyStructuredSDFG::set_output_dir(const std::string& path) { this->set_output_dir(std::filesystem::path(path)); }

void PyStructuredSDFG::set_output_dir(const std::filesystem::path& path) { this->output_dir_ = path; }

sdfg::plugins::Context& PyStructuredSDFG::docc_context() const { return docc_context_; }

const sdfg::types::IType& PyStructuredSDFG::return_type() const { return sdfg_->return_type(); }
Expand Down Expand Up @@ -182,8 +187,8 @@ void PyStructuredSDFG::expand(const docc::target::TargetOptions& options) {
local_buffer_reuse_pipeline.run(builder_opt, analysis_manager);

// Expand library nodes
sdfg::passes::Pipeline libnode_expansion = sdfg::passes::Pipeline::expansion();
libnode_expansion.run(builder_opt, analysis_manager);
sdfg::passes::MathExpansionPass math_expand;
math_expand.run(builder_opt, analysis_manager);

sdfg::passes::TensorToPointerConversionPass tensor_to_pointer_conversion_pass;
tensor_to_pointer_conversion_pass.run(builder_opt, analysis_manager);
Expand Down Expand Up @@ -448,10 +453,7 @@ std::string PyStructuredSDFG::compile(

sdfg::analysis::AnalysisManager analysis_manager(*sdfg_);

// Run expansion pass
sdfg::passes::Pipeline expansion = sdfg::passes::Pipeline::expansion();
sdfg::builder::StructuredSDFGBuilder builder_opt(*sdfg_);
expansion.run(builder_opt, analysis_manager);

// Instrumentation plan
std::unique_ptr<sdfg::codegen::InstrumentationPlan> instrumentation_plan;
Expand Down
4 changes: 4 additions & 0 deletions python/bindings/py_structured_sdfg.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class PyStructuredSDFG {
private:
sdfg::plugins::Context& docc_context_;
std::unique_ptr<sdfg::StructuredSDFG> sdfg_;
std::optional<std::filesystem::path> output_dir_;

PyStructuredSDFG(sdfg::plugins::Context& ctx, std::unique_ptr<sdfg::StructuredSDFG>& sdfg);

Expand All @@ -34,6 +35,9 @@ class PyStructuredSDFG {

std::string name() const;

void set_output_dir(const std::string& path);
void set_output_dir(const std::filesystem::path& path);

sdfg::plugins::Context& docc_context() const;

sdfg::StructuredSDFG& sdfg() { return *sdfg_; }
Expand Down
2 changes: 2 additions & 0 deletions python/docc/compiler/docc_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ def sdfg_pipe(
target_options.target = self.target
target_options.category = self.category
target_options.remote_tuning = remote_tuning
if self.debug_dump:
sdfg.set_output_dir(output_folder)

# Einsum detection
sdfg.einsum()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ class ReduceNode : public TensorNode {

data_flow::PointerAccessType pointer_access_type(int input_idx) const override;

std::string toStr() const override;

protected:
virtual bool expand_inner(
builder::StructuredSDFGBuilder& builder,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ class TensorLayout {
std::unique_ptr<TensorLayout> squeeze() const;

std::unique_ptr<TensorLayout> reshape(const symbolic::MultiExpression& new_shape) const;

static std::ostream& emit_symbolic_list(std::ostream& stream, const symbolic::MultiExpression& list);
};

std::ostream& operator<<(std::ostream& stream, const TensorLayout& layout);
Expand Down
2 changes: 0 additions & 2 deletions sdfg/include/sdfg/passes/pipeline.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,6 @@ class Pipeline : public Pass {
static Pipeline data_parallelism();

static Pipeline memory();

static Pipeline expansion();
};

} // namespace passes
Expand Down
55 changes: 26 additions & 29 deletions sdfg/include/sdfg/passes/schedules/expansion_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,52 +23,49 @@

#pragma once

#include "sdfg/data_flow/library_nodes/math/math_node.h"
#include "sdfg/passes/pass.h"
#include "sdfg/visitor/structured_sdfg_visitor.h"

namespace sdfg {
namespace passes {

/**
* @class Expansion
* @brief Visitor that expands library nodes into primitive operations
*
* The Expansion visitor traverses the SDFG and expands library nodes that
* have ImplementationType_NONE. This allows high-level mathematical operations
* to be transformed into lower-level constructs that can be optimized and
* scheduled.
*/
class Expansion : public visitor::StructuredSDFGVisitor {
class MathExpansionPass;

class MathExpansionVisitor : public visitor::ActualStructuredSDFGVisitor {
friend MathExpansionPass;

private:
builder::StructuredSDFGBuilder& builder_;
analysis::AnalysisManager& analysis_manager_;

struct LibNodeContainer {
math::MathNode& node;
structured_control_flow::Block& block;
};

std::vector<LibNodeContainer> nodes_to_expand_;

public:
/**
* @brief Construct the expansion visitor
* @param builder SDFG builder for creating new nodes
* @param analysis_manager Analysis manager for querying properties
*/
Expansion(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager);
MathExpansionVisitor(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager);

/**
* @brief Get the pass name
* @return Name of the pass
*/
static std::string name() { return "Expansion"; };

/**
* @brief Visit a block and attempt to expand its library nodes
* @param node Block to visit
* @return True if any expansion occurred
*/
bool accept(structured_control_flow::Block& node) override;
bool visit(sdfg::structured_control_flow::Block& node) override;
};

/**
* @typedef ExpansionPass
* @brief Pass wrapper for the Expansion visitor
*
* This typedef creates a pass from the Expansion visitor, allowing it to be
* used in the pass pipeline system.
* @class MathExpansionPass
* @brief Looks for and expands math-nodes that are not already mapped to a specific target
*/
typedef VisitorPass<Expansion> ExpansionPass;
class MathExpansionPass : public Pass {
std::string name() override { return "MathExpansion"; }

bool run_pass(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) override;
};

} // namespace passes
} // namespace sdfg
21 changes: 21 additions & 0 deletions sdfg/src/data_flow/library_nodes/math/tensor/reduce_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,27 @@ data_flow::PointerAccessType ReduceNode::pointer_access_type(int input_idx) cons
}
}

std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& list) {
os << "[";
for (size_t i = 0; i < list.size(); ++i) {
if (i > 0) os << ", ";
os << list[i];
}
os << "]";
return os;
}

std::string ReduceNode::toStr() const {
std::stringstream ss;
ss << this->code_.value();
ss << "(shape=";
TensorLayout::emit_symbolic_list(ss, shape_);
ss << ", axes=" << axes_;
ss << ", keep=" << this->keepdims_;
ss << ")";
return ss.str();
}

bool ReduceNode::expand_inner(
builder::StructuredSDFGBuilder& builder,
analysis::AnalysisManager& analysis_manager,
Expand Down
21 changes: 12 additions & 9 deletions sdfg/src/data_flow/library_nodes/math/tensor/tensor_layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,21 @@ TensorLayout TensorLayout::deserialize_from_json(const nlohmann::json& j) {
return std::move(TensorLayout(shape, strides, offset));
}

std::ostream& operator<<(std::ostream& stream, const TensorLayout& layout) {
stream << "{shape[";
for (size_t i = 0; i < layout.shape().size(); ++i) {
if (i > 0) stream << ", ";
stream << layout.shape().at(i)->__str__();
}
stream << "], strides=[";
for (size_t i = 0; i < layout.strides().size(); ++i) {
std::ostream& TensorLayout::emit_symbolic_list(std::ostream& stream, const symbolic::MultiExpression& list) {
stream << "[";
for (size_t i = 0; i < list.size(); ++i) {
if (i > 0) stream << ", ";
stream << layout.strides().at(i)->__str__();
stream << list.at(i)->__str__();
}
stream << "]";
return stream;
}

std::ostream& operator<<(std::ostream& stream, const TensorLayout& layout) {
stream << "{shape=";
TensorLayout::emit_symbolic_list(stream, layout.shape());
stream << ", strides=";
TensorLayout::emit_symbolic_list(stream, layout.strides());
if (SymEngine::neq(*layout.offset(), *symbolic::integer(0))) {
stream << ", off=" << layout.offset()->__str__();
}
Expand Down
8 changes: 0 additions & 8 deletions sdfg/src/passes/pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,5 @@ Pipeline Pipeline::memory() {
return p;
};

Pipeline Pipeline::expansion() {
Pipeline p("Expansion");

p.register_pass<ExpansionPass>();

return p;
};

} // namespace passes
} // namespace sdfg
38 changes: 30 additions & 8 deletions sdfg/src/passes/schedules/expansion_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,48 @@
namespace sdfg {
namespace passes {

Expansion::Expansion(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager)
: visitor::StructuredSDFGVisitor(builder, analysis_manager) {}
MathExpansionVisitor::
MathExpansionVisitor(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager)
: visitor::ActualStructuredSDFGVisitor(), builder_(builder), analysis_manager_(analysis_manager) {}

bool Expansion::accept(structured_control_flow::Block& node) {
bool MathExpansionVisitor::visit(structured_control_flow::Block& node) {
auto& dataflow = node.dataflow();

bool applied = false;

for (auto* library_node : dataflow.library_nodes()) {
if (library_node->implementation_type() != data_flow::ImplementationType_NONE) {
continue;
}

if (auto math_node = dynamic_cast<math::MathNode*>(library_node)) {
if (math_node->expand(this->builder_, this->analysis_manager_)) {
return true;
}
this->nodes_to_expand_.emplace_back(*math_node, node);
}
}
return applied;
return true;
}

bool MathExpansionPass::run_pass(builder::StructuredSDFGBuilder& builder, analysis::AnalysisManager& analysis_manager) {
MathExpansionVisitor v(builder, analysis_manager);

v.dispatch(builder.subject().root());

auto& nodes = v.nodes_to_expand_;

bool expanded_any = false;

for (auto& entry : std::views::reverse(nodes)) {
// TODO: check if the prerequisits are met, like if the libNode is standalone or if we need to cut it out of a
// larger block first

if (entry.node.expand(builder, analysis_manager)) {
// If expansion was successful, remove the original library node // TODO requires new API to do this clean
// builder.remove_node(entry.block, entry.node);
// remove block
expanded_any |= true;
}
}

return expanded_any;
}
} // namespace passes
} // namespace sdfg
4 changes: 2 additions & 2 deletions sdfg/tests/passes/schedules/expansion_pass_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ TEST(ExpansionPassTest, MeanNode_2D) {
EXPECT_EQ(block.dataflow().nodes().size(), 3);

analysis::AnalysisManager analysis_manager(builder.subject());
passes::ExpansionPass expansion_pass;
passes::MathExpansionPass expansion_pass;
EXPECT_TRUE(expansion_pass.run(builder, analysis_manager));

dump_sdfg(builder.subject(), "1.expanded");
Expand Down Expand Up @@ -86,7 +86,7 @@ TEST(ExpansionPassTest, StdNode_1D) {
dump_sdfg(builder.subject(), "0.init");

analysis::AnalysisManager analysis_manager(builder.subject());
passes::ExpansionPass expansion_pass;
passes::MathExpansionPass expansion_pass;
EXPECT_TRUE(expansion_pass.run(builder, analysis_manager));

dump_sdfg(builder.subject(), "1.expanded");
Expand Down
Loading