1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480
|
"""Common code for converting proto to other formats, such as JSON."""
import base64
import collections
import datetime
import json
import logging
from protorpc import messages
from protorpc import protojson
from googlecloudapis.apitools.base.py import exceptions
__all__ = [
    'CopyProtoMessage',
    'JsonToMessage',
    'MessageToJson',
    'DictToMessage',
    'MessageToDict',
    'PyValueToMessage',
    'MessageToPyValue',
    'MessageToRepr',
]

# An (encoder, decoder) pair as stored in the codec registries below.
_Codec = collections.namedtuple('_Codec', 'encoder decoder')
# What a field codec returns: the (possibly converted) value, plus a flag that
# is True when conversion is finished and False when the default protojson
# handling should still run on ``value``.
CodecResult = collections.namedtuple('CodecResult', 'value complete')

# TODO(user): Make these non-global.
# Registries filled in by MapUnrecognizedFields and the Register*Codec
# decorators.
# message class -> name of the field that collects its unrecognized fields.
_UNRECOGNIZED_FIELD_MAPPINGS = dict()
# message class -> _Codec replacing the default message encode/decode.
_CUSTOM_MESSAGE_CODECS = dict()
# field instance -> _Codec applied to that one field.
_CUSTOM_FIELD_CODECS = dict()
# field class (e.g. messages.BytesField) -> _Codec applied to every field
# whose type is exactly that class.
_FIELD_TYPE_CODECS = dict()
def MapUnrecognizedFields(field_name):
    """Return a class decorator naming the field for unrecognized data.

    The decorated message class is recorded so that, during JSON
    encoding/decoding, its unrecognized fields are stored in (and read back
    from) the field called ``field_name``.

    Args:
      field_name: str, name of the field on the decorated class that holds
        the unrecognized fields.

    Returns:
      A decorator that records the class and returns it unchanged.
    """
    def _RecordContainer(message_class):
        _UNRECOGNIZED_FIELD_MAPPINGS[message_class] = field_name
        return message_class

    return _RecordContainer
def RegisterCustomMessageCodec(encoder, decoder):
    """Return a class decorator installing a custom codec for a message class.

    Messages whose exact type is the decorated class are encoded with
    ``encoder(message)`` and decoded with ``decoder(encoded_message)``
    instead of the default apitools JSON handling.

    Args:
      encoder: callable turning a message into its JSON string.
      decoder: callable turning a JSON string into a message.

    Returns:
      A decorator that registers the class and returns it unchanged.
    """
    def _AttachCodec(message_class):
        _CUSTOM_MESSAGE_CODECS[message_class] = _Codec(
            encoder=encoder, decoder=decoder)
        return message_class

    return _AttachCodec
def RegisterCustomFieldCodec(encoder, decoder):
    """Return a decorator that attaches a custom codec to a single field.

    The decorator is applied to a field instance. That field is then encoded
    via ``encoder(field, value)`` and decoded via ``decoder(field, value)``.
    Both must return a CodecResult(value, complete).

    Args:
      encoder: callable used when encoding the field.
      decoder: callable used when decoding the field.

    Returns:
      A decorator that registers the field and returns it unchanged.
    """
    def _AttachCodec(target_field):
        _CUSTOM_FIELD_CODECS[target_field] = _Codec(
            encoder=encoder, decoder=decoder)
        return target_field

    return _AttachCodec
def RegisterFieldTypeCodec(encoder, decoder):
    """Return a decorator that attaches a codec to a whole field class.

    The decorator is applied to a Field subclass. Every field whose type is
    exactly that class then uses ``encoder(field, value)`` and
    ``decoder(field, value)``, each returning a CodecResult.

    Args:
      encoder: callable used when encoding fields of this type.
      decoder: callable used when decoding fields of this type.

    Returns:
      A decorator that registers the field class and returns it unchanged.
    """
    def _AttachCodec(field_class):
        _FIELD_TYPE_CODECS[field_class] = _Codec(
            encoder=encoder, decoder=decoder)
        return field_class

    return _AttachCodec
