#!/usr/bin/env python2.7
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Starts a local DNS server for use in tests"""

import argparse
import os
import platform
import signal
import sys
import threading
import time

import twisted
import twisted.internet
import twisted.internet.defer
import twisted.internet.protocol
import twisted.internet.reactor
import twisted.internet.threads
import twisted.names
import twisted.names.client
import twisted.names.dns
import twisted.names.server
from twisted.names import client, server, common, authority, dns
import yaml

_SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp'  # missing end '.' for twisted syntax
_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'

# Upper bound, in bytes, on the size of each string handed to dns.Record_TXT;
# longer TXT data is split into several strings of at most this length.
_MAX_TXT_CHUNK_LENGTH = 255


class NoFileAuthority(authority.FileAuthority):
    """Authority whose SOA and records are passed in directly.

    FileAuthority.__init__ is deliberately not called; only the
    common.ResolverBase initializer runs, and the `soa` and `records`
    attributes are then assigned from the constructor arguments.
    """

    def __init__(self, soa, records):
        # skip FileAuthority
        common.ResolverBase.__init__(self)
        self.soa = soa
        self.records = records


def start_local_dns_server(args):
    """Build the record set from the YAML config and serve it until stopped.

    Args:
      args: parsed command-line arguments. `args.records_config_path` is the
        path of the YAML records file and `args.port` is the port used for
        both the TCP and the UDP listener.

    This call blocks inside twisted.internet.reactor.run().
    """
    all_records = {}

    def _push_record(name, r):
        """Append record `r` to the list of records stored under `name`."""
        print('pushing record: |%s|' % name)
        all_records.setdefault(name, []).append(r)

    def _maybe_split_up_txt_data(name, txt_data, r_ttl):
        """Push one TXT record whose data is split into <=255-char strings."""
        txt_data_list = [
            txt_data[start:start + _MAX_TXT_CHUNK_LENGTH]
            for start in range(0, len(txt_data), _MAX_TXT_CHUNK_LENGTH)
        ]
        _push_record(name, dns.Record_TXT(*txt_data_list, ttl=r_ttl))

    with open(args.records_config_path) as config:
        # safe_load: a bare yaml.load(stream) can construct arbitrary Python
        # objects on old PyYAML and is rejected (missing Loader) by PyYAML 6+.
        test_records_config = yaml.safe_load(config)
    common_zone_name = test_records_config['resolver_tests_common_zone_name']
    for group in test_records_config['resolver_component_tests']:
        for name, records in group['records'].items():
            for record in records:
                r_type = record['type']
                r_data = record['data']
                r_ttl = int(record['TTL'])
                record_full_name = '%s.%s' % (name, common_zone_name)
                # The zone name in the config ends with '.'; the records dict
                # is keyed by the name without that trailing dot.
                assert record_full_name[-1] == '.'
                record_full_name = record_full_name[:-1]
                if r_type == 'A':
                    _push_record(record_full_name,
                                 dns.Record_A(r_data, ttl=r_ttl))
                elif r_type == 'AAAA':
                    _push_record(record_full_name,
                                 dns.Record_AAAA(r_data, ttl=r_ttl))
                elif r_type == 'SRV':
                    # SRV data is "<priority> <weight> <port> <target>", with
                    # the target given relative to the common zone.
                    p, w, port, target = r_data.split(' ')
                    p = int(p)
                    w = int(w)
                    port = int(port)
                    target_full_name = '%s.%s' % (target, common_zone_name)
                    _push_record(
                        record_full_name,
                        dns.Record_SRV(p, w, port, target_full_name,
                                       ttl=r_ttl))
                elif r_type == 'TXT':
                    _maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
                # Any other record type is ignored, as before.
    # Server health check record
    _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
                 dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
    soa_record = dns.Record_SOA(mname=common_zone_name)
    test_domain_com = NoFileAuthority(
        soa=(common_zone_name, soa_record),
        records=all_records,
    )
    # Named so that it does not shadow the imported `server` module.
    dns_server_factory = twisted.names.server.DNSServerFactory(
        authorities=[test_domain_com], verbose=2)
    dns_server_factory.noisy = 2
    twisted.internet.reactor.listenTCP(args.port, dns_server_factory)
    dns_proto = twisted.names.dns.DNSDatagramProtocol(dns_server_factory)
    dns_proto.noisy = 2
    twisted.internet.reactor.listenUDP(args.port, dns_proto)
    print('starting local dns server on 127.0.0.1:%s' % args.port)
    print('starting twisted.internet.reactor')
    twisted.internet.reactor.suggestThreadPoolSize(1)
    twisted.internet.reactor.run()


def _quit_on_signal(signum, _frame):
    """Signal handler: stop the reactor, flush stdout and exit with code 0."""
    print('Received SIGNAL %d. Quitting with exit code 0' % signum)
    twisted.internet.reactor.stop()
    sys.stdout.flush()
    sys.exit(0)


def flush_stdout_loop():
    """Flush stdout once per second, then SIGTERM this process.

    Runs for at most `max_timeouts` one-second iterations so that a server
    left behind by a test does not linger indefinitely.
    """
    num_timeouts_so_far = 0
    sleep_time = 1
    # Prevent zombies. Tests that use this server are short-lived.
    max_timeouts = 60 * 2
    while num_timeouts_so_far < max_timeouts:
        sys.stdout.flush()
        time.sleep(sleep_time)
        num_timeouts_so_far += 1
    print('Process timeout reached, or cancelled. Exitting 0.')
    os.kill(os.getpid(), signal.SIGTERM)


def main():
    """Parse arguments, install signal handlers and run the DNS server."""
    argp = argparse.ArgumentParser(
        description='Local DNS Server for resolver tests')
    argp.add_argument('-p', '--port', default=None, type=int,
                      help='Port for DNS server to listen on for TCP and UDP.')
    argp.add_argument(
        '-r', '--records_config_path', default=None, type=str,
        help=('Directory of resolver_test_record_groups.yaml file. '
              'Defauls to path needed when the test is invoked as part of run_tests.py.'))
    args = argp.parse_args()
    signal.signal(signal.SIGTERM, _quit_on_signal)
    signal.signal(signal.SIGINT, _quit_on_signal)
    output_flush_thread = threading.Thread(target=flush_stdout_loop)
    # Daemon thread: it must not keep the process alive once the main
    # thread exits. (`setDaemon` is the deprecated spelling of this.)
    output_flush_thread.daemon = True
    output_flush_thread.start()
    start_local_dns_server(args)


if __name__ == '__main__':
    main()