# Source code for rfpipe.source

from __future__ import print_function, division, absolute_import, unicode_literals
from builtins import bytes, dict, object, range, map, input, str
from future.utils import itervalues, viewitems, iteritems, listvalues, listitems

import os.path
import numpy as np
from numba import jit
from astropy import time
import sdmpy
import pwkit.environments.casa.util as casautil
from rfpipe import util, calibration, fileLock
import pickle

import logging
logger = logging.getLogger(__name__)

# CASA quanta tool, used to format times/angles for log messages
qa = casautil.tools.quanta()
default_timeout = 10  # multiple of read time in seconds to wait


def data_prep(st, segment, data, flagversion="latest"):
    """ Prepare a segment of visibility data for searching.

    Selects the polarizations of interest, runs the standard prep stages
    (online flags, downsampling, mock transients), applies telcal
    calibration and flagging, optionally subtracts the time mean, and
    optionally saves noise estimates. flagversion can be "latest" or
    "rtpipe" (for reproducing older flagging).
    Returns the prepared data with only st.chans channels kept.
    """
    if not np.any(data):
        return data

    # keep only the polarization products defined by the state
    pol_indices = [st.metadata.pols_orig.index(pol) for pol in st.pols]
    logger.debug('Selecting pols {0}'.format(st.pols))
    prepped = data.take(pol_indices, axis=3).copy()

    prepped = prep_standard(st, segment, prepped)
    prepped = calibration.apply_telcal(st, prepped)

    # support backwards compatibility for reproducible flagging
    if flagversion == "latest":
        prepped = flag_data(st, prepped)
    elif flagversion == "rtpipe":
        prepped = flag_data_rtpipe(st, prepped)

    if st.prefs.timesub == 'mean':
        logger.info('Subtracting mean visibility in time.')
        prepped = util.meantsub(prepped, parallel=st.prefs.nthread > 1)
    else:
        logger.info('No visibility subtraction done.')

    if st.prefs.savenoise:
        save_noise(st, segment, prepped.take(st.chans, axis=2))

    logger.debug('Selecting chans {0}'.format(st.chans))

    return prepped.take(st.chans, axis=2)
def read_segment(st, segment, cfile=None, timeout=default_timeout):
    """ Read a segment of data.
    cfile and timeout are specific to vys data.
    Returns data as defined in metadata (no downselection yet), or an
    empty array if nothing was read or the datasource is unknown.
    """
    # assumed read shape (st.readints, st.nbl, st.metadata.nchan_orig, st.npol)
    if st.metadata.datasource == 'sdm':
        data_read = read_bdf_segment(st, segment)
    elif st.metadata.datasource == 'vys':
        data_read = read_vys_segment(st, segment, cfile=cfile, timeout=timeout)
    elif st.metadata.datasource == 'sim':
        data_read = simulate_segment(st)
    else:
        logger.error('Datasource {0} not recognized.'
                     .format(st.metadata.datasource))
        # bugfix: data_read was previously left unbound here, raising a
        # NameError below instead of returning cleanly
        data_read = np.array([])

    if not np.any(data_read):
        logger.info('No data read.')
        return np.array([])
    else:
        return data_read