# TODO(user): Delete this function with the switch to proto2.
def CopyProtoMessage(message):
    """Return a new message equal to message, built by a JSON round trip.

    A plain protojson.ProtoJson codec is used, not the apitools codec, so
    the custom codecs registered in this module do not take part.

    Args:
      message: the message to copy.

    Returns:
      A freshly decoded instance of type(message).
    """
    plain_codec = protojson.ProtoJson()
    serialized = plain_codec.encode_message(message)
    return plain_codec.decode_message(type(message), serialized)
def MessageToJson(message, include_fields=None):
    """Serialize message to a JSON string using the apitools codec.

    Args:
      message: the message to encode.
      include_fields: optional iterable of dotted field paths (e.g. 'a.b')
        that are forced into the output with a null value.

    Returns:
      str, the JSON encoding of message.
    """
    codec = _ProtoJsonApiTools.Get()
    return _IncludeFields(codec.encode_message(message), message,
                          include_fields)
def JsonToMessage(message_type, message):
    """Parse the JSON string ``message`` into an instance of message_type.

    Args:
      message_type: the message class to produce.
      message: str, the JSON text to decode.

    Returns:
      A message_type instance decoded with the apitools codec.
    """
    codec = _ProtoJsonApiTools.Get()
    return codec.decode_message(message_type, message)
# TODO(user): Do this directly, instead of via JSON.
def DictToMessage(d, message_type):
    """Build a message_type instance from the JSON-compatible dict d.

    Args:
      d: dict, a value that json.dumps can serialize.
      message_type: the message class to produce.

    Returns:
      A message_type instance.
    """
    as_json = json.dumps(d)
    return JsonToMessage(message_type, as_json)
def MessageToDict(message):
    """Return message as the python value json.loads gives for its JSON.

    Args:
      message: the message to convert.

    Returns:
      The decoded JSON value (a dict for a regular message).
    """
    json_text = MessageToJson(message)
    return json.loads(json_text)
def PyValueToMessage(message_type, value):
    """Convert the given python value to a message of type message_type.

    Same conversion as DictToMessage (JSON round trip), with the arguments in
    (type, value) order.

    Args:
      message_type: the message class to produce.
      value: a value that json.dumps can serialize.

    Returns:
      A message_type instance.
    """
    return DictToMessage(value, message_type)
def MessageToPyValue(message):
    """Convert the given message to a python value.

    Same conversion as MessageToDict: the message is encoded to JSON and the
    result is parsed with json.loads.

    Args:
      message: the message to convert.

    Returns:
      The decoded JSON value.
    """
    return MessageToDict(message)
def MessageToRepr(msg, multiline=False, **kwargs):
    """Return a repr-style string for a protorpc message.

    protorpc.Message.__repr__ does not return anything that could be
    considered python code. Adding this function lets us print a protorpc
    message in such a way that it could be pasted into code later, and used
    to compare against other things.

    Args:
      msg: protorpc.Message, the message to be repr'd. The function recurses
        into field values, so lists, strings and datetimes are accepted too;
        any other value is rendered with repr().
      multiline: bool, True if the returned string should have each field
        assignment on its own line.
      **kwargs: {str:str}, Additional flags for how to format the string.

    Known **kwargs:
      shortstrings: bool, True if all string values should be truncated at
        100 characters, since when mocking the contents typically don't
        matter except for IDs, and IDs are usually less than 100 characters.
      no_modules: bool, True if the long module name should not be printed
        with each type.
      indent: int, current indentation column for multiline output; the
        recursion increases it by 4 per level (defaults to 0).

    Returns:
      str, A string of valid python (assuming the right imports have been
      made) that recreates the message passed into this function.
    """
    # TODO(user): craigcitro suggests a pretty-printer from apitools/gen.
    indent = kwargs.get('indent', 0)

    def IndentKwargs(kwargs):
        # Copy so the caller's kwargs dict is never mutated.
        kwargs = dict(kwargs)
        kwargs['indent'] = kwargs.get('indent', 0) + 4
        return kwargs

    if isinstance(msg, list):
        s = '['
        for item in msg:
            if multiline:
                s += '\n' + ' ' * (indent + 4)
            s += MessageToRepr(
                item, multiline=multiline, **IndentKwargs(kwargs)) + ','
        if multiline:
            s += '\n' + ' ' * indent
        s += ']'
        return s

    if isinstance(msg, messages.Message):
        s = type(msg).__name__ + '('
        if not kwargs.get('no_modules'):
            s = msg.__module__ + '.' + s
        # Fields are emitted sorted by name so the output is stable.
        for name in sorted(field.name for field in msg.all_fields()):
            if multiline:
                s += '\n' + ' ' * (indent + 4)
            value = getattr(msg, name)
            s += name + '=' + MessageToRepr(
                value, multiline=multiline, **IndentKwargs(kwargs)) + ','
        if multiline:
            s += '\n' + ' ' * indent
        s += ')'
        return s

    if isinstance(msg, basestring):
        if kwargs.get('shortstrings') and len(msg) > 100:
            msg = msg[:100]

    if isinstance(msg, datetime.datetime):
        # datetime.utcoffset() returns None for naive datetimes (no tzinfo);
        # those fall through to plain repr() instead of failing on
        # msg.tzinfo.utcoffset. For aware datetimes it calls
        # tzinfo.utcoffset(msg) with the datetime itself, which is what the
        # tzinfo API expects (not an int).
        offset = msg.utcoffset()
        if offset is not None:

            class SpecialTZInfo(datetime.tzinfo):
                """tzinfo whose repr is a TimeZoneOffset constructor call."""

                def __init__(self, offset):
                    super(SpecialTZInfo, self).__init__()
                    self.offset = offset

                def __repr__(self):
                    s = 'TimeZoneOffset(' + repr(self.offset) + ')'
                    if not kwargs.get('no_modules'):
                        s = 'protorpc.util.' + s
                    return s

            msg = datetime.datetime(
                msg.year, msg.month, msg.day, msg.hour, msg.minute,
                msg.second, msg.microsecond, SpecialTZInfo(offset))

    return repr(msg)
