From 3ffc9ae59d13066ffbed6edabba24989a938bb48 Mon Sep 17 00:00:00 2001 From: sebwrede Date: Tue, 28 Jun 2022 14:14:22 +0200 Subject: [PATCH] Include All Privacy Constraints in Remote Retrieval Add Privacy Constraint Propagation to Federated Cost-Based Planner Add PrivacyConstraintLoader Which Handles Loading of Privacy Constraints from Federated Workers and Propagation of the Constraints at the Coordinator Add Privacy Constraint to Explain and Fix Bug Add FederatedPlannerUtil Class and Edit Params Edit Hop Propagation to Throw Exception When Hop Type Is Unknown and Hop Has Privacy Constraint on Input Add Param in Documentation --- .../fedplanner/FederatedPlannerCostbased.java | 38 +-- .../fedplanner/FederatedPlannerUtils.java | 67 +++++ .../fedplanner/PrivacyConstraintLoader.java | 281 ++++++++++++++++++ .../hops/ipa/IPAPassRewriteFederatedPlan.java | 17 +- .../sysds/hops/rewrite/ProgramRewriter.java | 7 - .../rewrite/RewriteFederatedExecution.java | 187 ------------ .../runtime/privacy/PrivacyConstraint.java | 7 +- .../propagation/PrivacyPropagator.java | 44 ++- .../java/org/apache/sysds/utils/Explain.java | 3 + .../FederatedMultiplyPlanningTest.java | 8 + .../FederatedMultiplyPlanningTest11.dml | 34 +++ ...deratedMultiplyPlanningTest11Reference.dml | 32 ++ 12 files changed, 497 insertions(+), 228 deletions(-) create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java create mode 100644 src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java delete mode 100644 src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java create mode 100644 src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml create mode 100644 src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java index 1f9abb4c180..3c33d783ab1 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java @@ -170,7 +170,7 @@ private ArrayList rewriteDefaultStatementBlock(DMLProgram prog, selectFederatedExecutionPlan(sbHop, paramMap); if(sbHop instanceof FunctionOp) { String funcName = ((FunctionOp) sbHop).getFunctionName(); - Map funcParamMap = getParamMap((FunctionOp) sbHop); + Map funcParamMap = FederatedPlannerUtils.getParamMap((FunctionOp) sbHop); if ( paramMap != null && funcParamMap != null) funcParamMap.putAll(paramMap); paramMap = funcParamMap; @@ -182,22 +182,6 @@ private ArrayList rewriteDefaultStatementBlock(DMLProgram prog, return new ArrayList<>(Collections.singletonList(sb)); } - /** - * Return parameter map containing the mapping from parameter name to input hop - * for all parameters of the function hop. - * @param funcOp hop for which the mapping of parameter names to input hops are made - * @return parameter map or empty map if function has no parameters - */ - private Map getParamMap(FunctionOp funcOp){ - String[] inputNames = funcOp.getInputVariableNames(); - Map paramMap = new HashMap<>(); - if ( inputNames != null ){ - for ( int i = 0; i < funcOp.getInput().size(); i++ ) - paramMap.put(inputNames[i],funcOp.getInput(i)); - } - return paramMap; - } - /** * Set final fedouts of all hops starting from terminal hops. */ @@ -327,13 +311,21 @@ private void visitFedPlanHop(Hop currentHop, Map paramMap) { ArrayList hopRels = getFedPlans(currentHop, paramMap); // Put NONE HopRel into memo table if no FOUT or LOUT HopRels were added if(hopRels.isEmpty()) - hopRels.add(getNONEHopRel(currentHop)); + hopRels.add(getNONEHopRel(currentHop, paramMap)); addTrace(hopRels); hopRelMemo.put(currentHop, hopRels); } - private HopRel getNONEHopRel(Hop currentHop){ - HopRel noneHopRel = new HopRel(currentHop, FederatedOutput.NONE, hopRelMemo); + private ArrayList getHopInputs(Hop currentHop, Map paramMap){ + if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) ) + return FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites); + else + return currentHop.getInput(); + } + + private HopRel getNONEHopRel(Hop currentHop, Map paramMap){ + ArrayList inputs = getHopInputs(currentHop, paramMap); + HopRel noneHopRel = new HopRel(currentHop, FederatedOutput.NONE, hopRelMemo, inputs); FType[] inputFType = noneHopRel.getInputDependency().stream().map(HopRel::getFType).toArray(FType[]::new); FType outputFType = getFederatedOut(currentHop, inputFType); noneHopRel.setFType(outputFType); @@ -348,9 +340,7 @@ private HopRel getNONEHopRel(Hop currentHop){ */ private ArrayList getFedPlans(Hop currentHop, Map paramMap){ ArrayList hopRels = new ArrayList<>(); - ArrayList inputHops = currentHop.getInput(); - if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) ) - inputHops = getTransientInputs(currentHop, paramMap); + ArrayList inputHops = getHopInputs(currentHop, paramMap); if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTWRITE) ) transientWrites.put(currentHop.getName(), currentHop); if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.FEDERATED) ) @@ -453,6 +443,8 @@ private void updateExplain(){ private void debugLog(Hop currentHop){ if ( LOG.isDebugEnabled() ){ LOG.debug("Visiting HOP: " + currentHop + " Input size: " + currentHop.getInput().size()); + if (currentHop.getPrivacy() != null) + LOG.debug(currentHop.getPrivacy()); int index = 0; for ( Hop hop : currentHop.getInput()){ if ( hop == null ) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java new file mode 100644 index 00000000000..45b711a41d6 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; + +import org.apache.sysds.hops.FunctionOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.runtime.DMLRuntimeException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class FederatedPlannerUtils { + /** + * Get transient inputs from either paramMap or transientWrites. + * Inputs from paramMap has higher priority than inputs from transientWrites. + * @param currentHop hop for which inputs are read from maps + * @param paramMap of local parameters + * @param transientWrites map of transient writes + * @return inputs of currentHop + */ + public static ArrayList getTransientInputs(Hop currentHop, Map paramMap, Map transientWrites){ + Hop tWriteHop = null; + if ( paramMap != null) + tWriteHop = paramMap.get(currentHop.getName()); + if ( tWriteHop == null ) + tWriteHop = transientWrites.get(currentHop.getName()); + if ( tWriteHop == null ) + throw new DMLRuntimeException("Transient write not found for " + currentHop); + else + return new ArrayList<>(Collections.singletonList(tWriteHop)); + } + + /** + * Return parameter map containing the mapping from parameter name to input hop + * for all parameters of the function hop. + * @param funcOp hop for which the mapping of parameter names to input hops are made + * @return parameter map or empty map if function has no parameters + */ + public static Map getParamMap(FunctionOp funcOp){ + String[] inputNames = funcOp.getInputVariableNames(); + Map paramMap = new HashMap<>(); + if ( inputNames != null ){ + for ( int i = 0; i < funcOp.getInput().size(); i++ ) + paramMap.put(inputNames[i],funcOp.getInput(i)); + } + return paramMap; + } +} diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java new file mode 100644 index 00000000000..82e43169887 --- /dev/null +++ b/src/main/java/org/apache/sysds/hops/fedplanner/PrivacyConstraintLoader.java @@ -0,0 +1,281 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.hops.fedplanner; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.sysds.api.DMLException; +import org.apache.sysds.common.Types; +import org.apache.sysds.hops.FunctionOp; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; +import org.apache.sysds.parser.IfStatement; +import org.apache.sysds.parser.IfStatementBlock; +import org.apache.sysds.parser.Statement; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; +import org.apache.sysds.runtime.DMLRuntimeException; +import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysds.runtime.controlprogram.federated.FederatedData; +import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; +import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; +import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; +import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException; +import org.apache.sysds.runtime.instructions.cp.Data; +import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction; +import org.apache.sysds.runtime.io.IOUtilFunctions; +import org.apache.sysds.runtime.lineage.LineageItem; +import org.apache.sysds.runtime.privacy.DMLPrivacyException; +import org.apache.sysds.runtime.privacy.PrivacyConstraint; +import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator; +import org.apache.sysds.utils.JSONHelper; +import org.apache.wink.json4j.JSONObject; + +import java.io.BufferedReader; +import java.io.InputStreamReader; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Future; + +public class PrivacyConstraintLoader { + + private final Map memo = new HashMap<>(); + private final Map transientWrites = new HashMap<>(); + + public void loadConstraints(DMLProgram prog){ + rewriteStatementBlocks(prog, prog.getStatementBlocks(), null); + } + + private void rewriteStatementBlocks(DMLProgram prog, List sbs, Map paramMap) { + sbs.forEach(block -> rewriteStatementBlock(prog, block, paramMap)); + } + + private void rewriteStatementBlock(DMLProgram prog, StatementBlock block, Map paramMap){ + if(block instanceof WhileStatementBlock) + rewriteWhileStatementBlock(prog, (WhileStatementBlock) block, paramMap); + else if(block instanceof IfStatementBlock) + rewriteIfStatementBlock(prog, (IfStatementBlock) block, paramMap); + else if(block instanceof ForStatementBlock) { + // This also includes ParForStatementBlocks + rewriteForStatementBlock(prog, (ForStatementBlock) block, paramMap); + } + else if(block instanceof FunctionStatementBlock) + rewriteFunctionStatementBlock(prog, (FunctionStatementBlock) block, paramMap); + else { + // StatementBlock type (no subclass) + rewriteDefaultStatementBlock(prog, block, paramMap); + } + } + + private void rewriteWhileStatementBlock(DMLProgram prog, WhileStatementBlock whileSB, Map paramMap) { + Hop whilePredicateHop = whileSB.getPredicateHops(); + loadPrivacyConstraint(whilePredicateHop, paramMap); + for(Statement stm : whileSB.getStatements()) { + WhileStatement whileStm = (WhileStatement) stm; + rewriteStatementBlocks(prog, whileStm.getBody(), paramMap); + } + } + + private void rewriteIfStatementBlock(DMLProgram prog, IfStatementBlock ifSB, Map paramMap) { + loadPrivacyConstraint(ifSB.getPredicateHops(), paramMap); + for(Statement statement : ifSB.getStatements()) { + IfStatement ifStatement = (IfStatement) statement; + rewriteStatementBlocks(prog, ifStatement.getIfBody(), paramMap); + rewriteStatementBlocks(prog, ifStatement.getElseBody(), paramMap); + } + } + + private void rewriteForStatementBlock(DMLProgram prog, ForStatementBlock forSB, Map paramMap) { + loadPrivacyConstraint(forSB.getFromHops(), paramMap); + loadPrivacyConstraint(forSB.getToHops(), paramMap); + loadPrivacyConstraint(forSB.getIncrementHops(), paramMap); + for(Statement statement : forSB.getStatements()) { + ForStatement forStatement = ((ForStatement) statement); + rewriteStatementBlocks(prog, forStatement.getBody(), paramMap); + } + } + + private void rewriteFunctionStatementBlock(DMLProgram prog, FunctionStatementBlock funcSB, Map paramMap) { + for(Statement statement : funcSB.getStatements()) { + FunctionStatement funcStm = (FunctionStatement) statement; + rewriteStatementBlocks(prog, funcStm.getBody(), paramMap); + } + } + + private void rewriteDefaultStatementBlock(DMLProgram prog, StatementBlock sb, Map paramMap) { + if(sb.hasHops()) { + for(Hop sbHop : sb.getHops()) { + loadPrivacyConstraint(sbHop, paramMap); + if(sbHop instanceof FunctionOp) { + String funcName = ((FunctionOp) sbHop).getFunctionName(); + Map funcParamMap = FederatedPlannerUtils.getParamMap((FunctionOp) sbHop); + if ( paramMap != null && funcParamMap != null) + funcParamMap.putAll(paramMap); + paramMap = funcParamMap; + FunctionStatementBlock sbFuncBlock = prog.getBuiltinFunctionDictionary().getFunction(funcName); + rewriteStatementBlock(prog, sbFuncBlock, paramMap); + } + } + } + } + + private void loadPrivacyConstraint(Hop root, Map paramMap){ + if ( root != null && !memo.containsKey(root.getHopID()) ){ + for ( Hop input : root.getInput() ){ + loadPrivacyConstraint(input, paramMap); + } + propagatePrivConstraintsLocal(root, paramMap); + memo.put(root.getHopID(), root); + } + } + + private void propagatePrivConstraintsLocal(Hop currentHop, Map paramMap){ + if ( currentHop.isFederatedDataOp() ) + loadFederatedPrivacyConstraints(currentHop); + else if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTWRITE) ){ + currentHop.setPrivacy(currentHop.getInput(0).getPrivacy()); + transientWrites.put(currentHop.getName(), currentHop); + } + else if ( HopRewriteUtils.isData(currentHop, Types.OpOpData.TRANSIENTREAD) ){ + currentHop.setPrivacy(FederatedPlannerUtils.getTransientInputs(currentHop, paramMap, transientWrites).get(0).getPrivacy()); + } else { + PrivacyPropagator.hopPropagation(currentHop); + } + } + + /** + * Get privacy constraints from federated workers for DataOps. + * @hop hop for which privacy constraints are loaded + */ + private static void loadFederatedPrivacyConstraints(Hop hop){ + try { + PrivacyConstraint.PrivacyLevel constraintLevel = hop.getInput(0).getInput().stream().parallel() + .map( in -> ((LiteralOp)in).getStringValue() ) + .map(PrivacyConstraintLoader::sendPrivConstraintRequest) + .map(PrivacyConstraintLoader::unwrapPrivConstraint) + .map(constraint -> (constraint != null) ? constraint.getPrivacyLevel() : PrivacyConstraint.PrivacyLevel.None) + .reduce(PrivacyConstraint.PrivacyLevel.None, (out,in) -> { + if ( out == PrivacyConstraint.PrivacyLevel.Private || in == PrivacyConstraint.PrivacyLevel.Private ) + return PrivacyConstraint.PrivacyLevel.Private; + else if ( out == PrivacyConstraint.PrivacyLevel.PrivateAggregation || in == PrivacyConstraint.PrivacyLevel.PrivateAggregation ) + return PrivacyConstraint.PrivacyLevel.PrivateAggregation; + else + return out; + }); + PrivacyConstraint fedDataPrivConstraint = (constraintLevel != PrivacyConstraint.PrivacyLevel.None) ? + new PrivacyConstraint(constraintLevel) : null; + + hop.setPrivacy(fedDataPrivConstraint); + } + catch(Exception ex) { + throw new DMLException(ex); + } + } + + private static Future sendPrivConstraintRequest(String address) + { + try{ + String[] parsedAddress = InitFEDInstruction.parseURL(address); + String host = parsedAddress[0]; + int port = Integer.parseInt(parsedAddress[1]); + PrivacyConstraintRetriever retriever = new PrivacyConstraintRetriever(parsedAddress[2]); + FederatedRequest privacyRetrieval = + new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever); + InetSocketAddress inetAddress = new InetSocketAddress(InetAddress.getByName(host), port); + return FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval); + } catch(UnknownHostException ex){ + throw new DMLException(ex); + } + } + + private static PrivacyConstraint unwrapPrivConstraint(Future privConstraintFuture) + { + try { + FederatedResponse privConstraintResponse = privConstraintFuture.get(); + return (PrivacyConstraint) privConstraintResponse.getData()[0]; + } catch(Exception ex){ + throw new DMLException(ex); + } + } + + /** + * FederatedUDF for retrieving privacy constraint of data stored in file name. + */ + public static class PrivacyConstraintRetriever extends FederatedUDF { + private static final long serialVersionUID = 3551741240135587183L; + private final String filename; + + public PrivacyConstraintRetriever(String filename){ + super(new long[]{}); + this.filename = filename; + } + + /** + * Reads metadata JSON object, parses privacy constraint and returns the constraint in FederatedResponse. + * @param ec execution context + * @param data one or many data objects + * @return FederatedResponse with privacy constraint object + */ + @Override + public FederatedResponse execute(ExecutionContext ec, Data... data) { + PrivacyConstraint privacyConstraint; + FileSystem fs = null; + try { + String mtdname = DataExpression.getMTDFileName(filename); + Path path = new Path(mtdname); + fs = IOUtilFunctions.getFileSystem(mtdname); + try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) { + JSONObject metadataObject = JSONHelper.parse(br); + privacyConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject); + } + } + catch (DMLPrivacyException | FederatedWorkerHandlerException ex){ + throw ex; + } + catch (Exception ex) { + String msg = "Exception in reading metadata of: " + filename; + throw new DMLRuntimeException(msg); + } + finally { + IOUtilFunctions.closeSilently(fs); + } + return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint); + } + + @Override + public Pair getLineageItem(ExecutionContext ec) { + return null; + } + } + +} diff --git a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java index 6be3b9c8ec8..e6c683eb386 100644 --- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java +++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java @@ -23,6 +23,7 @@ import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner; +import org.apache.sysds.hops.fedplanner.PrivacyConstraintLoader; import org.apache.sysds.parser.DMLProgram; /** @@ -58,16 +59,24 @@ public boolean isApplicable(FunctionCallGraph fgraph) { */ @Override public boolean rewriteProgram(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes) { - // obtain planner instance according to config String splanner = ConfigurationManager.getDMLConfig() .getTextValue(DMLConfig.FEDERATED_PLANNER); + loadPrivacyConstraints(prog, splanner); + generatePlan(prog, fgraph, fcallSizes, splanner); + return false; + } + + private void loadPrivacyConstraints(DMLProgram prog, String splanner){ + if (FederatedPlanner.isCompiled(splanner)) + new PrivacyConstraintLoader().loadConstraints(prog); + } + + private void generatePlan(DMLProgram prog, FunctionCallGraph fgraph, FunctionCallSizeInfo fcallSizes, String splanner){ FederatedPlanner planner = FederatedPlanner.isCompiled(splanner) ? FederatedPlanner.valueOf(splanner.toUpperCase()) : FederatedPlanner.COMPILE_COST_BASED; - + // run planner rewrite with forced federated exec types planner.getPlanner().rewriteProgram(prog, fgraph, fcallSizes); - - return false; } } diff --git a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java index faec3504e9e..db20ada2808 100644 --- a/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java +++ b/src/main/java/org/apache/sysds/hops/rewrite/ProgramRewriter.java @@ -27,10 +27,8 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.conf.CompilerConfig.ConfigType; -import org.apache.sysds.conf.DMLConfig; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.OptimizerUtils; -import org.apache.sysds.hops.fedplanner.FTypes; import org.apache.sysds.parser.DMLProgram; import org.apache.sysds.parser.ForStatement; import org.apache.sysds.parser.ForStatementBlock; @@ -141,11 +139,6 @@ public ProgramRewriter(boolean staticRewrites, boolean dynamicRewrites) _dagRuleSet.add( new RewriteAlgebraicSimplificationDynamic() ); //dependencies: cse _dagRuleSet.add( new RewriteAlgebraicSimplificationStatic() ); //dependencies: cse } - String planner = ConfigurationManager.getDMLConfig() - .getTextValue(DMLConfig.FEDERATED_PLANNER); - if ( OptimizerUtils.FEDERATED_COMPILATION || FTypes.FederatedPlanner.isCompiled(planner) ) { - _dagRuleSet.add( new RewriteFederatedExecution() ); - } } // cleanup after all rewrites applied diff --git a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java b/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java deleted file mode 100644 index 822b4b5d952..00000000000 --- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteFederatedExecution.java +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.sysds.hops.rewrite; - -import org.apache.commons.lang3.tuple.Pair; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.log4j.Logger; -import org.apache.sysds.api.DMLException; -import org.apache.sysds.hops.Hop; -import org.apache.sysds.hops.LiteralOp; -import org.apache.sysds.parser.DataExpression; -import org.apache.sysds.runtime.DMLRuntimeException; -import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; -import org.apache.sysds.runtime.controlprogram.federated.FederatedData; -import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; -import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse; -import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF; -import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException; -import org.apache.sysds.runtime.instructions.cp.Data; -import org.apache.sysds.runtime.instructions.fed.InitFEDInstruction; -import org.apache.sysds.runtime.io.IOUtilFunctions; -import org.apache.sysds.runtime.lineage.LineageItem; -import org.apache.sysds.runtime.privacy.DMLPrivacyException; -import org.apache.sysds.runtime.privacy.PrivacyConstraint; -import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator; -import org.apache.sysds.utils.JSONHelper; -import org.apache.wink.json4j.JSONObject; - -import javax.net.ssl.SSLException; -import java.io.BufferedReader; -import java.io.InputStreamReader; -import java.net.InetAddress; -import java.net.InetSocketAddress; -import java.net.UnknownHostException; -import java.util.ArrayList; -import java.util.concurrent.Future; - -public class RewriteFederatedExecution extends HopRewriteRule { - private static final Logger LOG = Logger.getLogger(RewriteFederatedExecution.class); - - @Override - public ArrayList rewriteHopDAGs(ArrayList roots, ProgramRewriteStatus state) { - if ( roots != null ) - for ( Hop root : roots ) - rewriteHopDAG(root, state); - return roots; - } - - @Override - public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { - if ( root != null ) - visitHop(root); - return root; - } - - private void visitHop(Hop hop){ - if (hop.isVisited()) - return; - - LOG.debug("RewriteFederatedExecution visitHop + " + hop); - - // Depth first to get to the input - for ( Hop input : hop.getInput() ) - visitHop(input); - - privacyBasedHopDecisionWithFedCall(hop); - hop.setVisited(); - } - - /** - * Get privacy constraints of DataOps from federated worker, - * propagate privacy constraints from input to current hop, - * and set federated output flag. - * @param hop current hop - */ - private static void privacyBasedHopDecisionWithFedCall(Hop hop){ - loadFederatedPrivacyConstraints(hop); - PrivacyPropagator.hopPropagation(hop); - } - - /** - * Get privacy constraints from federated workers for DataOps. - * @hop hop for which privacy constraints are loaded - */ - private static void loadFederatedPrivacyConstraints(Hop hop){ - if ( hop.isFederatedDataOp() && hop.getPrivacy() == null){ - try { - LOG.debug("Load privacy constraints of " + hop); - PrivacyConstraint privConstraint = unwrapPrivConstraint(sendPrivConstraintRequest(hop)); - LOG.debug("PrivacyConstraint retrieved: " + privConstraint); - hop.setPrivacy(privConstraint); - } - catch(Exception e) { - throw new DMLException(e); - } - } - } - - private static Future sendPrivConstraintRequest(Hop hop) - throws UnknownHostException, SSLException - { - String address = ((LiteralOp) hop.getInput(0).getInput(0)).getStringValue(); - String[] parsedAddress = InitFEDInstruction.parseURL(address); - String host = parsedAddress[0]; - int port = Integer.parseInt(parsedAddress[1]); - PrivacyConstraintRetriever retriever = new PrivacyConstraintRetriever(parsedAddress[2]); - FederatedRequest privacyRetrieval = - new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1, retriever); - InetSocketAddress inetAddress = new InetSocketAddress(InetAddress.getByName(host), port); - return FederatedData.executeFederatedOperation(inetAddress, privacyRetrieval); - } - - private static PrivacyConstraint unwrapPrivConstraint(Future privConstraintFuture) - throws Exception - { - FederatedResponse privConstraintResponse = privConstraintFuture.get(); - return (PrivacyConstraint) privConstraintResponse.getData()[0]; - } - - /** - * FederatedUDF for retrieving privacy constraint of data stored in file name. - */ - public static class PrivacyConstraintRetriever extends FederatedUDF { - private static final long serialVersionUID = 3551741240135587183L; - private final String filename; - - public PrivacyConstraintRetriever(String filename){ - super(new long[]{}); - this.filename = filename; - } - - /** - * Reads metadata JSON object, parses privacy constraint and returns the constraint in FederatedResponse. - * @param ec execution context - * @param data one or many data objects - * @return FederatedResponse with privacy constraint object - */ - @Override - public FederatedResponse execute(ExecutionContext ec, Data... data) { - PrivacyConstraint privacyConstraint; - FileSystem fs = null; - try { - String mtdname = DataExpression.getMTDFileName(filename); - Path path = new Path(mtdname); - fs = IOUtilFunctions.getFileSystem(mtdname); - try(BufferedReader br = new BufferedReader(new InputStreamReader(fs.open(path)))) { - JSONObject metadataObject = JSONHelper.parse(br); - privacyConstraint = PrivacyPropagator.parseAndReturnPrivacyConstraint(metadataObject); - } - } - catch (DMLPrivacyException | FederatedWorkerHandlerException ex){ - throw ex; - } - catch (Exception ex) { - String msg = "Exception in reading metadata of: " + filename; - throw new DMLRuntimeException(msg); - } - finally { - IOUtilFunctions.closeSilently(fs); - } - return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, privacyConstraint); - } - - @Override - public Pair getLineageItem(ExecutionContext ec) { - return null; - } - } -} diff --git a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java index 8ea061844ad..fc9ba440c85 100644 --- a/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java +++ b/src/main/java/org/apache/sysds/runtime/privacy/PrivacyConstraint.java @@ -262,8 +262,11 @@ public boolean equals(Object other){ @Override public String toString(){ - return "General privacy level: " + privacyLevel + System.getProperty("line.separator") - + "Fine-grained privacy level: " + fineGrainedPrivacy.toString(); + String constraintString = "General privacy level: " + privacyLevel; + if ( fineGrainedPrivacy != null && fineGrainedPrivacy.hasConstraints() ) + constraintString = constraintString + System.getProperty("line.separator") + + "Fine-grained privacy level: " + fineGrainedPrivacy.toString(); + return constraintString; } } diff --git a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java index 7e6c0127e52..94834ebc6e1 100644 --- a/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java +++ b/src/main/java/org/apache/sysds/runtime/privacy/propagation/PrivacyPropagator.java @@ -23,11 +23,18 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Objects; +import org.apache.sysds.api.DMLException; import org.apache.sysds.hops.AggBinaryOp; import org.apache.sysds.hops.AggUnaryOp; import org.apache.sysds.hops.BinaryOp; +import org.apache.sysds.hops.DataGenOp; +import org.apache.sysds.hops.DataOp; +import org.apache.sysds.hops.FunctionOp; import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.LiteralOp; +import org.apache.sysds.hops.NaryOp; import org.apache.sysds.hops.ReorgOp; import org.apache.sysds.hops.TernaryOp; import org.apache.sysds.hops.UnaryOp; @@ -168,12 +175,39 @@ else if (privacyConstraint2 != null) * @param hop which the privacy constraints are propagated to */ public static void hopPropagation(Hop hop){ - PrivacyConstraint[] inputConstraints = hop.getInput().stream() + hopPropagation(hop, hop.getInput()); + } + + /** + * Propagate privacy constraints from input hops to given hop. + * @param hop which the privacy constraints are propagated to + * @param inputHops inputs to given hop + */ + public static void hopPropagation(Hop hop, ArrayList inputHops){ + PrivacyConstraint[] inputConstraints = inputHops.stream() .map(Hop::getPrivacy).toArray(PrivacyConstraint[]::new); - if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp ) - hop.setPrivacy(mergeNary(inputConstraints, OperatorType.NonAggregate)); + OperatorType opType = getOpType(hop); + hop.setPrivacy(mergeNary(inputConstraints, opType)); + if (opType == null && Arrays.stream(inputConstraints).anyMatch(Objects::nonNull)) + throw new DMLException("Input has constraint but hop type not recognized by PrivacyPropagator. " + + "Hop is " + hop + " " + hop.getClass()); + } + + /** + * Get operator type of given hop. + * Returns null if hop type is not known. + * @param hop for which operator type is returned + * @return operator type of hop or null if hop type is unknown + */ + private static OperatorType getOpType(Hop hop){ + if ( hop instanceof TernaryOp || hop instanceof BinaryOp || hop instanceof ReorgOp + || hop instanceof DataOp || hop instanceof LiteralOp || hop instanceof NaryOp + || hop instanceof DataGenOp || hop instanceof FunctionOp ) + return OperatorType.NonAggregate; else if ( hop instanceof AggBinaryOp || hop instanceof AggUnaryOp || hop instanceof UnaryOp ) - hop.setPrivacy(mergeNary(inputConstraints, OperatorType.Aggregate)); + return OperatorType.Aggregate; + else + return null; } /** @@ -406,7 +440,7 @@ private static Instruction throwExceptionIfInputOrInstPrivacy(Instruction inst, if (inputOperands != null){ for ( CPOperand input : inputOperands ){ PrivacyConstraint privacyConstraint = getInputPrivacyConstraint(ec, input); - if ( privacyConstraint != null){ + if ( privacyConstraint != null && privacyConstraint.hasConstraints()){ throw new DMLPrivacyException("Input of instruction " + inst + " has privacy constraints activated, but the constraints are not propagated during preprocessing of instruction."); } } diff --git a/src/main/java/org/apache/sysds/utils/Explain.java b/src/main/java/org/apache/sysds/utils/Explain.java index 589f23a845b..ded46c039a1 100644 --- a/src/main/java/org/apache/sysds/utils/Explain.java +++ b/src/main/java/org/apache/sysds/utils/Explain.java @@ -626,6 +626,9 @@ else if( hop.requiresCheckpoint() ) } } + if ( hop.getPrivacy() != null ) + sb.append(" ").append(hop.getPrivacy().getPrivacyLevel().name()); + sb.append('\n'); hop.setVisited(); diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java index 14c093ebe82..2477bdef851 100644 --- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedMultiplyPlanningTest.java @@ -55,6 +55,7 @@ public class FederatedMultiplyPlanningTest extends AutomatedTestBase { private final static String TEST_NAME_8 = "FederatedMultiplyPlanningTest8"; private final static String TEST_NAME_9 = "FederatedMultiplyPlanningTest9"; private final static String TEST_NAME_10 = "FederatedMultiplyPlanningTest10"; + private final static String TEST_NAME_11 = "FederatedMultiplyPlanningTest11"; private final static String TEST_CLASS_DIR = TEST_DIR + FederatedMultiplyPlanningTest.class.getSimpleName() + "/"; private static File TEST_CONF_FILE = new File(SCRIPT_DIR + TEST_DIR, "SystemDS-config-cost-based.xml"); @@ -77,6 +78,7 @@ public void setUp() { addTestConfiguration(TEST_NAME_8, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_8, new String[] {"Z.scalar"})); addTestConfiguration(TEST_NAME_9, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_9, new String[] {"Z.scalar"})); addTestConfiguration(TEST_NAME_10, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_10, new String[] {"Z"})); + addTestConfiguration(TEST_NAME_11, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME_11, new String[] {"Z"})); } @Parameterized.Parameters @@ -153,6 +155,12 @@ public void federatedMultiplyPlanningTest10(){ federatedTwoMatricesSingleNodeTest(TEST_NAME_10, expectedHeavyHitters); } + @Test + public void federatedMultiplyPlanningTest11(){ + String[] expectedHeavyHitters = new String[]{"fed_fedinit"}; + federatedTwoMatricesSingleNodeTest(TEST_NAME_11, expectedHeavyHitters); + } + private void writeStandardMatrix(String matrixName, long seed){ writeStandardMatrix(matrixName, seed, new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.PrivateAggregation)); } diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml new file mode 100644 index 00000000000..147bf2cd139 --- /dev/null +++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = federated(addresses=list($X1, $X2), + ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c))) +Y = federated(addresses=list($Y1, $Y2), + ranges=list(list(0, 0), list($r/2, $c), list($r / 2, 0), list($r, $c))) + +i = 0 +while(i < 10){ + Z0 = X * Y + Z = t(Z0) %*% X + i=i+1 +} + +write(Z, $Z) diff --git a/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml new file mode 100644 index 00000000000..187623bbfeb --- /dev/null +++ b/src/test/scripts/functions/privacy/fedplanning/FederatedMultiplyPlanningTest11Reference.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rbind(read($X1), read($X2)) +Y = rbind(read($Y1), read($Y2)) + +i = 0 +while(i < 10){ + Z0 = X * Y + Z = t(Z0) %*% X + i=i+1 +} + +write(Z, $Z)