def prep_standard(st, segment, data):
    """ Common first data prep stages, incl online flags, resampling,
    and mock transients.

    Operates on and returns a 4d visibility array
    (time, baseline, channel, pol). Returns input unchanged if empty.
    """
    if not np.any(data):
        return data

    # read Flag.xml and apply flags for given ant/time range
    if st.prefs.applyonlineflags and st.metadata.datasource == 'sdm':
        sdm = getsdm(st.metadata.filename, bdfdir=st.metadata.bdfdir)
        scan = sdm.scan(st.metadata.scan)

        # segment flagged from logical OR from (start, stop) flags
        t0, t1 = st.segmenttimes[segment]
        # 0=bad, 1=good. axis=0 is time axis.
        flags = scan.flags([t0, t1]).all(axis=0)

        if not flags.all():
            logger.info('Found antennas to flag in time range {0}-{1} '
                        .format(t0, t1))
            # zero out visibilities on flagged antennas (axis 1 is baseline;
            # flags is broadcast along that axis)
            data = np.where(flags[None, :, None, None] == 1,
                            np.require(data, requirements='W'), 0j)
        else:
            logger.info('No flagged antennas in time range {0}-{1} '
                        .format(t0, t1))
    else:
        logger.info('Not applying online flags.')

    # optionally integrate (downsample)
    if ((st.prefs.read_tdownsample > 1) or (st.prefs.read_fdownsample > 1)):
        data2 = np.zeros(st.datashape, dtype='complex64')
        if st.prefs.read_tdownsample > 1:
            logger.info('Downsampling in time by {0}'
                        .format(st.prefs.read_tdownsample))
            for i in range(st.datashape[0]):
                data2[i] = data[
                    i*st.prefs.read_tdownsample:(i+1)*st.prefs.read_tdownsample].mean(axis=0)
        if st.prefs.read_fdownsample > 1:
            logger.info('Downsampling in frequency by {0}'
                        .format(st.prefs.read_fdownsample))
            # NOTE(review): this reads from the pre-time-downsampled `data`;
            # if both t and f downsampling are >1 the shapes look
            # inconsistent — confirm both are never enabled together
            for i in range(st.datashape[2]):
                data2[:, :, i, :] = data[:, :, i*st.prefs.read_fdownsample:(i+1)*st.prefs.read_fdownsample].mean(axis=2)
        data = data2

    # optionally add transients
    if st.prefs.simulated_transient is not None:
        assert isinstance(st.prefs.simulated_transient, list), "Simulated transient must be list of tuples."
        uvw = util.get_uvw_segment(st, segment)
        for params in st.prefs.simulated_transient:
            assert len(params) == 7, ("Transient requires 7 parameters: "
                                      "(segment, i0/int, dm/pc/cm3, dt/s, "
                                      "amp/sys, dl/rad, dm/rad)")
            (mock_segment, i0, dm, dt, amp, l, m) = params
            # only inject mocks assigned to this segment
            if segment == mock_segment:
                logger.info("Adding transient to segment {0} at int {1}, DM {2}, "
                            "dt {3} with amp {4} and l,m={5},{6}"
                            .format(mock_segment, i0, dm, dt, amp, l, m))
                try:
                    # broadcast the (chan, time) dynamic spectrum over
                    # baselines and pols to match the data shape
                    model = np.require(np.broadcast_to(generate_transient(st, amp, i0, dm, dt)
                                                       .transpose()[:, None, :, None],
                                                       data.shape),
                                       requirements='W')
                except IndexError:
                    logger.warn("IndexError while adding transient. Skipping...")
                    continue

                # corrupt the model with inverse calibration so that the mock
                # behaves like uncalibrated data, then shift it off-center
                model = calibration.apply_telcal(st, model, sign=-1)
                util.phase_shift(model, uvw, -l, -m)
                data += model

    return data
def read_vys_segment(st, seg, cfile=None, timeout=default_timeout, offset=4):
    """ Read segment seg defined by state st from vys stream.
    Uses vysmaw application timefilter to receive multicast messages and
    pull spectra on the CBE.
    timeout is a multiple of read time in seconds to wait.
    offset is extra time in seconds to keep vys reader open.
    Returns the read data, or an empty array if nothing was received.
    """
    # TODO: support for time downsampling

    try:
        import vysmaw_reader
    except ImportError:
        logger.error('ImportError for vysmaw_reader. Cannot '
                     'consume vys data.')
        # bugfix: previously execution continued and crashed below with a
        # NameError on vysmaw_reader; re-raise the real error instead
        raise

    t0 = time.Time(st.segmenttimes[seg][0], format='mjd', precision=9).unix
    t1 = time.Time(st.segmenttimes[seg][1], format='mjd', precision=9).unix

    logger.info('Reading {0} s ints into shape {1} from {2} - {3} unix seconds'
                .format(st.metadata.inttime, st.datashape_orig, t0, t1))

    polmap_standard = ['A*A', 'A*B', 'B*A', 'B*B']
    bbmap_standard = ['AC1', 'AC2', 'AC', 'BD1', 'BD2', 'BD']

    # TODO: vysmaw currently pulls all data, but allocates buffer based on these.
    # buffer will be too small if taking subset of all data.
    pollist = np.array([polmap_standard.index(pol)
                        for pol in st.metadata.pols_orig], dtype=np.int32)  # TODO: use st.pols when vysmaw filter can too
    antlist = np.array([int(ant.lstrip('ea')) for ant in st.ants],
                       dtype=np.int32)
    spwlist = list(zip(*st.metadata.spworder))[0]  # list of strings ["bb-spw"] in increasing freq order
    bbsplist = np.array([(int(bbmap_standard.index(spw.split('-')[0])),
                          int(spw.split('-')[1])) for spw in spwlist],
                        dtype=np.int32)

    with vysmaw_reader.Reader(t0, t1, antlist, pollist, bbsplist,
                              inttime_micros=st.metadata.inttime*1000000.,
                              nchan=st.metadata.spw_nchan[0],
                              cfile=cfile, timeout=timeout,
                              offset=offset) as reader:
        if reader is not None:
            data = reader.readwindow()
        else:
            data = None

    # TODO: move pol selection up and into vysmaw filter function
    if data is not None:
        return data
    else:
        return np.array([])