def _GetField(message, field_path):
for field in field_path:
if field not in dir(message):
raise KeyError('no field "%s"' % field)
message = getattr(message, field)
return message
def _SetFieldToNone(dictblob, field_path):
for field in field_path[:-1]:
dictblob[field] = {}
dictblob = dictblob[field]
dictblob[field_path[-1]] = None
def _IncludeFields(encoded_message, message, include_fields):
    """Add the requested fields to the encoded message.

    Each entry of include_fields is a dotted path such as 'a.b'. It is first
    validated against message, then written into the JSON output with a
    null value.

    Args:
      encoded_message: str, the JSON encoding of message.
      message: the encoded message, used to validate the paths.
      include_fields: iterable of dotted field paths, or None.

    Returns:
      str, encoded_message unchanged when include_fields is None; otherwise
      the re-serialized JSON with every requested key set to null.

    Raises:
      exceptions.InvalidDataError: if a path does not name a field of
        message.
    """
    if include_fields is None:
        return encoded_message
    decoded = json.loads(encoded_message)
    for dotted_name in include_fields:
        path = dotted_name.split('.')
        try:
            _GetField(message, path)
        except KeyError:
            raise exceptions.InvalidDataError(
                'No field named %s in message of type %s' % (
                    dotted_name, type(message)))
        _SetFieldToNone(decoded, path)
    return json.dumps(decoded)
def _GetFieldCodecs(field, attr):
    """Return the registered codec callables that apply to field.

    Args:
      field: the field being encoded or decoded.
      attr: str, 'encoder' or 'decoder': which half of each codec to return.

    Returns:
      list of callables. The codec registered for this exact field comes
      first, then the one registered for type(field). Registries without an
      entry contribute nothing.
    """
    candidates = (
        _CUSTOM_FIELD_CODECS.get(field),
        _FIELD_TYPE_CODECS.get(type(field)),
    )
    found = []
    for codec in candidates:
        func = getattr(codec, attr, None)
        if func is not None:
            found.append(func)
    return found
