mirror of
https://github.com/ytdl-org/youtube-dl
synced 2024-11-30 14:42:58 +01:00
[utils] Add ability to control skipping false values in dict_get
This commit is contained in:
parent
52f5889f77
commit
86296ad2cd
@ -452,9 +452,15 @@ class TestUtil(unittest.TestCase):
|
|||||||
self.assertTrue(isinstance(data, bytes))
|
self.assertTrue(isinstance(data, bytes))
|
||||||
|
|
||||||
def test_dict_get(self):
|
def test_dict_get(self):
|
||||||
d = {
|
FALSE_VALUES = {
|
||||||
'a': 42,
|
'none': None,
|
||||||
|
'false': False,
|
||||||
|
'zero': 0,
|
||||||
|
'empty_string': '',
|
||||||
|
'empty_list': [],
|
||||||
}
|
}
|
||||||
|
d = FALSE_VALUES.copy()
|
||||||
|
d['a'] = 42
|
||||||
self.assertEqual(dict_get(d, 'a'), 42)
|
self.assertEqual(dict_get(d, 'a'), 42)
|
||||||
self.assertEqual(dict_get(d, 'b'), None)
|
self.assertEqual(dict_get(d, 'b'), None)
|
||||||
self.assertEqual(dict_get(d, 'b', 42), 42)
|
self.assertEqual(dict_get(d, 'b', 42), 42)
|
||||||
@ -463,6 +469,9 @@ class TestUtil(unittest.TestCase):
|
|||||||
self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42)
|
self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42)
|
||||||
self.assertEqual(dict_get(d, ('b', 'c', )), None)
|
self.assertEqual(dict_get(d, ('b', 'c', )), None)
|
||||||
self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42)
|
self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42)
|
||||||
|
for key, false_value in FALSE_VALUES.items():
|
||||||
|
self.assertEqual(dict_get(d, ('b', 'c', key, )), None)
|
||||||
|
self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value)
|
||||||
|
|
||||||
def test_encode_compat_str(self):
|
def test_encode_compat_str(self):
|
||||||
self.assertEqual(encode_compat_str(b'\xd1\x82\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'), 'тест')
|
self.assertEqual(encode_compat_str(b'\xd1\x82\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'), 'тест')
|
||||||
|
@ -1717,10 +1717,11 @@ def encode_dict(d, encoding='utf-8'):
|
|||||||
return dict((encode(k), encode(v)) for k, v in d.items())
|
return dict((encode(k), encode(v)) for k, v in d.items())
|
||||||
|
|
||||||
|
|
||||||
def dict_get(d, key_or_keys, default=None):
|
def dict_get(d, key_or_keys, default=None, skip_false_values=True):
|
||||||
if isinstance(key_or_keys, (list, tuple)):
|
if isinstance(key_or_keys, (list, tuple)):
|
||||||
for key in key_or_keys:
|
for key in key_or_keys:
|
||||||
if d.get(key):
|
if key not in d or d[key] is None or skip_false_values and not d[key]:
|
||||||
|
continue
|
||||||
return d[key]
|
return d[key]
|
||||||
return default
|
return default
|
||||||
return d.get(key_or_keys, default)
|
return d.get(key_or_keys, default)
|
||||||
|
Loading…
Reference in New Issue
Block a user