Coverage for src / sgn_gwframe / sources / frame.py: 96.0%
125 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-02-12 13:29 -0800
1"""Read GW frame files from a frame cache file."""
3# Copyright (C) 2024 Becca Ewing, Yun-Jing Huang
5from __future__ import annotations
7import logging
8from dataclasses import dataclass
9from typing import TYPE_CHECKING, Optional
11import gwframe
12import igwn_segments as segments
13from sgnts.base import Audioadapter, Offset, SeriesBuffer, TSFrame, TSSource
15if TYPE_CHECKING:
16 from sgn.base import SourcePad
# Module-level handle on the shared "sgn" logger; the FrameSource element
# itself logs through self.logger.
logger = logging.getLogger(name="sgn")
@dataclass
class CacheEntry:
    """Simple cache entry parser for frame cache files.

    Cache file format: observatory description gps_start duration path
    Example: L L1_GWOSC_16KHZ_R1 1240215487 32 ./path/to/file.gwf
    """

    observatory: str
    description: str
    gps_start: float
    duration: float
    path: str

    @classmethod
    def from_line(cls, line: str) -> CacheEntry:
        """Parse a cache file line into a CacheEntry.

        Args:
            line:
                str, one line of a frame cache file; surrounding whitespace
                (including a trailing newline) is ignored

        Returns:
            CacheEntry, the parsed entry

        Raises:
            ValueError:
                if the line does not have exactly five whitespace-separated
                fields, or if the gps_start / duration fields are not numbers
        """
        parts = line.split()
        # strip so the message does not carry the file's trailing newline
        msg = f"Invalid cache line format: {line.strip()}"
        if len(parts) != 5:
            raise ValueError(msg)
        observatory, description, gps_start, duration, path = parts
        try:
            start = float(gps_start)
            span = float(duration)
        except ValueError as err:
            # report the whole offending line, not just the unparsable token
            raise ValueError(msg) from err
        return cls(
            observatory=observatory,
            description=description,
            gps_start=start,
            duration=span,
            path=path,
        )

    @property
    def segment(self) -> segments.segment:
        """Return the time segment from gps_start to gps_start + duration."""
        return segments.segment(self.gps_start, self.gps_start + self.duration)