class _ProtoJsonApiTools(protojson.ProtoJson):
    """JSON encoder used by apitools clients.

    Extends protojson.ProtoJson in two ways. It consults this module's codec
    registries. On decode, it keeps unknown enum values, untyped unknown
    values and mapped unrecognized fields instead of dropping them.
    """

    # Shared instance handed out by Get(); created on first use.
    _INSTANCE = None

    @classmethod
    def Get(cls):
        """Return the shared instance, creating it on the first call."""
        if cls._INSTANCE is None:
            cls._INSTANCE = cls()
        return cls._INSTANCE

    def decode_message(self, message_type, encoded_message):  # pylint: disable=invalid-name
        """Decode the JSON string encoded_message into a message_type.

        A custom codec registered for message_type takes over entirely.
        Otherwise the base protojson decoder runs, and its result is passed
        through _ProcessUnknownEnums, _ProcessUnknownMessages and
        _DecodeUnknownFields to keep data the base decoder would drop.

        Args:
          message_type: the message class to produce.
          encoded_message: str, the JSON text.

        Returns:
          A message_type instance.
        """
        if message_type in _CUSTOM_MESSAGE_CODECS:
            return _CUSTOM_MESSAGE_CODECS[message_type].decoder(encoded_message)
        # We turn off the default logging in protorpc. We may want to
        # remove this later.
        root_logger = logging.getLogger()
        old_level = root_logger.level
        root_logger.setLevel(logging.ERROR)
        try:
            result = super(_ProtoJsonApiTools, self).decode_message(
                message_type, encoded_message)
        finally:
            # Restore even if decoding raises; otherwise a single bad payload
            # would leave the process-wide root logger at ERROR.
            root_logger.setLevel(old_level)
        result = _ProcessUnknownEnums(result, encoded_message)
        result = _ProcessUnknownMessages(result, encoded_message)
        return _DecodeUnknownFields(result, encoded_message)

    def decode_field(self, field, value):  # pylint: disable=g-bad-name
        """Decode the given JSON value.

        Registered field codecs run first (field-specific, then per field
        type). Each codec receives the value produced by the previous one, and
        the first result with complete=True is returned. Otherwise:
        - message fields recurse through decode_message;
        - string enum values the base decoder rejects become None, so
          _ProcessUnknownEnums can later keep the raw value;
        - everything else goes to the base decoder.

        Args:
          field: a messages.Field for the field we're decoding.
          value: a python value we'd like to decode.

        Returns:
          A value suitable for assignment to field.
        """
        for decoder in _GetFieldCodecs(field, 'decoder'):
            result = decoder(field, value)
            value = result.value
            if result.complete:
                return value
        if isinstance(field, messages.MessageField):
            field_value = self.decode_message(
                field.message_type, json.dumps(value))
        elif isinstance(field, messages.EnumField):
            try:
                field_value = super(_ProtoJsonApiTools, self).decode_field(
                    field, value)
            except messages.DecodeError:
                if not isinstance(value, basestring):
                    raise
                field_value = None
        else:
            field_value = super(_ProtoJsonApiTools, self).decode_field(
                field, value)
        return field_value

    def encode_message(self, message):  # pylint: disable=invalid-name
        """Encode message (or a FieldList of messages) as a JSON string.

        A FieldList becomes a JSON array of its encoded elements. A custom
        codec registered for the message's exact type takes over entirely.
        Otherwise, the pairs field mapped by MapUnrecognizedFields is
        converted back into unrecognized fields, and then the base encoder
        runs.

        Args:
          message: a message or messages.FieldList.

        Returns:
          str, the JSON encoding.
        """
        if isinstance(message, messages.FieldList):
            return '[%s]' % (', '.join(self.encode_message(x) for x in message))
        if type(message) in _CUSTOM_MESSAGE_CODECS:
            return _CUSTOM_MESSAGE_CODECS[type(message)].encoder(message)
        message = _EncodeUnknownFields(message)
        return super(_ProtoJsonApiTools, self).encode_message(message)

    def encode_field(self, field, value):  # pylint: disable=g-bad-name
        """Encode the given value as JSON.

        Registered field codecs run first, in the same order as in
        decode_field; the first result with complete=True is returned.

        Args:
          field: a messages.Field for the field we're encoding.
          value: a value for field.

        Returns:
          A python value suitable for json.dumps.
        """
        for encoder in _GetFieldCodecs(field, 'encoder'):
            result = encoder(field, value)
            value = result.value
            if result.complete:
                return value
        if isinstance(field, messages.MessageField):
            # Route nested messages through this class's encode_message so
            # custom codecs and unrecognized-field mapping apply to them too;
            # json.loads turns the JSON string back into a python value.
            value = json.loads(self.encode_message(value))
        return super(_ProtoJsonApiTools, self).encode_field(field, value)
