Coverage for src / sgn_gwframe / sources / frame.py: 96.0%

125 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-02-12 13:29 -0800

1"""Read GW frame files from a frame cache file.""" 

2 

3# Copyright (C) 2024 Becca Ewing, Yun-Jing Huang 

4 

5from __future__ import annotations 

6 

7import logging 

8from dataclasses import dataclass 

9from typing import TYPE_CHECKING, Optional 

10 

11import gwframe 

12import igwn_segments as segments 

13from sgnts.base import Audioadapter, Offset, SeriesBuffer, TSFrame, TSSource 

14 

15if TYPE_CHECKING: 

16 from sgn.base import SourcePad 

17 

18logger = logging.getLogger("sgn") 

19 

20 

@dataclass
class CacheEntry:
    """Simple cache entry parser for frame cache files.

    Cache file format: observatory description gps_start duration path
    Example: L L1_GWOSC_16KHZ_R1 1240215487 32 ./path/to/file.gwf

    The path may also be a local ``file://`` URL (``file:///path/to/file.gwf`` or
    ``file://localhost/path/to/file.gwf``); such URLs are converted to a plain
    filesystem path when the line is parsed. Any other path is kept verbatim.
    """

    observatory: str
    description: str
    gps_start: float
    duration: float
    path: str

    @classmethod
    def from_line(cls, line: str) -> CacheEntry:
        """Parse a cache file line into a CacheEntry.

        Args:
            line:
                str, one line of the cache file (surrounding whitespace, including
                a trailing newline, is ignored)

        Returns:
            CacheEntry, the parsed entry

        Raises:
            ValueError: if the line does not have exactly five whitespace-separated
                fields, or if the GPS start / duration fields are not numbers.
        """
        from urllib.parse import urlparse

        text = line.strip()
        parts = text.split()
        if len(parts) != 5:
            msg = f"Invalid cache line format: {text}"
            raise ValueError(msg)

        observatory, description, start_field, duration_field, path = parts
        try:
            gps_start = float(start_field)
            duration = float(duration_field)
        except ValueError as err:
            # Re-raise with the offending line so the bad cache entry is findable.
            msg = f"Invalid cache line format: {text}"
            raise ValueError(msg) from err

        # Cache files produced by LAL-style tools often store file:// URLs rather
        # than paths; only local URLs can be turned into a readable path.
        url = urlparse(path)
        if url.scheme == "file" and url.netloc in ("", "localhost"):
            path = url.path

        return cls(
            observatory=observatory,
            description=description,
            gps_start=gps_start,
            duration=duration,
            path=path,
        )

    @property
    def segment(self) -> segments.segment:
        """Return the [gps_start, gps_start + duration) time segment of this entry."""
        return segments.segment(self.gps_start, self.gps_start + self.duration)

54 

55 