def read_bdf_segment(st, segment):
    """ Read one segment of bdf (sdm) format data into a numpy array
    using sdmpy. Each segment covers st.readints integrations.
    """
    assert segment < st.nsegment, ('segment {0} is too big for nsegment {1}'
                                   .format(segment, st.nsegment))

    # number of integrations between scan start and this segment's start
    seg_start = st.segmenttimes[segment, 0]
    nskip = (24*3600*(seg_start - st.metadata.starttime_mjd)/st.metadata.inttime).astype(int)

    start_hms = qa.time(qa.quantity(st.segmenttimes[segment, 0], 'd'),
                        form=['hms'], prec=9)[0]
    stop_hms = qa.time(qa.quantity(st.segmenttimes[segment, 1], 'd'),
                       form=['hms'], prec=9)[0]
    logger.info('Reading scan {0}, segment {1}/{2}, times {3} to {4}'
                .format(st.metadata.scan, segment, len(st.segmenttimes)-1,
                        start_hms, stop_hms))

    return read_bdf(st, nskip=nskip).astype('complex64')
def read_bdf(st, nskip=0):
    """ Read a range of integrations from the bdf of the scan via sdmpy,
    starting after nskip integrations.
    Returns data in increasing frequency order.
    """
    assert os.path.exists(st.metadata.filename), ('sdmfile {0} does not exist'
                                                  .format(st.metadata.filename))
    assert st.metadata.bdfstr, ('bdfstr not defined for scan {0}'
                                .format(st.metadata.scan))

    logger.info('Reading %d ints starting at int %d' % (st.readints, nskip))
    sdm = getsdm(st.metadata.filename, bdfdir=st.metadata.bdfdir)
    scan = sdm.scan(st.metadata.scan)

    shape = (st.readints, st.metadata.nbl_orig, st.metadata.nchan_orig,
             st.metadata.npol_orig)
    data = np.empty(shape, dtype='complex64', order='C')

    # sort spectral windows into increasing reference-frequency order
    sortind = np.argsort(st.metadata.spw_reffreq)
    for offset, intnum in enumerate(range(nskip, nskip+st.readints)):
        raw = scan.bdf.get_integration(intnum).get_data(spwidx='all',
                                                        type='cross')
        data[offset] = raw.take(sortind, axis=1).reshape(st.metadata.nbl_orig,
                                                         st.metadata.nchan_orig,
                                                         st.metadata.npol_orig)

    return data