# TODO(user): Fold this and _IncludeFields in as codecs.
def _DecodeUnknownFields(message, encoded_message):
    """Rewrite unknown fields in message into message.destination.

    This applies only to message classes registered with
    MapUnrecognizedFields; other messages are returned untouched. For
    registered classes, the data the decoder could not map to a declared
    field is turned into a list of key/value pair messages. That list is
    assigned to the registered field.

    Args:
      message: the message decoded so far; modified in place.
      encoded_message: str, the JSON text message was decoded from.

    Returns:
      message.

    Raises:
      exceptions.InvalidDataFromServerError: if the registered field is not a
        MessageField.
    """
    destination = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
    if destination is None:
        return message
    pair_field = message.field_by_name(destination)
    if not isinstance(pair_field, messages.MessageField):
        raise exceptions.InvalidDataFromServerError(
            'Unrecognized fields must be mapped to a compound '
            'message type.')
    pair_type = pair_field.message_type
    # TODO(user): Add more error checking around the pair
    # type being exactly what we suspect (field names, etc).
    if isinstance(pair_type.value, messages.MessageField):
        # Message-valued entries are rebuilt from the raw JSON. Skip parsing
        # an empty payload (as _ProcessUnknownEnums/_ProcessUnknownMessages
        # do); json.loads('') would raise ValueError.
        decoded = json.loads(encoded_message) if encoded_message else {}
        new_values = _DecodeUnknownMessages(message, decoded, pair_type)
    else:
        new_values = _DecodeUnrecognizedFields(message, pair_type)
    setattr(message, destination, new_values)
    # We could probably get away with not setting this, but
    # why not clear it? (This is the name-mangled private attribute
    # Message.__unrecognized_fields.)
    setattr(message, '_Message__unrecognized_fields', {})
    return message
def _DecodeUnknownMessages(message, encoded_message, pair_type):
"""Process unknown fields in encoded_message of a message type."""
field_type = pair_type.value.type
new_values = []
all_field_names = [x.name for x in message.all_fields()]
for name, value_dict in encoded_message.iteritems():
if name in all_field_names:
continue
value = PyValueToMessage(field_type, value_dict)
new_pair = pair_type(key=name, value=value)
new_values.append(new_pair)
return new_values
def _DecodeUnrecognizedFields(message, pair_type):
    """Process unrecognized fields in message.

    Args:
      message: the decoded message whose unrecognized fields are converted.
      pair_type: the key/value pair message class.

    Returns:
      list of pair_type instances, one per unrecognized field. When the
      pair's ``value`` field is a MessageField, values are converted with
      DictToMessage; otherwise they are used as-is.
    """
    # The pair's value field does not change between iterations. Look it up
    # once and use the same object for the type check and the message type.
    # NOTE(review): _DecodeUnknownFields only calls this when ``value`` is
    # not a MessageField, so the conversion branch looks unreachable from
    # that caller.
    value_field = pair_type.field_by_name('value')
    values_are_messages = isinstance(value_field, messages.MessageField)
    new_values = []
    for unknown_field in message.all_unrecognized_fields():
        # TODO(user): Consider validating the variant if
        # the assignment below doesn't take care of it. It may
        # also be necessary to check it in the case that the
        # type has multiple encodings.
        value, _ = message.get_unrecognized_field_info(unknown_field)
        if values_are_messages:
            decoded_value = DictToMessage(value, value_field.message_type)
        else:
            decoded_value = value
        new_values.append(
            pair_type(key=str(unknown_field), value=decoded_value))
    return new_values
def _EncodeUnknownFields(message):
    """Remap unknown fields in message out of message.source.

    Messages whose class was not registered with MapUnrecognizedFields are
    returned as-is. For registered classes, a copy of message is returned.
    In the copy, every key/value pair held in the registered field becomes
    an unrecognized field, using the pair's ``value`` variant. Message
    values are converted to dicts first. The registered field is emptied.

    Args:
      message: the message about to be encoded.

    Returns:
      message itself, or the remapped copy.

    Raises:
      exceptions.InvalidUserInputError: if the registered field is not a
        MessageField.
    """
    source = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
    if source is None:
        return message
    copied = CopyProtoMessage(message)
    pairs_field = message.field_by_name(source)
    if not isinstance(pairs_field, messages.MessageField):
        raise exceptions.InvalidUserInputError(
            'Invalid pairs field %s' % pairs_field)
    value_variant = pairs_field.message_type.field_by_name('value').variant
    values_are_messages = value_variant == messages.Variant.MESSAGE
    for pair in getattr(message, source):
        if values_are_messages:
            encoded_value = MessageToDict(pair.value)
        else:
            encoded_value = pair.value
        copied.set_unrecognized_field(pair.key, encoded_value, value_variant)
    setattr(copied, source, [])
    return copied
