diff --git a/CSIKit/tools/batch_graph.py b/CSIKit/tools/batch_graph.py index 8b94f23..bf362f4 100644 --- a/CSIKit/tools/batch_graph.py +++ b/CSIKit/tools/batch_graph.py @@ -7,13 +7,16 @@ from CSIKit.reader import get_reader, IWLBeamformReader DEFAULT_PATH = "./data/intel/misc/log.all_csi.6.7.6.dat" + + # DEFAULT_PATH = "./data/pi/walk_1597159475.pcap" class BatchGraph: - def __init__(self, path: str=DEFAULT_PATH, scaled: bool=False, filter_mac: str=None): + def __init__(self, path: str = DEFAULT_PATH, scaled: bool = False, filter_mac: str = None): reader = get_reader(path) self.csi_data = reader.read_file(path, scaled=scaled, filter_mac=filter_mac) + self.path = path def prepostfilter(self): @@ -42,58 +45,48 @@ def prepostfilter(self): plt.ylabel("Amplitude (dBm)") plt.legend(loc="upper right") - plt.show() + output_path = self.path.replace('.dat', '_prepostfilter.png') + plt.savefig(output_path) + plt.close() def plotAllSubcarriers(self): finalEntry, no_frames, _ = get_CSI(self.csi_data) for x in finalEntry: - plt.plot(np.arange(no_frames)/20, x) + plt.plot(np.arange(no_frames) / 20, x) plt.xlabel("Time (s)") plt.ylabel("Amplitude (dBm)") plt.legend(loc="upper right") - plt.show() + output_path = self.path.replace('.dat', '_all_subcarriers.png') + plt.savefig(output_path) + plt.close() def heatmap(self): - # self.csi_data.frames = self.csi_data.frames[slice(1, len(self.csi_data.frames), 2)] - finalEntry, no_frames, no_subcarriers = get_CSI(self.csi_data) if len(finalEntry.shape) == 4: - #>1 antenna stream. - #Loading the first for ease. finalEntry = finalEntry[:, :, 0, 0] - # from CSIKit.filters.wavelets.dwt import denoise - # finalEntry = denoise(finalEntry) - - #Transpose to get subcarriers * amplitude. finalEntry = np.transpose(finalEntry) x_label = "Time (s)" try: x = self.csi_data.timestamps - x = [timestamp-x[0] for timestamp in x] - except AttributeError as e: - #No timestamp in frame. Likely an IWL entry. - #Will be moving timestamps to CSIData to account for this. + x = [timestamp - x[0] for timestamp in x] + except AttributeError: x = [0] if sum(x) == 0: - #Some files have invalid timestamp_low values which means we can't plot based on timestamps. - #Instead we'll just plot by frame count. - xlim = no_frames - x_label = "Frame No." else: xlim = max(x) limits = [0, xlim, 1, no_subcarriers] - + _, ax = plt.subplots() im = ax.imshow(finalEntry, cmap="jet", extent=limits, aspect="auto") @@ -105,13 +98,12 @@ def heatmap(self): plt.title(self.csi_data.filename) - plt.show() + output_path = self.path.replace('.dat', '_heatmap.png') + plt.savefig(output_path) + plt.close() - # Simple implementation. - # Takes a 2D matrix of CSI data and timestamps to visualise amplitude. - # CSI data should be arranged as (no_frames, no_subcarriers) @staticmethod - def plot_heatmap(csi_matrix, timestamps): + def plot_heatmap(csi_matrix, timestamps, output_path): csi_matrix = np.transpose(csi_matrix) @@ -119,20 +111,14 @@ def plot_heatmap(csi_matrix, timestamps): try: x = timestamps x = [timestamp - x[0] for timestamp in x] - except AttributeError as e: - # No timestamp in frame. Likely an IWL entry. - # Will be moving timestamps to CSIData to account for this. + except AttributeError: x = [0] if sum(x) == 0: - # Some files have invalid timestamp_low values which means we can't plot based on timestamps. - # Instead we'll just plot by frame count. - xlim = csi_matrix.shape[1] - x_label = "Frame No." else: - xlim = max(x) + xlim = max(x) limits = [0, xlim, 1, csi_matrix.shape[0]] @@ -147,31 +133,22 @@ def plot_heatmap(csi_matrix, timestamps): plt.title("CSI Amplitude Heatmap Plot") - plt.show() + plt.savefig(output_path) + plt.close() def sumsqrssi(self): finalEntry, no_frames, no_subcarriers = get_CSI(self.csi_data, extract_as_dBm=False) if len(finalEntry.shape) == 4: - # >1 antenna stream. - # Loading the first for ease. finalEntry = finalEntry[:, :, 0, 0] - rssis = [] - sumsq = [] - - csi = finalEntry rss = [x.rssi for x in self.csi_data.frames] sumsq = np.sum(csi ** 2, axis=1) norm_sumsq = np.sqrt(sumsq) / no_subcarriers - line = [] - - for rss_value, sumsq_value in zip(rss, norm_sumsq): - line.append(matlab.db(sumsq_value) / rss_value) + line = [matlab.db(sumsq_value) / rss_value for rss_value, sumsq_value in zip(rss, norm_sumsq)] - # plt.scatter(rssis, sumsq) plt.plot(self.csi_data.timestamps, line) plt.xlabel("Time (s)") @@ -179,8 +156,11 @@ def sumsqrssi(self): plt.title(self.csi_data.filename) - plt.show() + output_path = self.path.replace('.dat', '_sumsqrssi.png') + plt.savefig(output_path) + plt.close() + if __name__ == "__main__": bg = BatchGraph() - bg.heatmap() \ No newline at end of file + bg.heatmap()