@dataclass(kw_only=True)
class FrameSource(TSSource):
    """Read GW frame files from a frame cache file

    Args:
        channel_names:
            list[str], a list of channel names of the data, e.g.,
            ["L1:GWOSC-16KHZ_R1_STRAIN", "L1:GWOSC-16KHZ_R1_DQMASK"]. Source pads will
            be automatically generated for each channel, with channel name as pad name.
        framecache:
            str, cache file to read data from
        instrument:
            str, optional, only read gwf files from this instrument. Default: None

    Raises:
        ValueError:
            if source pad names are given and differ from the channel names, if no
            start time is given, if a channel name does not contain ``instrument``,
            or if no cache entry intersects the analysis segment.
    """

    channel_names: list[str]
    framecache: str
    instrument: Optional[str] = None

    def __post_init__(self) -> None:
        # Pads are named after channels. Compare as tuples so that a list of
        # pad names matching the channels is accepted as well as a tuple.
        channels = tuple(self.channel_names)
        if len(self.source_pad_names) > 0 and tuple(self.source_pad_names) != channels:
            msg = "Expected source pad names to match channel names"
            raise ValueError(msg)
        self.source_pad_names = channels

        super().__post_init__()
        # per-pad frame counter, reported in each frame's metadata by new()
        self.cnt = dict.fromkeys(self.source_pads, 0)

        if self.start is None:
            msg = "FrameSource requires a start time to be specified"
            raise ValueError(msg)
        # end time of the data pushed into the adapters so far; load_gwf_data
        # uses it to detect holes and overlaps between consecutive files
        self.last_epoch: float = self.start

        if self.instrument is not None:
            for channel in self.channel_names:
                if self.instrument not in channel:
                    msg = (
                        f"Instrument '{self.instrument}' does not match "
                        f"channel name '{channel}'"
                    )
                    raise ValueError(msg)

        # init analysis segment
        self.analysis_seg = segments.segment(self.start, self.end)

        # load the cache, keeping only entries that can supply analysis data
        self.cache: list[CacheEntry] = self._load_cache()
        if not self.cache:
            msg = (
                f"No frame files in cache '{self.framecache}' intersect the "
                f"analysis segment {self.analysis_seg}"
            )
            raise ValueError(msg)

        missing_segments = self._find_missing_segments(self.analysis_seg, self.cache)
        if missing_segments:
            self.logger.warning(
                "Cache has missing segments %s, padding with gaps",
                missing_segments,
            )

        self.A = {c: Audioadapter() for c in self.channel_names}

        # load first segment of data to read sample rate
        self.rates: dict[str, int] = {}
        self.load_gwf_data(self.cache[0])
        self.logger.info("Initialized sample rates per channel: %s", self.rates)

        # Set buffer parameters for each pad (supports different sample rates per pad)
        for pad in self.source_pads:
            self.set_pad_buffer_params(
                sample_shape=(), rate=self.rates[self.rsrcs[pad]], pad=pad
            )

        # now that we have loaded data from this frame,
        # remove it from the cache
        self.cache.pop(0)

    def _load_cache(self) -> list[CacheEntry]:
        """Parse the frame cache and return the usable entries sorted by GPS start.

        Entries are kept only if they match ``instrument`` (when given) and
        intersect the analysis segment.
        """
        self.logger.info("Loading frame cache from %s", self.framecache)
        with open(self.framecache, encoding="utf-8") as f:
            cache = [CacheEntry.from_line(line) for line in f if line.strip()]

        if self.instrument is not None:
            # only keep files with the correct instrument
            ifos = self.ifo_strings(self.instrument)
            cache = [entry for entry in cache if entry.observatory in ifos]

        # only keep files that intersect the analysis segment; segment `&`
        # raises ValueError when the two segments do not overlap
        kept: list[CacheEntry] = []
        for entry in cache:
            try:
                self.analysis_seg & entry.segment
            except ValueError:
                continue
            kept.append(entry)

        # make sure it is sorted by gps time
        kept.sort(key=lambda x: x.segment[0])
        return kept

    @staticmethod
    def _find_missing_segments(
        analysis_seg: segments.segment, cache: list[CacheEntry]
    ) -> list[segments.segment]:
        """Return the parts of ``analysis_seg`` not covered by the sorted ``cache``.

        Scanning stops as soon as the rest of the analysis segment is covered,
        so overlapping or duplicate entries after that point are ignored.
        """
        remaining = analysis_seg
        missing: list[segments.segment] = []
        for entry in cache:
            seg = entry.segment
            if remaining in seg:
                # the cache contains all the rest of the proposed segment
                return missing
            if remaining[0] < seg[0]:
                # there is a discontinuity before this entry
                missing.append(segments.segment(remaining[0], seg[0]))
                if seg[1] >= remaining[1]:
                    return missing
                remaining = segments.segment(seg[1], remaining[1])
            else:
                # entry overlaps the start of what is left; trim it off
                remaining -= seg
        if remaining:
            missing.append(remaining)
        return missing

    @staticmethod
    def ifo_strings(ifo: str) -> tuple[str, str]:
        """Make a tuple of the ifo name without and with a trailing "1".

        The configured instrument may be given as e.g. "H" or "H1", so both
        forms are produced for comparison with the cache observatory column.

        Args:
            ifo:
                str, the ifo name, e.g., "H" or "H1"

        Returns:
            tuple[str, str], the ifo name without and with the trailing "1",
            e.g. ("H", "H1") for either "H" or "H1"
        """
        if ifo.endswith("1"):
            return (ifo[:-1], ifo)
        return (ifo, ifo + "1")

    def load_gwf_data(self, frame_file: CacheEntry) -> None:
        """Load timeseries data from a gwf frame file into the channel adapters.

        Only the part of the file inside the analysis segment is read. If the
        file starts after the end of previously loaded data, a gap buffer
        covering the hole is pushed first for every channel. Sample rates are
        recorded from the first file read.

        Args:
            frame_file:
                CacheEntry, the gwf frame file to read timeseries data from

        Raises:
            ValueError:
                if the file starts before the end of previously loaded data
        """
        intersection = self.analysis_seg & frame_file.segment
        start = intersection[0]
        end = intersection[1]

        data_dict = gwframe.read(
            frame_file.path, channel=self.channel_names, start=start, end=end
        )

        if not self.rates:
            for channel, data in data_dict.items():
                self.rates[channel] = int(data.sample_rate)

        if self.last_epoch > start:
            msg = (
                f"Unexpected epoch: {start}, expected: {self.last_epoch}; frame "
                "data overlaps previously loaded data"
            )
            raise ValueError(msg)

        has_gap = self.last_epoch < start
        if has_gap:
            self.logger.warning(
                "Unexpected epoch: %f, expected: %f, sending gap buffer",
                start,
                self.last_epoch,
            )

        for channel, data in data_dict.items():
            rate = self.rates[channel]
            if has_gap:
                # fill the hole with a gap buffer (data=None); round rather
                # than truncate so float error cannot drop a sample
                self.A[channel].push(
                    SeriesBuffer(
                        offset=Offset.fromsec(self.last_epoch),
                        sample_rate=rate,
                        data=None,
                        shape=(round((start - self.last_epoch) * rate),),
                    )
                )
            self.A[channel].push(
                SeriesBuffer(
                    offset=Offset.fromsec(float(start)),
                    sample_rate=rate,
                    data=data.array,
                )
            )

        self.last_epoch = end

    def internal(self) -> None:
        """Read frame files until every channel has at least one buffer of data.

        Files are taken from the front of the time-sorted cache, and all channels
        are read from each file at once. Reading continues while any channel
        adapter holds fewer samples than one output buffer and the cache is not
        exhausted. This way, frame files shorter than one buffer do not leave the
        next frame without data.
        """
        while self.cache and any(
            adapter.size < self.num_samples(self.rates[channel])
            for channel, adapter in self.A.items()
        ):
            self.load_gwf_data(self.cache[0])
            # now that we have loaded data from this frame,
            # remove it from the cache
            self.cache.pop(0)

    def new(self, pad: SourcePad) -> TSFrame:
        """New frames are created on "pad" with an instance specific count and a name
        derived from the channel name.

        EOS is not set here. It is presumably handled by ``prepare_frame`` from
        the source's end time (TODO confirm).

        Args:
            pad:
                SourcePad, the pad for which to produce a new TSFrame

        Returns:
            TSFrame, the TSFrame that carries a list of SeriesBuffers
        """
        self.cnt[pad] += 1
        channel = self.rsrcs[pad]
        metadata = {"cnt": self.cnt[pad], "name": "'%s'" % pad.name}
        frame = self.prepare_frame(pad, metadata=metadata)

        adapter = self.A[channel]
        # only fill the frame once the adapter holds data through its end;
        # otherwise the frame is returned as prepared and nothing is flushed
        if adapter.end_offset >= frame.end_offset:
            bufs = adapter.get_sliced_buffers((frame.offset, frame.end_offset))
            frame.set_buffers(list(bufs))
            adapter.flush_samples_by_end_offset(frame.end_offset)

        return frame