def _SafeEncodeBytes(field, value):
    """Encode the bytes in value as urlsafe base64.

    Args:
      field: the field being encoded; when field.repeated, each element of
        value is encoded separately.
      value: the raw bytes value (a sequence of them for repeated fields).

    Returns:
      CodecResult. On success it holds the base64 text with complete=True.
      If base64 encoding raises TypeError, it holds the untouched value with
      complete=False, so default handling continues.
    """
    try:
        if field.repeated:
            encoded = [base64.urlsafe_b64encode(item) for item in value]
        else:
            encoded = base64.urlsafe_b64encode(value)
    except TypeError:
        return CodecResult(value=value, complete=False)
    return CodecResult(value=encoded, complete=True)
def _SafeDecodeBytes(unused_field, value):
    """Decode the urlsafe base64 value into bytes.

    Args:
      unused_field: the field being decoded (not consulted).
      value: the JSON value, expected to be urlsafe base64 text.

    Returns:
      CodecResult. On success it holds the decoded bytes with complete=True.
      If value cannot be decoded, it holds the untouched value with
      complete=False, so the default protojson handling runs instead.
    """
    try:
        result = base64.urlsafe_b64decode(str(value))
        complete = True
    except (TypeError, ValueError):
        # Python 2's base64 reports malformed input (e.g. bad padding) as
        # TypeError. str() on non-ASCII unicode raises UnicodeEncodeError,
        # and Python 3's decoder raises binascii.Error; both subclass
        # ValueError and previously escaped the fallback.
        result = value
        complete = False
    return CodecResult(value=result, complete=complete)
def _ProcessUnknownEnums(message, encoded_message):
"""Add unknown enum values from encoded_message as unknown fields.
ProtoRPC diverges from the usual protocol buffer behavior here and
doesn't allow unknown fields. Throwing on unknown fields makes it
impossible to let servers add new enum values and stay compatible
with older clients, which isn't reasonable for us. We simply store
unrecognized enum values as unknown fields, and all is well.
Args:
message: Proto message we've decoded thus far.
encoded_message: JSON string we're decoding.
Returns:
message, with any unknown enums stored as unrecognized fields.
"""
if not encoded_message:
return message
decoded_message = json.loads(encoded_message)
for field in message.all_fields():
if (isinstance(field, messages.EnumField) and
field.name in decoded_message and
message.get_assigned_value(field.name) is None):
message.set_unrecognized_field(field.name, decoded_message[field.name],
messages.Variant.ENUM)
return message
def _ProcessUnknownMessages(message, encoded_message):
"""Store any remaining unknown fields as strings.
ProtoRPC currently ignores unknown values for which no type can be
determined (and logs a "No variant found" message). For the purposes
of reserializing, this is quite harmful (since it throws away
information). Here we simply add those as unknown fields of type
string (so that they can easily be reserialized).
Args:
message: Proto message we've decoded thus far.
encoded_message: JSON string we're decoding.
Returns:
message, with any remaining unrecognized fields saved.
"""
if not encoded_message:
return message
decoded_message = json.loads(encoded_message)
message_fields = [x.name for x in message.all_fields()] + list(
message.all_unrecognized_fields())
missing_fields = [x for x in decoded_message.iterkeys()
if x not in message_fields]
for field_name in missing_fields:
message.set_unrecognized_field(field_name, decoded_message[field_name],
messages.Variant.STRING)
return message
# Bytes fields are carried as urlsafe base64 in JSON: install the
# _SafeEncodeBytes/_SafeDecodeBytes pair for every messages.BytesField.
RegisterFieldTypeCodec(encoder=_SafeEncodeBytes,
                       decoder=_SafeDecodeBytes)(messages.BytesField)
|