def save_noise(st, segment, data, chunk=200):
    """ Calculates noise properties and saves values to pickle.
    chunk defines the window (in integrations) per measurement; at least
    one measurement is always made. Appends results to st.noisefile under
    a file lock, spilling to a per-segment file on lock timeout.
    """
    from rfpipe.search import grid_image
    uvw = util.get_uvw_segment(st, segment)
    chunk = min(chunk, max(1, st.readints-1))  # ensure at least one measurement
    ranges = list(zip(list(range(0, st.readints-chunk, chunk)),
                      list(range(chunk, st.readints, chunk))))

    results = []
    for (r0, r1) in ranges:
        imid = (r0+r1)//2
        noiseperbl = estimate_noiseperbl(data[r0:r1])
        imstd = grid_image(data, uvw, st.npixx, st.npixy, st.uvres, 'fftw', 1,
                           integrations=imid).std()
        zerofrac = float(len(np.where(data[r0:r1] == 0j)[0]))/data[r0:r1].size
        results.append((segment, imid, noiseperbl, zerofrac, imstd))

    try:
        noisefile = st.noisefile
        with fileLock.FileLock(noisefile+'.lock', timeout=10):
            with open(noisefile, 'ab+') as pkl:
                pickle.dump(results, pkl)
    except fileLock.FileLock.FileLockException:
        # bugfix: rstrip('.pkl') strips *characters* ('.', 'p', 'k', 'l')
        # from the end, mangling names like "noise_sp.pkl"; remove the
        # suffix explicitly instead
        if st.noisefile.endswith('.pkl'):
            base = st.noisefile[:-len('.pkl')]
        else:
            base = st.noisefile
        noisefile = '{0}_seg{1}.pkl'.format(base, segment)
        logger.warning('Noise file writing timeout. '
                       'Spilling to new file {0}.'.format(noisefile))
        with open(noisefile, 'ab+') as pkl:
            pickle.dump(results, pkl)

    if len(results):
        logger.info('Wrote {0} noise measurement{1} from segment {2} to {3}'
                    .format(len(results), 's'[:len(results)-1], segment,
                            noisefile))
def estimate_noiseperbl(data):
    """ Estimate a single noise value per baseline for input to
    detect_bispectra or imaging.

    Averages over axis 2 (channels), then measures the standard deviation
    of the imaginary part, which is insensitive to calibrated, on-axis
    signal.
    """
    # imaginary part of the channel-averaged data carries noise only
    imag_of_mean = data.mean(axis=2).imag
    noiseperbl = imag_of_mean.std()
    logger.debug('Measured noise per baseline of {0:.3f}'.format(noiseperbl))
    return noiseperbl
def flag_data(st, data):
    """ Identifies bad data and flags it to 0.

    Applies each (mode, arg0, arg1) entry of st.prefs.flaglist, combining
    the resulting per-(int, chan, pol) boolean masks multiplicatively,
    then zeroes flagged visibilities.
    """
    # data = np.ma.masked_equal(data, 0j)
    # TODO remove this and ignore zeros manually
    flags = np.ones_like(data, dtype=bool)
    for flagparams in st.prefs.flaglist:
        mode, arg0, arg1 = flagparams
        if mode == 'blstd':
            flags *= flag_blstd(data, arg0, arg1)[:, None, :, :]
        elif mode == 'badchtslide':
            flags *= flag_badchtslide(data, arg0, arg1)[:, None, :, :]
        else:
            # typo fix ("Flaging") and logger.warn is a deprecated alias
            logger.warning("Flagging mode {0} not available.".format(mode))

    return data*flags
def flag_blstd(data, sigma, convergence):
    """ Use data (4d) to calculate (int, chan, pol) cells to be flagged.

    Iteratively sigma-clips the per-cell standard deviation over baselines
    until the overall std changes by less than the convergence fraction,
    then flags cells above median + sigma*std. Returns a bool array of
    shape (int, chan, pol) with False marking flagged cells.
    """
    nint, _, nchan, npol = data.shape
    keep = np.ones((nint, nchan, npol), dtype=bool)

    blstd = data.std(axis=1)

    # iterate toward robust median/std by masking outliers each pass
    med_new = np.ma.median(blstd)
    std_new = blstd.std()
    std_prev = std_new*2  # TODO: is this initialization used?
    while (std_prev-std_new)/std_prev > convergence:
        std_prev = std_new
        med_prev = med_new
        blstd = np.ma.masked_where(blstd > med_prev + sigma*std_prev,
                                   blstd, copy=False)
        med_new = np.ma.median(blstd)
        std_new = blstd.std()

    # flag cells whose baseline std is still too high
    badt, badch, badpol = np.where(blstd > med_new + sigma*std_new)
    logger.info("flag by blstd: {0} of {1} total channel/time/pol cells flagged."
                .format(len(badt), nint*nchan*npol))

    for t, ch, p in zip(badt, badch, badpol):
        keep[t, ch, p] = False

    return keep