@dataclass(kw_only=True)
class FrameSource(TSSource):
    """Read GW frame files from a frame cache file.

    Args:
        channel_names:
            list[str], a list of channel names of the data, e.g.,
            ["L1:GWOSC-16KHZ_R1_STRAIN", "L1:GWOSC-16KHZ_R1_DQMASK"]. Source pads will
            be automatically generated for each channel, with channel name as pad name.
        framecache:
            str, cache file to read data from
        instrument:
            str, optional, only read gwf files from this instrument, given either as
            a one-letter observatory code (e.g. "H") or with the trailing "1"
            (e.g. "H1"). Default: None
    """

    channel_names: list[str]
    framecache: str
    instrument: Optional[str] = None

    def __post_init__(self) -> None:
        # Pads are named after channels: either the caller left them unnamed, or the
        # given names must equal the channel list exactly (same order).
        if len(self.source_pad_names) > 0:
            if self.source_pad_names != tuple(self.channel_names):
                msg = "Expected source pad names to match channel names"
                raise ValueError(msg)
        else:
            self.source_pad_names = tuple(self.channel_names)

        super().__post_init__()
        self.cnt = dict.fromkeys(self.source_pads, 0)

        if self.start is None:
            msg = "FrameSource requires a start time to be specified"
            raise ValueError(msg)
        # End time (seconds) of the data pushed into the adapters so far.
        self.last_epoch: float = self.start

        if self.instrument is not None:
            for channel in self.channel_names:
                if not self._channel_matches_instrument(channel, self.instrument):
                    msg = (
                        f"Instrument '{self.instrument}' does not match "
                        f"channel name '{channel}'"
                    )
                    raise ValueError(msg)

        # init analysis segment
        self.analysis_seg = segments.segment(self.start, self.end)

        self.cache: list[CacheEntry] = self._load_cache()
        if not self.cache:
            # Sample rates are read from the first frame, so nothing can be set up.
            msg = (
                f"No frame files in cache {self.framecache} overlap the analysis "
                f"segment {self.analysis_seg}"
            )
            raise ValueError(msg)

        missing_segments = self._find_missing_segments()
        if missing_segments:
            self.logger.warning(
                "Cache has missing segments %s, padding with gaps",
                missing_segments,
            )

        self.A = {c: Audioadapter() for c in self.channel_names}

        # load first segment of data to read sample rate
        self.rates: dict[str, int] = {}
        self.load_gwf_data(self.cache[0])
        self.logger.info("Initialized sample rates per channel: %s", self.rates)

        # Set buffer parameters for each pad (supports different sample rates per pad)
        for pad in self.source_pads:
            self.set_pad_buffer_params(
                sample_shape=(), rate=self.rates[self.rsrcs[pad]], pad=pad
            )

        # now that we have loaded data from this frame, remove it from the cache
        self.cache.pop(0)

    def _load_cache(self) -> list[CacheEntry]:
        """Parse the cache file and keep only usable entries, sorted by start time.

        Entries are dropped when they belong to another instrument (if
        ``self.instrument`` is set) or do not overlap ``self.analysis_seg``.
        """
        self.logger.info("Loading frame cache from %s", self.framecache)
        with open(self.framecache) as f:
            entries = [CacheEntry.from_line(line) for line in f if line.strip()]

        if self.instrument is not None:
            observatories = self.ifo_strings(self.instrument)
            entries = [e for e in entries if e.observatory in observatories]

        entries = [e for e in entries if self.analysis_seg.intersects(e.segment)]
        entries.sort(key=lambda e: e.segment[0])
        return entries

    def _find_missing_segments(self) -> list[segments.segment]:
        """Return the parts of the analysis segment not covered by ``self.cache``.

        Assumes ``self.cache`` is sorted by start time and every entry overlaps the
        analysis segment (as produced by ``_load_cache``).
        """
        seg_start, seg_end = self.analysis_seg
        cursor = seg_start  # everything before this time is known to be covered
        missing = []
        for entry in self.cache:
            if cursor >= seg_end:
                # fully covered; later (overlapping) entries cannot create gaps
                break
            entry_start, entry_end = entry.segment
            if entry_start > cursor:
                missing.append(segments.segment(cursor, entry_start))
            cursor = max(cursor, entry_end)
        if cursor < seg_end:
            missing.append(segments.segment(cursor, seg_end))
        return missing

    @staticmethod
    def _channel_matches_instrument(channel: str, instrument: str) -> bool:
        """Check whether a channel name belongs to the given instrument.

        For names of the form "IFO:NAME" the IFO prefix is compared against
        ``ifo_strings(instrument)``. Names without a ":" fall back to a substring
        test.
        """
        prefix, sep, _ = channel.partition(":")
        if sep:
            return prefix in FrameSource.ifo_strings(instrument)
        return instrument in channel

    @staticmethod
    def ifo_strings(ifo: str) -> tuple[str, str]:
        """Make a tuple of possible ifo strings, with and without the "1" at the end.

        The instrument may be given as e.g. "H" or "H1", so both spellings are
        returned for string comparison.

        Args:
            ifo:
                str, the ifo name, e.g., "H" or "H1"

        Returns:
            tuple[str, str], a tuple of the ifo name without and with the "1" at the
            end
        """
        if ifo.endswith("1"):
            return (ifo[0], ifo)
        return (ifo, ifo + "1")

    def load_gwf_data(self, frame_file: CacheEntry) -> None:
        """Read one gwf frame file and push its data into the per-channel adapters.

        Only the part of the file that overlaps the analysis segment is read. If the
        file starts after the end of the previously loaded data, a gap buffer
        covering the missing span is pushed first. The first call also records
        each channel's sample rate in ``self.rates``. ``self.last_epoch`` is
        advanced to the end of the data read.

        Args:
            frame_file:
                CacheEntry, the gwf frame file to read timeseries data from

        Raises:
            ValueError: if the file starts before the end of the previously loaded
                data (overlapping frame files are not supported).
        """
        start, end = self.analysis_seg & frame_file.segment

        # Check before reading so a bad cache fails without an expensive read.
        if self.last_epoch > start:
            msg = (
                f"Unexpected epoch: {start}, expected: {self.last_epoch}; "
                "overlapping frame files are not supported"
            )
            raise ValueError(msg)

        data_dict = gwframe.read(
            frame_file.path, channel=self.channel_names, start=start, end=end
        )

        if not self.rates:
            for channel, data in data_dict.items():
                self.rates[channel] = int(data.sample_rate)

        has_gap = self.last_epoch < start
        if has_gap:
            self.logger.warning(
                "Unexpected epoch: %f, expected: %f, sending gap buffer",
                start,
                self.last_epoch,
            )

        for channel, data in data_dict.items():
            rate = self.rates[channel]
            if has_gap:
                self.A[channel].push(
                    SeriesBuffer(
                        offset=Offset.fromsec(self.last_epoch),
                        sample_rate=rate,
                        data=None,
                        # round, not truncate: float error must not drop a sample
                        shape=(round((start - self.last_epoch) * rate),),
                    )
                )
            self.A[channel].push(
                SeriesBuffer(
                    offset=Offset.fromsec(float(start)),
                    sample_rate=rate,
                    data=data.array,
                )
            )

        self.last_epoch = end

    def _needs_more_data(self) -> bool:
        """Return True if any channel holds less than one buffer length of samples."""
        return any(
            adapter.size < self.num_samples(self.rates[channel])
            for channel, adapter in self.A.items()
        )

    def internal(self) -> None:
        """Read gw frame files from the cache until every channel has at least one
        buffer length of data, or the cache is exhausted. All channels are read at
        once.
        """
        # Loop rather than read a single file: frame files shorter than a buffer
        # would otherwise never accumulate enough data in time.
        while self.cache and self._needs_more_data():
            self.load_gwf_data(self.cache[0])
            # now that we have loaded data from this frame, remove it from the cache
            self.cache.pop(0)

    def new(self, pad: SourcePad) -> TSFrame:
        """Produce the next TSFrame on ``pad``.

        Frames carry an instance-specific count and a name derived from the channel
        name. The frame is filled from the channel's adapter only when the adapter
        holds data up to the frame's end offset. Otherwise it is returned as built
        by ``prepare_frame``; presumably that is a gap frame, which should be
        confirmed in TSSource.

        Args:
            pad:
                SourcePad, the pad for which to produce a new TSFrame

        Returns:
            TSFrame, the TSFrame that carries a list of SeriesBuffers
        """
        self.cnt[pad] += 1
        channel = self.rsrcs[pad]
        metadata = {"cnt": self.cnt[pad], "name": f"'{pad.name}'"}

        frame = self.prepare_frame(pad, metadata=metadata)

        adapter = self.A[channel]
        if adapter.end_offset >= frame.end_offset:
            bufs = adapter.get_sliced_buffers((frame.offset, frame.end_offset))
            frame.set_buffers(list(bufs))
            # drop samples already emitted so the adapter only holds future data
            adapter.flush_samples_by_end_offset(frame.end_offset)

        return frame