|
31 | 31 | }, |
32 | 32 | { |
33 | 33 | "cell_type": "code", |
34 | | - "execution_count": 2, |
| 34 | + "execution_count": null, |
35 | 35 | "metadata": { |
36 | 36 | "colab": { |
37 | 37 | "base_uri": "https://localhost:8080/", |
|
46 | 46 | "\n", |
47 | 47 | "%load_ext autoreload\n", |
48 | 48 | "%autoreload 2\n", |
| 49 | + "from tempfile import TemporaryDirectory\n", |
49 | 50 | "import torch \n", |
50 | 51 | "\n", |
51 | 52 | "from pathlib import Path\n", |
52 | | - "\n", |
53 | | - "\n", |
54 | 53 | "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", |
55 | | - "results_path = Path(\"~/tm_example_pytorch/\").expanduser()\n", |
56 | | - "results_path.mkdir(parents=True, exist_ok=True)" |
| 54 | + "results_path = Path(TemporaryDirectory().name)/\"tm_example_pytorch\"\n", |
| 55 | + "results_path.mkdir(parents=True, exist_ok=True)\n", |
| 56 | + "print(f\"Saving results to {results_path}\")" |
57 | 57 | ] |
58 | 58 | }, |
59 | 59 | { |
|
279 | 279 | }, |
280 | 280 | { |
281 | 281 | "cell_type": "code", |
282 | | - "execution_count": 6, |
| 282 | + "execution_count": null, |
283 | 283 | "metadata": {}, |
284 | 284 | "outputs": [ |
285 | 285 | { |
|
313 | 313 | "# print(model)\n", |
314 | 314 | "\n", |
315 | 315 | "# Create an ActivationsModule from the vanilla model\n", |
316 | | - "activations_module = tm.pytorch.AutoActivationsModule(normalized_model)\n", |
317 | | - "print(activations_module.activation_names())\n", |
318 | | - "def filter(module, name): return True #name.endswith(\"conv2\")\n", |
319 | 316 | "\n", |
320 | | - "activations_module = tm.pytorch.model.FilteredActivationsModule(activations_module,filter)\n", |
| 317 | + "activations = tm.pytorch.get_activations(model)\n", |
| 318 | + "def filter_stochastic(a):\n", |
| 319 | + " return not str(a).startswith(\"StochasticDepth\")\n", |
| 320 | + "activations = {k:v for k,v in activations.items() if filter_stochastic(k)}\n", |
| 321 | + "\n", |
| 322 | + "activations_module = tm.pytorch.ActivationsModule(model,activations)\n", |
| 323 | + "\n", |
321 | 324 | "print(activations_module.activation_names())" |
322 | 325 | ] |
323 | 326 | }, |
|
336 | 339 | }, |
337 | 340 | { |
338 | 341 | "cell_type": "code", |
339 | | - "execution_count": 29, |
| 342 | + "execution_count": null, |
340 | 343 | "metadata": {}, |
341 | 344 | "outputs": [ |
342 | 345 | { |
|
357 | 360 | " measure_result = pickle.load(f)\n", |
358 | 361 | " print(f\"loaded measure results from {filepath}\", measure_result) \n", |
359 | 362 | "else:\n", |
360 | | - " options = tm.pytorch.PyTorchMeasureOptions(batch_size=128, num_workers=0,model_device=device,measure_device=device,data_device=\"cpu\")\n", |
| 363 | + " options = tm.pytorch.PyTorchMeasureOptions(batch_size=128, num_workers=0,model_device=device,measure_device=device,data_device=torch.device(\"cpu\"))\n", |
361 | 364 | "\n", |
362 | 365 | " # Define the measure and evaluate it\n", |
363 | 366 | " measure = tm.pytorch.NormalizedVarianceInvariance()\n", |
|
425 | 428 | }, |
426 | 429 | { |
427 | 430 | "cell_type": "code", |
428 | | - "execution_count": 15, |
| 431 | + "execution_count": null, |
429 | 432 | "metadata": {}, |
430 | 433 | "outputs": [ |
431 | 434 | { |
|
453 | 456 | "from tmeasures.visualization.weights import reorder_conv2d_weights,plot_conv2d_filters,plot_conv2d_filters_rgb\n", |
454 | 457 | "\n", |
455 | 458 | "activation_name = \"/ResNet_1/layer3/BasicBlock_1/conv2\"\n", |
456 | | - "#\"/ResNet_1/layer1/BasicBlock_0/conv2\"\n", |
457 | | - "activation_index = activations_module.inner_model.names.index(activation_name)\n", |
458 | | - "activation = activations_module.inner_model.activations[activation_index]\n", |
| 459 | + "\n", |
| 460 | + "activation = activations[activation_name]\n", |
459 | 461 | "invariance = measure_result.layers_dict()[activation_name]\n", |
460 | 462 | "# print(activation)\n", |
461 | 463 | "# print(invariance.shape)\n", |
|
563 | 565 | "provenance": [] |
564 | 566 | }, |
565 | 567 | "kernelspec": { |
566 | | - "display_name": "tm", |
| 568 | + "display_name": ".venv", |
567 | 569 | "language": "python", |
568 | 570 | "name": "python3" |
569 | 571 | }, |
|
0 commit comments