def flag_badchtslide(data, sigma, win):
    """ Use data (4d) to calculate (int, chan, pol) cells to be flagged.

    Channels and integrations whose mean amplitude deviates from the
    median of a sliding window of length win by more than sigma times the
    std of those deviations are flagged. Returns a bool array of shape
    (int, chan, pol) with False marking flagged cells.
    """
    nint, _, nchan, npol = data.shape
    keep = np.ones((nint, nchan, npol), dtype=bool)

    meanamp = np.abs(data).mean(axis=1)
    spec = meanamp.mean(axis=0)  # (chan, pol) spectrum
    lc = meanamp.mean(axis=1)    # (int, pol) lightcurve

    # bad channels deviate from sliding median of the spectrum
    specdev = slidedev(spec, win)
    badch = np.where(specdev > sigma*specdev.std(axis=0))

    # bad integrations deviate from sliding median of the lightcurve
    lcdev = slidedev(lc, win)
    badt = np.where(lcdev > sigma*lcdev.std(axis=0))

    badtcnt = len(np.unique(badt))
    badchcnt = len(np.unique(badch))
    logger.info("flag by badchtslide: {0}/{1} pol-times and {2}/{3} pol-chans flagged."
                .format(badtcnt, nint*npol, badchcnt, nchan*npol))

    for ch, p in zip(*badch):
        keep[:, ch, p] = False
    for t, p in zip(*badt):
        keep[t, :, p] = False

    return keep
