diff --git a/runtime/image_classification/models/resnet50/gpus=2/__init__.py b/runtime/image_classification/models/resnet50/gpus=2/__init__.py index 49e5a26..4d1edcd 100644 --- a/runtime/image_classification/models/resnet50/gpus=2/__init__.py +++ b/runtime/image_classification/models/resnet50/gpus=2/__init__.py @@ -12,7 +12,7 @@ def arch(): def model(criterion): return [ - (Stage0(), ["input"], ["out0", "out1"]), + (Stage0(), ["input0"], ["out0", "out1"]), (Stage1(), ["out0", "out1"], ["out3", "out2"]), (Stage2(), ["out3", "out2"], ["out4", "out5"]), (Stage3(), ["out4", "out5"], ["out6"]),