@jit
def slidedev(arr, win):
    """ Given a (len x 2) array, calculate the deviation from the median
    per pol, where the median is taken over a sliding window of width win
    around each index (excluding the index itself).

    NOTE(review): numba @jit cannot compile np.ma.median in nopython mode;
    this function has relied on object-mode fallback — confirm before
    upgrading numba.
    """
    med = np.zeros((len(arr), 2))
    for i in range(len(arr)):
        # neighbor indices within the window, excluding i itself
        inds = list(range(max(0, i-win//2), i)) + list(range(i+1, min(i+win//2, len(arr))))
        # bugfix: previously the same window median was written to med[j]
        # for every j in inds (each later window overwriting earlier ones),
        # so med[i] ended up holding the median of a window centered past i.
        # Each index gets the median of its own window.
        med[i] = np.ma.median(arr.take(inds, axis=0), axis=0)

    return arr-med
def flag_data_rtpipe(st, data):
    """ Flagging data in single process using the old rtpipe/rtlib code.
    Deprecated; kept for reproducing rtpipe-era flagging.
    """
    try:
        import rtlib_cython as rtlib
    except ImportError:
        # NOTE(review): if the import fails, execution continues and the
        # rtlib calls below raise NameError — confirm intended behavior
        logger.error("rtpipe not installed. Cannot import rtlib for flagging.")

    # **hack!**
    # minimal dict mimicking the rtpipe state object that rtlib expects
    d = {'dataformat': 'sdm',
         'ants': [int(ant.lstrip('ea')) for ant in st.ants],
         'excludeants': st.prefs.excludeants,
         'nants': len(st.ants)}

    # apply each flagging mode per spw and pol via rtlib
    for flag in st.prefs.flaglist:
        mode, sig, conv = flag
        for spw in st.spw:
            chans = np.arange(st.metadata.spw_nchan[spw]*spw,
                              st.metadata.spw_nchan[spw]*(1+spw))
            for pol in range(st.npol):
                status = rtlib.dataflag(data, chans, pol, d, sig, mode, conv)
                logger.info(status)

    # hack to get rid of bad spw/pol combos whacked by rfi
    if st.prefs.badspwpol:
        logger.info('Comparing overall power between spw/pol. Removing those with {0} times typical value'.format(st.prefs.badspwpol))
        spwpol = {}
        for spw in st.spw:
            chans = np.arange(st.metadata.spw_nchan[spw]*spw,
                              st.metadata.spw_nchan[spw]*(1+spw))
            for pol in range(st.npol):
                spwpol[(spw, pol)] = np.abs(data[:, :, chans, pol]).std()

        meanstd = np.mean(list(spwpol.values()))
        # zero any spw/pol whose std exceeds badspwpol times the mean std
        for (spw, pol) in spwpol:
            if spwpol[(spw, pol)] > st.prefs.badspwpol*meanstd:
                logger.info('Flagging all of (spw %d, pol %d) for excess noise.' % (spw, pol))
                chans = np.arange(st.metadata.spw_nchan[spw]*spw,
                                  st.metadata.spw_nchan[spw]*(1+spw))
                data[:, :, chans, pol] = 0j

    return data
def simulate_segment(st, loc=0., scale=1.):
    """ Simulate visibilities for a segment.

    Real and imaginary parts are drawn independently from a normal
    distribution with the given loc and scale.
    """
    shape = st.datashape_orig
    logger.info('Simulating data with shape {0}'.format(shape))
    data = np.empty(shape, dtype='complex64', order='C')
    data.real = np.random.normal(loc=loc, scale=scale,
                                 size=shape).astype(np.float32)
    data.imag = np.random.normal(loc=loc, scale=scale,
                                 size=shape).astype(np.float32)
    return data
def sdm_sources(sdmname):
    """ Use sdmpy to build a dict of source name and (ra, dec), keyed by
    sourceId, from the SDM Field table.
    """
    sdm = getsdm(sdmname)

    sourcedict = {}
    for row in sdm['Field']:
        direction = str(row.referenceDir)
        # skip leading tokens of the direction string; remaining two are
        # ra and dec (per original comment — confirm against SDM format)
        ra, dec = (float(val) for val in direction.split(' ')[3:])
        sourcedict[int(row.sourceId)] = {'source': str(row.fieldName),
                                         'ra': ra,
                                         'dec': dec}

    return sourcedict
def getsdm(*args, **kwargs):
    """ Wrap sdmpy.SDM to get around schema change error.

    Retries with use_xsd=False if the first construction attempt fails.
    """
    try:
        sdm = sdmpy.SDM(*args, **kwargs)
    except Exception:
        # was a bare except, which also swallowed KeyboardInterrupt and
        # SystemExit; narrow to Exception before retrying without xsd
        kwargs['use_xsd'] = False
        sdm = sdmpy.SDM(*args, **kwargs)

    return sdm
def generate_transient(st, amp, i0, dm, dt):
    """ Create a dynamic spectrum for given parameters
    amp is in system units (post calibration)
    i0 is a float for integration relative to start of segment.
    dm/dt are in units of pc/cm3 and seconds, respectively

    Returns a complex64 array of shape (nchan_orig, readints) holding the
    dispersed pulse, with amp split fractionally across the 1-3
    integrations each channel's pulse overlaps.
    """
    model = np.zeros((st.metadata.nchan_orig, st.readints), dtype='complex64')
    chans = np.arange(st.nchan)
    # per-channel (fractional) integration of pulse arrival, including
    # dispersion delay relative to the highest frequency
    i = i0 + util.calc_delay2(st.freq, st.freq.max(), dm)/st.inttime
    i_f = np.floor(i).astype(int)
    # first and last integrations touched by the pulse of width dt
    imax = np.ceil(i + dt/st.inttime).astype(int)
    imin = i_f
    # number of integrations the pulse spans per channel (1, 2 or 3)
    i_r = imax - imin

    # pulse fits entirely within one integration: full amplitude there
    if np.any(i_r == 1):
        ir1 = np.where(i_r == 1)
        model[chans[ir1], i_f[ir1]] += amp

    # pulse straddles two integrations: split amp by fractional overlap
    if np.any(i_r == 2):
        ir2 = np.where(i_r == 2)
        i_c = np.ceil(i).astype(int)
        f1 = (dt/st.inttime - (i_c - i))/(dt/st.inttime)
        f0 = 1 - f1
        model[chans[ir2], i_f[ir2]] += f0[ir2]*amp
        model[chans[ir2], i_f[ir2]+1] += f1[ir2]*amp

    # pulse spans three integrations: end fractions plus a full middle one
    if np.any(i_r == 3):
        ir3 = np.where(i_r == 3)
        f2 = (i + dt/st.inttime - (imax - 1))/(dt/st.inttime)
        f0 = ((i_f + 1) - i)/(dt/st.inttime)
        f1 = 1 - f2 - f0
        model[chans[ir3], i_f[ir3]] += f0[ir3]*amp
        model[chans[ir3], i_f[ir3]+1] += f1[ir3]*amp
        model[chans[ir3], i_f[ir3]+2] += f2[ir3]*amp

    return model