Massive roll-up of changes. See CHANGES.txt.

This commit is contained in:
kenton@google.com 2009-12-18 02:11:36 +00:00
parent d5cf7b55a6
commit fccb146e3f
153 changed files with 11580 additions and 7435 deletions

View File

@ -1,3 +1,49 @@
2009-12-17 version 2.3.0:
General
* Parsers for repeated numeric fields now always accept both packed and
unpacked input. The [packed=true] option only affects serializers.
Therefore, it is possible to switch a field to packed format without
breaking backwards-compatibility -- as long as all parties are using
protobuf 2.3.0 or above, at least.
* The generic RPC service code generated by the C++, Java, and Python
generators can be disabled via file options:
option cc_generic_services = false;
option java_generic_services = false;
option py_generic_services = false;
This allows plugins to generate alternative code, possibly specific to some
particular RPC implementation.
protoc
* Now supports a plugin system for code generators. Plugins can generate
code for new languages or inject additional code into the output of other
code generators. Plugins are just binaries which accept a protocol buffer
on stdin and write a protocol buffer to stdout, so they may be written in
any language. See src/google/protobuf/compiler/plugin.proto.
* inf, -inf, and nan can now be used as default values for float and double
fields.
C++
* Various speed and code size optimizations.
* DynamicMessageFactory is now fully thread-safe.
* Message::Utf8DebugString() method is like DebugString() but avoids escaping
UTF-8 bytes.
* Compiled-in message types can now contain dynamic extensions, through use
of CodedInputStream::SetExtensionRegistry().
Java
* parseDelimitedFrom() and mergeDelimitedFrom() now detect EOF and return
false/null instead of throwing an exception.
* Fixed some initialization ordering bugs.
* Fixes for OpenJDK 7.
Python
* 10-25 times faster than 2.2.0, still pure-Python.
* Calling a mutating method on a sub-message always instantiates the message
in its parent even if the mutating method doesn't actually mutate anything
(e.g. parsing from an empty string).
* Expanded descriptors a bit.
2009-08-11 version 2.2.0: 2009-08-11 version 2.2.0:
C++ C++

View File

@ -114,18 +114,12 @@ EXTRA_DIST = \
python/google/protobuf/internal/generator_test.py \ python/google/protobuf/internal/generator_test.py \
python/google/protobuf/internal/containers.py \ python/google/protobuf/internal/containers.py \
python/google/protobuf/internal/decoder.py \ python/google/protobuf/internal/decoder.py \
python/google/protobuf/internal/decoder_test.py \
python/google/protobuf/internal/descriptor_test.py \ python/google/protobuf/internal/descriptor_test.py \
python/google/protobuf/internal/encoder.py \ python/google/protobuf/internal/encoder.py \
python/google/protobuf/internal/encoder_test.py \
python/google/protobuf/internal/input_stream.py \
python/google/protobuf/internal/input_stream_test.py \
python/google/protobuf/internal/message_listener.py \ python/google/protobuf/internal/message_listener.py \
python/google/protobuf/internal/message_test.py \ python/google/protobuf/internal/message_test.py \
python/google/protobuf/internal/more_extensions.proto \ python/google/protobuf/internal/more_extensions.proto \
python/google/protobuf/internal/more_messages.proto \ python/google/protobuf/internal/more_messages.proto \
python/google/protobuf/internal/output_stream.py \
python/google/protobuf/internal/output_stream_test.py \
python/google/protobuf/internal/reflection_test.py \ python/google/protobuf/internal/reflection_test.py \
python/google/protobuf/internal/service_reflection_test.py \ python/google/protobuf/internal/service_reflection_test.py \
python/google/protobuf/internal/test_util.py \ python/google/protobuf/internal/test_util.py \

View File

@ -4,6 +4,8 @@
# be included in the distribution. These files are not checked in because they # be included in the distribution. These files are not checked in because they
# are automatically generated. # are automatically generated.
set -e
# Check that we're being run from the right directory. # Check that we're being run from the right directory.
if test ! -f src/google/protobuf/stubs/common.h; then if test ! -f src/google/protobuf/stubs/common.h; then
cat >&2 << __EOF__ cat >&2 << __EOF__
@ -13,6 +15,14 @@ __EOF__
exit 1 exit 1
fi fi
# Check that gtest is present. Usually it is already there since the
# directory is set up as an SVN external.
if test ! -e gtest; then
echo "Google Test not present. Fetching gtest-1.3.0 from the web..."
curl http://googletest.googlecode.com/files/gtest-1.3.0.tar.bz2 | tar jx
mv gtest-1.3.0 gtest
fi
set -ex set -ex
# Temporary hack: Must change C runtime library to "multi-threaded DLL", # Temporary hack: Must change C runtime library to "multi-threaded DLL",

View File

@ -27,5 +27,7 @@ __EOF__
fi fi
cd src cd src
make $@ protoc && ./protoc --cpp_out=dllexport_decl=LIBPROTOBUF_EXPORT:. google/protobuf/descriptor.proto make $@ protoc &&
./protoc --cpp_out=dllexport_decl=LIBPROTOBUF_EXPORT:. google/protobuf/descriptor.proto && \
./protoc --cpp_out=dllexport_decl=LIBPROTOC_EXPORT:. google/protobuf/compiler/plugin.proto
cd .. cd ..

View File

@ -113,6 +113,7 @@
<arg value="../src/google/protobuf/unittest_import_lite.proto" /> <arg value="../src/google/protobuf/unittest_import_lite.proto" />
<arg value="../src/google/protobuf/unittest_lite_imports_nonlite.proto" /> <arg value="../src/google/protobuf/unittest_lite_imports_nonlite.proto" />
<arg value="../src/google/protobuf/unittest_enormous_descriptor.proto" /> <arg value="../src/google/protobuf/unittest_enormous_descriptor.proto" />
<arg value="../src/google/protobuf/unittest_no_generic_services.proto" />
</exec> </exec>
</tasks> </tasks>
<testSourceRoot>target/generated-test-sources</testSourceRoot> <testSourceRoot>target/generated-test-sources</testSourceRoot>

View File

@ -311,6 +311,12 @@ public abstract class AbstractMessage extends AbstractMessageLite
} else { } else {
field = extension.descriptor; field = extension.descriptor;
defaultInstance = extension.defaultInstance; defaultInstance = extension.defaultInstance;
if (defaultInstance == null &&
field.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
throw new IllegalStateException(
"Message-typed extension lacked default instance: " +
field.getFullName());
}
} }
} else { } else {
field = null; field = null;
@ -319,15 +325,28 @@ public abstract class AbstractMessage extends AbstractMessageLite
field = type.findFieldByNumber(fieldNumber); field = type.findFieldByNumber(fieldNumber);
} }
if (field == null || wireType != boolean unknown = false;
FieldSet.getWireFormatForFieldType( boolean packed = false;
field.getLiteType(), if (field == null) {
field.getOptions().getPacked())) { unknown = true; // Unknown field.
// Unknown field or wrong wire type. Skip. } else if (wireType == FieldSet.getWireFormatForFieldType(
field.getLiteType(),
false /* isPacked */)) {
packed = false;
} else if (field.isPackable() &&
wireType == FieldSet.getWireFormatForFieldType(
field.getLiteType(),
true /* isPacked */)) {
packed = true;
} else {
unknown = true; // Unknown wire type.
}
if (unknown) { // Unknown field or wrong wire type. Skip.
return unknownFields.mergeFieldFrom(tag, input); return unknownFields.mergeFieldFrom(tag, input);
} }
if (field.getOptions().getPacked()) { if (packed) {
final int length = input.readRawVarint32(); final int length = input.readRawVarint32();
final int limit = input.pushLimit(length); final int limit = input.pushLimit(length);
if (field.getLiteType() == WireFormat.FieldType.ENUM) { if (field.getLiteType() == WireFormat.FieldType.ENUM) {
@ -673,13 +692,13 @@ public abstract class AbstractMessage extends AbstractMessageLite
} }
@Override @Override
public BuilderType mergeDelimitedFrom(final InputStream input) public boolean mergeDelimitedFrom(final InputStream input)
throws IOException { throws IOException {
return super.mergeDelimitedFrom(input); return super.mergeDelimitedFrom(input);
} }
@Override @Override
public BuilderType mergeDelimitedFrom( public boolean mergeDelimitedFrom(
final InputStream input, final InputStream input,
final ExtensionRegistryLite extensionRegistry) final ExtensionRegistryLite extensionRegistry)
throws IOException { throws IOException {

View File

@ -86,7 +86,7 @@ public abstract class AbstractMessageLite implements MessageLite {
CodedOutputStream.computeRawVarint32Size(serialized) + serialized); CodedOutputStream.computeRawVarint32Size(serialized) + serialized);
final CodedOutputStream codedOutput = final CodedOutputStream codedOutput =
CodedOutputStream.newInstance(output, bufferSize); CodedOutputStream.newInstance(output, bufferSize);
codedOutput.writeRawVarint32(getSerializedSize()); codedOutput.writeRawVarint32(serialized);
writeTo(codedOutput); writeTo(codedOutput);
codedOutput.flush(); codedOutput.flush();
} }
@ -105,13 +105,7 @@ public abstract class AbstractMessageLite implements MessageLite {
public BuilderType mergeFrom(final CodedInputStream input) public BuilderType mergeFrom(final CodedInputStream input)
throws IOException { throws IOException {
// TODO(kenton): Don't use null here. Currently we have to because return mergeFrom(input, ExtensionRegistryLite.getEmptyRegistry());
// using ExtensionRegistry.getEmptyRegistry() would imply a dependency
// on ExtensionRegistry. However, AbstractMessage overrides this with
// a correct implementation, and lite messages don't yet support
// extensions, so it ends up not mattering for now. It will matter
// once lite messages support extensions.
return mergeFrom(input, null);
} }
// Re-defined here for return type covariance. // Re-defined here for return type covariance.
@ -275,20 +269,24 @@ public abstract class AbstractMessageLite implements MessageLite {
} }
} }
public BuilderType mergeDelimitedFrom( public boolean mergeDelimitedFrom(
final InputStream input, final InputStream input,
final ExtensionRegistryLite extensionRegistry) final ExtensionRegistryLite extensionRegistry)
throws IOException { throws IOException {
final int size = CodedInputStream.readRawVarint32(input); final int firstByte = input.read();
if (firstByte == -1) {
return false;
}
final int size = CodedInputStream.readRawVarint32(firstByte, input);
final InputStream limitedInput = new LimitedInputStream(input, size); final InputStream limitedInput = new LimitedInputStream(input, size);
return mergeFrom(limitedInput, extensionRegistry); mergeFrom(limitedInput, extensionRegistry);
return true;
} }
public BuilderType mergeDelimitedFrom(final InputStream input) public boolean mergeDelimitedFrom(final InputStream input)
throws IOException { throws IOException {
final int size = CodedInputStream.readRawVarint32(input); return mergeDelimitedFrom(input,
final InputStream limitedInput = new LimitedInputStream(input, size); ExtensionRegistryLite.getEmptyRegistry());
return mergeFrom(limitedInput);
} }
/** /**

View File

@ -98,6 +98,24 @@ public final class ByteString {
return copyFrom(bytes, 0, bytes.length); return copyFrom(bytes, 0, bytes.length);
} }
/**
* Copies {@code size} bytes from a {@code java.nio.ByteBuffer} into
* a {@code ByteString}.
*/
public static ByteString copyFrom(final ByteBuffer bytes, final int size) {
final byte[] copy = new byte[size];
bytes.get(copy);
return new ByteString(copy);
}
/**
* Copies the remaining bytes from a {@code java.nio.ByteBuffer} into
* a {@code ByteString}.
*/
public static ByteString copyFrom(final ByteBuffer bytes) {
return copyFrom(bytes, bytes.remaining());
}
/** /**
* Encodes {@code text} into a sequence of bytes using the named charset * Encodes {@code text} into a sequence of bytes using the named charset
* and returns the result as a {@code ByteString}. * and returns the result as a {@code ByteString}.

View File

@ -84,8 +84,9 @@ public final class CodedInputStream {
} }
lastTag = readRawVarint32(); lastTag = readRawVarint32();
if (lastTag == 0) { if (WireFormat.getTagFieldNumber(lastTag) == 0) {
// If we actually read zero, that's not a valid tag. // If we actually read zero (or any tag number corresponding to field
// number zero), that's not a valid tag.
throw InvalidProtocolBufferException.invalidTag(); throw InvalidProtocolBufferException.invalidTag();
} }
return lastTag; return lastTag;
@ -355,8 +356,26 @@ public final class CodedInputStream {
* CodedInputStream buffers its input. * CodedInputStream buffers its input.
*/ */
static int readRawVarint32(final InputStream input) throws IOException { static int readRawVarint32(final InputStream input) throws IOException {
int result = 0; final int firstByte = input.read();
int offset = 0; if (firstByte == -1) {
throw InvalidProtocolBufferException.truncatedMessage();
}
return readRawVarint32(firstByte, input);
}
/**
* Like {@link #readRawVarint32(InputStream)}, but expects that the caller
* has already read one byte. This allows the caller to determine if EOF
* has been reached before attempting to read.
*/
static int readRawVarint32(final int firstByte,
final InputStream input) throws IOException {
if ((firstByte & 0x80) == 0) {
return firstByte;
}
int result = firstByte & 0x7f;
int offset = 7;
for (; offset < 32; offset += 7) { for (; offset < 32; offset += 7) {
final int b = input.read(); final int b = input.read();
if (b == -1) { if (b == -1) {

View File

@ -48,7 +48,7 @@ import java.io.UnsupportedEncodingException;
* (given a message object of the type) {@code message.getDescriptorForType()}. * (given a message object of the type) {@code message.getDescriptorForType()}.
* *
* Descriptors are built from DescriptorProtos, as defined in * Descriptors are built from DescriptorProtos, as defined in
* {@code net/proto2/proto/descriptor.proto}. * {@code google/protobuf/descriptor.proto}.
* *
* @author kenton@google.com Kenton Varda * @author kenton@google.com Kenton Varda
*/ */
@ -699,6 +699,11 @@ public final class Descriptors {
return getOptions().getPacked(); return getOptions().getPacked();
} }
/** Can this field be packed? i.e. is it a repeated primitive field? */
public boolean isPackable() {
return isRepeated() && getLiteType().isPackable();
}
/** Returns true if the field had an explicitly-defined default value. */ /** Returns true if the field had an explicitly-defined default value. */
public boolean hasDefaultValue() { return proto.hasDefaultValue(); } public boolean hasDefaultValue() { return proto.hasDefaultValue(); }
@ -810,39 +815,34 @@ public final class Descriptors {
private Object defaultValue; private Object defaultValue;
public enum Type { public enum Type {
DOUBLE (FieldDescriptorProto.Type.TYPE_DOUBLE , JavaType.DOUBLE ), DOUBLE (JavaType.DOUBLE ),
FLOAT (FieldDescriptorProto.Type.TYPE_FLOAT , JavaType.FLOAT ), FLOAT (JavaType.FLOAT ),
INT64 (FieldDescriptorProto.Type.TYPE_INT64 , JavaType.LONG ), INT64 (JavaType.LONG ),
UINT64 (FieldDescriptorProto.Type.TYPE_UINT64 , JavaType.LONG ), UINT64 (JavaType.LONG ),
INT32 (FieldDescriptorProto.Type.TYPE_INT32 , JavaType.INT ), INT32 (JavaType.INT ),
FIXED64 (FieldDescriptorProto.Type.TYPE_FIXED64 , JavaType.LONG ), FIXED64 (JavaType.LONG ),
FIXED32 (FieldDescriptorProto.Type.TYPE_FIXED32 , JavaType.INT ), FIXED32 (JavaType.INT ),
BOOL (FieldDescriptorProto.Type.TYPE_BOOL , JavaType.BOOLEAN ), BOOL (JavaType.BOOLEAN ),
STRING (FieldDescriptorProto.Type.TYPE_STRING , JavaType.STRING ), STRING (JavaType.STRING ),
GROUP (FieldDescriptorProto.Type.TYPE_GROUP , JavaType.MESSAGE ), GROUP (JavaType.MESSAGE ),
MESSAGE (FieldDescriptorProto.Type.TYPE_MESSAGE , JavaType.MESSAGE ), MESSAGE (JavaType.MESSAGE ),
BYTES (FieldDescriptorProto.Type.TYPE_BYTES , JavaType.BYTE_STRING), BYTES (JavaType.BYTE_STRING),
UINT32 (FieldDescriptorProto.Type.TYPE_UINT32 , JavaType.INT ), UINT32 (JavaType.INT ),
ENUM (FieldDescriptorProto.Type.TYPE_ENUM , JavaType.ENUM ), ENUM (JavaType.ENUM ),
SFIXED32(FieldDescriptorProto.Type.TYPE_SFIXED32, JavaType.INT ), SFIXED32(JavaType.INT ),
SFIXED64(FieldDescriptorProto.Type.TYPE_SFIXED64, JavaType.LONG ), SFIXED64(JavaType.LONG ),
SINT32 (FieldDescriptorProto.Type.TYPE_SINT32 , JavaType.INT ), SINT32 (JavaType.INT ),
SINT64 (FieldDescriptorProto.Type.TYPE_SINT64 , JavaType.LONG ); SINT64 (JavaType.LONG );
Type(final FieldDescriptorProto.Type proto, final JavaType javaType) { Type(final JavaType javaType) {
this.proto = proto;
this.javaType = javaType; this.javaType = javaType;
if (ordinal() != proto.getNumber() - 1) {
throw new RuntimeException(
"descriptor.proto changed but Desrciptors.java wasn't updated.");
}
} }
private FieldDescriptorProto.Type proto;
private JavaType javaType; private JavaType javaType;
public FieldDescriptorProto.Type toProto() { return proto; } public FieldDescriptorProto.Type toProto() {
return FieldDescriptorProto.Type.valueOf(ordinal() + 1);
}
public JavaType getJavaType() { return javaType; } public JavaType getJavaType() { return javaType; }
public static Type valueOf(final FieldDescriptorProto.Type type) { public static Type valueOf(final FieldDescriptorProto.Type type) {
@ -902,16 +902,10 @@ public final class Descriptors {
} }
// Only repeated primitive fields may be packed. // Only repeated primitive fields may be packed.
if (proto.getOptions().getPacked()) { if (proto.getOptions().getPacked() && !isPackable()) {
if (proto.getLabel() != FieldDescriptorProto.Label.LABEL_REPEATED || throw new DescriptorValidationException(this,
proto.getType() == FieldDescriptorProto.Type.TYPE_STRING || "[packed = true] can only be specified for repeated primitive " +
proto.getType() == FieldDescriptorProto.Type.TYPE_GROUP || "fields.");
proto.getType() == FieldDescriptorProto.Type.TYPE_MESSAGE ||
proto.getType() == FieldDescriptorProto.Type.TYPE_BYTES) {
throw new DescriptorValidationException(this,
"[packed = true] can only be specified for repeated primitive " +
"fields.");
}
} }
if (isExtension) { if (isExtension) {
@ -1030,10 +1024,26 @@ public final class Descriptors {
defaultValue = TextFormat.parseUInt64(proto.getDefaultValue()); defaultValue = TextFormat.parseUInt64(proto.getDefaultValue());
break; break;
case FLOAT: case FLOAT:
defaultValue = Float.valueOf(proto.getDefaultValue()); if (proto.getDefaultValue().equals("inf")) {
defaultValue = Float.POSITIVE_INFINITY;
} else if (proto.getDefaultValue().equals("-inf")) {
defaultValue = Float.NEGATIVE_INFINITY;
} else if (proto.getDefaultValue().equals("nan")) {
defaultValue = Float.NaN;
} else {
defaultValue = Float.valueOf(proto.getDefaultValue());
}
break; break;
case DOUBLE: case DOUBLE:
defaultValue = Double.valueOf(proto.getDefaultValue()); if (proto.getDefaultValue().equals("inf")) {
defaultValue = Double.POSITIVE_INFINITY;
} else if (proto.getDefaultValue().equals("-inf")) {
defaultValue = Double.NEGATIVE_INFINITY;
} else if (proto.getDefaultValue().equals("nan")) {
defaultValue = Double.NaN;
} else {
defaultValue = Double.valueOf(proto.getDefaultValue());
}
break; break;
case BOOL: case BOOL:
defaultValue = Boolean.valueOf(proto.getDefaultValue()); defaultValue = Boolean.valueOf(proto.getDefaultValue());
@ -1064,12 +1074,9 @@ public final class Descriptors {
"Message type had default value."); "Message type had default value.");
} }
} catch (NumberFormatException e) { } catch (NumberFormatException e) {
final DescriptorValidationException validationException = throw new DescriptorValidationException(this,
new DescriptorValidationException(this, "Could not parse default value: \"" +
"Could not parse default value: \"" + proto.getDefaultValue() + '\"', e);
proto.getDefaultValue() + '\"');
validationException.initCause(e);
throw validationException;
} }
} else { } else {
// Determine the default default for this field. // Determine the default default for this field.
@ -1536,14 +1543,7 @@ public final class Descriptors {
private DescriptorValidationException( private DescriptorValidationException(
final GenericDescriptor problemDescriptor, final GenericDescriptor problemDescriptor,
final String description) { final String description) {
this(problemDescriptor, description, null); super(problemDescriptor.getFullName() + ": " + description);
}
private DescriptorValidationException(
final GenericDescriptor problemDescriptor,
final String description,
final Throwable cause) {
super(problemDescriptor.getFullName() + ": " + description, cause);
// Note that problemDescriptor may be partially uninitialized, so we // Note that problemDescriptor may be partially uninitialized, so we
// don't want to expose it directly to the user. So, we only provide // don't want to expose it directly to the user. So, we only provide
@ -1553,6 +1553,14 @@ public final class Descriptors {
this.description = description; this.description = description;
} }
private DescriptorValidationException(
final GenericDescriptor problemDescriptor,
final String description,
final Throwable cause) {
this(problemDescriptor, description);
initCause(cause);
}
private DescriptorValidationException( private DescriptorValidationException(
final FileDescriptor problemDescriptor, final FileDescriptor problemDescriptor,
final String description) { final String description) {

View File

@ -157,6 +157,11 @@ public final class ExtensionRegistry extends ExtensionRegistryLite {
public void add(final GeneratedMessage.GeneratedExtension<?, ?> extension) { public void add(final GeneratedMessage.GeneratedExtension<?, ?> extension) {
if (extension.getDescriptor().getJavaType() == if (extension.getDescriptor().getJavaType() ==
FieldDescriptor.JavaType.MESSAGE) { FieldDescriptor.JavaType.MESSAGE) {
if (extension.getMessageDefaultInstance() == null) {
throw new IllegalStateException(
"Registered message-type extension had null default instance: " +
extension.getDescriptor().getFullName());
}
add(new ExtensionInfo(extension.getDescriptor(), add(new ExtensionInfo(extension.getDescriptor(),
extension.getMessageDefaultInstance())); extension.getMessageDefaultInstance()));
} else { } else {

View File

@ -789,6 +789,10 @@ public abstract class GeneratedMessage extends AbstractMessage {
messageDefaultInstance = messageDefaultInstance =
(Message) invokeOrDie(getMethodOrDie(type, "getDefaultInstance"), (Message) invokeOrDie(getMethodOrDie(type, "getDefaultInstance"),
null); null);
if (messageDefaultInstance == null) {
throw new IllegalStateException(
type.getName() + ".getDefaultInstance() returned null.");
}
break; break;
case ENUM: case ENUM:
enumValueOf = getMethodOrDie(type, "valueOf", enumValueOf = getMethodOrDie(type, "valueOf",

View File

@ -303,7 +303,7 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite {
final ExtensionRegistryLite extensionRegistry, final ExtensionRegistryLite extensionRegistry,
final int tag) throws IOException { final int tag) throws IOException {
final FieldSet<ExtensionDescriptor> extensions = final FieldSet<ExtensionDescriptor> extensions =
internalGetResult().extensions; ((ExtendableMessage) internalGetResult()).extensions;
final int wireType = WireFormat.getTagWireType(tag); final int wireType = WireFormat.getTagWireType(tag);
final int fieldNumber = WireFormat.getTagFieldNumber(tag); final int fieldNumber = WireFormat.getTagFieldNumber(tag);
@ -312,15 +312,29 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite {
extensionRegistry.findLiteExtensionByNumber( extensionRegistry.findLiteExtensionByNumber(
getDefaultInstanceForType(), fieldNumber); getDefaultInstanceForType(), fieldNumber);
if (extension == null || wireType != boolean unknown = false;
FieldSet.getWireFormatForFieldType( boolean packed = false;
extension.descriptor.getLiteType(), if (extension == null) {
extension.descriptor.isPacked())) { unknown = true; // Unknown field.
// Unknown field or wrong wire type. Skip. } else if (wireType == FieldSet.getWireFormatForFieldType(
extension.descriptor.getLiteType(),
false /* isPacked */)) {
packed = false; // Normal, unpacked value.
} else if (extension.descriptor.isRepeated &&
extension.descriptor.type.isPackable() &&
wireType == FieldSet.getWireFormatForFieldType(
extension.descriptor.getLiteType(),
true /* isPacked */)) {
packed = true; // Packed value.
} else {
unknown = true; // Wrong wire type.
}
if (unknown) { // Unknown field or wrong wire type. Skip.
return input.skipField(tag); return input.skipField(tag);
} }
if (extension.descriptor.isPacked()) { if (packed) {
final int length = input.readRawVarint32(); final int length = input.readRawVarint32();
final int limit = input.pushLimit(length); final int limit = input.pushLimit(length);
if (extension.descriptor.getLiteType() == WireFormat.FieldType.ENUM) { if (extension.descriptor.getLiteType() == WireFormat.FieldType.ENUM) {
@ -396,7 +410,8 @@ public abstract class GeneratedMessageLite extends AbstractMessageLite {
} }
protected final void mergeExtensionFields(final MessageType other) { protected final void mergeExtensionFields(final MessageType other) {
internalGetResult().extensions.mergeFrom(other.extensions); ((ExtendableMessage) internalGetResult()).extensions.mergeFrom(
((ExtendableMessage) other).extensions);
} }
} }

View File

@ -296,9 +296,9 @@ public interface Message extends MessageLite {
Builder mergeFrom(InputStream input, Builder mergeFrom(InputStream input,
ExtensionRegistryLite extensionRegistry) ExtensionRegistryLite extensionRegistry)
throws IOException; throws IOException;
Builder mergeDelimitedFrom(InputStream input) boolean mergeDelimitedFrom(InputStream input)
throws IOException; throws IOException;
Builder mergeDelimitedFrom(InputStream input, boolean mergeDelimitedFrom(InputStream input,
ExtensionRegistryLite extensionRegistry) ExtensionRegistryLite extensionRegistry)
throws IOException; throws IOException;
} }

View File

@ -317,14 +317,18 @@ public interface MessageLite {
* then the message data. Use * then the message data. Use
* {@link MessageLite#writeDelimitedTo(OutputStream)} to write messages in * {@link MessageLite#writeDelimitedTo(OutputStream)} to write messages in
* this format. * this format.
*
* @returns True if successful, or false if the stream is at EOF when the
* method starts. Any other error (including reaching EOF during
* parsing) will cause an exception to be thrown.
*/ */
Builder mergeDelimitedFrom(InputStream input) boolean mergeDelimitedFrom(InputStream input)
throws IOException; throws IOException;
/** /**
* Like {@link #mergeDelimitedFrom(InputStream)} but supporting extensions. * Like {@link #mergeDelimitedFrom(InputStream)} but supporting extensions.
*/ */
Builder mergeDelimitedFrom(InputStream input, boolean mergeDelimitedFrom(InputStream input,
ExtensionRegistryLite extensionRegistry) ExtensionRegistryLite extensionRegistry)
throws IOException; throws IOException;
} }

View File

@ -426,7 +426,7 @@ public final class TextFormat {
Pattern.compile("(\\s|(#.*$))++", Pattern.MULTILINE); Pattern.compile("(\\s|(#.*$))++", Pattern.MULTILINE);
private static final Pattern TOKEN = Pattern.compile( private static final Pattern TOKEN = Pattern.compile(
"[a-zA-Z_][0-9a-zA-Z_+-]*+|" + // an identifier "[a-zA-Z_][0-9a-zA-Z_+-]*+|" + // an identifier
"[0-9+-][0-9a-zA-Z_.+-]*+|" + // a number "[.]?[0-9+-][0-9a-zA-Z_.+-]*+|" + // a number
"\"([^\"\n\\\\]|\\\\.)*+(\"|\\\\?$)|" + // a double-quoted string "\"([^\"\n\\\\]|\\\\.)*+(\"|\\\\?$)|" + // a double-quoted string
"\'([^\"\n\\\\]|\\\\.)*+(\'|\\\\?$)", // a single-quoted string "\'([^\"\n\\\\]|\\\\.)*+(\'|\\\\?$)", // a single-quoted string
Pattern.MULTILINE); Pattern.MULTILINE);

View File

@ -30,6 +30,8 @@
package com.google.protobuf; package com.google.protobuf;
import com.google.protobuf.AbstractMessageLite.Builder.LimitedInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
@ -551,19 +553,23 @@ public final class UnknownFieldSet implements MessageLite {
return this; return this;
} }
public Builder mergeDelimitedFrom(InputStream input) public boolean mergeDelimitedFrom(InputStream input)
throws IOException { throws IOException {
final int size = CodedInputStream.readRawVarint32(input); final int firstByte = input.read();
final InputStream limitedInput = if (firstByte == -1) {
new AbstractMessage.Builder.LimitedInputStream(input, size); return false;
return mergeFrom(limitedInput, null); }
final int size = CodedInputStream.readRawVarint32(firstByte, input);
final InputStream limitedInput = new LimitedInputStream(input, size);
mergeFrom(limitedInput);
return true;
} }
public Builder mergeDelimitedFrom( public boolean mergeDelimitedFrom(
InputStream input, InputStream input,
ExtensionRegistryLite extensionRegistry) throws IOException { ExtensionRegistryLite extensionRegistry) throws IOException {
// UnknownFieldSet has no extensions. // UnknownFieldSet has no extensions.
return mergeFrom(input); return mergeDelimitedFrom(input);
} }
public Builder mergeFrom( public Builder mergeFrom(

View File

@ -113,10 +113,18 @@ public final class WireFormat {
FIXED64 (JavaType.LONG , WIRETYPE_FIXED64 ), FIXED64 (JavaType.LONG , WIRETYPE_FIXED64 ),
FIXED32 (JavaType.INT , WIRETYPE_FIXED32 ), FIXED32 (JavaType.INT , WIRETYPE_FIXED32 ),
BOOL (JavaType.BOOLEAN , WIRETYPE_VARINT ), BOOL (JavaType.BOOLEAN , WIRETYPE_VARINT ),
STRING (JavaType.STRING , WIRETYPE_LENGTH_DELIMITED), STRING (JavaType.STRING , WIRETYPE_LENGTH_DELIMITED) {
GROUP (JavaType.MESSAGE , WIRETYPE_START_GROUP ), public boolean isPackable() { return false; }
MESSAGE (JavaType.MESSAGE , WIRETYPE_LENGTH_DELIMITED), },
BYTES (JavaType.BYTE_STRING, WIRETYPE_LENGTH_DELIMITED), GROUP (JavaType.MESSAGE , WIRETYPE_START_GROUP ) {
public boolean isPackable() { return false; }
},
MESSAGE (JavaType.MESSAGE , WIRETYPE_LENGTH_DELIMITED) {
public boolean isPackable() { return false; }
},
BYTES (JavaType.BYTE_STRING, WIRETYPE_LENGTH_DELIMITED) {
public boolean isPackable() { return false; }
},
UINT32 (JavaType.INT , WIRETYPE_VARINT ), UINT32 (JavaType.INT , WIRETYPE_VARINT ),
ENUM (JavaType.ENUM , WIRETYPE_VARINT ), ENUM (JavaType.ENUM , WIRETYPE_VARINT ),
SFIXED32(JavaType.INT , WIRETYPE_FIXED32 ), SFIXED32(JavaType.INT , WIRETYPE_FIXED32 ),
@ -134,6 +142,8 @@ public final class WireFormat {
public JavaType getJavaType() { return javaType; } public JavaType getJavaType() { return javaType; }
public int getWireType() { return wireType; } public int getWireType() { return wireType; }
public boolean isPackable() { return true; }
} }
// Field numbers for feilds in MessageSet wire format. // Field numbers for feilds in MessageSet wire format.

View File

@ -38,6 +38,7 @@ import protobuf_unittest.UnittestProto.TestAllTypes;
import protobuf_unittest.UnittestProto.TestPackedTypes; import protobuf_unittest.UnittestProto.TestPackedTypes;
import protobuf_unittest.UnittestProto.TestRequired; import protobuf_unittest.UnittestProto.TestRequired;
import protobuf_unittest.UnittestProto.TestRequiredForeign; import protobuf_unittest.UnittestProto.TestRequiredForeign;
import protobuf_unittest.UnittestProto.TestUnpackedTypes;
import junit.framework.TestCase; import junit.framework.TestCase;
@ -238,6 +239,43 @@ public class AbstractMessageTest extends TestCase {
TestUtil.assertPackedFieldsSet((TestPackedTypes) message.wrappedMessage); TestUtil.assertPackedFieldsSet((TestPackedTypes) message.wrappedMessage);
} }
public void testUnpackedSerialization() throws Exception {
Message abstractMessage =
new AbstractMessageWrapper(TestUtil.getUnpackedSet());
TestUtil.assertUnpackedFieldsSet(
TestUnpackedTypes.parseFrom(abstractMessage.toByteString()));
assertEquals(TestUtil.getUnpackedSet().toByteString(),
abstractMessage.toByteString());
}
public void testParsePackedToUnpacked() throws Exception {
AbstractMessageWrapper.Builder builder =
new AbstractMessageWrapper.Builder(TestUnpackedTypes.newBuilder());
AbstractMessageWrapper message =
builder.mergeFrom(TestUtil.getPackedSet().toByteString()).build();
TestUtil.assertUnpackedFieldsSet(
(TestUnpackedTypes) message.wrappedMessage);
}
public void testParseUnpackedToPacked() throws Exception {
AbstractMessageWrapper.Builder builder =
new AbstractMessageWrapper.Builder(TestPackedTypes.newBuilder());
AbstractMessageWrapper message =
builder.mergeFrom(TestUtil.getUnpackedSet().toByteString()).build();
TestUtil.assertPackedFieldsSet((TestPackedTypes) message.wrappedMessage);
}
public void testUnpackedParsing() throws Exception {
AbstractMessageWrapper.Builder builder =
new AbstractMessageWrapper.Builder(TestUnpackedTypes.newBuilder());
AbstractMessageWrapper message =
builder.mergeFrom(TestUtil.getUnpackedSet().toByteString()).build();
TestUtil.assertUnpackedFieldsSet(
(TestUnpackedTypes) message.wrappedMessage);
}
public void testOptimizedForSize() throws Exception { public void testOptimizedForSize() throws Exception {
// We're mostly only checking that this class was compiled successfully. // We're mostly only checking that this class was compiled successfully.
TestOptimizedForSize message = TestOptimizedForSize message =

View File

@ -490,4 +490,18 @@ public class CodedInputStreamTest extends TestCase {
assertEquals(0, in.readTag()); assertEquals(0, in.readTag());
assertEquals(5, in.getTotalBytesRead()); assertEquals(5, in.getTotalBytesRead());
} }
public void testInvalidTag() throws Exception {
// Any tag number which corresponds to field number zero is invalid and
// should throw InvalidProtocolBufferException.
for (int i = 0; i < 8; i++) {
try {
CodedInputStream.newInstance(bytes(i)).readTag();
fail("Should have thrown an exception.");
} catch (InvalidProtocolBufferException e) {
assertEquals(InvalidProtocolBufferException.invalidTag().getMessage(),
e.getMessage());
}
}
}
} }

View File

@ -30,6 +30,10 @@
package com.google.protobuf; package com.google.protobuf;
import com.google.protobuf.DescriptorProtos.DescriptorProto;
import com.google.protobuf.DescriptorProtos.FieldDescriptorProto;
import com.google.protobuf.DescriptorProtos.FileDescriptorProto;
import com.google.protobuf.Descriptors.DescriptorValidationException;
import com.google.protobuf.Descriptors.FileDescriptor; import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.Descriptors.FieldDescriptor;
@ -63,6 +67,22 @@ import java.util.Collections;
* @author kenton@google.com Kenton Varda * @author kenton@google.com Kenton Varda
*/ */
public class DescriptorsTest extends TestCase { public class DescriptorsTest extends TestCase {
// Regression test for bug where referencing a FieldDescriptor.Type value
// before a FieldDescriptorProto.Type value would yield a
// ExceptionInInitializerError.
private static final Object STATIC_INIT_TEST = FieldDescriptor.Type.BOOL;
public void testFieldTypeEnumMapping() throws Exception {
assertEquals(FieldDescriptor.Type.values().length,
FieldDescriptorProto.Type.values().length);
for (FieldDescriptor.Type type : FieldDescriptor.Type.values()) {
FieldDescriptorProto.Type protoType = type.toProto();
assertEquals("TYPE_" + type.name(), protoType.name());
assertEquals(type, FieldDescriptor.Type.valueOf(protoType));
}
}
public void testFileDescriptor() throws Exception { public void testFileDescriptor() throws Exception {
FileDescriptor file = UnittestProto.getDescriptor(); FileDescriptor file = UnittestProto.getDescriptor();
@ -405,4 +425,35 @@ public class DescriptorsTest extends TestCase {
UnittestEnormousDescriptor.getDescriptor() UnittestEnormousDescriptor.getDescriptor()
.toProto().getSerializedSize() > 65536); .toProto().getSerializedSize() > 65536);
} }
/**
* Tests that the DescriptorValidationException works as intended.
*/
public void testDescriptorValidatorException() throws Exception {
FileDescriptorProto fileDescriptorProto = FileDescriptorProto.newBuilder()
.setName("foo.proto")
.addMessageType(DescriptorProto.newBuilder()
.setName("Foo")
.addField(FieldDescriptorProto.newBuilder()
.setLabel(FieldDescriptorProto.Label.LABEL_OPTIONAL)
.setType(FieldDescriptorProto.Type.TYPE_INT32)
.setName("foo")
.setNumber(1)
.setDefaultValue("invalid")
.build())
.build())
.build();
try {
Descriptors.FileDescriptor.buildFrom(fileDescriptorProto,
new FileDescriptor[0]);
fail("DescriptorValidationException expected");
} catch (DescriptorValidationException e) {
// Expected; check that the error message contains some useful hints
assertTrue(e.getMessage().indexOf("foo") != -1);
assertTrue(e.getMessage().indexOf("Foo") != -1);
assertTrue(e.getMessage().indexOf("invalid") != -1);
assertTrue(e.getCause() instanceof NumberFormatException);
assertTrue(e.getCause().getMessage().indexOf("invalid") != -1);
}
}
} }

View File

@ -39,6 +39,8 @@ import protobuf_unittest.UnittestProto.ForeignEnum;
import protobuf_unittest.UnittestProto.TestAllTypes; import protobuf_unittest.UnittestProto.TestAllTypes;
import protobuf_unittest.UnittestProto.TestAllExtensions; import protobuf_unittest.UnittestProto.TestAllExtensions;
import protobuf_unittest.UnittestProto.TestExtremeDefaultValues; import protobuf_unittest.UnittestProto.TestExtremeDefaultValues;
import protobuf_unittest.UnittestProto.TestPackedTypes;
import protobuf_unittest.UnittestProto.TestUnpackedTypes;
import protobuf_unittest.MultipleFilesTestProto; import protobuf_unittest.MultipleFilesTestProto;
import protobuf_unittest.MessageWithNoOuter; import protobuf_unittest.MessageWithNoOuter;
import protobuf_unittest.EnumWithNoOuter; import protobuf_unittest.EnumWithNoOuter;
@ -303,8 +305,15 @@ public class GeneratedMessageTest extends TestCase {
TestUtil.assertClear(TestAllTypes.getDefaultInstance()); TestUtil.assertClear(TestAllTypes.getDefaultInstance());
TestUtil.assertClear(TestAllTypes.newBuilder().build()); TestUtil.assertClear(TestAllTypes.newBuilder().build());
assertEquals("\u1234", TestExtremeDefaultValues message =
TestExtremeDefaultValues.getDefaultInstance().getUtf8String()); TestExtremeDefaultValues.getDefaultInstance();
assertEquals("\u1234", message.getUtf8String());
assertEquals(Double.POSITIVE_INFINITY, message.getInfDouble());
assertEquals(Double.NEGATIVE_INFINITY, message.getNegInfDouble());
assertTrue(Double.isNaN(message.getNanDouble()));
assertEquals(Float.POSITIVE_INFINITY, message.getInfFloat());
assertEquals(Float.NEGATIVE_INFINITY, message.getNegInfFloat());
assertTrue(Float.isNaN(message.getNanFloat()));
} }
public void testReflectionGetters() throws Exception { public void testReflectionGetters() throws Exception {
@ -361,6 +370,20 @@ public class GeneratedMessageTest extends TestCase {
assertTrue(map.findValueByNumber(12345) == null); assertTrue(map.findValueByNumber(12345) == null);
} }
public void testParsePackedToUnpacked() throws Exception {
TestUnpackedTypes.Builder builder = TestUnpackedTypes.newBuilder();
TestUnpackedTypes message =
builder.mergeFrom(TestUtil.getPackedSet().toByteString()).build();
TestUtil.assertUnpackedFieldsSet(message);
}
public void testParseUnpackedToPacked() throws Exception {
TestPackedTypes.Builder builder = TestPackedTypes.newBuilder();
TestPackedTypes message =
builder.mergeFrom(TestUtil.getUnpackedSet().toByteString()).build();
TestUtil.assertPackedFieldsSet(message);
}
// ================================================================= // =================================================================
// Extensions. // Extensions.
@ -615,4 +638,12 @@ public class GeneratedMessageTest extends TestCase {
UnittestProto.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48); UnittestProto.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48);
assertEquals(UnittestProto.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 51); assertEquals(UnittestProto.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 51);
} }
public void testRecursiveMessageDefaultInstance() throws Exception {
UnittestProto.TestRecursiveMessage message =
UnittestProto.TestRecursiveMessage.getDefaultInstance();
assertTrue(message != null);
assertTrue(message.getA() != null);
assertTrue(message.getA() == message);
}
} }

View File

@ -30,7 +30,9 @@
package com.google.protobuf; package com.google.protobuf;
import com.google.protobuf.Descriptors.FileDescriptor;
import com.google.protobuf.Descriptors.MethodDescriptor; import com.google.protobuf.Descriptors.MethodDescriptor;
import google.protobuf.no_generic_services_test.UnittestNoGenericServices;
import protobuf_unittest.MessageWithNoOuter; import protobuf_unittest.MessageWithNoOuter;
import protobuf_unittest.ServiceWithNoOuter; import protobuf_unittest.ServiceWithNoOuter;
import protobuf_unittest.UnittestProto.TestAllTypes; import protobuf_unittest.UnittestProto.TestAllTypes;
@ -44,6 +46,9 @@ import org.easymock.classextension.EasyMock;
import org.easymock.classextension.IMocksControl; import org.easymock.classextension.IMocksControl;
import org.easymock.IArgumentMatcher; import org.easymock.IArgumentMatcher;
import java.util.HashSet;
import java.util.Set;
import junit.framework.TestCase; import junit.framework.TestCase;
/** /**
@ -220,6 +225,48 @@ public class ServiceTest extends TestCase {
control.verify(); control.verify();
} }
public void testNoGenericServices() throws Exception {
// Non-services should be usable.
UnittestNoGenericServices.TestMessage message =
UnittestNoGenericServices.TestMessage.newBuilder()
.setA(123)
.setExtension(UnittestNoGenericServices.testExtension, 456)
.build();
assertEquals(123, message.getA());
assertEquals(1, UnittestNoGenericServices.TestEnum.FOO.getNumber());
// Build a list of the class names nested in UnittestNoGenericServices.
String outerName = "google.protobuf.no_generic_services_test." +
"UnittestNoGenericServices";
Class<?> outerClass = Class.forName(outerName);
Set<String> innerClassNames = new HashSet<String>();
for (Class<?> innerClass : outerClass.getClasses()) {
String fullName = innerClass.getName();
// Figure out the unqualified name of the inner class.
// Note: Surprisingly, the full name of an inner class will be separated
// from the outer class name by a '$' rather than a '.'. This is not
// mentioned in the documentation for java.lang.Class. I don't want to
// make assumptions, so I'm just going to accept any character as the
// separator.
assertTrue(fullName.startsWith(outerName));
innerClassNames.add(fullName.substring(outerName.length() + 1));
}
// No service class should have been generated.
assertTrue(innerClassNames.contains("TestMessage"));
assertTrue(innerClassNames.contains("TestEnum"));
assertFalse(innerClassNames.contains("TestService"));
// But descriptors are there.
FileDescriptor file = UnittestNoGenericServices.getDescriptor();
assertEquals(1, file.getServices().size());
assertEquals("TestService", file.getServices().get(0).getName());
assertEquals(1, file.getServices().get(0).getMethods().size());
assertEquals("Foo",
file.getServices().get(0).getMethods().get(0).getName());
}
// ================================================================= // =================================================================
/** /**

View File

@ -217,6 +217,7 @@ import protobuf_unittest.UnittestProto.TestAllExtensions;
import protobuf_unittest.UnittestProto.TestAllTypes; import protobuf_unittest.UnittestProto.TestAllTypes;
import protobuf_unittest.UnittestProto.TestPackedExtensions; import protobuf_unittest.UnittestProto.TestPackedExtensions;
import protobuf_unittest.UnittestProto.TestPackedTypes; import protobuf_unittest.UnittestProto.TestPackedTypes;
import protobuf_unittest.UnittestProto.TestUnpackedTypes;
import protobuf_unittest.UnittestProto.ForeignMessage; import protobuf_unittest.UnittestProto.ForeignMessage;
import protobuf_unittest.UnittestProto.ForeignEnum; import protobuf_unittest.UnittestProto.ForeignEnum;
import com.google.protobuf.test.UnittestImport.ImportMessage; import com.google.protobuf.test.UnittestImport.ImportMessage;
@ -289,6 +290,12 @@ class TestUtil {
return builder.build(); return builder.build();
} }
public static TestUnpackedTypes getUnpackedSet() {
TestUnpackedTypes.Builder builder = TestUnpackedTypes.newBuilder();
setUnpackedFields(builder);
return builder.build();
}
public static TestPackedExtensions getPackedExtensionsSet() { public static TestPackedExtensions getPackedExtensionsSet() {
TestPackedExtensions.Builder builder = TestPackedExtensions.newBuilder(); TestPackedExtensions.Builder builder = TestPackedExtensions.newBuilder();
setPackedExtensions(builder); setPackedExtensions(builder);
@ -955,6 +962,42 @@ class TestUtil {
message.addPackedEnum (ForeignEnum.FOREIGN_BAZ); message.addPackedEnum (ForeignEnum.FOREIGN_BAZ);
} }
/**
* Set every field of {@code message} to a unique value. Must correspond with
* the values applied by {@code setPackedFields}.
*/
public static void setUnpackedFields(TestUnpackedTypes.Builder message) {
message.addUnpackedInt32 (601);
message.addUnpackedInt64 (602);
message.addUnpackedUint32 (603);
message.addUnpackedUint64 (604);
message.addUnpackedSint32 (605);
message.addUnpackedSint64 (606);
message.addUnpackedFixed32 (607);
message.addUnpackedFixed64 (608);
message.addUnpackedSfixed32(609);
message.addUnpackedSfixed64(610);
message.addUnpackedFloat (611);
message.addUnpackedDouble (612);
message.addUnpackedBool (true);
message.addUnpackedEnum (ForeignEnum.FOREIGN_BAR);
// Add a second one of each field.
message.addUnpackedInt32 (701);
message.addUnpackedInt64 (702);
message.addUnpackedUint32 (703);
message.addUnpackedUint64 (704);
message.addUnpackedSint32 (705);
message.addUnpackedSint64 (706);
message.addUnpackedFixed32 (707);
message.addUnpackedFixed64 (708);
message.addUnpackedSfixed32(709);
message.addUnpackedSfixed64(710);
message.addUnpackedFloat (711);
message.addUnpackedDouble (712);
message.addUnpackedBool (false);
message.addUnpackedEnum (ForeignEnum.FOREIGN_BAZ);
}
/** /**
* Assert (using {@code junit.framework.Assert}} that all fields of * Assert (using {@code junit.framework.Assert}} that all fields of
* {@code message} are set to the values assigned by {@code setPackedFields}. * {@code message} are set to the values assigned by {@code setPackedFields}.
@ -1004,6 +1047,55 @@ class TestUtil {
Assert.assertEquals(ForeignEnum.FOREIGN_BAZ, message.getPackedEnum(1)); Assert.assertEquals(ForeignEnum.FOREIGN_BAZ, message.getPackedEnum(1));
} }
/**
* Assert (using {@code junit.framework.Assert}} that all fields of
* {@code message} are set to the values assigned by {@code setUnpackedFields}.
*/
public static void assertUnpackedFieldsSet(TestUnpackedTypes message) {
Assert.assertEquals(2, message.getUnpackedInt32Count ());
Assert.assertEquals(2, message.getUnpackedInt64Count ());
Assert.assertEquals(2, message.getUnpackedUint32Count ());
Assert.assertEquals(2, message.getUnpackedUint64Count ());
Assert.assertEquals(2, message.getUnpackedSint32Count ());
Assert.assertEquals(2, message.getUnpackedSint64Count ());
Assert.assertEquals(2, message.getUnpackedFixed32Count ());
Assert.assertEquals(2, message.getUnpackedFixed64Count ());
Assert.assertEquals(2, message.getUnpackedSfixed32Count());
Assert.assertEquals(2, message.getUnpackedSfixed64Count());
Assert.assertEquals(2, message.getUnpackedFloatCount ());
Assert.assertEquals(2, message.getUnpackedDoubleCount ());
Assert.assertEquals(2, message.getUnpackedBoolCount ());
Assert.assertEquals(2, message.getUnpackedEnumCount ());
Assert.assertEquals(601 , message.getUnpackedInt32 (0));
Assert.assertEquals(602 , message.getUnpackedInt64 (0));
Assert.assertEquals(603 , message.getUnpackedUint32 (0));
Assert.assertEquals(604 , message.getUnpackedUint64 (0));
Assert.assertEquals(605 , message.getUnpackedSint32 (0));
Assert.assertEquals(606 , message.getUnpackedSint64 (0));
Assert.assertEquals(607 , message.getUnpackedFixed32 (0));
Assert.assertEquals(608 , message.getUnpackedFixed64 (0));
Assert.assertEquals(609 , message.getUnpackedSfixed32(0));
Assert.assertEquals(610 , message.getUnpackedSfixed64(0));
Assert.assertEquals(611 , message.getUnpackedFloat (0), 0.0);
Assert.assertEquals(612 , message.getUnpackedDouble (0), 0.0);
Assert.assertEquals(true , message.getUnpackedBool (0));
Assert.assertEquals(ForeignEnum.FOREIGN_BAR, message.getUnpackedEnum(0));
Assert.assertEquals(701 , message.getUnpackedInt32 (1));
Assert.assertEquals(702 , message.getUnpackedInt64 (1));
Assert.assertEquals(703 , message.getUnpackedUint32 (1));
Assert.assertEquals(704 , message.getUnpackedUint64 (1));
Assert.assertEquals(705 , message.getUnpackedSint32 (1));
Assert.assertEquals(706 , message.getUnpackedSint64 (1));
Assert.assertEquals(707 , message.getUnpackedFixed32 (1));
Assert.assertEquals(708 , message.getUnpackedFixed64 (1));
Assert.assertEquals(709 , message.getUnpackedSfixed32(1));
Assert.assertEquals(710 , message.getUnpackedSfixed64(1));
Assert.assertEquals(711 , message.getUnpackedFloat (1), 0.0);
Assert.assertEquals(712 , message.getUnpackedDouble (1), 0.0);
Assert.assertEquals(false, message.getUnpackedBool (1));
Assert.assertEquals(ForeignEnum.FOREIGN_BAZ, message.getUnpackedEnum(1));
}
// =================================================================== // ===================================================================
// Like above, but for extensions // Like above, but for extensions

View File

@ -68,7 +68,7 @@ public class TextFormatTest extends TestCase {
private static String allExtensionsSetText = TestUtil.readTextFromFile( private static String allExtensionsSetText = TestUtil.readTextFromFile(
"text_format_unittest_extensions_data.txt"); "text_format_unittest_extensions_data.txt");
private String exoticText = private static String exoticText =
"repeated_int32: -1\n" + "repeated_int32: -1\n" +
"repeated_int32: -2147483648\n" + "repeated_int32: -2147483648\n" +
"repeated_int64: -1\n" + "repeated_int64: -1\n" +
@ -80,7 +80,13 @@ public class TextFormatTest extends TestCase {
"repeated_double: 123.0\n" + "repeated_double: 123.0\n" +
"repeated_double: 123.5\n" + "repeated_double: 123.5\n" +
"repeated_double: 0.125\n" + "repeated_double: 0.125\n" +
"repeated_double: .125\n" +
"repeated_double: -.125\n" +
"repeated_double: 1.23E17\n" + "repeated_double: 1.23E17\n" +
"repeated_double: 1.23E+17\n" +
"repeated_double: -1.23e-17\n" +
"repeated_double: .23e+17\n" +
"repeated_double: -.23E17\n" +
"repeated_double: 1.235E22\n" + "repeated_double: 1.235E22\n" +
"repeated_double: 1.235E-18\n" + "repeated_double: 1.235E-18\n" +
"repeated_double: 123.456789\n" + "repeated_double: 123.456789\n" +
@ -91,6 +97,10 @@ public class TextFormatTest extends TestCase {
"\\341\\210\\264\"\n" + "\\341\\210\\264\"\n" +
"repeated_bytes: \"\\000\\001\\a\\b\\f\\n\\r\\t\\v\\\\\\'\\\"\\376\"\n"; "repeated_bytes: \"\\000\\001\\a\\b\\f\\n\\r\\t\\v\\\\\\'\\\"\\376\"\n";
private static String canonicalExoticText =
exoticText.replace(": .", ": 0.").replace(": -.", ": -0.") // short-form double
.replace("23e", "23E").replace("E+", "E").replace("0.23E17", "2.3E16");
private String messageSetText = private String messageSetText =
"[protobuf_unittest.TestMessageSetExtension1] {\n" + "[protobuf_unittest.TestMessageSetExtension1] {\n" +
" i: 123\n" + " i: 123\n" +
@ -231,7 +241,13 @@ public class TextFormatTest extends TestCase {
.addRepeatedDouble(123) .addRepeatedDouble(123)
.addRepeatedDouble(123.5) .addRepeatedDouble(123.5)
.addRepeatedDouble(0.125) .addRepeatedDouble(0.125)
.addRepeatedDouble(.125)
.addRepeatedDouble(-.125)
.addRepeatedDouble(123e15) .addRepeatedDouble(123e15)
.addRepeatedDouble(123e15)
.addRepeatedDouble(-1.23e-17)
.addRepeatedDouble(.23e17)
.addRepeatedDouble(-23e15)
.addRepeatedDouble(123.5e20) .addRepeatedDouble(123.5e20)
.addRepeatedDouble(123.5e-20) .addRepeatedDouble(123.5e-20)
.addRepeatedDouble(123.456789) .addRepeatedDouble(123.456789)
@ -244,7 +260,7 @@ public class TextFormatTest extends TestCase {
.addRepeatedBytes(bytes("\0\001\007\b\f\n\r\t\013\\\'\"\u00fe")) .addRepeatedBytes(bytes("\0\001\007\b\f\n\r\t\013\\\'\"\u00fe"))
.build(); .build();
assertEquals(exoticText, message.toString()); assertEquals(canonicalExoticText, message.toString());
} }
public void testPrintMessageSet() throws Exception { public void testPrintMessageSet() throws Exception {
@ -319,7 +335,7 @@ public class TextFormatTest extends TestCase {
// Too lazy to check things individually. Don't try to debug this // Too lazy to check things individually. Don't try to debug this
// if testPrintExotic() is failing. // if testPrintExotic() is failing.
assertEquals(exoticText, builder.build().toString()); assertEquals(canonicalExoticText, builder.build().toString());
} }
public void testParseMessageSet() throws Exception { public void testParseMessageSet() throws Exception {

View File

@ -235,6 +235,9 @@ public class WireFormatTest extends TestCase {
TestUtil.assertPackedFieldsSet(TestPackedTypes.parseDelimitedFrom(input)); TestUtil.assertPackedFieldsSet(TestPackedTypes.parseDelimitedFrom(input));
assertEquals(34, input.read()); assertEquals(34, input.read());
assertEquals(-1, input.read()); assertEquals(-1, input.read());
// We're at EOF, so parsing again should return null.
assertTrue(TestAllTypes.parseDelimitedFrom(input) == null);
} }
private void assertFieldsInOrder(ByteString data) throws Exception { private void assertFieldsInOrder(ByteString data) throws Exception {

View File

@ -44,12 +44,24 @@ file, in types that make this information accessible in Python.
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
class Error(Exception):
"""Base error for this module."""
class DescriptorBase(object): class DescriptorBase(object):
"""Descriptors base class. """Descriptors base class.
This class is the base of all descriptor classes. It provides common options This class is the base of all descriptor classes. It provides common options
related functionaility. related functionaility.
Attributes:
has_options: True if the descriptor has non-default options. Usually it
is not necessary to read this -- just call GetOptions() which will
happily return the default instance. However, it's sometimes useful
for efficiency, and also useful inside the protobuf implementation to
avoid some bootstrapping issues.
""" """
def __init__(self, options, options_class_name): def __init__(self, options, options_class_name):
@ -60,6 +72,9 @@ class DescriptorBase(object):
self._options = options self._options = options
self._options_class_name = options_class_name self._options_class_name = options_class_name
# Does this descriptor have non-default options?
self.has_options = options is not None
def GetOptions(self): def GetOptions(self):
"""Retrieves descriptor options. """Retrieves descriptor options.
@ -78,7 +93,70 @@ class DescriptorBase(object):
return self._options return self._options
class Descriptor(DescriptorBase): class _NestedDescriptorBase(DescriptorBase):
"""Common class for descriptors that can be nested."""
def __init__(self, options, options_class_name, name, full_name,
file, containing_type, serialized_start=None,
serialized_end=None):
"""Constructor.
Args:
options: Protocol message options or None
to use default message options.
options_class_name: (str) The class name of the above options.
name: (str) Name of this protocol message type.
full_name: (str) Fully-qualified name of this protocol message type,
which will include protocol "package" name and the name of any
enclosing types.
file: (FileDescriptor) Reference to file info.
containing_type: if provided, this is a nested descriptor, with this
descriptor as parent, otherwise None.
serialized_start: The start index (inclusive) in block in the
file.serialized_pb that describes this descriptor.
serialized_end: The end index (exclusive) in block in the
file.serialized_pb that describes this descriptor.
"""
super(_NestedDescriptorBase, self).__init__(
options, options_class_name)
self.name = name
# TODO(falk): Add function to calculate full_name instead of having it in
# memory?
self.full_name = full_name
self.file = file
self.containing_type = containing_type
self._serialized_start = serialized_start
self._serialized_end = serialized_end
def GetTopLevelContainingType(self):
"""Returns the root if this is a nested type, or itself if its the root."""
desc = self
while desc.containing_type is not None:
desc = desc.containing_type
return desc
def CopyToProto(self, proto):
"""Copies this to the matching proto in descriptor_pb2.
Args:
proto: An empty proto instance from descriptor_pb2.
Raises:
Error: If self couldnt be serialized, due to to few constructor arguments.
"""
if (self.file is not None and
self._serialized_start is not None and
self._serialized_end is not None):
proto.ParseFromString(self.file.serialized_pb[
self._serialized_start:self._serialized_end])
else:
raise Error('Descriptor does not contain serialization.')
class Descriptor(_NestedDescriptorBase):
"""Descriptor for a protocol message type. """Descriptor for a protocol message type.
@ -89,10 +167,8 @@ class Descriptor(DescriptorBase):
which will include protocol "package" name and the name of any which will include protocol "package" name and the name of any
enclosing types. enclosing types.
filename: (str) Name of the .proto file containing this message.
containing_type: (Descriptor) Reference to the descriptor of the containing_type: (Descriptor) Reference to the descriptor of the
type containing us, or None if we have no containing type. type containing us, or None if this is top-level.
fields: (list of FieldDescriptors) Field descriptors for all fields: (list of FieldDescriptors) Field descriptors for all
fields in this type. fields in this type.
@ -123,20 +199,28 @@ class Descriptor(DescriptorBase):
objects as |extensions|, but indexed by "name" attribute of each objects as |extensions|, but indexed by "name" attribute of each
FieldDescriptor. FieldDescriptor.
is_extendable: Does this type define any extension ranges?
options: (descriptor_pb2.MessageOptions) Protocol message options or None options: (descriptor_pb2.MessageOptions) Protocol message options or None
to use default message options. to use default message options.
file: (FileDescriptor) Reference to file descriptor.
""" """
def __init__(self, name, full_name, filename, containing_type, def __init__(self, name, full_name, filename, containing_type, fields,
fields, nested_types, enum_types, extensions, options=None): nested_types, enum_types, extensions, options=None,
is_extendable=True, extension_ranges=None, file=None,
serialized_start=None, serialized_end=None):
"""Arguments to __init__() are as described in the description """Arguments to __init__() are as described in the description
of Descriptor fields above. of Descriptor fields above.
Note that filename is an obsolete argument, that is not used anymore.
Please use file.name to access this as an attribute.
""" """
super(Descriptor, self).__init__(options, 'MessageOptions') super(Descriptor, self).__init__(
self.name = name options, 'MessageOptions', name, full_name, file,
self.full_name = full_name containing_type, serialized_start=serialized_start,
self.filename = filename serialized_end=serialized_start)
self.containing_type = containing_type
# We have fields in addition to fields_by_name and fields_by_number, # We have fields in addition to fields_by_name and fields_by_number,
# so that: # so that:
@ -163,6 +247,20 @@ class Descriptor(DescriptorBase):
for extension in self.extensions: for extension in self.extensions:
extension.extension_scope = self extension.extension_scope = self
self.extensions_by_name = dict((f.name, f) for f in extensions) self.extensions_by_name = dict((f.name, f) for f in extensions)
self.is_extendable = is_extendable
self.extension_ranges = extension_ranges
self._serialized_start = serialized_start
self._serialized_end = serialized_end
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.DescriptorProto.
Args:
proto: An empty descriptor_pb2.DescriptorProto.
"""
# This function is overriden to give a better doc comment.
super(Descriptor, self).CopyToProto(proto)
# TODO(robinson): We should have aggressive checking here, # TODO(robinson): We should have aggressive checking here,
@ -195,6 +293,8 @@ class FieldDescriptor(DescriptorBase):
label: (One of the LABEL_* constants below) Tells whether this label: (One of the LABEL_* constants below) Tells whether this
field is optional, required, or repeated. field is optional, required, or repeated.
has_default_value: (bool) True if this field has a default value defined,
otherwise false.
default_value: (Varies) Default value of this field. Only default_value: (Varies) Default value of this field. Only
meaningful for non-repeated scalar fields. Repeated fields meaningful for non-repeated scalar fields. Repeated fields
should always set this to [], and non-repeated composite should always set this to [], and non-repeated composite
@ -272,7 +372,8 @@ class FieldDescriptor(DescriptorBase):
def __init__(self, name, full_name, index, number, type, cpp_type, label, def __init__(self, name, full_name, index, number, type, cpp_type, label,
default_value, message_type, enum_type, containing_type, default_value, message_type, enum_type, containing_type,
is_extension, extension_scope, options=None): is_extension, extension_scope, options=None,
has_default_value=True):
"""The arguments are as described in the description of FieldDescriptor """The arguments are as described in the description of FieldDescriptor
attributes above. attributes above.
@ -288,6 +389,7 @@ class FieldDescriptor(DescriptorBase):
self.type = type self.type = type
self.cpp_type = cpp_type self.cpp_type = cpp_type
self.label = label self.label = label
self.has_default_value = has_default_value
self.default_value = default_value self.default_value = default_value
self.containing_type = containing_type self.containing_type = containing_type
self.message_type = message_type self.message_type = message_type
@ -296,7 +398,7 @@ class FieldDescriptor(DescriptorBase):
self.extension_scope = extension_scope self.extension_scope = extension_scope
class EnumDescriptor(DescriptorBase): class EnumDescriptor(_NestedDescriptorBase):
"""Descriptor for an enum defined in a .proto file. """Descriptor for an enum defined in a .proto file.
@ -305,7 +407,6 @@ class EnumDescriptor(DescriptorBase):
name: (str) Name of the enum type. name: (str) Name of the enum type.
full_name: (str) Full name of the type, including package name full_name: (str) Full name of the type, including package name
and any enclosing type(s). and any enclosing type(s).
filename: (str) Name of the .proto file in which this appears.
values: (list of EnumValueDescriptors) List of the values values: (list of EnumValueDescriptors) List of the values
in this enum. in this enum.
@ -317,23 +418,41 @@ class EnumDescriptor(DescriptorBase):
type of this enum, or None if this is an enum defined at the type of this enum, or None if this is an enum defined at the
top level in a .proto file. Set by Descriptor's constructor top level in a .proto file. Set by Descriptor's constructor
if we're passed into one. if we're passed into one.
file: (FileDescriptor) Reference to file descriptor.
options: (descriptor_pb2.EnumOptions) Enum options message or options: (descriptor_pb2.EnumOptions) Enum options message or
None to use default enum options. None to use default enum options.
""" """
def __init__(self, name, full_name, filename, values, def __init__(self, name, full_name, filename, values,
containing_type=None, options=None): containing_type=None, options=None, file=None,
"""Arguments are as described in the attribute description above.""" serialized_start=None, serialized_end=None):
super(EnumDescriptor, self).__init__(options, 'EnumOptions') """Arguments are as described in the attribute description above.
self.name = name
self.full_name = full_name Note that filename is an obsolete argument, that is not used anymore.
self.filename = filename Please use file.name to access this as an attribute.
"""
super(EnumDescriptor, self).__init__(
options, 'EnumOptions', name, full_name, file,
containing_type, serialized_start=serialized_start,
serialized_end=serialized_start)
self.values = values self.values = values
for value in self.values: for value in self.values:
value.type = self value.type = self
self.values_by_name = dict((v.name, v) for v in values) self.values_by_name = dict((v.name, v) for v in values)
self.values_by_number = dict((v.number, v) for v in values) self.values_by_number = dict((v.number, v) for v in values)
self.containing_type = containing_type
self._serialized_start = serialized_start
self._serialized_end = serialized_end
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.EnumDescriptorProto.
Args:
proto: An empty descriptor_pb2.EnumDescriptorProto.
"""
# This function is overriden to give a better doc comment.
super(EnumDescriptor, self).CopyToProto(proto)
class EnumValueDescriptor(DescriptorBase): class EnumValueDescriptor(DescriptorBase):
@ -360,7 +479,7 @@ class EnumValueDescriptor(DescriptorBase):
self.type = type self.type = type
class ServiceDescriptor(DescriptorBase): class ServiceDescriptor(_NestedDescriptorBase):
"""Descriptor for a service. """Descriptor for a service.
@ -372,12 +491,15 @@ class ServiceDescriptor(DescriptorBase):
service. service.
options: (descriptor_pb2.ServiceOptions) Service options message or options: (descriptor_pb2.ServiceOptions) Service options message or
None to use default service options. None to use default service options.
file: (FileDescriptor) Reference to file info.
""" """
def __init__(self, name, full_name, index, methods, options=None): def __init__(self, name, full_name, index, methods, options=None, file=None,
super(ServiceDescriptor, self).__init__(options, 'ServiceOptions') serialized_start=None, serialized_end=None):
self.name = name super(ServiceDescriptor, self).__init__(
self.full_name = full_name options, 'ServiceOptions', name, full_name, file,
None, serialized_start=serialized_start,
serialized_end=serialized_end)
self.index = index self.index = index
self.methods = methods self.methods = methods
# Set the containing service for each method in this service. # Set the containing service for each method in this service.
@ -391,6 +513,15 @@ class ServiceDescriptor(DescriptorBase):
return method return method
return None return None
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.ServiceDescriptorProto.
Args:
proto: An empty descriptor_pb2.ServiceDescriptorProto.
"""
# This function is overriden to give a better doc comment.
super(ServiceDescriptor, self).CopyToProto(proto)
class MethodDescriptor(DescriptorBase): class MethodDescriptor(DescriptorBase):
@ -423,6 +554,32 @@ class MethodDescriptor(DescriptorBase):
self.output_type = output_type self.output_type = output_type
class FileDescriptor(DescriptorBase):
"""Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto.
name: name of file, relative to root of source tree.
package: name of the package
serialized_pb: (str) Byte string of serialized
descriptor_pb2.FileDescriptorProto.
"""
def __init__(self, name, package, options=None, serialized_pb=None):
"""Constructor."""
super(FileDescriptor, self).__init__(options, 'FileOptions')
self.name = name
self.package = package
self.serialized_pb = serialized_pb
def CopyToProto(self, proto):
"""Copies this to a descriptor_pb2.FileDescriptorProto.
Args:
proto: An empty descriptor_pb2.FileDescriptorProto.
"""
proto.ParseFromString(self.serialized_pb)
def _ParseOptions(message, string): def _ParseOptions(message, string):
"""Parses serialized options. """Parses serialized options.
@ -430,4 +587,4 @@ def _ParseOptions(message, string):
proto2 files. It must not be used outside proto2. proto2 files. It must not be used outside proto2.
""" """
message.ParseFromString(string) message.ParseFromString(string)
return message; return message

View File

@ -54,8 +54,7 @@ class BaseContainer(object):
Args: Args:
message_listener: A MessageListener implementation. message_listener: A MessageListener implementation.
The RepeatedScalarFieldContainer will call this object's The RepeatedScalarFieldContainer will call this object's
TransitionToNonempty() method when it transitions from being empty to Modified() method when it is modified.
being nonempty.
""" """
self._message_listener = message_listener self._message_listener = message_listener
self._values = [] self._values = []
@ -73,6 +72,9 @@ class BaseContainer(object):
# The concrete classes should define __eq__. # The concrete classes should define __eq__.
return not self == other return not self == other
def __repr__(self):
return repr(self._values)
class RepeatedScalarFieldContainer(BaseContainer): class RepeatedScalarFieldContainer(BaseContainer):
@ -86,8 +88,7 @@ class RepeatedScalarFieldContainer(BaseContainer):
Args: Args:
message_listener: A MessageListener implementation. message_listener: A MessageListener implementation.
The RepeatedScalarFieldContainer will call this object's The RepeatedScalarFieldContainer will call this object's
TransitionToNonempty() method when it transitions from being empty to Modified() method when it is modified.
being nonempty.
type_checker: A type_checkers.ValueChecker instance to run on elements type_checker: A type_checkers.ValueChecker instance to run on elements
inserted into this container. inserted into this container.
""" """
@ -96,44 +97,47 @@ class RepeatedScalarFieldContainer(BaseContainer):
def append(self, value): def append(self, value):
"""Appends an item to the list. Similar to list.append().""" """Appends an item to the list. Similar to list.append()."""
self.insert(len(self._values), value) self._type_checker.CheckValue(value)
self._values.append(value)
if not self._message_listener.dirty:
self._message_listener.Modified()
def insert(self, key, value): def insert(self, key, value):
"""Inserts the item at the specified position. Similar to list.insert().""" """Inserts the item at the specified position. Similar to list.insert()."""
self._type_checker.CheckValue(value) self._type_checker.CheckValue(value)
self._values.insert(key, value) self._values.insert(key, value)
self._message_listener.ByteSizeDirty() if not self._message_listener.dirty:
if len(self._values) == 1: self._message_listener.Modified()
self._message_listener.TransitionToNonempty()
def extend(self, elem_seq): def extend(self, elem_seq):
"""Extends by appending the given sequence. Similar to list.extend().""" """Extends by appending the given sequence. Similar to list.extend()."""
if not elem_seq: if not elem_seq:
return return
orig_empty = len(self._values) == 0
new_values = [] new_values = []
for elem in elem_seq: for elem in elem_seq:
self._type_checker.CheckValue(elem) self._type_checker.CheckValue(elem)
new_values.append(elem) new_values.append(elem)
self._values.extend(new_values) self._values.extend(new_values)
self._message_listener.ByteSizeDirty() self._message_listener.Modified()
if orig_empty:
self._message_listener.TransitionToNonempty() def MergeFrom(self, other):
"""Appends the contents of another repeated field of the same type to this
one. We do not check the types of the individual fields.
"""
self._values.extend(other._values)
self._message_listener.Modified()
def remove(self, elem): def remove(self, elem):
"""Removes an item from the list. Similar to list.remove().""" """Removes an item from the list. Similar to list.remove()."""
self._values.remove(elem) self._values.remove(elem)
self._message_listener.ByteSizeDirty() self._message_listener.Modified()
def __setitem__(self, key, value): def __setitem__(self, key, value):
"""Sets the item on the specified position.""" """Sets the item on the specified position."""
# No need to call TransitionToNonempty(), since if we're able to
# set the element at this index, we were already nonempty before
# this method was called.
self._message_listener.ByteSizeDirty()
self._type_checker.CheckValue(value) self._type_checker.CheckValue(value)
self._values[key] = value self._values[key] = value
self._message_listener.Modified()
def __getslice__(self, start, stop): def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices.""" """Retrieves the subset of items from between the specified indices."""
@ -146,17 +150,17 @@ class RepeatedScalarFieldContainer(BaseContainer):
self._type_checker.CheckValue(value) self._type_checker.CheckValue(value)
new_values.append(value) new_values.append(value)
self._values[start:stop] = new_values self._values[start:stop] = new_values
self._message_listener.ByteSizeDirty() self._message_listener.Modified()
def __delitem__(self, key): def __delitem__(self, key):
"""Deletes the item at the specified position.""" """Deletes the item at the specified position."""
del self._values[key] del self._values[key]
self._message_listener.ByteSizeDirty() self._message_listener.Modified()
def __delslice__(self, start, stop): def __delslice__(self, start, stop):
"""Deletes the subset of items from between the specified indices.""" """Deletes the subset of items from between the specified indices."""
del self._values[start:stop] del self._values[start:stop]
self._message_listener.ByteSizeDirty() self._message_listener.Modified()
def __eq__(self, other): def __eq__(self, other):
"""Compares the current instance with another one.""" """Compares the current instance with another one."""
@ -186,8 +190,7 @@ class RepeatedCompositeFieldContainer(BaseContainer):
Args: Args:
message_listener: A MessageListener implementation. message_listener: A MessageListener implementation.
The RepeatedCompositeFieldContainer will call this object's The RepeatedCompositeFieldContainer will call this object's
TransitionToNonempty() method when it transitions from being empty to Modified() method when it is modified.
being nonempty.
message_descriptor: A Descriptor instance describing the protocol type message_descriptor: A Descriptor instance describing the protocol type
that should be present in this container. We'll use the that should be present in this container. We'll use the
_concrete_class field of this descriptor when the client calls add(). _concrete_class field of this descriptor when the client calls add().
@ -199,10 +202,24 @@ class RepeatedCompositeFieldContainer(BaseContainer):
new_element = self._message_descriptor._concrete_class() new_element = self._message_descriptor._concrete_class()
new_element._SetListener(self._message_listener) new_element._SetListener(self._message_listener)
self._values.append(new_element) self._values.append(new_element)
self._message_listener.ByteSizeDirty() if not self._message_listener.dirty:
self._message_listener.TransitionToNonempty() self._message_listener.Modified()
return new_element return new_element
def MergeFrom(self, other):
"""Appends the contents of another repeated field of the same type to this
one, copying each individual message.
"""
message_class = self._message_descriptor._concrete_class
listener = self._message_listener
values = self._values
for message in other._values:
new_element = message_class()
new_element._SetListener(listener)
new_element.MergeFrom(message)
values.append(new_element)
listener.Modified()
def __getslice__(self, start, stop): def __getslice__(self, start, stop):
"""Retrieves the subset of items from between the specified indices.""" """Retrieves the subset of items from between the specified indices."""
return self._values[start:stop] return self._values[start:stop]
@ -210,12 +227,12 @@ class RepeatedCompositeFieldContainer(BaseContainer):
def __delitem__(self, key): def __delitem__(self, key):
"""Deletes the item at the specified position.""" """Deletes the item at the specified position."""
del self._values[key] del self._values[key]
self._message_listener.ByteSizeDirty() self._message_listener.Modified()
def __delslice__(self, start, stop): def __delslice__(self, start, stop):
"""Deletes the subset of items from between the specified indices.""" """Deletes the subset of items from between the specified indices."""
del self._values[start:stop] del self._values[start:stop]
self._message_listener.ByteSizeDirty() self._message_listener.Modified()
def __eq__(self, other): def __eq__(self, other):
"""Compares the current instance with another one.""" """Compares the current instance with another one."""

View File

@ -28,182 +28,614 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Class for decoding protocol buffer primitives. """Code for decoding protocol buffer primitives.
Contains the logic for decoding every logical protocol field type This code is very similar to encoder.py -- read the docs for that module first.
from one of the 5 physical wire types.
A "decoder" is a function with the signature:
Decode(buffer, pos, end, message, field_dict)
The arguments are:
buffer: The string containing the encoded message.
pos: The current position in the string.
end: The position in the string where the current message ends. May be
less than len(buffer) if we're reading a sub-message.
message: The message object into which we're parsing.
field_dict: message._fields (avoids a hashtable lookup).
The decoder reads the field and stores it into field_dict, returning the new
buffer position. A decoder for a repeated field may proactively decode all of
the elements of that field, if they appear consecutively.
Note that decoders may throw any of the following:
IndexError: Indicates a truncated message.
struct.error: Unpacking of a fixed-width field failed.
message.DecodeError: Other errors.
Decoders are expected to raise an exception if they are called with pos > end.
This allows callers to be lax about bounds checking: it's fineto read past
"end" as long as you are sure that someone else will notice and throw an
exception later on.
Something up the call stack is expected to catch IndexError and struct.error
and convert them to message.DecodeError.
Decoders are constructed using decoder constructors with the signature:
MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
The arguments are:
field_number: The field number of the field we want to decode.
is_repeated: Is the field a repeated field? (bool)
is_packed: Is the field a packed field? (bool)
key: The key to use when looking up the field within field_dict.
(This is actually the FieldDescriptor but nothing in this
file should depend on that.)
new_default: A function which takes a message object as a parameter and
returns a new instance of the default value for this field.
(This is called for repeated fields and sub-messages, when an
instance does not already exist.)
As with encoders, we define a decoder constructor for every type of field.
Then, for every field of every message class we construct an actual decoder.
That decoder goes into a dict indexed by tag, so when we decode a message
we repeatedly read a tag, look up the corresponding decoder, and invoke it.
""" """
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'kenton@google.com (Kenton Varda)'
import struct import struct
from google.protobuf import message from google.protobuf.internal import encoder
from google.protobuf.internal import input_stream
from google.protobuf.internal import wire_format from google.protobuf.internal import wire_format
from google.protobuf import message
# This is not for optimization, but rather to avoid conflicts with local
# Note that much of this code is ported from //net/proto/ProtocolBuffer, and # variables named "message".
# that the interface is strongly inspired by WireFormat from the C++ proto2 _DecodeError = message.DecodeError
# implementation.
class Decoder(object): def _VarintDecoder(mask):
"""Return an encoder for a basic varint value (does not include tag).
"""Decodes logical protocol buffer fields from the wire.""" Decoded values will be bitwise-anded with the given mask before being
returned, e.g. to limit them to 32 bits. The returned decoder does not
take the usual "end" parameter -- the caller is expected to do bounds checking
after the fact (often the caller can defer such checking until later). The
decoder returns a (value, new_pos) pair.
"""
def __init__(self, s): local_ord = ord
"""Initializes the decoder to read from s. def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
b = local_ord(buffer[pos])
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
result &= mask
return (result, pos)
shift += 7
if shift >= 64:
raise _DecodeError('Too many bytes when decoding varint.')
return DecodeVarint
Args:
s: An immutable sequence of bytes, which must be accessible def _SignedVarintDecoder(mask):
via the Python buffer() primitive (i.e., buffer(s)). """Like _VarintDecoder() but decodes signed values."""
local_ord = ord
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
b = local_ord(buffer[pos])
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
if result > 0x7fffffffffffffff:
result -= (1 << 64)
result |= ~mask
else:
result &= mask
return (result, pos)
shift += 7
if shift >= 64:
raise _DecodeError('Too many bytes when decoding varint.')
return DecodeVarint
_DecodeVarint = _VarintDecoder((1 << 64) - 1)
_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)
# Use these versions for values which must be limited to 32 bits.
_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)
def ReadTag(buffer, pos):
"""Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
We return the raw bytes of the tag rather than decoding them. The raw
bytes can then be used to look up the proper decoder. This effectively allows
us to trade some work that would be done in pure-python (decoding a varint)
for work that is done in C (searching for a byte string in a hash table).
In a low-level language it would be much cheaper to decode the varint and
use that, but not in Python.
"""
start = pos
while ord(buffer[pos]) & 0x80:
pos += 1
pos += 1
return (buffer[start:pos], pos)
# --------------------------------------------------------------------
def _SimpleDecoder(wire_type, decode_value):
"""Return a constructor for a decoder for fields of a particular type.
Args:
wire_type: The field's wire type.
decode_value: A function which decodes an individual value, e.g.
_DecodeVarint()
"""
def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
if is_packed:
local_DecodeVarint = _DecodeVarint
def DecodePackedField(buffer, pos, end, message, field_dict):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
(endpoint, pos) = local_DecodeVarint(buffer, pos)
endpoint += pos
if endpoint > end:
raise _DecodeError('Truncated message.')
while pos < endpoint:
(element, pos) = decode_value(buffer, pos)
value.append(element)
if pos > endpoint:
del value[-1] # Discard corrupt value.
raise _DecodeError('Packed element was truncated.')
return pos
return DecodePackedField
elif is_repeated:
tag_bytes = encoder.TagBytes(field_number, wire_type)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
(element, new_pos) = decode_value(buffer, pos)
value.append(element)
# Predict that the next tag is another copy of the same repeated
# field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
# Prediction failed. Return.
if new_pos > end:
raise _DecodeError('Truncated message.')
return new_pos
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
(field_dict[key], pos) = decode_value(buffer, pos)
if pos > end:
del field_dict[key] # Discard corrupt value.
raise _DecodeError('Truncated message.')
return pos
return DecodeField
return SpecificDecoder
def _ModifiedDecoder(wire_type, decode_value, modify_value):
"""Like SimpleDecoder but additionally invokes modify_value on every value
before storing it. Usually modify_value is ZigZagDecode.
"""
# Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
# not enough to make a significant difference.
def InnerDecode(buffer, pos):
(result, new_pos) = decode_value(buffer, pos)
return (modify_value(result), new_pos)
return _SimpleDecoder(wire_type, InnerDecode)
def _StructPackDecoder(wire_type, format):
"""Return a constructor for a decoder for a fixed-width field.
Args:
wire_type: The field's wire type.
format: The format string to pass to struct.unpack().
"""
value_size = struct.calcsize(format)
local_unpack = struct.unpack
# Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
# not enough to make a significant difference.
# Note that we expect someone up-stack to catch struct.error and convert
# it to _DecodeError -- this way we don't have to set up exception-
# handling blocks every time we parse one value.
def InnerDecode(buffer, pos):
new_pos = pos + value_size
result = local_unpack(format, buffer[pos:new_pos])[0]
return (result, new_pos)
return _SimpleDecoder(wire_type, InnerDecode)
# --------------------------------------------------------------------
Int32Decoder = EnumDecoder = _SimpleDecoder(
wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(
wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)
SInt32Decoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
SInt64Decoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f')
DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d')
BoolDecoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
"""Returns a decoder for a string field."""
local_DecodeVarint = _DecodeVarint
local_unicode = unicode
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
return new_pos
return DecodeField
def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
"""Returns a decoder for a bytes field."""
local_DecodeVarint = _DecodeVarint
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
value.append(buffer[pos:new_pos])
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
field_dict[key] = buffer[pos:new_pos]
return new_pos
return DecodeField
def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
"""Returns a decoder for a group field."""
end_tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_END_GROUP)
end_tag_len = len(end_tag_bytes)
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_START_GROUP)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read sub-message.
pos = value.add()._InternalParse(buffer, pos, end)
# Read end tag.
new_pos = pos+end_tag_len
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
raise _DecodeError('Missing group end tag.')
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read sub-message.
pos = value._InternalParse(buffer, pos, end)
# Read end tag.
new_pos = pos+end_tag_len
if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
raise _DecodeError('Missing group end tag.')
return new_pos
return DecodeField
def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
"""Returns a decoder for a message field."""
local_DecodeVarint = _DecodeVarint
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_LENGTH_DELIMITED)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read length.
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated message.')
# Read sub-message.
if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
# The only reason _InternalParse would return early is if it
# encountered an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
# Prediction failed. Return.
return new_pos
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
# Read length.
(size, pos) = local_DecodeVarint(buffer, pos)
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated message.')
# Read sub-message.
if value._InternalParse(buffer, pos, new_pos) != new_pos:
# The only reason _InternalParse would return early is if it encountered
# an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
return new_pos
return DecodeField
# --------------------------------------------------------------------
MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
def MessageSetItemDecoder(extensions_by_number):
"""Returns a decoder for a MessageSet item.
The parameter is the _extensions_by_number map for the message class.
The message set message looks like this:
message MessageSet {
repeated group Item = 1 {
required int32 type_id = 2;
required string message = 3;
}
}
"""
type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
local_ReadTag = ReadTag
local_DecodeVarint = _DecodeVarint
local_SkipField = SkipField
def DecodeItem(buffer, pos, end, message, field_dict):
type_id = -1
message_start = -1
message_end = -1
# Technically, type_id and message can appear in any order, so we need
# a little loop here.
while 1:
(tag_bytes, pos) = local_ReadTag(buffer, pos)
if tag_bytes == type_id_tag_bytes:
(type_id, pos) = local_DecodeVarint(buffer, pos)
elif tag_bytes == message_tag_bytes:
(size, message_start) = local_DecodeVarint(buffer, pos)
pos = message_end = message_start + size
elif tag_bytes == item_end_tag_bytes:
break
else:
pos = SkipField(buffer, pos, end, tag_bytes)
if pos == -1:
raise _DecodeError('Missing group end tag.')
if pos > end:
raise _DecodeError('Truncated message.')
if type_id == -1:
raise _DecodeError('MessageSet item missing type_id.')
if message_start == -1:
raise _DecodeError('MessageSet item missing message.')
extension = extensions_by_number.get(type_id)
if extension is not None:
value = field_dict.get(extension)
if value is None:
value = field_dict.setdefault(
extension, extension.message_type._concrete_class())
if value._InternalParse(buffer, message_start,message_end) != message_end:
# The only reason _InternalParse would return early is if it encountered
# an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
return pos
return DecodeItem
# --------------------------------------------------------------------
# Optimization is not as heavy here because calls to SkipField() are rare,
# except for handling end-group tags.
def _SkipVarint(buffer, pos, end):
"""Skip a varint value. Returns the new position."""
while ord(buffer[pos]) & 0x80:
pos += 1
pos += 1
if pos > end:
raise _DecodeError('Truncated message.')
return pos
def _SkipFixed64(buffer, pos, end):
"""Skip a fixed64 value. Returns the new position."""
pos += 8
if pos > end:
raise _DecodeError('Truncated message.')
return pos
def _SkipLengthDelimited(buffer, pos, end):
"""Skip a length-delimited value. Returns the new position."""
(size, pos) = _DecodeVarint(buffer, pos)
pos += size
if pos > end:
raise _DecodeError('Truncated message.')
return pos
def _SkipGroup(buffer, pos, end):
"""Skip sub-group. Returns the new position."""
while 1:
(tag_bytes, pos) = ReadTag(buffer, pos)
new_pos = SkipField(buffer, pos, end, tag_bytes)
if new_pos == -1:
return pos
pos = new_pos
def _EndGroup(buffer, pos, end):
"""Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
return -1
def _SkipFixed32(buffer, pos, end):
"""Skip a fixed32 value. Returns the new position."""
pos += 4
if pos > end:
raise _DecodeError('Truncated message.')
return pos
def _RaiseInvalidWireType(buffer, pos, end):
"""Skip function for unknown wire types. Raises an exception."""
raise _DecodeError('Tag had invalid wire type.')
def _FieldSkipper():
"""Constructs the SkipField function."""
WIRETYPE_TO_SKIPPER = [
_SkipVarint,
_SkipFixed64,
_SkipLengthDelimited,
_SkipGroup,
_EndGroup,
_SkipFixed32,
_RaiseInvalidWireType,
_RaiseInvalidWireType,
]
wiretype_mask = wire_format.TAG_TYPE_MASK
local_ord = ord
def SkipField(buffer, pos, end, tag_bytes):
"""Skips a field with the specified tag.
|pos| should point to the byte immediately after the tag.
Returns:
The new position (after the tag value), or -1 if the tag is an end-group
tag (in which case the calling loop should break).
""" """
self._stream = input_stream.InputStream(s)
def EndOfStream(self): # The wire type is always in the first byte since varints are little-endian.
"""Returns true iff we've reached the end of the bytes we're reading.""" wire_type = local_ord(tag_bytes[0]) & wiretype_mask
return self._stream.EndOfStream() return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
def Position(self): return SkipField
"""Returns the 0-indexed position in |s|."""
return self._stream.Position()
def ReadFieldNumberAndWireType(self): SkipField = _FieldSkipper()
"""Reads a tag from the wire. Returns a (field_number, wire_type) pair."""
tag_and_type = self.ReadUInt32()
return wire_format.UnpackTag(tag_and_type)
def SkipBytes(self, bytes):
"""Skips the specified number of bytes on the wire."""
self._stream.SkipBytes(bytes)
# Note that the Read*() methods below are not exactly symmetrical with the
# corresponding Encoder.Append*() methods. Those Encoder methods first
# encode a tag, but the Read*() methods below assume that the tag has already
# been read, and that the client wishes to read a field of the specified type
# starting at the current position.
def ReadInt32(self):
"""Reads and returns a signed, varint-encoded, 32-bit integer."""
return self._stream.ReadVarint32()
def ReadInt64(self):
"""Reads and returns a signed, varint-encoded, 64-bit integer."""
return self._stream.ReadVarint64()
def ReadUInt32(self):
"""Reads and returns an signed, varint-encoded, 32-bit integer."""
return self._stream.ReadVarUInt32()
def ReadUInt64(self):
"""Reads and returns an signed, varint-encoded,64-bit integer."""
return self._stream.ReadVarUInt64()
def ReadSInt32(self):
"""Reads and returns a signed, zigzag-encoded, varint-encoded,
32-bit integer."""
return wire_format.ZigZagDecode(self._stream.ReadVarUInt32())
def ReadSInt64(self):
"""Reads and returns a signed, zigzag-encoded, varint-encoded,
64-bit integer."""
return wire_format.ZigZagDecode(self._stream.ReadVarUInt64())
def ReadFixed32(self):
"""Reads and returns an unsigned, fixed-width, 32-bit integer."""
return self._stream.ReadLittleEndian32()
def ReadFixed64(self):
"""Reads and returns an unsigned, fixed-width, 64-bit integer."""
return self._stream.ReadLittleEndian64()
def ReadSFixed32(self):
"""Reads and returns a signed, fixed-width, 32-bit integer."""
value = self._stream.ReadLittleEndian32()
if value >= (1 << 31):
value -= (1 << 32)
return value
def ReadSFixed64(self):
"""Reads and returns a signed, fixed-width, 64-bit integer."""
value = self._stream.ReadLittleEndian64()
if value >= (1 << 63):
value -= (1 << 64)
return value
def ReadFloat(self):
"""Reads and returns a 4-byte floating-point number."""
serialized = self._stream.ReadBytes(4)
return struct.unpack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, serialized)[0]
def ReadDouble(self):
"""Reads and returns an 8-byte floating-point number."""
serialized = self._stream.ReadBytes(8)
return struct.unpack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, serialized)[0]
def ReadBool(self):
"""Reads and returns a bool."""
i = self._stream.ReadVarUInt32()
return bool(i)
def ReadEnum(self):
"""Reads and returns an enum value."""
return self._stream.ReadVarUInt32()
def ReadString(self):
"""Reads and returns a length-delimited string."""
bytes = self.ReadBytes()
return unicode(bytes, 'utf-8')
def ReadBytes(self):
"""Reads and returns a length-delimited byte sequence."""
length = self._stream.ReadVarUInt32()
return self._stream.ReadBytes(length)
def ReadMessageInto(self, msg):
"""Calls msg.MergeFromString() to merge
length-delimited serialized message data into |msg|.
REQUIRES: The decoder must be positioned at the serialized "length"
prefix to a length-delmiited serialized message.
POSTCONDITION: The decoder is positioned just after the
serialized message, and we have merged those serialized
contents into |msg|.
"""
length = self._stream.ReadVarUInt32()
sub_buffer = self._stream.GetSubBuffer(length)
num_bytes_used = msg.MergeFromString(sub_buffer)
if num_bytes_used != length:
raise message.DecodeError(
'Submessage told to deserialize from %d-byte encoding, '
'but used only %d bytes' % (length, num_bytes_used))
self._stream.SkipBytes(num_bytes_used)
def ReadGroupInto(self, expected_field_number, group):
"""Calls group.MergeFromString() to merge
END_GROUP-delimited serialized message data into |group|.
We'll raise an exception if we don't find an END_GROUP
tag immediately after the serialized message contents.
REQUIRES: The decoder is positioned just after the START_GROUP
tag for this group.
POSTCONDITION: The decoder is positioned just after the
END_GROUP tag for this group, and we have merged
the contents of the group into |group|.
"""
sub_buffer = self._stream.GetSubBuffer() # No a priori length limit.
num_bytes_used = group.MergeFromString(sub_buffer)
if num_bytes_used < 0:
raise message.DecodeError('Group message reported negative bytes read.')
self._stream.SkipBytes(num_bytes_used)
field_number, field_type = self.ReadFieldNumberAndWireType()
if field_type != wire_format.WIRETYPE_END_GROUP:
raise message.DecodeError('Group message did not end with an END_GROUP.')
if field_number != expected_field_number:
raise message.DecodeError('END_GROUP tag had field '
'number %d, was expecting field number %d' % (
field_number, expected_field_number))
# We're now positioned just after the END_GROUP tag. Perfect.

View File

@ -1,256 +0,0 @@
#! /usr/bin/python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for google.protobuf.internal.decoder."""
__author__ = 'robinson@google.com (Will Robinson)'
import struct
import unittest
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
from google.protobuf.internal import input_stream
from google.protobuf.internal import wire_format
from google.protobuf import message
import logging
import mox
class DecoderTest(unittest.TestCase):
def setUp(self):
self.mox = mox.Mox()
self.mock_stream = self.mox.CreateMock(input_stream.InputStream)
self.mock_message = self.mox.CreateMock(message.Message)
def testReadFieldNumberAndWireType(self):
# Test field numbers that will require various varint sizes.
for expected_field_number in (1, 15, 16, 2047, 2048):
for expected_wire_type in range(6): # Highest-numbered wiretype is 5.
e = encoder.Encoder()
e.AppendTag(expected_field_number, expected_wire_type)
s = e.ToString()
d = decoder.Decoder(s)
field_number, wire_type = d.ReadFieldNumberAndWireType()
self.assertEqual(expected_field_number, field_number)
self.assertEqual(expected_wire_type, wire_type)
def ReadScalarTestHelper(self, test_name, decoder_method, expected_result,
expected_stream_method_name,
stream_method_return, *args):
"""Helper for testReadScalars below.
Calls one of the Decoder.Read*() methods and ensures that the results are
as expected.
Args:
test_name: Name of this test, used for logging only.
decoder_method: Unbound decoder.Decoder method to call.
expected_result: Value we expect returned from decoder_method().
expected_stream_method_name: (string) Name of the InputStream
method we expect Decoder to call to actually read the value
on the wire.
stream_method_return: Value our mocked-out stream method should
return to the decoder.
args: Additional arguments that we expect to be passed to the
stream method.
"""
logging.info('Testing %s scalar input.\n'
'Calling %r(), and expecting that to call the '
'stream method %s(%r), which will return %r. Finally, '
'expecting the Decoder method to return %r'% (
test_name, decoder_method,
expected_stream_method_name, args, stream_method_return,
expected_result))
d = decoder.Decoder('')
d._stream = self.mock_stream
if decoder_method in (decoder.Decoder.ReadString,
decoder.Decoder.ReadBytes):
self.mock_stream.ReadVarUInt32().AndReturn(len(stream_method_return))
# We have to use names instead of methods to work around some
# mox weirdness. (ResetAll() is overzealous).
expected_stream_method = getattr(self.mock_stream,
expected_stream_method_name)
expected_stream_method(*args).AndReturn(stream_method_return)
self.mox.ReplayAll()
result = decoder_method(d)
self.assertEqual(expected_result, result)
self.assert_(isinstance(result, type(expected_result)))
self.mox.VerifyAll()
self.mox.ResetAll()
VAL = 1.125 # Perfectly representable as a float (no rounding error).
LITTLE_FLOAT_VAL = '\x00\x00\x90?'
LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
def testReadScalars(self):
test_string = 'I can feel myself getting sutpider.'
scalar_tests = [
['int32', decoder.Decoder.ReadInt32, 0, 'ReadVarint32', 0],
['int64', decoder.Decoder.ReadInt64, 0, 'ReadVarint64', 0],
['uint32', decoder.Decoder.ReadUInt32, 0, 'ReadVarUInt32', 0],
['uint64', decoder.Decoder.ReadUInt64, 0, 'ReadVarUInt64', 0],
['fixed32', decoder.Decoder.ReadFixed32, 0xffffffff,
'ReadLittleEndian32', 0xffffffff],
['fixed64', decoder.Decoder.ReadFixed64, 0xffffffffffffffff,
'ReadLittleEndian64', 0xffffffffffffffff],
['sfixed32', decoder.Decoder.ReadSFixed32, long(-1),
'ReadLittleEndian32', long(0xffffffff)],
['sfixed64', decoder.Decoder.ReadSFixed64, long(-1),
'ReadLittleEndian64', 0xffffffffffffffff],
['float', decoder.Decoder.ReadFloat, self.VAL,
'ReadBytes', self.LITTLE_FLOAT_VAL, 4],
['double', decoder.Decoder.ReadDouble, self.VAL,
'ReadBytes', self.LITTLE_DOUBLE_VAL, 8],
['bool', decoder.Decoder.ReadBool, True, 'ReadVarUInt32', 1],
['enum', decoder.Decoder.ReadEnum, 23, 'ReadVarUInt32', 23],
['string', decoder.Decoder.ReadString,
unicode(test_string, 'utf-8'), 'ReadBytes', test_string,
len(test_string)],
['utf8-string', decoder.Decoder.ReadString,
unicode(test_string, 'utf-8'), 'ReadBytes', test_string,
len(test_string)],
['bytes', decoder.Decoder.ReadBytes,
test_string, 'ReadBytes', test_string, len(test_string)],
# We test zigzag decoding routines more extensively below.
['sint32', decoder.Decoder.ReadSInt32, -1, 'ReadVarUInt32', 1],
['sint64', decoder.Decoder.ReadSInt64, -1, 'ReadVarUInt64', 1],
]
# Ensure that we're testing different Decoder methods and using
# different test names in all test cases above.
self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests)))
self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests)))
for args in scalar_tests:
self.ReadScalarTestHelper(*args)
def testReadMessageInto(self):
length = 23
def Test(simulate_error):
d = decoder.Decoder('')
d._stream = self.mock_stream
self.mock_stream.ReadVarUInt32().AndReturn(length)
sub_buffer = object()
self.mock_stream.GetSubBuffer(length).AndReturn(sub_buffer)
if simulate_error:
self.mock_message.MergeFromString(sub_buffer).AndReturn(length - 1)
self.mox.ReplayAll()
self.assertRaises(
message.DecodeError, d.ReadMessageInto, self.mock_message)
else:
self.mock_message.MergeFromString(sub_buffer).AndReturn(length)
self.mock_stream.SkipBytes(length)
self.mox.ReplayAll()
d.ReadMessageInto(self.mock_message)
self.mox.VerifyAll()
self.mox.ResetAll()
Test(simulate_error=False)
Test(simulate_error=True)
def testReadGroupInto_Success(self):
# Test both the empty and nonempty cases.
for num_bytes in (5, 0):
field_number = expected_field_number = 10
d = decoder.Decoder('')
d._stream = self.mock_stream
sub_buffer = object()
self.mock_stream.GetSubBuffer().AndReturn(sub_buffer)
self.mock_message.MergeFromString(sub_buffer).AndReturn(num_bytes)
self.mock_stream.SkipBytes(num_bytes)
self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
field_number, wire_format.WIRETYPE_END_GROUP))
self.mox.ReplayAll()
d.ReadGroupInto(expected_field_number, self.mock_message)
self.mox.VerifyAll()
self.mox.ResetAll()
def ReadGroupInto_FailureTestHelper(self, bytes_read):
d = decoder.Decoder('')
d._stream = self.mock_stream
sub_buffer = object()
self.mock_stream.GetSubBuffer().AndReturn(sub_buffer)
self.mock_message.MergeFromString(sub_buffer).AndReturn(bytes_read)
return d
def testReadGroupInto_NegativeBytesReported(self):
expected_field_number = 10
d = self.ReadGroupInto_FailureTestHelper(bytes_read=-1)
self.mox.ReplayAll()
self.assertRaises(message.DecodeError,
d.ReadGroupInto, expected_field_number,
self.mock_message)
self.mox.VerifyAll()
def testReadGroupInto_NoEndGroupTag(self):
field_number = expected_field_number = 10
num_bytes = 5
d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes)
self.mock_stream.SkipBytes(num_bytes)
# Right field number, wrong wire type.
self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
field_number, wire_format.WIRETYPE_LENGTH_DELIMITED))
self.mox.ReplayAll()
self.assertRaises(message.DecodeError,
d.ReadGroupInto, expected_field_number,
self.mock_message)
self.mox.VerifyAll()
def testReadGroupInto_WrongFieldNumberInEndGroupTag(self):
expected_field_number = 10
field_number = expected_field_number + 1
num_bytes = 5
d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes)
self.mock_stream.SkipBytes(num_bytes)
# Wrong field number, right wire type.
self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag(
field_number, wire_format.WIRETYPE_END_GROUP))
self.mox.ReplayAll()
self.assertRaises(message.DecodeError,
d.ReadGroupInto, expected_field_number,
self.mock_message)
self.mox.VerifyAll()
def testSkipBytes(self):
d = decoder.Decoder('')
num_bytes = 1024
self.mock_stream.SkipBytes(num_bytes)
d._stream = self.mock_stream
self.mox.ReplayAll()
d.SkipBytes(num_bytes)
self.mox.VerifyAll()
if __name__ == '__main__':
unittest.main()

View File

@ -35,16 +35,30 @@
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
import unittest import unittest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pb2
from google.protobuf import descriptor from google.protobuf import descriptor
from google.protobuf import text_format
TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """
name: 'TestEmptyMessage'
"""
class DescriptorTest(unittest.TestCase): class DescriptorTest(unittest.TestCase):
def setUp(self): def setUp(self):
self.my_file = descriptor.FileDescriptor(
name='some/filename/some.proto',
package='protobuf_unittest'
)
self.my_enum = descriptor.EnumDescriptor( self.my_enum = descriptor.EnumDescriptor(
name='ForeignEnum', name='ForeignEnum',
full_name='protobuf_unittest.ForeignEnum', full_name='protobuf_unittest.ForeignEnum',
filename='ForeignEnum', filename=None,
file=self.my_file,
values=[ values=[
descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4), descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5), descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
@ -53,7 +67,8 @@ class DescriptorTest(unittest.TestCase):
self.my_message = descriptor.Descriptor( self.my_message = descriptor.Descriptor(
name='NestedMessage', name='NestedMessage',
full_name='protobuf_unittest.TestAllTypes.NestedMessage', full_name='protobuf_unittest.TestAllTypes.NestedMessage',
filename='some/filename/some.proto', filename=None,
file=self.my_file,
containing_type=None, containing_type=None,
fields=[ fields=[
descriptor.FieldDescriptor( descriptor.FieldDescriptor(
@ -61,7 +76,7 @@ class DescriptorTest(unittest.TestCase):
full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb', full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
index=0, number=1, index=0, number=1,
type=5, cpp_type=1, label=1, type=5, cpp_type=1, label=1,
default_value=0, has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None, message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None), is_extension=False, extension_scope=None),
], ],
@ -80,6 +95,7 @@ class DescriptorTest(unittest.TestCase):
self.my_service = descriptor.ServiceDescriptor( self.my_service = descriptor.ServiceDescriptor(
name='TestServiceWithOptions', name='TestServiceWithOptions',
full_name='protobuf_unittest.TestServiceWithOptions', full_name='protobuf_unittest.TestServiceWithOptions',
file=self.my_file,
index=0, index=0,
methods=[ methods=[
self.my_method self.my_method
@ -109,5 +125,210 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual(self.my_service.GetOptions(), self.assertEqual(self.my_service.GetOptions(),
descriptor_pb2.ServiceOptions()) descriptor_pb2.ServiceOptions())
def testFileDescriptorReferences(self):
self.assertEqual(self.my_enum.file, self.my_file)
self.assertEqual(self.my_message.file, self.my_file)
def testFileDescriptor(self):
self.assertEqual(self.my_file.name, 'some/filename/some.proto')
self.assertEqual(self.my_file.package, 'protobuf_unittest')
class DescriptorCopyToProtoTest(unittest.TestCase):
"""Tests for CopyTo functions of Descriptor."""
def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii):
expected_proto = expected_class()
text_format.Merge(expected_ascii, expected_proto)
self.assertEqual(
actual_proto, expected_proto,
'Not equal,\nActual:\n%s\nExpected:\n%s\n'
% (str(actual_proto), str(expected_proto)))
def _InternalTestCopyToProto(self, desc, expected_proto_class,
expected_proto_ascii):
actual = expected_proto_class()
desc.CopyToProto(actual)
self._AssertProtoEqual(
actual, expected_proto_class, expected_proto_ascii)
def testCopyToProto_EmptyMessage(self):
self._InternalTestCopyToProto(
unittest_pb2.TestEmptyMessage.DESCRIPTOR,
descriptor_pb2.DescriptorProto,
TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII)
def testCopyToProto_NestedMessage(self):
TEST_NESTED_MESSAGE_ASCII = """
name: 'NestedMessage'
field: <
name: 'bb'
number: 1
label: 1 # Optional
type: 5 # TYPE_INT32
>
"""
self._InternalTestCopyToProto(
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
descriptor_pb2.DescriptorProto,
TEST_NESTED_MESSAGE_ASCII)
def testCopyToProto_ForeignNestedMessage(self):
TEST_FOREIGN_NESTED_ASCII = """
name: 'TestForeignNested'
field: <
name: 'foreign_nested'
number: 1
label: 1 # Optional
type: 11 # TYPE_MESSAGE
type_name: '.protobuf_unittest.TestAllTypes.NestedMessage'
>
"""
self._InternalTestCopyToProto(
unittest_pb2.TestForeignNested.DESCRIPTOR,
descriptor_pb2.DescriptorProto,
TEST_FOREIGN_NESTED_ASCII)
def testCopyToProto_ForeignEnum(self):
TEST_FOREIGN_ENUM_ASCII = """
name: 'ForeignEnum'
value: <
name: 'FOREIGN_FOO'
number: 4
>
value: <
name: 'FOREIGN_BAR'
number: 5
>
value: <
name: 'FOREIGN_BAZ'
number: 6
>
"""
self._InternalTestCopyToProto(
unittest_pb2._FOREIGNENUM,
descriptor_pb2.EnumDescriptorProto,
TEST_FOREIGN_ENUM_ASCII)
def testCopyToProto_Options(self):
TEST_DEPRECATED_FIELDS_ASCII = """
name: 'TestDeprecatedFields'
field: <
name: 'deprecated_int32'
number: 1
label: 1 # Optional
type: 5 # TYPE_INT32
options: <
deprecated: true
>
>
"""
self._InternalTestCopyToProto(
unittest_pb2.TestDeprecatedFields.DESCRIPTOR,
descriptor_pb2.DescriptorProto,
TEST_DEPRECATED_FIELDS_ASCII)
def testCopyToProto_AllExtensions(self):
TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII = """
name: 'TestEmptyMessageWithExtensions'
extension_range: <
start: 1
end: 536870912
>
"""
self._InternalTestCopyToProto(
unittest_pb2.TestEmptyMessageWithExtensions.DESCRIPTOR,
descriptor_pb2.DescriptorProto,
TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII)
def testCopyToProto_SeveralExtensions(self):
TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII = """
name: 'TestMultipleExtensionRanges'
extension_range: <
start: 42
end: 43
>
extension_range: <
start: 4143
end: 4244
>
extension_range: <
start: 65536
end: 536870912
>
"""
self._InternalTestCopyToProto(
unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR,
descriptor_pb2.DescriptorProto,
TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII)
def testCopyToProto_FileDescriptor(self):
UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
name: 'google/protobuf/unittest_import.proto'
package: 'protobuf_unittest_import'
message_type: <
name: 'ImportMessage'
field: <
name: 'd'
number: 1
label: 1 # Optional
type: 5 # TYPE_INT32
>
>
""" +
"""enum_type: <
name: 'ImportEnum'
value: <
name: 'IMPORT_FOO'
number: 7
>
value: <
name: 'IMPORT_BAR'
number: 8
>
value: <
name: 'IMPORT_BAZ'
number: 9
>
>
options: <
java_package: 'com.google.protobuf.test'
optimize_for: 1 # SPEED
>
""")
self._InternalTestCopyToProto(
unittest_import_pb2.DESCRIPTOR,
descriptor_pb2.FileDescriptorProto,
UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
def testCopyToProto_ServiceDescriptor(self):
TEST_SERVICE_ASCII = """
name: 'TestService'
method: <
name: 'Foo'
input_type: '.protobuf_unittest.FooRequest'
output_type: '.protobuf_unittest.FooResponse'
>
method: <
name: 'Bar'
input_type: '.protobuf_unittest.BarRequest'
output_type: '.protobuf_unittest.BarResponse'
>
"""
self._InternalTestCopyToProto(
unittest_pb2.TestService.DESCRIPTOR,
descriptor_pb2.ServiceDescriptorProto,
TEST_SERVICE_ASCII)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -28,253 +28,659 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Class for encoding protocol message primitives. """Code for encoding protocol message primitives.
Contains the logic for encoding every logical protocol field type Contains the logic for encoding every logical protocol field type
into one of the 5 physical wire types. into one of the 5 physical wire types.
This code is designed to push the Python interpreter's performance to the
limits.
The basic idea is that at startup time, for every field (i.e. every
FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
sizer takes a value of this field's type and computes its byte size. The
encoder takes a writer function and a value. It encodes the value into byte
strings and invokes the writer function to write those strings. Typically the
writer function is the write() method of a cStringIO.
We try to do as much work as possible when constructing the writer and the
sizer rather than when calling them. In particular:
* We copy any needed global functions to local variables, so that we do not need
to do costly global table lookups at runtime.
* Similarly, we try to do any attribute lookups at startup time if possible.
* Every field's tag is encoded to bytes at startup, since it can't change at
runtime.
* Whatever component of the field size we can compute at startup, we do.
* We *avoid* sharing code if doing so would make the code slower and not sharing
does not burden us too much. For example, encoders for repeated fields do
not just call the encoders for singular fields in a loop because this would
add an extra function call overhead for every loop iteration; instead, we
manually inline the single-value encoder into the loop.
* If a Python function lacks a return statement, Python actually generates
instructions to pop the result of the last statement off the stack, push
None onto the stack, and then return that. If we really don't care what
value is returned, then we can save two instructions by returning the
result of the last statement. It looks funny but it helps.
* We assume that type and bounds checking has happened at a higher level.
""" """
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'kenton@google.com (Kenton Varda)'
import struct import struct
from google.protobuf import message
from google.protobuf.internal import wire_format from google.protobuf.internal import wire_format
from google.protobuf.internal import output_stream
# Note that much of this code is ported from //net/proto/ProtocolBuffer, and def _VarintSize(value):
# that the interface is strongly inspired by WireFormat from the C++ proto2 """Compute the size of a varint value."""
# implementation. if value <= 0x7f: return 1
if value <= 0x3fff: return 2
if value <= 0x1fffff: return 3
if value <= 0xfffffff: return 4
if value <= 0x7ffffffff: return 5
if value <= 0x3ffffffffff: return 6
if value <= 0x1ffffffffffff: return 7
if value <= 0xffffffffffffff: return 8
if value <= 0x7fffffffffffffff: return 9
return 10
class Encoder(object): def _SignedVarintSize(value):
"""Compute the size of a signed varint value."""
"""Encodes logical protocol buffer fields to the wire format.""" if value < 0: return 10
if value <= 0x7f: return 1
def __init__(self): if value <= 0x3fff: return 2
self._stream = output_stream.OutputStream() if value <= 0x1fffff: return 3
if value <= 0xfffffff: return 4
def ToString(self): if value <= 0x7ffffffff: return 5
"""Returns all values encoded in this object as a string.""" if value <= 0x3ffffffffff: return 6
return self._stream.ToString() if value <= 0x1ffffffffffff: return 7
if value <= 0xffffffffffffff: return 8
# Append*NoTag methods. These are necessary for serializing packed if value <= 0x7fffffffffffffff: return 9
# repeated fields. The Append*() methods call these methods to do return 10
# the actual serialization.
def AppendInt32NoTag(self, value):
"""Appends a 32-bit integer to our buffer, varint-encoded."""
self._stream.AppendVarint32(value)
def AppendInt64NoTag(self, value):
"""Appends a 64-bit integer to our buffer, varint-encoded."""
self._stream.AppendVarint64(value)
def AppendUInt32NoTag(self, unsigned_value):
"""Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
self._stream.AppendVarUInt32(unsigned_value)
def AppendUInt64NoTag(self, unsigned_value):
"""Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
self._stream.AppendVarUInt64(unsigned_value)
def AppendSInt32NoTag(self, value):
"""Appends a 32-bit integer to our buffer, zigzag-encoded and then
varint-encoded.
"""
zigzag_value = wire_format.ZigZagEncode(value)
self._stream.AppendVarUInt32(zigzag_value)
def AppendSInt64NoTag(self, value):
"""Appends a 64-bit integer to our buffer, zigzag-encoded and then
varint-encoded.
"""
zigzag_value = wire_format.ZigZagEncode(value)
self._stream.AppendVarUInt64(zigzag_value)
def AppendFixed32NoTag(self, unsigned_value):
"""Appends an unsigned 32-bit integer to our buffer, in little-endian
byte-order.
"""
self._stream.AppendLittleEndian32(unsigned_value)
def AppendFixed64NoTag(self, unsigned_value):
"""Appends an unsigned 64-bit integer to our buffer, in little-endian
byte-order.
"""
self._stream.AppendLittleEndian64(unsigned_value)
def AppendSFixed32NoTag(self, value):
"""Appends a signed 32-bit integer to our buffer, in little-endian
byte-order.
"""
sign = (value & 0x80000000) and -1 or 0
if value >> 32 != sign:
raise message.EncodeError('SFixed32 out of range: %d' % value)
self._stream.AppendLittleEndian32(value & 0xffffffff)
def AppendSFixed64NoTag(self, value):
"""Appends a signed 64-bit integer to our buffer, in little-endian
byte-order.
"""
sign = (value & 0x8000000000000000) and -1 or 0
if value >> 64 != sign:
raise message.EncodeError('SFixed64 out of range: %d' % value)
self._stream.AppendLittleEndian64(value & 0xffffffffffffffff)
def AppendFloatNoTag(self, value):
"""Appends a floating-point number to our buffer."""
self._stream.AppendRawBytes(
struct.pack(wire_format.FORMAT_FLOAT_LITTLE_ENDIAN, value))
def AppendDoubleNoTag(self, value):
"""Appends a double-precision floating-point number to our buffer."""
self._stream.AppendRawBytes(
struct.pack(wire_format.FORMAT_DOUBLE_LITTLE_ENDIAN, value))
def AppendBoolNoTag(self, value):
"""Appends a boolean to our buffer."""
self.AppendInt32NoTag(value)
def AppendEnumNoTag(self, value):
"""Appends an enum value to our buffer."""
self.AppendInt32NoTag(value)
# All the Append*() methods below first append a tag+type pair to the buffer def _TagSize(field_number):
# before appending the specified value. """Returns the number of bytes required to serialize a tag with this field
number."""
# Just pass in type 0, since the type won't affect the tag+type size.
return _VarintSize(wire_format.PackTag(field_number, 0))
def AppendInt32(self, field_number, value):
"""Appends a 32-bit integer to our buffer, varint-encoded."""
self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
self.AppendInt32NoTag(value)
def AppendInt64(self, field_number, value): # --------------------------------------------------------------------
"""Appends a 64-bit integer to our buffer, varint-encoded.""" # In this section we define some generic sizers. Each of these functions
self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) # takes parameters specific to a particular field type, e.g. int32 or fixed64.
self.AppendInt64NoTag(value) # It returns another function which in turn takes parameters specific to a
# particular field, e.g. the field number and whether it is repeated or packed.
# Look at the next section to see how these are used.
def AppendUInt32(self, field_number, unsigned_value):
"""Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
self.AppendUInt32NoTag(unsigned_value)
def AppendUInt64(self, field_number, unsigned_value): def _SimpleSizer(compute_value_size):
"""Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" """A sizer which uses the function compute_value_size to compute the size of
self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) each value. Typically compute_value_size is _VarintSize."""
self.AppendUInt64NoTag(unsigned_value)
def AppendSInt32(self, field_number, value): def SpecificSizer(field_number, is_repeated, is_packed):
"""Appends a 32-bit integer to our buffer, zigzag-encoded and then tag_size = _TagSize(field_number)
varint-encoded. if is_packed:
""" local_VarintSize = _VarintSize
self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) def PackedFieldSize(value):
self.AppendSInt32NoTag(value) result = 0
for element in value:
result += compute_value_size(element)
return result + local_VarintSize(result) + tag_size
return PackedFieldSize
elif is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
result += compute_value_size(element)
return result
return RepeatedFieldSize
else:
def FieldSize(value):
return tag_size + compute_value_size(value)
return FieldSize
def AppendSInt64(self, field_number, value): return SpecificSizer
"""Appends a 64-bit integer to our buffer, zigzag-encoded and then
varint-encoded.
"""
self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
self.AppendSInt64NoTag(value)
def AppendFixed32(self, field_number, unsigned_value):
"""Appends an unsigned 32-bit integer to our buffer, in little-endian
byte-order.
"""
self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
self.AppendFixed32NoTag(unsigned_value)
def AppendFixed64(self, field_number, unsigned_value): def _ModifiedSizer(compute_value_size, modify_value):
"""Appends an unsigned 64-bit integer to our buffer, in little-endian """Like SimpleSizer, but modify_value is invoked on each value before it is
byte-order. passed to compute_value_size. modify_value is typically ZigZagEncode."""
"""
self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
self.AppendFixed64NoTag(unsigned_value)
def AppendSFixed32(self, field_number, value): def SpecificSizer(field_number, is_repeated, is_packed):
"""Appends a signed 32-bit integer to our buffer, in little-endian tag_size = _TagSize(field_number)
byte-order. if is_packed:
""" local_VarintSize = _VarintSize
self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) def PackedFieldSize(value):
self.AppendSFixed32NoTag(value) result = 0
for element in value:
result += compute_value_size(modify_value(element))
return result + local_VarintSize(result) + tag_size
return PackedFieldSize
elif is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
result += compute_value_size(modify_value(element))
return result
return RepeatedFieldSize
else:
def FieldSize(value):
return tag_size + compute_value_size(modify_value(value))
return FieldSize
def AppendSFixed64(self, field_number, value): return SpecificSizer
"""Appends a signed 64-bit integer to our buffer, in little-endian
byte-order.
"""
self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
self.AppendSFixed64NoTag(value)
def AppendFloat(self, field_number, value):
"""Appends a floating-point number to our buffer."""
self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
self.AppendFloatNoTag(value)
def AppendDouble(self, field_number, value): def _FixedSizer(value_size):
"""Appends a double-precision floating-point number to our buffer.""" """Like _SimpleSizer except for a fixed-size field. The input is the size
self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) of one value."""
self.AppendDoubleNoTag(value)
def AppendBool(self, field_number, value): def SpecificSizer(field_number, is_repeated, is_packed):
"""Appends a boolean to our buffer.""" tag_size = _TagSize(field_number)
self.AppendInt32(field_number, value) if is_packed:
local_VarintSize = _VarintSize
def PackedFieldSize(value):
result = len(value) * value_size
return result + local_VarintSize(result) + tag_size
return PackedFieldSize
elif is_repeated:
element_size = value_size + tag_size
def RepeatedFieldSize(value):
return len(value) * element_size
return RepeatedFieldSize
else:
field_size = value_size + tag_size
def FieldSize(value):
return field_size
return FieldSize
def AppendEnum(self, field_number, value): return SpecificSizer
"""Appends an enum value to our buffer."""
self.AppendInt32(field_number, value)
def AppendString(self, field_number, value):
"""Appends a length-prefixed unicode string, encoded as UTF-8 to our buffer,
with the length varint-encoded.
"""
self.AppendBytes(field_number, value.encode('utf-8'))
def AppendBytes(self, field_number, value): # ====================================================================
"""Appends a length-prefixed sequence of bytes to our buffer, with the # Here we declare a sizer constructor for each field type. Each "sizer
length varint-encoded. # constructor" is a function that takes (field_number, is_repeated, is_packed)
""" # as parameters and returns a sizer, which in turn takes a field value as
self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) # a parameter and returns its encoded size.
self._stream.AppendVarUInt32(len(value))
self._stream.AppendRawBytes(value)
# TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to
# avoid the extra string copy here. We can do so if we widen the Message
# interface to be able to serialize to a stream in addition to a string. The
# challenge when thinking ahead to the Python/C API implementation of Message
# is finding a stream-like Python thing to which we can write raw bytes
# from C. I'm not sure such a thing exists(?). (array.array is pretty much
# what we want, but it's not directly exposed in the Python/C API).
def AppendGroup(self, field_number, group): Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)
"""Appends a group to our buffer.
"""
self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP)
self._stream.AppendRawBytes(group.SerializeToString())
self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP)
def AppendMessage(self, field_number, msg): UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)
"""Appends a nested message to our buffer.
"""
self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
self._stream.AppendVarUInt32(msg.ByteSize())
self._stream.AppendRawBytes(msg.SerializeToString())
def AppendMessageSetItem(self, field_number, msg): SInt32Sizer = SInt64Sizer = _ModifiedSizer(
"""Appends an item using the message set wire format. _SignedVarintSize, wire_format.ZigZagEncode)
The message set message looks like this: Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
message MessageSet { Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)
repeated group Item = 1 {
required int32 type_id = 2; BoolSizer = _FixedSizer(1)
required string message = 3;
}
def StringSizer(field_number, is_repeated, is_packed):
"""Returns a sizer for a string field."""
tag_size = _TagSize(field_number)
local_VarintSize = _VarintSize
local_len = len
assert not is_packed
if is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
l = local_len(element.encode('utf-8'))
result += local_VarintSize(l) + l
return result
return RepeatedFieldSize
else:
def FieldSize(value):
l = local_len(value.encode('utf-8'))
return tag_size + local_VarintSize(l) + l
return FieldSize
def BytesSizer(field_number, is_repeated, is_packed):
"""Returns a sizer for a bytes field."""
tag_size = _TagSize(field_number)
local_VarintSize = _VarintSize
local_len = len
assert not is_packed
if is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
l = local_len(element)
result += local_VarintSize(l) + l
return result
return RepeatedFieldSize
else:
def FieldSize(value):
l = local_len(value)
return tag_size + local_VarintSize(l) + l
return FieldSize
def GroupSizer(field_number, is_repeated, is_packed):
"""Returns a sizer for a group field."""
tag_size = _TagSize(field_number) * 2
assert not is_packed
if is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
result += element.ByteSize()
return result
return RepeatedFieldSize
else:
def FieldSize(value):
return tag_size + value.ByteSize()
return FieldSize
def MessageSizer(field_number, is_repeated, is_packed):
"""Returns a sizer for a message field."""
tag_size = _TagSize(field_number)
local_VarintSize = _VarintSize
assert not is_packed
if is_repeated:
def RepeatedFieldSize(value):
result = tag_size * len(value)
for element in value:
l = element.ByteSize()
result += local_VarintSize(l) + l
return result
return RepeatedFieldSize
else:
def FieldSize(value):
l = value.ByteSize()
return tag_size + local_VarintSize(l) + l
return FieldSize
# --------------------------------------------------------------------
# MessageSet is special.
def MessageSetItemSizer(field_number):
"""Returns a sizer for extensions of MessageSet.
The message set message looks like this:
message MessageSet {
repeated group Item = 1 {
required int32 type_id = 2;
required string message = 3;
} }
""" }
self.AppendTag(1, wire_format.WIRETYPE_START_GROUP) """
self.AppendInt32(2, field_number) static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
self.AppendMessage(3, msg) _TagSize(3))
self.AppendTag(1, wire_format.WIRETYPE_END_GROUP) local_VarintSize = _VarintSize
def AppendTag(self, field_number, wire_type): def FieldSize(value):
"""Appends a tag containing field number and wire type information.""" l = value.ByteSize()
self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type)) return static_size + local_VarintSize(l) + l
return FieldSize
# ====================================================================
# Encoders!
def _VarintEncoder():
"""Return an encoder for a basic varint value (does not include tag)."""
local_chr = chr
def EncodeVarint(write, value):
bits = value & 0x7f
value >>= 7
while value:
write(local_chr(0x80|bits))
bits = value & 0x7f
value >>= 7
return write(local_chr(bits))
return EncodeVarint
def _SignedVarintEncoder():
"""Return an encoder for a basic signed varint value (does not include
tag)."""
local_chr = chr
def EncodeSignedVarint(write, value):
if value < 0:
value += (1 << 64)
bits = value & 0x7f
value >>= 7
while value:
write(local_chr(0x80|bits))
bits = value & 0x7f
value >>= 7
return write(local_chr(bits))
return EncodeSignedVarint
_EncodeVarint = _VarintEncoder()
_EncodeSignedVarint = _SignedVarintEncoder()
def _VarintBytes(value):
"""Encode the given integer as a varint and return the bytes. This is only
called at startup time so it doesn't need to be fast."""
pieces = []
_EncodeVarint(pieces.append, value)
return "".join(pieces)
def TagBytes(field_number, wire_type):
"""Encode the given tag and return the bytes. Only called at startup."""
return _VarintBytes(wire_format.PackTag(field_number, wire_type))
# --------------------------------------------------------------------
# As with sizers (see above), we have a number of common encoder
# implementations.
def _SimpleEncoder(wire_type, encode_value, compute_value_size):
"""Return a constructor for an encoder for fields of a particular type.
Args:
wire_type: The field's wire type, for encoding tags.
encode_value: A function which encodes an individual value, e.g.
_EncodeVarint().
compute_value_size: A function which computes the size of an individual
value, e.g. _VarintSize().
"""
def SpecificEncoder(field_number, is_repeated, is_packed):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value):
write(tag_bytes)
size = 0
for element in value:
size += compute_value_size(element)
local_EncodeVarint(write, size)
for element in value:
encode_value(write, element)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeRepeatedField(write, value):
for element in value:
write(tag_bytes)
encode_value(write, element)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeField(write, value):
write(tag_bytes)
return encode_value(write, value)
return EncodeField
return SpecificEncoder
def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
"""Like SimpleEncoder but additionally invokes modify_value on every value
before passing it to encode_value. Usually modify_value is ZigZagEncode."""
def SpecificEncoder(field_number, is_repeated, is_packed):
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value):
write(tag_bytes)
size = 0
for element in value:
size += compute_value_size(modify_value(element))
local_EncodeVarint(write, size)
for element in value:
encode_value(write, modify_value(element))
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeRepeatedField(write, value):
for element in value:
write(tag_bytes)
encode_value(write, modify_value(element))
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeField(write, value):
write(tag_bytes)
return encode_value(write, modify_value(value))
return EncodeField
return SpecificEncoder
def _StructPackEncoder(wire_type, format):
"""Return a constructor for an encoder for a fixed-width field.
Args:
wire_type: The field's wire type, for encoding tags.
format: The format string to pass to struct.pack().
"""
value_size = struct.calcsize(format)
def SpecificEncoder(field_number, is_repeated, is_packed):
local_struct_pack = struct.pack
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value):
write(tag_bytes)
local_EncodeVarint(write, len(value) * value_size)
for element in value:
write(local_struct_pack(format, element))
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeRepeatedField(write, value):
for element in value:
write(tag_bytes)
write(local_struct_pack(format, element))
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_type)
def EncodeField(write, value):
write(tag_bytes)
return write(local_struct_pack(format, value))
return EncodeField
return SpecificEncoder
# ====================================================================
# Here we declare an encoder constructor for each field type. These work
# very similarly to sizer constructors, described earlier.
Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
UInt32Encoder = UInt64Encoder = _SimpleEncoder(
wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
wire_format.ZigZagEncode)
# Note that Python conveniently guarantees that when using the '<' prefix on
# formats, they will also have the same size across all platforms (as opposed
# to without the prefix, where their sizes depend on the C compiler's basic
# type sizes).
Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f')
DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d')
def BoolEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a boolean field."""
false_byte = chr(0)
true_byte = chr(1)
if is_packed:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
def EncodePackedField(write, value):
write(tag_bytes)
local_EncodeVarint(write, len(value))
for element in value:
if element:
write(true_byte)
else:
write(false_byte)
return EncodePackedField
elif is_repeated:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
def EncodeRepeatedField(write, value):
for element in value:
write(tag_bytes)
if element:
write(true_byte)
else:
write(false_byte)
return EncodeRepeatedField
else:
tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
def EncodeField(write, value):
write(tag_bytes)
if value:
return write(true_byte)
return write(false_byte)
return EncodeField
def StringEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a string field."""
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
local_len = len
assert not is_packed
if is_repeated:
def EncodeRepeatedField(write, value):
for element in value:
encoded = element.encode('utf-8')
write(tag)
local_EncodeVarint(write, local_len(encoded))
write(encoded)
return EncodeRepeatedField
else:
def EncodeField(write, value):
encoded = value.encode('utf-8')
write(tag)
local_EncodeVarint(write, local_len(encoded))
return write(encoded)
return EncodeField
def BytesEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a bytes field."""
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
local_len = len
assert not is_packed
if is_repeated:
def EncodeRepeatedField(write, value):
for element in value:
write(tag)
local_EncodeVarint(write, local_len(element))
write(element)
return EncodeRepeatedField
else:
def EncodeField(write, value):
write(tag)
local_EncodeVarint(write, local_len(value))
return write(value)
return EncodeField
def GroupEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a group field."""
start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
assert not is_packed
if is_repeated:
def EncodeRepeatedField(write, value):
for element in value:
write(start_tag)
element._InternalSerialize(write)
write(end_tag)
return EncodeRepeatedField
else:
def EncodeField(write, value):
write(start_tag)
value._InternalSerialize(write)
return write(end_tag)
return EncodeField
def MessageEncoder(field_number, is_repeated, is_packed):
"""Returns an encoder for a message field."""
tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
local_EncodeVarint = _EncodeVarint
assert not is_packed
if is_repeated:
def EncodeRepeatedField(write, value):
for element in value:
write(tag)
local_EncodeVarint(write, element.ByteSize())
element._InternalSerialize(write)
return EncodeRepeatedField
else:
def EncodeField(write, value):
write(tag)
local_EncodeVarint(write, value.ByteSize())
return value._InternalSerialize(write)
return EncodeField
# --------------------------------------------------------------------
# As before, MessageSet is special.
def MessageSetItemEncoder(field_number):
"""Encoder for extensions of MessageSet.
The message set message looks like this:
message MessageSet {
repeated group Item = 1 {
required int32 type_id = 2;
required string message = 3;
}
}
"""
start_bytes = "".join([
TagBytes(1, wire_format.WIRETYPE_START_GROUP),
TagBytes(2, wire_format.WIRETYPE_VARINT),
_VarintBytes(field_number),
TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
local_EncodeVarint = _EncodeVarint
def EncodeField(write, value):
write(start_bytes)
local_EncodeVarint(write, value.ByteSize())
value._InternalSerialize(write)
return write(end_bytes)
return EncodeField

View File

@ -1,286 +0,0 @@
#! /usr/bin/python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for google.protobuf.internal.encoder."""
__author__ = 'robinson@google.com (Will Robinson)'
import struct
import logging
import unittest
from google.protobuf.internal import wire_format
from google.protobuf.internal import encoder
from google.protobuf.internal import output_stream
from google.protobuf import message
import mox
class EncoderTest(unittest.TestCase):
def setUp(self):
self.mox = mox.Mox()
self.encoder = encoder.Encoder()
self.mock_stream = self.mox.CreateMock(output_stream.OutputStream)
self.mock_message = self.mox.CreateMock(message.Message)
self.encoder._stream = self.mock_stream
def PackTag(self, field_number, wire_type):
return wire_format.PackTag(field_number, wire_type)
def AppendScalarTestHelper(self, test_name, encoder_method,
expected_stream_method_name,
wire_type, field_value,
expected_value=None, expected_length=None,
is_tag_test=True):
"""Helper for testAppendScalars.
Calls one of the Encoder methods, and ensures that the Encoder
in turn makes the expected calls into its OutputStream.
Args:
test_name: Name of this test, used only for logging.
encoder_method: Callable on self.encoder. This is the Encoder
method we're testing. If is_tag_test=True, the encoder method
accepts a field_number and field_value. if is_tag_test=False,
the encoder method accepts a field_value.
expected_stream_method_name: (string) Name of the OutputStream
method we expect Encoder to call to actually put the value
on the wire.
wire_type: The WIRETYPE_* constant we expect encoder to
use in the specified encoder_method.
field_value: The value we're trying to encode. Passed
into encoder_method.
expected_value: The value we expect Encoder to pass into
the OutputStream method. If None, we expect field_value
to pass through unmodified.
expected_length: The length we expect Encoder to pass to the
AppendVarUInt32 method. If None we expect the length of the
field_value.
is_tag_test: A Boolean. If True (the default), we append the
the packed field number and wire_type to the stream before
the field value.
"""
if expected_value is None:
expected_value = field_value
logging.info('Testing %s scalar output.\n'
'Calling %r(%r), and expecting that to call the '
'stream method %s(%r).' % (
test_name, encoder_method, field_value,
expected_stream_method_name, expected_value))
if is_tag_test:
field_number = 10
# Should first append the field number and type information.
self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type))
# If we're length-delimited, we should then append the length.
if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
if expected_length is None:
expected_length = len(field_value)
self.mock_stream.AppendVarUInt32(expected_length)
# Should then append the value itself.
# We have to use names instead of methods to work around some
# mox weirdness. (ResetAll() is overzealous).
expected_stream_method = getattr(self.mock_stream,
expected_stream_method_name)
expected_stream_method(expected_value)
self.mox.ReplayAll()
if is_tag_test:
encoder_method(field_number, field_value)
else:
encoder_method(field_value)
self.mox.VerifyAll()
self.mox.ResetAll()
VAL = 1.125 # Perfectly representable as a float (no rounding error).
LITTLE_FLOAT_VAL = '\x00\x00\x90?'
LITTLE_DOUBLE_VAL = '\x00\x00\x00\x00\x00\x00\xf2?'
def testAppendScalars(self):
utf8_bytes = '\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'
utf8_string = unicode(utf8_bytes, 'utf-8')
scalar_tests = [
['int32', self.encoder.AppendInt32, 'AppendVarint32',
wire_format.WIRETYPE_VARINT, 0],
['int64', self.encoder.AppendInt64, 'AppendVarint64',
wire_format.WIRETYPE_VARINT, 0],
['uint32', self.encoder.AppendUInt32, 'AppendVarUInt32',
wire_format.WIRETYPE_VARINT, 0],
['uint64', self.encoder.AppendUInt64, 'AppendVarUInt64',
wire_format.WIRETYPE_VARINT, 0],
['fixed32', self.encoder.AppendFixed32, 'AppendLittleEndian32',
wire_format.WIRETYPE_FIXED32, 0],
['fixed64', self.encoder.AppendFixed64, 'AppendLittleEndian64',
wire_format.WIRETYPE_FIXED64, 0],
['sfixed32', self.encoder.AppendSFixed32, 'AppendLittleEndian32',
wire_format.WIRETYPE_FIXED32, -1, 0xffffffff],
['sfixed64', self.encoder.AppendSFixed64, 'AppendLittleEndian64',
wire_format.WIRETYPE_FIXED64, -1, 0xffffffffffffffff],
['float', self.encoder.AppendFloat, 'AppendRawBytes',
wire_format.WIRETYPE_FIXED32, self.VAL, self.LITTLE_FLOAT_VAL],
['double', self.encoder.AppendDouble, 'AppendRawBytes',
wire_format.WIRETYPE_FIXED64, self.VAL, self.LITTLE_DOUBLE_VAL],
['bool', self.encoder.AppendBool, 'AppendVarint32',
wire_format.WIRETYPE_VARINT, False],
['enum', self.encoder.AppendEnum, 'AppendVarint32',
wire_format.WIRETYPE_VARINT, 0],
['string', self.encoder.AppendString, 'AppendRawBytes',
wire_format.WIRETYPE_LENGTH_DELIMITED,
"You're in a maze of twisty little passages, all alike."],
['utf8-string', self.encoder.AppendString, 'AppendRawBytes',
wire_format.WIRETYPE_LENGTH_DELIMITED, utf8_string,
utf8_bytes, len(utf8_bytes)],
# We test zigzag encoding routines more extensively below.
['sint32', self.encoder.AppendSInt32, 'AppendVarUInt32',
wire_format.WIRETYPE_VARINT, -1, 1],
['sint64', self.encoder.AppendSInt64, 'AppendVarUInt64',
wire_format.WIRETYPE_VARINT, -1, 1],
]
# Ensure that we're testing different Encoder methods and using
# different test names in all test cases above.
self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests)))
self.assert_(len(scalar_tests) >= len(set(t[1] for t in scalar_tests)))
for args in scalar_tests:
self.AppendScalarTestHelper(*args)
def testAppendScalarsWithoutTags(self):
scalar_no_tag_tests = [
['int32', self.encoder.AppendInt32NoTag, 'AppendVarint32', None, 0],
['int64', self.encoder.AppendInt64NoTag, 'AppendVarint64', None, 0],
['uint32', self.encoder.AppendUInt32NoTag, 'AppendVarUInt32', None, 0],
['uint64', self.encoder.AppendUInt64NoTag, 'AppendVarUInt64', None, 0],
['fixed32', self.encoder.AppendFixed32NoTag,
'AppendLittleEndian32', None, 0],
['fixed64', self.encoder.AppendFixed64NoTag,
'AppendLittleEndian64', None, 0],
['sfixed32', self.encoder.AppendSFixed32NoTag,
'AppendLittleEndian32', None, 0],
['sfixed64', self.encoder.AppendSFixed64NoTag,
'AppendLittleEndian64', None, 0],
['float', self.encoder.AppendFloatNoTag,
'AppendRawBytes', None, self.VAL, self.LITTLE_FLOAT_VAL],
['double', self.encoder.AppendDoubleNoTag,
'AppendRawBytes', None, self.VAL, self.LITTLE_DOUBLE_VAL],
['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0],
['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0],
['sint32', self.encoder.AppendSInt32NoTag,
'AppendVarUInt32', None, -1, 1],
['sint64', self.encoder.AppendSInt64NoTag,
'AppendVarUInt64', None, -1, 1],
]
self.assertEqual(len(scalar_no_tag_tests),
len(set(t[0] for t in scalar_no_tag_tests)))
self.assert_(len(scalar_no_tag_tests) >=
len(set(t[1] for t in scalar_no_tag_tests)))
for args in scalar_no_tag_tests:
# For no tag tests, the wire_type is not used, so we put in None.
self.AppendScalarTestHelper(is_tag_test=False, *args)
def testAppendGroup(self):
field_number = 23
# Should first append the start-group marker.
self.mock_stream.AppendVarUInt32(
self.PackTag(field_number, wire_format.WIRETYPE_START_GROUP))
# Should then serialize itself.
self.mock_message.SerializeToString().AndReturn('foo')
self.mock_stream.AppendRawBytes('foo')
# Should finally append the end-group marker.
self.mock_stream.AppendVarUInt32(
self.PackTag(field_number, wire_format.WIRETYPE_END_GROUP))
self.mox.ReplayAll()
self.encoder.AppendGroup(field_number, self.mock_message)
self.mox.VerifyAll()
def testAppendMessage(self):
field_number = 23
byte_size = 42
# Should first append the field number and type information.
self.mock_stream.AppendVarUInt32(
self.PackTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED))
# Should then append its length.
self.mock_message.ByteSize().AndReturn(byte_size)
self.mock_stream.AppendVarUInt32(byte_size)
# Should then serialize itself to the encoder.
self.mock_message.SerializeToString().AndReturn('foo')
self.mock_stream.AppendRawBytes('foo')
self.mox.ReplayAll()
self.encoder.AppendMessage(field_number, self.mock_message)
self.mox.VerifyAll()
def testAppendMessageSetItem(self):
field_number = 23
byte_size = 42
# Should first append the field number and type information.
self.mock_stream.AppendVarUInt32(
self.PackTag(1, wire_format.WIRETYPE_START_GROUP))
self.mock_stream.AppendVarUInt32(
self.PackTag(2, wire_format.WIRETYPE_VARINT))
self.mock_stream.AppendVarint32(field_number)
self.mock_stream.AppendVarUInt32(
self.PackTag(3, wire_format.WIRETYPE_LENGTH_DELIMITED))
# Should then append its length.
self.mock_message.ByteSize().AndReturn(byte_size)
self.mock_stream.AppendVarUInt32(byte_size)
# Should then serialize itself to the encoder.
self.mock_message.SerializeToString().AndReturn('foo')
self.mock_stream.AppendRawBytes('foo')
self.mock_stream.AppendVarUInt32(
self.PackTag(1, wire_format.WIRETYPE_END_GROUP))
self.mox.ReplayAll()
self.encoder.AppendMessageSetItem(field_number, self.mock_message)
self.mox.VerifyAll()
def testAppendSFixed(self):
# Most of our bounds-checking is done in output_stream.py,
# but encoder.py is responsible for transforming signed
# fixed-width integers into unsigned ones, so we test here
# to ensure that we're not losing any entropy when we do
# that conversion.
field_number = 10
self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32,
10, wire_format.UINT32_MAX + 1)
self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32,
10, -(1 << 32))
self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64,
10, wire_format.UINT64_MAX + 1)
self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64,
10, -(1 << 64))
if __name__ == '__main__':
unittest.main()

View File

@ -35,15 +35,20 @@
# indirect testing of the protocol compiler output. # indirect testing of the protocol compiler output.
"""Unittest that directly tests the output of the pure-Python protocol """Unittest that directly tests the output of the pure-Python protocol
compiler. See //net/proto2/internal/reflection_test.py for a test which compiler. See //google/protobuf/reflection_test.py for a test which
further ensures that we can use Python protocol message objects as we expect. further ensures that we can use Python protocol message objects as we expect.
""" """
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
import unittest import unittest
from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
from google.protobuf import unittest_no_generic_services_pb2
MAX_EXTENSION = 536870912
class GeneratorTest(unittest.TestCase): class GeneratorTest(unittest.TestCase):
@ -71,6 +76,31 @@ class GeneratorTest(unittest.TestCase):
self.assertEqual(3, proto.BAZ) self.assertEqual(3, proto.BAZ)
self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
def testExtremeDefaultValues(self):
message = unittest_pb2.TestExtremeDefaultValues()
self.assertEquals(float('inf'), message.inf_double)
self.assertEquals(float('-inf'), message.neg_inf_double)
self.assert_(message.nan_double != message.nan_double)
self.assertEquals(float('inf'), message.inf_float)
self.assertEquals(float('-inf'), message.neg_inf_float)
self.assert_(message.nan_float != message.nan_float)
def testHasDefaultValues(self):
desc = unittest_pb2.TestAllTypes.DESCRIPTOR
expected_has_default_by_name = {
'optional_int32': False,
'repeated_int32': False,
'optional_nested_message': False,
'default_int32': True,
}
has_default_by_name = dict(
[(f.name, f.has_default_value)
for f in desc.fields
if f.name in expected_has_default_by_name])
self.assertEqual(expected_has_default_by_name, has_default_by_name)
def testContainingTypeBehaviorForExtensions(self): def testContainingTypeBehaviorForExtensions(self):
self.assertEqual(unittest_pb2.optional_int32_extension.containing_type, self.assertEqual(unittest_pb2.optional_int32_extension.containing_type,
unittest_pb2.TestAllExtensions.DESCRIPTOR) unittest_pb2.TestAllExtensions.DESCRIPTOR)
@ -95,6 +125,81 @@ class GeneratorTest(unittest.TestCase):
proto = unittest_mset_pb2.TestMessageSet() proto = unittest_mset_pb2.TestMessageSet()
self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format)
def testNestedTypes(self):
self.assertEquals(
set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
set([
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
unittest_pb2.TestAllTypes.OptionalGroup.DESCRIPTOR,
unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR,
]))
self.assertEqual(unittest_pb2.TestEmptyMessage.DESCRIPTOR.nested_types, [])
self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.nested_types, [])
def testContainingType(self):
self.assertTrue(
unittest_pb2.TestEmptyMessage.DESCRIPTOR.containing_type is None)
self.assertTrue(
unittest_pb2.TestAllTypes.DESCRIPTOR.containing_type is None)
self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
unittest_pb2.TestAllTypes.DESCRIPTOR)
self.assertEqual(
unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
unittest_pb2.TestAllTypes.DESCRIPTOR)
self.assertEqual(
unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR.containing_type,
unittest_pb2.TestAllTypes.DESCRIPTOR)
def testContainingTypeInEnumDescriptor(self):
self.assertTrue(unittest_pb2._FOREIGNENUM.containing_type is None)
self.assertEqual(unittest_pb2._TESTALLTYPES_NESTEDENUM.containing_type,
unittest_pb2.TestAllTypes.DESCRIPTOR)
def testPackage(self):
self.assertEqual(
unittest_pb2.TestAllTypes.DESCRIPTOR.file.package,
'protobuf_unittest')
desc = unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR
self.assertEqual(desc.file.package, 'protobuf_unittest')
self.assertEqual(
unittest_import_pb2.ImportMessage.DESCRIPTOR.file.package,
'protobuf_unittest_import')
self.assertEqual(
unittest_pb2._FOREIGNENUM.file.package, 'protobuf_unittest')
self.assertEqual(
unittest_pb2._TESTALLTYPES_NESTEDENUM.file.package,
'protobuf_unittest')
self.assertEqual(
unittest_import_pb2._IMPORTENUM.file.package,
'protobuf_unittest_import')
def testExtensionRange(self):
self.assertEqual(
unittest_pb2.TestAllTypes.DESCRIPTOR.extension_ranges, [])
self.assertEqual(
unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges,
[(1, MAX_EXTENSION)])
self.assertEqual(
unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges,
[(42, 43), (4143, 4244), (65536, MAX_EXTENSION)])
def testFileDescriptor(self):
self.assertEqual(unittest_pb2.DESCRIPTOR.name,
'google/protobuf/unittest.proto')
self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest')
self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None)
def testNoGenericServices(self):
# unittest_no_generic_services.proto should contain defs for everything
# except services.
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage"))
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO"))
self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension"))
self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService"))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,338 +0,0 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""InputStream is the primitive interface for reading bits from the wire.
All protocol buffer deserialization can be expressed in terms of
the InputStream primitives provided here.
"""
__author__ = 'robinson@google.com (Will Robinson)'
import array
import struct
from google.protobuf import message
from google.protobuf.internal import wire_format
# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
# that the interface is strongly inspired by CodedInputStream from the C++
# proto2 implementation.
class InputStreamBuffer(object):
"""Contains all logic for reading bits, and dealing with stream position.
If an InputStream method ever raises an exception, the stream is left
in an indeterminate state and is not safe for further use.
"""
def __init__(self, s):
# What we really want is something like array('B', s), where elements we
# read from the array are already given to us as one-byte integers. BUT
# using array() instead of buffer() would force full string copies to result
# from each GetSubBuffer() call.
#
# So, if the N serialized bytes of a single protocol buffer object are
# split evenly between 2 child messages, and so on recursively, using
# array('B', s) instead of buffer() would incur an additional N*logN bytes
# copied during deserialization.
#
# The higher constant overhead of having to ord() for every byte we read
# from the buffer in _ReadVarintHelper() could definitely lead to worse
# performance in many real-world scenarios, even if the asymptotic
# complexity is better. However, our real answer is that the mythical
# Python/C extension module output mode for the protocol compiler will
# be blazing-fast and will eliminate most use of this class anyway.
self._buffer = buffer(s)
self._pos = 0
def EndOfStream(self):
"""Returns true iff we're at the end of the stream.
If this returns true, then a call to any other InputStream method
will raise an exception.
"""
return self._pos >= len(self._buffer)
def Position(self):
"""Returns the current position in the stream, or equivalently, the
number of bytes read so far.
"""
return self._pos
def GetSubBuffer(self, size=None):
"""Returns a sequence-like object that represents a portion of our
underlying sequence.
Position 0 in the returned object corresponds to self.Position()
in this stream.
If size is specified, then the returned object ends after the
next "size" bytes in this stream. If size is not specified,
then the returned object ends at the end of this stream.
We guarantee that the returned object R supports the Python buffer
interface (and thus that the call buffer(R) will work).
Note that the returned buffer is read-only.
The intended use for this method is for nested-message and nested-group
deserialization, where we want to make a recursive MergeFromString()
call on the portion of the original sequence that contains the serialized
nested message. (And we'd like to do so without making unnecessary string
copies).
REQUIRES: size is nonnegative.
"""
# Note that buffer() doesn't perform any actual string copy.
if size is None:
return buffer(self._buffer, self._pos)
else:
if size < 0:
raise message.DecodeError('Negative size %d' % size)
return buffer(self._buffer, self._pos, size)
def SkipBytes(self, num_bytes):
"""Skip num_bytes bytes ahead, or go to the end of the stream, whichever
comes first.
REQUIRES: num_bytes is nonnegative.
"""
if num_bytes < 0:
raise message.DecodeError('Negative num_bytes %d' % num_bytes)
self._pos += num_bytes
self._pos = min(self._pos, len(self._buffer))
def ReadBytes(self, size):
"""Reads up to 'size' bytes from the stream, stopping early
only if we reach the end of the stream. Returns the bytes read
as a string.
"""
if size < 0:
raise message.DecodeError('Negative size %d' % size)
s = (self._buffer[self._pos : self._pos + size])
self._pos += len(s) # Only advance by the number of bytes actually read.
return s
def ReadLittleEndian32(self):
"""Interprets the next 4 bytes of the stream as a little-endian
encoded, unsiged 32-bit integer, and returns that integer.
"""
try:
i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
self._buffer[self._pos : self._pos + 4])
self._pos += 4
return i[0] # unpack() result is a 1-element tuple.
except struct.error, e:
raise message.DecodeError(e)
def ReadLittleEndian64(self):
"""Interprets the next 8 bytes of the stream as a little-endian
encoded, unsiged 64-bit integer, and returns that integer.
"""
try:
i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
self._buffer[self._pos : self._pos + 8])
self._pos += 8
return i[0] # unpack() result is a 1-element tuple.
except struct.error, e:
raise message.DecodeError(e)
def ReadVarint32(self):
"""Reads a varint from the stream, interprets this varint
as a signed, 32-bit integer, and returns the integer.
"""
i = self.ReadVarint64()
if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
raise message.DecodeError('Value out of range for int32: %d' % i)
return int(i)
def ReadVarUInt32(self):
"""Reads a varint from the stream, interprets this varint
as an unsigned, 32-bit integer, and returns the integer.
"""
i = self.ReadVarUInt64()
if i > wire_format.UINT32_MAX:
raise message.DecodeError('Value out of range for uint32: %d' % i)
return i
def ReadVarint64(self):
"""Reads a varint from the stream, interprets this varint
as a signed, 64-bit integer, and returns the integer.
"""
i = self.ReadVarUInt64()
if i > wire_format.INT64_MAX:
i -= (1 << 64)
return i
def ReadVarUInt64(self):
"""Reads a varint from the stream, interprets this varint
as an unsigned, 64-bit integer, and returns the integer.
"""
i = self._ReadVarintHelper()
if not 0 <= i <= wire_format.UINT64_MAX:
raise message.DecodeError('Value out of range for uint64: %d' % i)
return i
def _ReadVarintHelper(self):
"""Helper for the various varint-reading methods above.
Reads an unsigned, varint-encoded integer from the stream and
returns this integer.
Does no bounds checking except to ensure that we read at most as many bytes
as could possibly be present in a varint-encoded 64-bit number.
"""
result = 0
shift = 0
while 1:
if shift >= 64:
raise message.DecodeError('Too many bytes when decoding varint.')
try:
b = ord(self._buffer[self._pos])
except IndexError:
raise message.DecodeError('Truncated varint.')
self._pos += 1
result |= ((b & 0x7f) << shift)
shift += 7
if not (b & 0x80):
return result
class InputStreamArray(object):
"""Contains all logic for reading bits, and dealing with stream position.
If an InputStream method ever raises an exception, the stream is left
in an indeterminate state and is not safe for further use.
This alternative to InputStreamBuffer is used in environments where buffer()
is unavailble, such as Google App Engine.
"""
def __init__(self, s):
self._buffer = array.array('B', s)
self._pos = 0
def EndOfStream(self):
return self._pos >= len(self._buffer)
def Position(self):
return self._pos
def GetSubBuffer(self, size=None):
if size is None:
return self._buffer[self._pos : ].tostring()
else:
if size < 0:
raise message.DecodeError('Negative size %d' % size)
return self._buffer[self._pos : self._pos + size].tostring()
def SkipBytes(self, num_bytes):
if num_bytes < 0:
raise message.DecodeError('Negative num_bytes %d' % num_bytes)
self._pos += num_bytes
self._pos = min(self._pos, len(self._buffer))
def ReadBytes(self, size):
if size < 0:
raise message.DecodeError('Negative size %d' % size)
s = self._buffer[self._pos : self._pos + size].tostring()
self._pos += len(s) # Only advance by the number of bytes actually read.
return s
def ReadLittleEndian32(self):
try:
i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN,
self._buffer[self._pos : self._pos + 4])
self._pos += 4
return i[0] # unpack() result is a 1-element tuple.
except struct.error, e:
raise message.DecodeError(e)
def ReadLittleEndian64(self):
try:
i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN,
self._buffer[self._pos : self._pos + 8])
self._pos += 8
return i[0] # unpack() result is a 1-element tuple.
except struct.error, e:
raise message.DecodeError(e)
def ReadVarint32(self):
i = self.ReadVarint64()
if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX:
raise message.DecodeError('Value out of range for int32: %d' % i)
return int(i)
def ReadVarUInt32(self):
i = self.ReadVarUInt64()
if i > wire_format.UINT32_MAX:
raise message.DecodeError('Value out of range for uint32: %d' % i)
return i
def ReadVarint64(self):
i = self.ReadVarUInt64()
if i > wire_format.INT64_MAX:
i -= (1 << 64)
return i
def ReadVarUInt64(self):
i = self._ReadVarintHelper()
if not 0 <= i <= wire_format.UINT64_MAX:
raise message.DecodeError('Value out of range for uint64: %d' % i)
return i
def _ReadVarintHelper(self):
result = 0
shift = 0
while 1:
if shift >= 64:
raise message.DecodeError('Too many bytes when decoding varint.')
try:
b = self._buffer[self._pos]
except IndexError:
raise message.DecodeError('Truncated varint.')
self._pos += 1
result |= ((b & 0x7f) << shift)
shift += 7
if not (b & 0x80):
return result
try:
buffer('')
InputStream = InputStreamBuffer
except NotImplementedError:
# Google App Engine: dev_appserver.py
InputStream = InputStreamArray
except RuntimeError:
# Google App Engine: production
InputStream = InputStreamArray

View File

@ -1,314 +0,0 @@
#! /usr/bin/python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for google.protobuf.internal.input_stream."""
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
from google.protobuf import message
from google.protobuf.internal import wire_format
from google.protobuf.internal import input_stream
class InputStreamBufferTest(unittest.TestCase):
def setUp(self):
self.__original_input_stream = input_stream.InputStream
input_stream.InputStream = input_stream.InputStreamBuffer
def tearDown(self):
input_stream.InputStream = self.__original_input_stream
def testEndOfStream(self):
stream = input_stream.InputStream('abcd')
self.assertFalse(stream.EndOfStream())
self.assertEqual('abcd', stream.ReadBytes(10))
self.assertTrue(stream.EndOfStream())
def testPosition(self):
stream = input_stream.InputStream('abcd')
self.assertEqual(0, stream.Position())
self.assertEqual(0, stream.Position()) # No side-effects.
stream.ReadBytes(1)
self.assertEqual(1, stream.Position())
stream.ReadBytes(1)
self.assertEqual(2, stream.Position())
stream.ReadBytes(10)
self.assertEqual(4, stream.Position()) # Can't go past end of stream.
def testGetSubBuffer(self):
stream = input_stream.InputStream('abcd')
# Try leaving out the size.
self.assertEqual('abcd', str(stream.GetSubBuffer()))
stream.SkipBytes(1)
# GetSubBuffer() always starts at current size.
self.assertEqual('bcd', str(stream.GetSubBuffer()))
# Try 0-size.
self.assertEqual('', str(stream.GetSubBuffer(0)))
# Negative sizes should raise an error.
self.assertRaises(message.DecodeError, stream.GetSubBuffer, -1)
# Positive sizes should work as expected.
self.assertEqual('b', str(stream.GetSubBuffer(1)))
self.assertEqual('bc', str(stream.GetSubBuffer(2)))
# Sizes longer than remaining bytes in the buffer should
# return the whole remaining buffer.
self.assertEqual('bcd', str(stream.GetSubBuffer(1000)))
def testSkipBytes(self):
stream = input_stream.InputStream('')
# Skipping bytes when at the end of stream
# should have no effect.
stream.SkipBytes(0)
stream.SkipBytes(1)
stream.SkipBytes(2)
self.assertTrue(stream.EndOfStream())
self.assertEqual(0, stream.Position())
# Try skipping within a stream.
stream = input_stream.InputStream('abcd')
self.assertEqual(0, stream.Position())
stream.SkipBytes(1)
self.assertEqual(1, stream.Position())
stream.SkipBytes(10) # Can't skip past the end.
self.assertEqual(4, stream.Position())
# Ensure that a negative skip raises an exception.
stream = input_stream.InputStream('abcd')
stream.SkipBytes(1)
self.assertRaises(message.DecodeError, stream.SkipBytes, -1)
def testReadBytes(self):
s = 'abcd'
# Also test going past the total stream length.
for i in range(len(s) + 10):
stream = input_stream.InputStream(s)
self.assertEqual(s[:i], stream.ReadBytes(i))
self.assertEqual(min(i, len(s)), stream.Position())
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadBytes, -1)
def EnsureFailureOnEmptyStream(self, input_stream_method):
"""Helper for integer-parsing tests below.
Ensures that the given InputStream method raises a DecodeError
if called on a stream with no bytes remaining.
"""
stream = input_stream.InputStream('')
self.assertRaises(message.DecodeError, input_stream_method, stream)
def testReadLittleEndian32(self):
self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian32)
s = ''
# Read 0.
s += '\x00\x00\x00\x00'
# Read 1.
s += '\x01\x00\x00\x00'
# Read a bunch of different bytes.
s += '\x01\x02\x03\x04'
# Read max unsigned 32-bit int.
s += '\xff\xff\xff\xff'
# Try a read with fewer than 4 bytes left in the stream.
s += '\x00\x00\x00'
stream = input_stream.InputStream(s)
self.assertEqual(0, stream.ReadLittleEndian32())
self.assertEqual(4, stream.Position())
self.assertEqual(1, stream.ReadLittleEndian32())
self.assertEqual(8, stream.Position())
self.assertEqual(0x04030201, stream.ReadLittleEndian32())
self.assertEqual(12, stream.Position())
self.assertEqual(wire_format.UINT32_MAX, stream.ReadLittleEndian32())
self.assertEqual(16, stream.Position())
self.assertRaises(message.DecodeError, stream.ReadLittleEndian32)
def testReadLittleEndian64(self):
self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian64)
s = ''
# Read 0.
s += '\x00\x00\x00\x00\x00\x00\x00\x00'
# Read 1.
s += '\x01\x00\x00\x00\x00\x00\x00\x00'
# Read a bunch of different bytes.
s += '\x01\x02\x03\x04\x05\x06\x07\x08'
# Read max unsigned 64-bit int.
s += '\xff\xff\xff\xff\xff\xff\xff\xff'
# Try a read with fewer than 8 bytes left in the stream.
s += '\x00\x00\x00'
stream = input_stream.InputStream(s)
self.assertEqual(0, stream.ReadLittleEndian64())
self.assertEqual(8, stream.Position())
self.assertEqual(1, stream.ReadLittleEndian64())
self.assertEqual(16, stream.Position())
self.assertEqual(0x0807060504030201, stream.ReadLittleEndian64())
self.assertEqual(24, stream.Position())
self.assertEqual(wire_format.UINT64_MAX, stream.ReadLittleEndian64())
self.assertEqual(32, stream.Position())
self.assertRaises(message.DecodeError, stream.ReadLittleEndian64)
def ReadVarintSuccessTestHelper(self, varints_and_ints, read_method):
"""Helper for tests below that test successful reads of various varints.
Args:
varints_and_ints: Iterable of (str, integer) pairs, where the string
gives the wire encoding and the integer gives the value we expect
to be returned by the read_method upon encountering this string.
read_method: Unbound InputStream method that is capable of reading
the encoded strings provided in the first elements of varints_and_ints.
"""
s = ''.join(s for s, i in varints_and_ints)
stream = input_stream.InputStream(s)
expected_pos = 0
self.assertEqual(expected_pos, stream.Position())
for s, expected_int in varints_and_ints:
self.assertEqual(expected_int, read_method(stream))
expected_pos += len(s)
self.assertEqual(expected_pos, stream.Position())
def testReadVarint32Success(self):
varints_and_ints = [
('\x00', 0),
('\x01', 1),
('\x7f', 127),
('\x80\x01', 128),
('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
('\xff\xff\xff\xff\x07', wire_format.INT32_MAX),
('\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01', wire_format.INT32_MIN),
]
self.ReadVarintSuccessTestHelper(varints_and_ints,
input_stream.InputStream.ReadVarint32)
def testReadVarint32Failure(self):
self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint32)
# Try and fail to read INT32_MAX + 1.
s = '\x80\x80\x80\x80\x08'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarint32)
# Try and fail to read INT32_MIN - 1.
s = '\xfe\xff\xff\xff\xf7\xff\xff\xff\xff\x01'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarint32)
# Try and fail to read something that looks like
# a varint with more than 10 bytes.
s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarint32)
def testReadVarUInt32Success(self):
varints_and_ints = [
('\x00', 0),
('\x01', 1),
('\x7f', 127),
('\x80\x01', 128),
('\xff\xff\xff\xff\x0f', wire_format.UINT32_MAX),
]
self.ReadVarintSuccessTestHelper(varints_and_ints,
input_stream.InputStream.ReadVarUInt32)
def testReadVarUInt32Failure(self):
self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt32)
# Try and fail to read UINT32_MAX + 1
s = '\x80\x80\x80\x80\x10'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarUInt32)
# Try and fail to read something that looks like
# a varint with more than 10 bytes.
s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarUInt32)
def testReadVarint64Success(self):
varints_and_ints = [
('\x00', 0),
('\x01', 1),
('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1),
('\x7f', 127),
('\x80\x01', 128),
('\xff\xff\xff\xff\xff\xff\xff\xff\x7f', wire_format.INT64_MAX),
('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', wire_format.INT64_MIN),
]
self.ReadVarintSuccessTestHelper(varints_and_ints,
input_stream.InputStream.ReadVarint64)
def testReadVarint64Failure(self):
self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint64)
# Try and fail to read something with the mythical 64th bit set.
s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarint64)
# Try and fail to read something that looks like
# a varint with more than 10 bytes.
s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarint64)
def testReadVarUInt64Success(self):
varints_and_ints = [
('\x00', 0),
('\x01', 1),
('\x7f', 127),
('\x80\x01', 128),
('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', 1 << 63),
]
self.ReadVarintSuccessTestHelper(varints_and_ints,
input_stream.InputStream.ReadVarUInt64)
def testReadVarUInt64Failure(self):
self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt64)
# Try and fail to read something with the mythical 64th bit set.
s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
# Try and fail to read something that looks like
# a varint with more than 10 bytes.
s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00'
stream = input_stream.InputStream(s)
self.assertRaises(message.DecodeError, stream.ReadVarUInt64)
class InputStreamArrayTest(InputStreamBufferTest):
def setUp(self):
# Test InputStreamArray against the same tests in InputStreamBuffer
self.__original_input_stream = input_stream.InputStream
input_stream.InputStream = input_stream.InputStreamArray
def tearDown(self):
input_stream.InputStream = self.__original_input_stream
if __name__ == '__main__':
unittest.main()

View File

@ -39,22 +39,34 @@ __author__ = 'robinson@google.com (Will Robinson)'
class MessageListener(object): class MessageListener(object):
"""Listens for transitions to nonempty and for invalidations of cached """Listens for modifications made to a message. Meant to be registered via
byte sizes. Meant to be registered via Message._SetListener(). Message._SetListener().
Attributes:
dirty: If True, then calling Modified() would be a no-op. This can be
used to avoid these calls entirely in the common case.
""" """
def TransitionToNonempty(self): def Modified(self):
"""Called the *first* time that this message becomes nonempty. """Called every time the message is modified in such a way that the parent
Implementations are free (but not required) to call this method multiple message may need to be updated. This currently means either:
times after the message has become nonempty. (a) The message was modified for the first time, so the parent message
""" should henceforth mark the message as present.
raise NotImplementedError (b) The message's cached byte size became dirty -- i.e. the message was
modified for the first time after a previous call to ByteSize().
Therefore the parent should also mark its byte size as dirty.
Note that (a) implies (b), since new objects start out with a client cached
size (zero). However, we document (a) explicitly because it is important.
def ByteSizeDirty(self): Modified() will *only* be called in response to one of these two events --
"""Called *every* time the cached byte size value not every time the sub-message is modified.
for this object is invalidated (transitions from being
"clean" to "dirty"). Note that if the listener's |dirty| attribute is true, then calling
Modified at the moment would be a no-op, so it can be skipped. Performance-
sensitive callers should check this attribute directly before calling since
it will be true most of the time.
""" """
raise NotImplementedError raise NotImplementedError
@ -62,8 +74,5 @@ class NullMessageListener(object):
"""No-op MessageListener implementation.""" """No-op MessageListener implementation."""
def TransitionToNonempty(self): def Modified(self):
pass
def ByteSizeDirty(self):
pass pass

View File

@ -30,7 +30,16 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Tests python protocol buffers against the golden message.""" """Tests python protocol buffers against the golden message.
Note that the golden messages exercise every known field type, thus this
test ends up exercising and verifying nearly all of the parsing and
serialization code in the whole library.
TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
sense to call this a test of the "message" module, which only declares an
abstract interface.
"""
__author__ = 'gps@google.com (Gregory P. Smith)' __author__ = 'gps@google.com (Gregory P. Smith)'
@ -40,14 +49,41 @@ from google.protobuf import unittest_pb2
from google.protobuf.internal import test_util from google.protobuf.internal import test_util
class MessageTest(test_util.GoldenMessageTestCase): class MessageTest(unittest.TestCase):
def testGoldenMessage(self): def testGoldenMessage(self):
golden_data = test_util.GoldenFile('golden_message').read() golden_data = test_util.GoldenFile('golden_message').read()
golden_message = unittest_pb2.TestAllTypes() golden_message = unittest_pb2.TestAllTypes()
golden_message.ParseFromString(golden_data) golden_message.ParseFromString(golden_data)
self.ExpectAllFieldsSet(golden_message) test_util.ExpectAllFieldsSet(self, golden_message)
self.assertTrue(golden_message.SerializeToString() == golden_data)
def testGoldenExtensions(self):
golden_data = test_util.GoldenFile('golden_message').read()
golden_message = unittest_pb2.TestAllExtensions()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestAllExtensions()
test_util.SetAllExtensions(all_set)
self.assertEquals(all_set, golden_message)
self.assertTrue(golden_message.SerializeToString() == golden_data)
def testGoldenPackedMessage(self):
golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
golden_message = unittest_pb2.TestPackedTypes()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(all_set)
self.assertEquals(all_set, golden_message)
self.assertTrue(all_set.SerializeToString() == golden_data)
def testGoldenPackedExtensions(self):
golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
golden_message = unittest_pb2.TestPackedExtensions()
golden_message.ParseFromString(golden_data)
all_set = unittest_pb2.TestPackedExtensions()
test_util.SetAllPackedExtensions(all_set)
self.assertEquals(all_set, golden_message)
self.assertTrue(all_set.SerializeToString() == golden_data)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -1,125 +0,0 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""OutputStream is the primitive interface for sticking bits on the wire.
All protocol buffer serialization can be expressed in terms of
the OutputStream primitives provided here.
"""
__author__ = 'robinson@google.com (Will Robinson)'
import array
import struct
from google.protobuf import message
from google.protobuf.internal import wire_format
# Note that much of this code is ported from //net/proto/ProtocolBuffer, and
# that the interface is strongly inspired by CodedOutputStream from the C++
# proto2 implementation.
class OutputStream(object):
"""Contains all logic for writing bits, and ToString() to get the result."""
def __init__(self):
self._buffer = array.array('B')
def AppendRawBytes(self, raw_bytes):
"""Appends raw_bytes to our internal buffer."""
self._buffer.fromstring(raw_bytes)
def AppendLittleEndian32(self, unsigned_value):
"""Appends an unsigned 32-bit integer to the internal buffer,
in little-endian byte order.
"""
if not 0 <= unsigned_value <= wire_format.UINT32_MAX:
raise message.EncodeError(
'Unsigned 32-bit out of range: %d' % unsigned_value)
self._buffer.fromstring(struct.pack(
wire_format.FORMAT_UINT32_LITTLE_ENDIAN, unsigned_value))
def AppendLittleEndian64(self, unsigned_value):
"""Appends an unsigned 64-bit integer to the internal buffer,
in little-endian byte order.
"""
if not 0 <= unsigned_value <= wire_format.UINT64_MAX:
raise message.EncodeError(
'Unsigned 64-bit out of range: %d' % unsigned_value)
self._buffer.fromstring(struct.pack(
wire_format.FORMAT_UINT64_LITTLE_ENDIAN, unsigned_value))
def AppendVarint32(self, value):
"""Appends a signed 32-bit integer to the internal buffer,
encoded as a varint. (Note that a negative varint32 will
always require 10 bytes of space.)
"""
if not wire_format.INT32_MIN <= value <= wire_format.INT32_MAX:
raise message.EncodeError('Value out of range: %d' % value)
self.AppendVarint64(value)
def AppendVarUInt32(self, value):
"""Appends an unsigned 32-bit integer to the internal buffer,
encoded as a varint.
"""
if not 0 <= value <= wire_format.UINT32_MAX:
raise message.EncodeError('Value out of range: %d' % value)
self.AppendVarUInt64(value)
def AppendVarint64(self, value):
"""Appends a signed 64-bit integer to the internal buffer,
encoded as a varint.
"""
if not wire_format.INT64_MIN <= value <= wire_format.INT64_MAX:
raise message.EncodeError('Value out of range: %d' % value)
if value < 0:
value += (1 << 64)
self.AppendVarUInt64(value)
def AppendVarUInt64(self, unsigned_value):
"""Appends an unsigned 64-bit integer to the internal buffer,
encoded as a varint.
"""
if not 0 <= unsigned_value <= wire_format.UINT64_MAX:
raise message.EncodeError('Value out of range: %d' % unsigned_value)
while True:
bits = unsigned_value & 0x7f
unsigned_value >>= 7
if not unsigned_value:
self._buffer.append(bits)
break
self._buffer.append(0x80|bits)
def ToString(self):
"""Returns a string containing the bytes in our internal buffer."""
return self._buffer.tostring()

View File

@ -1,178 +0,0 @@
#! /usr/bin/python
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
# http://code.google.com/p/protobuf/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
# * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Test for google.protobuf.internal.output_stream."""
__author__ = 'robinson@google.com (Will Robinson)'
import unittest
from google.protobuf import message
from google.protobuf.internal import output_stream
from google.protobuf.internal import wire_format
class OutputStreamTest(unittest.TestCase):
def setUp(self):
self.stream = output_stream.OutputStream()
def testAppendRawBytes(self):
# Empty string.
self.stream.AppendRawBytes('')
self.assertEqual('', self.stream.ToString())
# Nonempty string.
self.stream.AppendRawBytes('abc')
self.assertEqual('abc', self.stream.ToString())
# Ensure that we're actually appending.
self.stream.AppendRawBytes('def')
self.assertEqual('abcdef', self.stream.ToString())
def AppendNumericTestHelper(self, append_fn, values_and_strings):
"""For each (value, expected_string) pair in values_and_strings,
calls an OutputStream.Append*(value) method on an OutputStream and ensures
that the string written to that stream matches expected_string.
Args:
append_fn: Unbound OutputStream method that takes an integer or
long value as input.
values_and_strings: Iterable of (value, expected_string) pairs.
"""
for conversion in (int, long):
for value, string in values_and_strings:
stream = output_stream.OutputStream()
expected_string = ''
append_fn(stream, conversion(value))
expected_string += string
self.assertEqual(expected_string, stream.ToString())
def AppendOverflowTestHelper(self, append_fn, value):
"""Calls an OutputStream.Append*(value) method and asserts
that the method raises message.EncodeError.
Args:
append_fn: Unbound OutputStream method that takes an integer or
long value as input.
value: Value to pass to append_fn which should cause an
message.EncodeError.
"""
stream = output_stream.OutputStream()
self.assertRaises(message.EncodeError, append_fn, stream, value)
def testAppendLittleEndian32(self):
append_fn = output_stream.OutputStream.AppendLittleEndian32
values_and_expected_strings = [
(0, '\x00\x00\x00\x00'),
(1, '\x01\x00\x00\x00'),
((1 << 32) - 1, '\xff\xff\xff\xff'),
]
self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
self.AppendOverflowTestHelper(append_fn, 1 << 32)
self.AppendOverflowTestHelper(append_fn, -1)
def testAppendLittleEndian64(self):
append_fn = output_stream.OutputStream.AppendLittleEndian64
values_and_expected_strings = [
(0, '\x00\x00\x00\x00\x00\x00\x00\x00'),
(1, '\x01\x00\x00\x00\x00\x00\x00\x00'),
((1 << 64) - 1, '\xff\xff\xff\xff\xff\xff\xff\xff'),
]
self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
self.AppendOverflowTestHelper(append_fn, 1 << 64)
self.AppendOverflowTestHelper(append_fn, -1)
def testAppendVarint32(self):
append_fn = output_stream.OutputStream.AppendVarint32
values_and_expected_strings = [
(0, '\x00'),
(1, '\x01'),
(127, '\x7f'),
(128, '\x80\x01'),
(-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
(wire_format.INT32_MAX, '\xff\xff\xff\xff\x07'),
(wire_format.INT32_MIN, '\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01'),
]
self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MAX + 1)
self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MIN - 1)
def testAppendVarUInt32(self):
append_fn = output_stream.OutputStream.AppendVarUInt32
values_and_expected_strings = [
(0, '\x00'),
(1, '\x01'),
(127, '\x7f'),
(128, '\x80\x01'),
(wire_format.UINT32_MAX, '\xff\xff\xff\xff\x0f'),
]
self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
self.AppendOverflowTestHelper(append_fn, -1)
self.AppendOverflowTestHelper(append_fn, wire_format.UINT32_MAX + 1)
def testAppendVarint64(self):
append_fn = output_stream.OutputStream.AppendVarint64
values_and_expected_strings = [
(0, '\x00'),
(1, '\x01'),
(127, '\x7f'),
(128, '\x80\x01'),
(-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
(wire_format.INT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\x7f'),
(wire_format.INT64_MIN, '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01'),
]
self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MAX + 1)
self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MIN - 1)
def testAppendVarUInt64(self):
append_fn = output_stream.OutputStream.AppendVarUInt64
values_and_expected_strings = [
(0, '\x00'),
(1, '\x01'),
(127, '\x7f'),
(128, '\x80\x01'),
(wire_format.UINT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'),
]
self.AppendNumericTestHelper(append_fn, values_and_expected_strings)
self.AppendOverflowTestHelper(append_fn, -1)
self.AppendOverflowTestHelper(append_fn, wire_format.UINT64_MAX + 1)
if __name__ == '__main__':
unittest.main()

View File

@ -38,6 +38,7 @@ pure-Python protocol compiler.
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
import operator import operator
import struct
import unittest import unittest
# TODO(robinson): When we split this test in two, only some of these imports # TODO(robinson): When we split this test in two, only some of these imports
@ -56,6 +57,51 @@ from google.protobuf.internal import test_util
from google.protobuf.internal import decoder from google.protobuf.internal import decoder
class _MiniDecoder(object):
"""Decodes a stream of values from a string.
Once upon a time we actually had a class called decoder.Decoder. Then we
got rid of it during a redesign that made decoding much, much faster overall.
But a couple tests in this file used it to check that the serialized form of
a message was correct. So, this class implements just the methods that were
used by said tests, so that we don't have to rewrite the tests.
"""
def __init__(self, bytes):
self._bytes = bytes
self._pos = 0
def ReadVarint(self):
result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
return result
ReadInt32 = ReadVarint
ReadInt64 = ReadVarint
ReadUInt32 = ReadVarint
ReadUInt64 = ReadVarint
def ReadSInt64(self):
return wire_format.ZigZagDecode(self.ReadVarint())
ReadSInt32 = ReadSInt64
def ReadFieldNumberAndWireType(self):
return wire_format.UnpackTag(self.ReadVarint())
def ReadFloat(self):
result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
self._pos += 4
return result
def ReadDouble(self):
result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
self._pos += 8
return result
def EndOfStream(self):
return self._pos == len(self._bytes)
class ReflectionTest(unittest.TestCase): class ReflectionTest(unittest.TestCase):
def assertIs(self, values, others): def assertIs(self, values, others):
@ -63,6 +109,97 @@ class ReflectionTest(unittest.TestCase):
for i in range(len(values)): for i in range(len(values)):
self.assertTrue(values[i] is others[i]) self.assertTrue(values[i] is others[i])
def testScalarConstructor(self):
# Constructor with only scalar types should succeed.
proto = unittest_pb2.TestAllTypes(
optional_int32=24,
optional_double=54.321,
optional_string='optional_string')
self.assertEqual(24, proto.optional_int32)
self.assertEqual(54.321, proto.optional_double)
self.assertEqual('optional_string', proto.optional_string)
def testRepeatedScalarConstructor(self):
# Constructor with only repeated scalar types should succeed.
proto = unittest_pb2.TestAllTypes(
repeated_int32=[1, 2, 3, 4],
repeated_double=[1.23, 54.321],
repeated_bool=[True, False, False],
repeated_string=["optional_string"])
self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
self.assertEquals([1.23, 54.321], list(proto.repeated_double))
self.assertEquals([True, False, False], list(proto.repeated_bool))
self.assertEquals(["optional_string"], list(proto.repeated_string))
def testRepeatedCompositeConstructor(self):
# Constructor with only repeated composite types should succeed.
proto = unittest_pb2.TestAllTypes(
repeated_nested_message=[
unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.FOO),
unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.BAR)],
repeated_foreign_message=[
unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)],
repeatedgroup=[
unittest_pb2.TestAllTypes.RepeatedGroup(),
unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
self.assertEquals(
[unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.FOO),
unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.BAR)],
list(proto.repeated_nested_message))
self.assertEquals(
[unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)],
list(proto.repeated_foreign_message))
self.assertEquals(
[unittest_pb2.TestAllTypes.RepeatedGroup(),
unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
list(proto.repeatedgroup))
def testMixedConstructor(self):
# Constructor with only mixed types should succeed.
proto = unittest_pb2.TestAllTypes(
optional_int32=24,
optional_string='optional_string',
repeated_double=[1.23, 54.321],
repeated_bool=[True, False, False],
repeated_nested_message=[
unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.FOO),
unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.BAR)],
repeated_foreign_message=[
unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)])
self.assertEqual(24, proto.optional_int32)
self.assertEqual('optional_string', proto.optional_string)
self.assertEquals([1.23, 54.321], list(proto.repeated_double))
self.assertEquals([True, False, False], list(proto.repeated_bool))
self.assertEquals(
[unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.FOO),
unittest_pb2.TestAllTypes.NestedMessage(
bb=unittest_pb2.TestAllTypes.BAR)],
list(proto.repeated_nested_message))
self.assertEquals(
[unittest_pb2.ForeignMessage(c=-43),
unittest_pb2.ForeignMessage(c=45324),
unittest_pb2.ForeignMessage(c=12)],
list(proto.repeated_foreign_message))
def testSimpleHasBits(self): def testSimpleHasBits(self):
# Test a scalar. # Test a scalar.
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
@ -218,12 +355,23 @@ class ReflectionTest(unittest.TestCase):
proto.optional_fixed32 = 1 proto.optional_fixed32 = 1
proto.optional_int32 = 5 proto.optional_int32 = 5
proto.optional_string = 'foo' proto.optional_string = 'foo'
# Access sub-message but don't set it yet.
nested_message = proto.optional_nested_message
self.assertEqual( self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
(proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
(proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
proto.ListFields()) proto.ListFields())
proto.optional_nested_message.bb = 123
self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
(proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
(proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
(proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
nested_message) ],
proto.ListFields())
def testRepeatedListFields(self): def testRepeatedListFields(self):
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
proto.repeated_fixed32.append(1) proto.repeated_fixed32.append(1)
@ -234,6 +382,7 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_string.append('baz') proto.repeated_string.append('baz')
proto.repeated_string.extend(str(x) for x in xrange(2)) proto.repeated_string.extend(str(x) for x in xrange(2))
proto.optional_int32 = 21 proto.optional_int32 = 21
proto.repeated_bool # Access but don't set anything; should not be listed.
self.assertEqual( self.assertEqual(
[ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21),
(proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]),
@ -731,7 +880,6 @@ class ReflectionTest(unittest.TestCase):
extendee_proto.ClearExtension(extension) extendee_proto.ClearExtension(extension)
extension_proto.foreign_message_int = 23 extension_proto.foreign_message_int = 23
self.assertTrue(not toplevel.HasField('submessage'))
self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
def testExtensionFailureModes(self): def testExtensionFailureModes(self):
@ -957,57 +1105,75 @@ class ReflectionTest(unittest.TestCase):
empty_proto = unittest_pb2.TestAllExtensions() empty_proto = unittest_pb2.TestAllExtensions()
self.assertEquals(proto, empty_proto) self.assertEquals(proto, empty_proto)
def assertInitialized(self, proto):
self.assertTrue(proto.IsInitialized())
# Neither method should raise an exception.
proto.SerializeToString()
proto.SerializePartialToString()
def assertNotInitialized(self, proto):
self.assertFalse(proto.IsInitialized())
self.assertRaises(message.EncodeError, proto.SerializeToString)
# "Partial" serialization doesn't care if message is uninitialized.
proto.SerializePartialToString()
def testIsInitialized(self): def testIsInitialized(self):
# Trivial cases - all optional fields and extensions. # Trivial cases - all optional fields and extensions.
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
self.assertTrue(proto.IsInitialized()) self.assertInitialized(proto)
proto = unittest_pb2.TestAllExtensions() proto = unittest_pb2.TestAllExtensions()
self.assertTrue(proto.IsInitialized()) self.assertInitialized(proto)
# The case of uninitialized required fields. # The case of uninitialized required fields.
proto = unittest_pb2.TestRequired() proto = unittest_pb2.TestRequired()
self.assertFalse(proto.IsInitialized()) self.assertNotInitialized(proto)
proto.a = proto.b = proto.c = 2 proto.a = proto.b = proto.c = 2
self.assertTrue(proto.IsInitialized()) self.assertInitialized(proto)
# The case of uninitialized submessage. # The case of uninitialized submessage.
proto = unittest_pb2.TestRequiredForeign() proto = unittest_pb2.TestRequiredForeign()
self.assertTrue(proto.IsInitialized()) self.assertInitialized(proto)
proto.optional_message.a = 1 proto.optional_message.a = 1
self.assertFalse(proto.IsInitialized()) self.assertNotInitialized(proto)
proto.optional_message.b = 0 proto.optional_message.b = 0
proto.optional_message.c = 0 proto.optional_message.c = 0
self.assertTrue(proto.IsInitialized()) self.assertInitialized(proto)
# Uninitialized repeated submessage. # Uninitialized repeated submessage.
message1 = proto.repeated_message.add() message1 = proto.repeated_message.add()
self.assertFalse(proto.IsInitialized()) self.assertNotInitialized(proto)
message1.a = message1.b = message1.c = 0 message1.a = message1.b = message1.c = 0
self.assertTrue(proto.IsInitialized()) self.assertInitialized(proto)
# Uninitialized repeated group in an extension. # Uninitialized repeated group in an extension.
proto = unittest_pb2.TestAllExtensions() proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.multi extension = unittest_pb2.TestRequired.multi
message1 = proto.Extensions[extension].add() message1 = proto.Extensions[extension].add()
message2 = proto.Extensions[extension].add() message2 = proto.Extensions[extension].add()
self.assertFalse(proto.IsInitialized()) self.assertNotInitialized(proto)
message1.a = 1 message1.a = 1
message1.b = 1 message1.b = 1
message1.c = 1 message1.c = 1
self.assertFalse(proto.IsInitialized()) self.assertNotInitialized(proto)
message2.a = 2 message2.a = 2
message2.b = 2 message2.b = 2
message2.c = 2 message2.c = 2
self.assertTrue(proto.IsInitialized()) self.assertInitialized(proto)
# Uninitialized nonrepeated message in an extension. # Uninitialized nonrepeated message in an extension.
proto = unittest_pb2.TestAllExtensions() proto = unittest_pb2.TestAllExtensions()
extension = unittest_pb2.TestRequired.single extension = unittest_pb2.TestRequired.single
proto.Extensions[extension].a = 1 proto.Extensions[extension].a = 1
self.assertFalse(proto.IsInitialized()) self.assertNotInitialized(proto)
proto.Extensions[extension].b = 2 proto.Extensions[extension].b = 2
proto.Extensions[extension].c = 3 proto.Extensions[extension].c = 3
self.assertTrue(proto.IsInitialized()) self.assertInitialized(proto)
# Try passing an errors list.
errors = []
proto = unittest_pb2.TestRequired()
self.assertFalse(proto.IsInitialized(errors))
self.assertEqual(errors, ['a', 'b', 'c'])
def testStringUTF8Encoding(self): def testStringUTF8Encoding(self):
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
@ -1079,6 +1245,36 @@ class ReflectionTest(unittest.TestCase):
test_utf8_bytes, len(test_utf8_bytes) * '\xff') test_utf8_bytes, len(test_utf8_bytes) * '\xff')
self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
def testEmptyNestedMessage(self):
proto = unittest_pb2.TestAllTypes()
proto.optional_nested_message.MergeFrom(
unittest_pb2.TestAllTypes.NestedMessage())
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
proto.optional_nested_message.CopyFrom(
unittest_pb2.TestAllTypes.NestedMessage())
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
proto.optional_nested_message.MergeFromString('')
self.assertTrue(proto.HasField('optional_nested_message'))
proto = unittest_pb2.TestAllTypes()
proto.optional_nested_message.ParseFromString('')
self.assertTrue(proto.HasField('optional_nested_message'))
serialized = proto.SerializeToString()
proto2 = unittest_pb2.TestAllTypes()
proto2.MergeFromString(serialized)
self.assertTrue(proto2.HasField('optional_nested_message'))
def testSetInParent(self):
proto = unittest_pb2.TestAllTypes()
self.assertFalse(proto.HasField('optionalgroup'))
proto.optionalgroup.SetInParent()
self.assertTrue(proto.HasField('optionalgroup'))
# Since we had so many tests for protocol buffer equality, we broke these out # Since we had so many tests for protocol buffer equality, we broke these out
# into separate TestCase classes. # into separate TestCase classes.
@ -1541,6 +1737,47 @@ class SerializationTest(unittest.TestCase):
second_proto.MergeFromString(serialized) second_proto.MergeFromString(serialized)
self.assertEqual(first_proto, second_proto) self.assertEqual(first_proto, second_proto)
def testSerializeNegativeValues(self):
first_proto = unittest_pb2.TestAllTypes()
first_proto.optional_int32 = -1
first_proto.optional_int64 = -(2 << 40)
first_proto.optional_sint32 = -3
first_proto.optional_sint64 = -(4 << 40)
first_proto.optional_sfixed32 = -5
first_proto.optional_sfixed64 = -(6 << 40)
second_proto = unittest_pb2.TestAllTypes.FromString(
first_proto.SerializeToString())
self.assertEqual(first_proto, second_proto)
def testParseTruncated(self):
first_proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(first_proto)
serialized = first_proto.SerializeToString()
for truncation_point in xrange(len(serialized) + 1):
try:
second_proto = unittest_pb2.TestAllTypes()
unknown_fields = unittest_pb2.TestEmptyMessage()
pos = second_proto._InternalParse(serialized, 0, truncation_point)
# If we didn't raise an error then we read exactly the amount expected.
self.assertEqual(truncation_point, pos)
# Parsing to unknown fields should not throw if parsing to known fields
# did not.
try:
pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
self.assertEqual(truncation_point, pos2)
except message.DecodeError:
self.fail('Parsing unknown fields failed when parsing known fields '
'did not.')
except message.DecodeError:
# Parsing unknown fields should also fail.
self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
serialized, 0, truncation_point)
def testCanonicalSerializationOrder(self): def testCanonicalSerializationOrder(self):
proto = more_messages_pb2.OutOfOrderFields() proto = more_messages_pb2.OutOfOrderFields()
# These are also their tag numbers. Even though we're setting these in # These are also their tag numbers. Even though we're setting these in
@ -1553,7 +1790,7 @@ class SerializationTest(unittest.TestCase):
proto.optional_int32 = 1 proto.optional_int32 = 1
serialized = proto.SerializeToString() serialized = proto.SerializeToString()
self.assertEqual(proto.ByteSize(), len(serialized)) self.assertEqual(proto.ByteSize(), len(serialized))
d = decoder.Decoder(serialized) d = _MiniDecoder(serialized)
ReadTag = d.ReadFieldNumberAndWireType ReadTag = d.ReadFieldNumberAndWireType
self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
self.assertEqual(1, d.ReadInt32()) self.assertEqual(1, d.ReadInt32())
@ -1709,7 +1946,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises( self._CheckRaises(
message.EncodeError, message.EncodeError,
proto.SerializeToString, proto.SerializeToString,
'Required field protobuf_unittest.TestRequired.a is not set.') 'Message is missing required fields: a,b,c')
# Shouldn't raise exceptions. # Shouldn't raise exceptions.
partial = proto.SerializePartialToString() partial = proto.SerializePartialToString()
@ -1717,7 +1954,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises( self._CheckRaises(
message.EncodeError, message.EncodeError,
proto.SerializeToString, proto.SerializeToString,
'Required field protobuf_unittest.TestRequired.b is not set.') 'Message is missing required fields: b,c')
# Shouldn't raise exceptions. # Shouldn't raise exceptions.
partial = proto.SerializePartialToString() partial = proto.SerializePartialToString()
@ -1725,7 +1962,7 @@ class SerializationTest(unittest.TestCase):
self._CheckRaises( self._CheckRaises(
message.EncodeError, message.EncodeError,
proto.SerializeToString, proto.SerializeToString,
'Required field protobuf_unittest.TestRequired.c is not set.') 'Message is missing required fields: c')
# Shouldn't raise exceptions. # Shouldn't raise exceptions.
partial = proto.SerializePartialToString() partial = proto.SerializePartialToString()
@ -1744,6 +1981,38 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(2, proto2.b) self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c) self.assertEqual(3, proto2.c)
def testSerializeUninitializedSubMessage(self):
proto = unittest_pb2.TestRequiredForeign()
# Sub-message doesn't exist yet, so this succeeds.
proto.SerializeToString()
proto.optional_message.a = 1
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
'Message is missing required fields: '
'optional_message.b,optional_message.c')
proto.optional_message.b = 2
proto.optional_message.c = 3
proto.SerializeToString()
proto.repeated_message.add().a = 1
proto.repeated_message.add().b = 2
self._CheckRaises(
message.EncodeError,
proto.SerializeToString,
'Message is missing required fields: '
'repeated_message[0].b,repeated_message[0].c,'
'repeated_message[1].a,repeated_message[1].c')
proto.repeated_message[0].b = 2
proto.repeated_message[0].c = 3
proto.repeated_message[1].a = 1
proto.repeated_message[1].c = 3
proto.SerializeToString()
def testSerializeAllPackedFields(self): def testSerializeAllPackedFields(self):
first_proto = unittest_pb2.TestPackedTypes() first_proto = unittest_pb2.TestPackedTypes()
second_proto = unittest_pb2.TestPackedTypes() second_proto = unittest_pb2.TestPackedTypes()
@ -1786,7 +2055,7 @@ class SerializationTest(unittest.TestCase):
proto.packed_float.append(2.0) # 4 bytes, will be before double proto.packed_float.append(2.0) # 4 bytes, will be before double
serialized = proto.SerializeToString() serialized = proto.SerializeToString()
self.assertEqual(proto.ByteSize(), len(serialized)) self.assertEqual(proto.ByteSize(), len(serialized))
d = decoder.Decoder(serialized) d = _MiniDecoder(serialized)
ReadTag = d.ReadFieldNumberAndWireType ReadTag = d.ReadFieldNumberAndWireType
self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
self.assertEqual(1+1+1+2, d.ReadInt32()) self.assertEqual(1+1+1+2, d.ReadInt32())
@ -1803,6 +2072,24 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(1000.0, d.ReadDouble()) self.assertEqual(1000.0, d.ReadDouble())
self.assertTrue(d.EndOfStream()) self.assertTrue(d.EndOfStream())
def testParsePackedFromUnpacked(self):
unpacked = unittest_pb2.TestUnpackedTypes()
test_util.SetAllUnpackedFields(unpacked)
packed = unittest_pb2.TestPackedTypes()
packed.MergeFromString(unpacked.SerializeToString())
expected = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(expected)
self.assertEqual(expected, packed)
def testParseUnpackedFromPacked(self):
packed = unittest_pb2.TestPackedTypes()
test_util.SetAllPackedFields(packed)
unpacked = unittest_pb2.TestUnpackedTypes()
unpacked.MergeFromString(packed.SerializeToString())
expected = unittest_pb2.TestUnpackedTypes()
test_util.SetAllUnpackedFields(expected)
self.assertEqual(expected, unpacked)
def testFieldNumbers(self): def testFieldNumbers(self):
proto = unittest_pb2.TestAllTypes() proto = unittest_pb2.TestAllTypes()
self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
@ -1944,33 +2231,6 @@ class OptionsTest(unittest.TestCase):
field_descriptor.label) field_descriptor.label)
class UtilityTest(unittest.TestCase):
def testImergeSorted(self):
ImergeSorted = reflection._ImergeSorted
# Various types of emptiness.
self.assertEqual([], list(ImergeSorted()))
self.assertEqual([], list(ImergeSorted([])))
self.assertEqual([], list(ImergeSorted([], [])))
# One nonempty list.
self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3])))
self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], [])))
self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3])))
# Merging some nonempty lists together.
self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2])))
self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2])))
self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], [])))
# Elements repeated across component iterators.
self.assertEqual([1, 2, 2, 3, 3],
list(ImergeSorted([1, 2], [3], [2, 3])))
# Elements repeated within an iterator.
self.assertEqual([1, 2, 2, 3, 3],
list(ImergeSorted([1, 2, 2], [3], [3])))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -31,14 +31,13 @@
"""Utilities for Python proto2 tests. """Utilities for Python proto2 tests.
This is intentionally modeled on C++ code in This is intentionally modeled on C++ code in
//net/proto2/internal/test_util.*. //google/protobuf/test_util.*.
""" """
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
import os.path import os.path
import unittest
from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_pb2
from google.protobuf import unittest_pb2 from google.protobuf import unittest_pb2
@ -353,198 +352,198 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized):
raise ValueError('Expected %r, found %r' % (expected, serialized)) raise ValueError('Expected %r, found %r' % (expected, serialized))
class GoldenMessageTestCase(unittest.TestCase): def ExpectAllFieldsSet(test_case, message):
"""This adds methods to TestCase useful for verifying our Golden Message.""" """Check all fields for correct values have after Set*Fields() is called."""
test_case.assertTrue(message.HasField('optional_int32'))
test_case.assertTrue(message.HasField('optional_int64'))
test_case.assertTrue(message.HasField('optional_uint32'))
test_case.assertTrue(message.HasField('optional_uint64'))
test_case.assertTrue(message.HasField('optional_sint32'))
test_case.assertTrue(message.HasField('optional_sint64'))
test_case.assertTrue(message.HasField('optional_fixed32'))
test_case.assertTrue(message.HasField('optional_fixed64'))
test_case.assertTrue(message.HasField('optional_sfixed32'))
test_case.assertTrue(message.HasField('optional_sfixed64'))
test_case.assertTrue(message.HasField('optional_float'))
test_case.assertTrue(message.HasField('optional_double'))
test_case.assertTrue(message.HasField('optional_bool'))
test_case.assertTrue(message.HasField('optional_string'))
test_case.assertTrue(message.HasField('optional_bytes'))
def ExpectAllFieldsSet(self, message): test_case.assertTrue(message.HasField('optionalgroup'))
"""Check all fields for correct values have after Set*Fields() is called.""" test_case.assertTrue(message.HasField('optional_nested_message'))
self.assertTrue(message.HasField('optional_int32')) test_case.assertTrue(message.HasField('optional_foreign_message'))
self.assertTrue(message.HasField('optional_int64')) test_case.assertTrue(message.HasField('optional_import_message'))
self.assertTrue(message.HasField('optional_uint32'))
self.assertTrue(message.HasField('optional_uint64'))
self.assertTrue(message.HasField('optional_sint32'))
self.assertTrue(message.HasField('optional_sint64'))
self.assertTrue(message.HasField('optional_fixed32'))
self.assertTrue(message.HasField('optional_fixed64'))
self.assertTrue(message.HasField('optional_sfixed32'))
self.assertTrue(message.HasField('optional_sfixed64'))
self.assertTrue(message.HasField('optional_float'))
self.assertTrue(message.HasField('optional_double'))
self.assertTrue(message.HasField('optional_bool'))
self.assertTrue(message.HasField('optional_string'))
self.assertTrue(message.HasField('optional_bytes'))
self.assertTrue(message.HasField('optionalgroup')) test_case.assertTrue(message.optionalgroup.HasField('a'))
self.assertTrue(message.HasField('optional_nested_message')) test_case.assertTrue(message.optional_nested_message.HasField('bb'))
self.assertTrue(message.HasField('optional_foreign_message')) test_case.assertTrue(message.optional_foreign_message.HasField('c'))
self.assertTrue(message.HasField('optional_import_message')) test_case.assertTrue(message.optional_import_message.HasField('d'))
self.assertTrue(message.optionalgroup.HasField('a')) test_case.assertTrue(message.HasField('optional_nested_enum'))
self.assertTrue(message.optional_nested_message.HasField('bb')) test_case.assertTrue(message.HasField('optional_foreign_enum'))
self.assertTrue(message.optional_foreign_message.HasField('c')) test_case.assertTrue(message.HasField('optional_import_enum'))
self.assertTrue(message.optional_import_message.HasField('d'))
self.assertTrue(message.HasField('optional_nested_enum')) test_case.assertTrue(message.HasField('optional_string_piece'))
self.assertTrue(message.HasField('optional_foreign_enum')) test_case.assertTrue(message.HasField('optional_cord'))
self.assertTrue(message.HasField('optional_import_enum'))
self.assertTrue(message.HasField('optional_string_piece')) test_case.assertEqual(101, message.optional_int32)
self.assertTrue(message.HasField('optional_cord')) test_case.assertEqual(102, message.optional_int64)
test_case.assertEqual(103, message.optional_uint32)
test_case.assertEqual(104, message.optional_uint64)
test_case.assertEqual(105, message.optional_sint32)
test_case.assertEqual(106, message.optional_sint64)
test_case.assertEqual(107, message.optional_fixed32)
test_case.assertEqual(108, message.optional_fixed64)
test_case.assertEqual(109, message.optional_sfixed32)
test_case.assertEqual(110, message.optional_sfixed64)
test_case.assertEqual(111, message.optional_float)
test_case.assertEqual(112, message.optional_double)
test_case.assertEqual(True, message.optional_bool)
test_case.assertEqual('115', message.optional_string)
test_case.assertEqual('116', message.optional_bytes)
self.assertEqual(101, message.optional_int32) test_case.assertEqual(117, message.optionalgroup.a)
self.assertEqual(102, message.optional_int64) test_case.assertEqual(118, message.optional_nested_message.bb)
self.assertEqual(103, message.optional_uint32) test_case.assertEqual(119, message.optional_foreign_message.c)
self.assertEqual(104, message.optional_uint64) test_case.assertEqual(120, message.optional_import_message.d)
self.assertEqual(105, message.optional_sint32)
self.assertEqual(106, message.optional_sint64)
self.assertEqual(107, message.optional_fixed32)
self.assertEqual(108, message.optional_fixed64)
self.assertEqual(109, message.optional_sfixed32)
self.assertEqual(110, message.optional_sfixed64)
self.assertEqual(111, message.optional_float)
self.assertEqual(112, message.optional_double)
self.assertEqual(True, message.optional_bool)
self.assertEqual('115', message.optional_string)
self.assertEqual('116', message.optional_bytes)
self.assertEqual(117, message.optionalgroup.a); test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
self.assertEqual(118, message.optional_nested_message.bb) message.optional_nested_enum)
self.assertEqual(119, message.optional_foreign_message.c) test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
self.assertEqual(120, message.optional_import_message.d) message.optional_foreign_enum)
test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
message.optional_import_enum)
self.assertEqual(unittest_pb2.TestAllTypes.BAZ, # -----------------------------------------------------------------
message.optional_nested_enum)
self.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum)
self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
message.optional_import_enum)
# ----------------------------------------------------------------- test_case.assertEqual(2, len(message.repeated_int32))
test_case.assertEqual(2, len(message.repeated_int64))
test_case.assertEqual(2, len(message.repeated_uint32))
test_case.assertEqual(2, len(message.repeated_uint64))
test_case.assertEqual(2, len(message.repeated_sint32))
test_case.assertEqual(2, len(message.repeated_sint64))
test_case.assertEqual(2, len(message.repeated_fixed32))
test_case.assertEqual(2, len(message.repeated_fixed64))
test_case.assertEqual(2, len(message.repeated_sfixed32))
test_case.assertEqual(2, len(message.repeated_sfixed64))
test_case.assertEqual(2, len(message.repeated_float))
test_case.assertEqual(2, len(message.repeated_double))
test_case.assertEqual(2, len(message.repeated_bool))
test_case.assertEqual(2, len(message.repeated_string))
test_case.assertEqual(2, len(message.repeated_bytes))
self.assertEqual(2, len(message.repeated_int32)) test_case.assertEqual(2, len(message.repeatedgroup))
self.assertEqual(2, len(message.repeated_int64)) test_case.assertEqual(2, len(message.repeated_nested_message))
self.assertEqual(2, len(message.repeated_uint32)) test_case.assertEqual(2, len(message.repeated_foreign_message))
self.assertEqual(2, len(message.repeated_uint64)) test_case.assertEqual(2, len(message.repeated_import_message))
self.assertEqual(2, len(message.repeated_sint32)) test_case.assertEqual(2, len(message.repeated_nested_enum))
self.assertEqual(2, len(message.repeated_sint64)) test_case.assertEqual(2, len(message.repeated_foreign_enum))
self.assertEqual(2, len(message.repeated_fixed32)) test_case.assertEqual(2, len(message.repeated_import_enum))
self.assertEqual(2, len(message.repeated_fixed64))
self.assertEqual(2, len(message.repeated_sfixed32))
self.assertEqual(2, len(message.repeated_sfixed64))
self.assertEqual(2, len(message.repeated_float))
self.assertEqual(2, len(message.repeated_double))
self.assertEqual(2, len(message.repeated_bool))
self.assertEqual(2, len(message.repeated_string))
self.assertEqual(2, len(message.repeated_bytes))
self.assertEqual(2, len(message.repeatedgroup)) test_case.assertEqual(2, len(message.repeated_string_piece))
self.assertEqual(2, len(message.repeated_nested_message)) test_case.assertEqual(2, len(message.repeated_cord))
self.assertEqual(2, len(message.repeated_foreign_message))
self.assertEqual(2, len(message.repeated_import_message))
self.assertEqual(2, len(message.repeated_nested_enum))
self.assertEqual(2, len(message.repeated_foreign_enum))
self.assertEqual(2, len(message.repeated_import_enum))
self.assertEqual(2, len(message.repeated_string_piece)) test_case.assertEqual(201, message.repeated_int32[0])
self.assertEqual(2, len(message.repeated_cord)) test_case.assertEqual(202, message.repeated_int64[0])
test_case.assertEqual(203, message.repeated_uint32[0])
test_case.assertEqual(204, message.repeated_uint64[0])
test_case.assertEqual(205, message.repeated_sint32[0])
test_case.assertEqual(206, message.repeated_sint64[0])
test_case.assertEqual(207, message.repeated_fixed32[0])
test_case.assertEqual(208, message.repeated_fixed64[0])
test_case.assertEqual(209, message.repeated_sfixed32[0])
test_case.assertEqual(210, message.repeated_sfixed64[0])
test_case.assertEqual(211, message.repeated_float[0])
test_case.assertEqual(212, message.repeated_double[0])
test_case.assertEqual(True, message.repeated_bool[0])
test_case.assertEqual('215', message.repeated_string[0])
test_case.assertEqual('216', message.repeated_bytes[0])
self.assertEqual(201, message.repeated_int32[0]) test_case.assertEqual(217, message.repeatedgroup[0].a)
self.assertEqual(202, message.repeated_int64[0]) test_case.assertEqual(218, message.repeated_nested_message[0].bb)
self.assertEqual(203, message.repeated_uint32[0]) test_case.assertEqual(219, message.repeated_foreign_message[0].c)
self.assertEqual(204, message.repeated_uint64[0]) test_case.assertEqual(220, message.repeated_import_message[0].d)
self.assertEqual(205, message.repeated_sint32[0])
self.assertEqual(206, message.repeated_sint64[0])
self.assertEqual(207, message.repeated_fixed32[0])
self.assertEqual(208, message.repeated_fixed64[0])
self.assertEqual(209, message.repeated_sfixed32[0])
self.assertEqual(210, message.repeated_sfixed64[0])
self.assertEqual(211, message.repeated_float[0])
self.assertEqual(212, message.repeated_double[0])
self.assertEqual(True, message.repeated_bool[0])
self.assertEqual('215', message.repeated_string[0])
self.assertEqual('216', message.repeated_bytes[0])
self.assertEqual(217, message.repeatedgroup[0].a) test_case.assertEqual(unittest_pb2.TestAllTypes.BAR,
self.assertEqual(218, message.repeated_nested_message[0].bb) message.repeated_nested_enum[0])
self.assertEqual(219, message.repeated_foreign_message[0].c) test_case.assertEqual(unittest_pb2.FOREIGN_BAR,
self.assertEqual(220, message.repeated_import_message[0].d) message.repeated_foreign_enum[0])
test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
message.repeated_import_enum[0])
self.assertEqual(unittest_pb2.TestAllTypes.BAR, test_case.assertEqual(301, message.repeated_int32[1])
message.repeated_nested_enum[0]) test_case.assertEqual(302, message.repeated_int64[1])
self.assertEqual(unittest_pb2.FOREIGN_BAR, test_case.assertEqual(303, message.repeated_uint32[1])
message.repeated_foreign_enum[0]) test_case.assertEqual(304, message.repeated_uint64[1])
self.assertEqual(unittest_import_pb2.IMPORT_BAR, test_case.assertEqual(305, message.repeated_sint32[1])
message.repeated_import_enum[0]) test_case.assertEqual(306, message.repeated_sint64[1])
test_case.assertEqual(307, message.repeated_fixed32[1])
test_case.assertEqual(308, message.repeated_fixed64[1])
test_case.assertEqual(309, message.repeated_sfixed32[1])
test_case.assertEqual(310, message.repeated_sfixed64[1])
test_case.assertEqual(311, message.repeated_float[1])
test_case.assertEqual(312, message.repeated_double[1])
test_case.assertEqual(False, message.repeated_bool[1])
test_case.assertEqual('315', message.repeated_string[1])
test_case.assertEqual('316', message.repeated_bytes[1])
self.assertEqual(301, message.repeated_int32[1]) test_case.assertEqual(317, message.repeatedgroup[1].a)
self.assertEqual(302, message.repeated_int64[1]) test_case.assertEqual(318, message.repeated_nested_message[1].bb)
self.assertEqual(303, message.repeated_uint32[1]) test_case.assertEqual(319, message.repeated_foreign_message[1].c)
self.assertEqual(304, message.repeated_uint64[1]) test_case.assertEqual(320, message.repeated_import_message[1].d)
self.assertEqual(305, message.repeated_sint32[1])
self.assertEqual(306, message.repeated_sint64[1])
self.assertEqual(307, message.repeated_fixed32[1])
self.assertEqual(308, message.repeated_fixed64[1])
self.assertEqual(309, message.repeated_sfixed32[1])
self.assertEqual(310, message.repeated_sfixed64[1])
self.assertEqual(311, message.repeated_float[1])
self.assertEqual(312, message.repeated_double[1])
self.assertEqual(False, message.repeated_bool[1])
self.assertEqual('315', message.repeated_string[1])
self.assertEqual('316', message.repeated_bytes[1])
self.assertEqual(317, message.repeatedgroup[1].a) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
self.assertEqual(318, message.repeated_nested_message[1].bb) message.repeated_nested_enum[1])
self.assertEqual(319, message.repeated_foreign_message[1].c) test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
self.assertEqual(320, message.repeated_import_message[1].d) message.repeated_foreign_enum[1])
test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
message.repeated_import_enum[1])
self.assertEqual(unittest_pb2.TestAllTypes.BAZ, # -----------------------------------------------------------------
message.repeated_nested_enum[1])
self.assertEqual(unittest_pb2.FOREIGN_BAZ,
message.repeated_foreign_enum[1])
self.assertEqual(unittest_import_pb2.IMPORT_BAZ,
message.repeated_import_enum[1])
# ----------------------------------------------------------------- test_case.assertTrue(message.HasField('default_int32'))
test_case.assertTrue(message.HasField('default_int64'))
test_case.assertTrue(message.HasField('default_uint32'))
test_case.assertTrue(message.HasField('default_uint64'))
test_case.assertTrue(message.HasField('default_sint32'))
test_case.assertTrue(message.HasField('default_sint64'))
test_case.assertTrue(message.HasField('default_fixed32'))
test_case.assertTrue(message.HasField('default_fixed64'))
test_case.assertTrue(message.HasField('default_sfixed32'))
test_case.assertTrue(message.HasField('default_sfixed64'))
test_case.assertTrue(message.HasField('default_float'))
test_case.assertTrue(message.HasField('default_double'))
test_case.assertTrue(message.HasField('default_bool'))
test_case.assertTrue(message.HasField('default_string'))
test_case.assertTrue(message.HasField('default_bytes'))
self.assertTrue(message.HasField('default_int32')) test_case.assertTrue(message.HasField('default_nested_enum'))
self.assertTrue(message.HasField('default_int64')) test_case.assertTrue(message.HasField('default_foreign_enum'))
self.assertTrue(message.HasField('default_uint32')) test_case.assertTrue(message.HasField('default_import_enum'))
self.assertTrue(message.HasField('default_uint64'))
self.assertTrue(message.HasField('default_sint32'))
self.assertTrue(message.HasField('default_sint64'))
self.assertTrue(message.HasField('default_fixed32'))
self.assertTrue(message.HasField('default_fixed64'))
self.assertTrue(message.HasField('default_sfixed32'))
self.assertTrue(message.HasField('default_sfixed64'))
self.assertTrue(message.HasField('default_float'))
self.assertTrue(message.HasField('default_double'))
self.assertTrue(message.HasField('default_bool'))
self.assertTrue(message.HasField('default_string'))
self.assertTrue(message.HasField('default_bytes'))
self.assertTrue(message.HasField('default_nested_enum')) test_case.assertEqual(401, message.default_int32)
self.assertTrue(message.HasField('default_foreign_enum')) test_case.assertEqual(402, message.default_int64)
self.assertTrue(message.HasField('default_import_enum')) test_case.assertEqual(403, message.default_uint32)
test_case.assertEqual(404, message.default_uint64)
test_case.assertEqual(405, message.default_sint32)
test_case.assertEqual(406, message.default_sint64)
test_case.assertEqual(407, message.default_fixed32)
test_case.assertEqual(408, message.default_fixed64)
test_case.assertEqual(409, message.default_sfixed32)
test_case.assertEqual(410, message.default_sfixed64)
test_case.assertEqual(411, message.default_float)
test_case.assertEqual(412, message.default_double)
test_case.assertEqual(False, message.default_bool)
test_case.assertEqual('415', message.default_string)
test_case.assertEqual('416', message.default_bytes)
self.assertEqual(401, message.default_int32) test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
self.assertEqual(402, message.default_int64) message.default_nested_enum)
self.assertEqual(403, message.default_uint32) test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
self.assertEqual(404, message.default_uint64) message.default_foreign_enum)
self.assertEqual(405, message.default_sint32) test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
self.assertEqual(406, message.default_sint64) message.default_import_enum)
self.assertEqual(407, message.default_fixed32)
self.assertEqual(408, message.default_fixed64)
self.assertEqual(409, message.default_sfixed32)
self.assertEqual(410, message.default_sfixed64)
self.assertEqual(411, message.default_float)
self.assertEqual(412, message.default_double)
self.assertEqual(False, message.default_bool)
self.assertEqual('415', message.default_string)
self.assertEqual('416', message.default_bytes)
self.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum)
self.assertEqual(unittest_pb2.FOREIGN_FOO, message.default_foreign_enum)
self.assertEqual(unittest_import_pb2.IMPORT_FOO,
message.default_import_enum)
def GoldenFile(filename): def GoldenFile(filename):
"""Finds the given golden file and returns a file object representing it.""" """Finds the given golden file and returns a file object representing it."""
@ -570,21 +569,21 @@ def SetAllPackedFields(message):
Args: Args:
message: A unittest_pb2.TestPackedTypes instance. message: A unittest_pb2.TestPackedTypes instance.
""" """
message.packed_int32.extend([101, 102]) message.packed_int32.extend([601, 701])
message.packed_int64.extend([103, 104]) message.packed_int64.extend([602, 702])
message.packed_uint32.extend([105, 106]) message.packed_uint32.extend([603, 703])
message.packed_uint64.extend([107, 108]) message.packed_uint64.extend([604, 704])
message.packed_sint32.extend([109, 110]) message.packed_sint32.extend([605, 705])
message.packed_sint64.extend([111, 112]) message.packed_sint64.extend([606, 706])
message.packed_fixed32.extend([113, 114]) message.packed_fixed32.extend([607, 707])
message.packed_fixed64.extend([115, 116]) message.packed_fixed64.extend([608, 708])
message.packed_sfixed32.extend([117, 118]) message.packed_sfixed32.extend([609, 709])
message.packed_sfixed64.extend([119, 120]) message.packed_sfixed64.extend([610, 710])
message.packed_float.extend([121.0, 122.0]) message.packed_float.extend([611.0, 711.0])
message.packed_double.extend([122.0, 123.0]) message.packed_double.extend([612.0, 712.0])
message.packed_bool.extend([True, False]) message.packed_bool.extend([True, False])
message.packed_enum.extend([unittest_pb2.FOREIGN_FOO, message.packed_enum.extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAR]) unittest_pb2.FOREIGN_BAZ])
def SetAllPackedExtensions(message): def SetAllPackedExtensions(message):
@ -596,17 +595,41 @@ def SetAllPackedExtensions(message):
extensions = message.Extensions extensions = message.Extensions
pb2 = unittest_pb2 pb2 = unittest_pb2
extensions[pb2.packed_int32_extension].append(101) extensions[pb2.packed_int32_extension].extend([601, 701])
extensions[pb2.packed_int64_extension].append(102) extensions[pb2.packed_int64_extension].extend([602, 702])
extensions[pb2.packed_uint32_extension].append(103) extensions[pb2.packed_uint32_extension].extend([603, 703])
extensions[pb2.packed_uint64_extension].append(104) extensions[pb2.packed_uint64_extension].extend([604, 704])
extensions[pb2.packed_sint32_extension].append(105) extensions[pb2.packed_sint32_extension].extend([605, 705])
extensions[pb2.packed_sint64_extension].append(106) extensions[pb2.packed_sint64_extension].extend([606, 706])
extensions[pb2.packed_fixed32_extension].append(107) extensions[pb2.packed_fixed32_extension].extend([607, 707])
extensions[pb2.packed_fixed64_extension].append(108) extensions[pb2.packed_fixed64_extension].extend([608, 708])
extensions[pb2.packed_sfixed32_extension].append(109) extensions[pb2.packed_sfixed32_extension].extend([609, 709])
extensions[pb2.packed_sfixed64_extension].append(110) extensions[pb2.packed_sfixed64_extension].extend([610, 710])
extensions[pb2.packed_float_extension].append(111.0) extensions[pb2.packed_float_extension].extend([611.0, 711.0])
extensions[pb2.packed_double_extension].append(112.0) extensions[pb2.packed_double_extension].extend([612.0, 712.0])
extensions[pb2.packed_bool_extension].append(True) extensions[pb2.packed_bool_extension].extend([True, False])
extensions[pb2.packed_enum_extension].append(pb2.FOREIGN_BAZ) extensions[pb2.packed_enum_extension].extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAZ])
def SetAllUnpackedFields(message):
"""Sets every field in the message to a unique value.
Args:
message: A unittest_pb2.TestUnpackedTypes instance.
"""
message.unpacked_int32.extend([601, 701])
message.unpacked_int64.extend([602, 702])
message.unpacked_uint32.extend([603, 703])
message.unpacked_uint64.extend([604, 704])
message.unpacked_sint32.extend([605, 705])
message.unpacked_sint64.extend([606, 706])
message.unpacked_fixed32.extend([607, 707])
message.unpacked_fixed64.extend([608, 708])
message.unpacked_sfixed32.extend([609, 709])
message.unpacked_sfixed64.extend([610, 710])
message.unpacked_float.extend([611.0, 711.0])
message.unpacked_double.extend([612.0, 712.0])
message.unpacked_bool.extend([True, False])
message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR,
unittest_pb2.FOREIGN_BAZ])

View File

@ -43,7 +43,7 @@ from google.protobuf import unittest_pb2
from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_mset_pb2
class TextFormatTest(test_util.GoldenMessageTestCase): class TextFormatTest(unittest.TestCase):
def ReadGolden(self, golden_filename): def ReadGolden(self, golden_filename):
f = test_util.GoldenFile(golden_filename) f = test_util.GoldenFile(golden_filename)
golden_lines = f.readlines() golden_lines = f.readlines()
@ -149,7 +149,7 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
parsed_message = unittest_pb2.TestAllTypes() parsed_message = unittest_pb2.TestAllTypes()
text_format.Merge(ascii_text, parsed_message) text_format.Merge(ascii_text, parsed_message)
self.assertEqual(message, parsed_message) self.assertEqual(message, parsed_message)
self.ExpectAllFieldsSet(message) test_util.ExpectAllFieldsSet(self, message)
def testMergeAllExtensions(self): def testMergeAllExtensions(self):
message = unittest_pb2.TestAllExtensions() message = unittest_pb2.TestAllExtensions()
@ -212,12 +212,18 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
text_format.Merge, text, message) text_format.Merge, text, message)
def testMergeBadExtension(self): def testMergeBadExtension(self):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllExtensions()
text = '[unknown_extension]: 8\n' text = '[unknown_extension]: 8\n'
self.assertRaisesWithMessage( self.assertRaisesWithMessage(
text_format.ParseError, text_format.ParseError,
'1:2 : Extension "unknown_extension" not registered.', '1:2 : Extension "unknown_extension" not registered.',
text_format.Merge, text, message) text_format.Merge, text, message)
message = unittest_pb2.TestAllTypes()
self.assertRaisesWithMessage(
text_format.ParseError,
('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
'extensions.'),
text_format.Merge, text, message)
def testMergeGroupNotClosed(self): def testMergeGroupNotClosed(self):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
@ -231,6 +237,19 @@ class TextFormatTest(test_util.GoldenMessageTestCase):
text_format.ParseError, '1:16 : Expected "}".', text_format.ParseError, '1:16 : Expected "}".',
text_format.Merge, text, message) text_format.Merge, text, message)
def testMergeEmptyGroup(self):
message = unittest_pb2.TestAllTypes()
text = 'OptionalGroup: {}'
text_format.Merge(text, message)
self.assertTrue(message.HasField('optionalgroup'))
message.Clear()
message = unittest_pb2.TestAllTypes()
text = 'OptionalGroup: <>'
text_format.Merge(text, message)
self.assertTrue(message.HasField('optionalgroup'))
def testMergeBadEnumValue(self): def testMergeBadEnumValue(self):
message = unittest_pb2.TestAllTypes() message = unittest_pb2.TestAllTypes()
text = 'optional_nested_enum: BARR' text = 'optional_nested_enum: BARR'

View File

@ -192,47 +192,72 @@ TYPE_TO_BYTE_SIZE_FN = {
} }
# Maps from field type to an unbound Encoder method F, such that # Maps from field types to encoder constructors.
# F(encoder, field_number, value) will append the serialization TYPE_TO_ENCODER = {
# of a value of this type to the encoder. _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
_Encoder = encoder.Encoder _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
TYPE_TO_SERIALIZE_METHOD = { _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
_FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble, _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
_FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat, _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
_FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64, _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
_FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64, _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
_FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32, _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
_FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64, _FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
_FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32, _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
_FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool, _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
_FieldDescriptor.TYPE_STRING: _Encoder.AppendString, _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
_FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup, _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
_FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage, _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
_FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes, _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
_FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32, _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
_FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum, _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
_FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32, _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
_FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64,
_FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32,
_FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64,
} }
TYPE_TO_NOTAG_SERIALIZE_METHOD = { # Maps from field types to sizer constructors.
_FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDoubleNoTag, TYPE_TO_SIZER = {
_FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloatNoTag, _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
_FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64NoTag, _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
_FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64NoTag, _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
_FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32NoTag, _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
_FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64NoTag, _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
_FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32NoTag, _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
_FieldDescriptor.TYPE_BOOL: _Encoder.AppendBoolNoTag, _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
_FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32NoTag, _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
_FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnumNoTag, _FieldDescriptor.TYPE_STRING: encoder.StringSizer,
_FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32NoTag, _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
_FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64NoTag, _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
_FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32NoTag, _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
_FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64NoTag, _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
_FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
_FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
_FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
_FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
_FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
}
# Maps from field type to a decoder constructor.
TYPE_TO_DECODER = {
_FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
_FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
_FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
_FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
_FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
_FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
_FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
_FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
_FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
_FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
_FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
_FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
_FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
_FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
_FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
_FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
_FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
_FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
} }
# Maps from field type to expected wiretype. # Maps from field type to expected wiretype.
@ -259,29 +284,3 @@ FIELD_TYPE_TO_WIRE_TYPE = {
_FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT, _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
_FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT, _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
} }
# Maps from field type to an unbound Decoder method F,
# such that F(decoder) will read a field of the requested type.
#
# Note that Message and Group are intentionally missing here.
# They're handled by _RecursivelyMerge().
_Decoder = decoder.Decoder
TYPE_TO_DESERIALIZE_METHOD = {
_FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble,
_FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat,
_FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64,
_FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64,
_FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32,
_FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64,
_FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32,
_FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool,
_FieldDescriptor.TYPE_STRING: _Decoder.ReadString,
_FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes,
_FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32,
_FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum,
_FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32,
_FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64,
_FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32,
_FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64,
}

View File

@ -33,16 +33,17 @@
__author__ = 'robinson@google.com (Will Robinson)' __author__ = 'robinson@google.com (Will Robinson)'
import struct import struct
from google.protobuf import descriptor
from google.protobuf import message from google.protobuf import message
TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag. TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag.
_TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7 TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7
# These numbers identify the wire type of a protocol buffer value. # These numbers identify the wire type of a protocol buffer value.
# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded # We use the least-significant TAG_TYPE_BITS bits of the varint-encoded
# tag-and-type to store one of these WIRETYPE_* constants. # tag-and-type to store one of these WIRETYPE_* constants.
# These values must match WireType enum in //net/proto2/public/wire_format.h. # These values must match WireType enum in google/protobuf/wire_format.h.
WIRETYPE_VARINT = 0 WIRETYPE_VARINT = 0
WIRETYPE_FIXED64 = 1 WIRETYPE_FIXED64 = 1
WIRETYPE_LENGTH_DELIMITED = 2 WIRETYPE_LENGTH_DELIMITED = 2
@ -93,7 +94,7 @@ def UnpackTag(tag):
"""The inverse of PackTag(). Given an unsigned 32-bit number, """The inverse of PackTag(). Given an unsigned 32-bit number,
returns a (field_number, wire_type) tuple. returns a (field_number, wire_type) tuple.
""" """
return (tag >> TAG_TYPE_BITS), (tag & _TAG_TYPE_MASK) return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK)
def ZigZagEncode(value): def ZigZagEncode(value):
@ -245,3 +246,23 @@ def _VarUInt64ByteSizeNoTag(uint64):
if uint64 > UINT64_MAX: if uint64 > UINT64_MAX:
raise message.EncodeError('Value out of range: %d' % uint64) raise message.EncodeError('Value out of range: %d' % uint64)
return 10 return 10
NON_PACKABLE_TYPES = (
descriptor.FieldDescriptor.TYPE_STRING,
descriptor.FieldDescriptor.TYPE_GROUP,
descriptor.FieldDescriptor.TYPE_MESSAGE,
descriptor.FieldDescriptor.TYPE_BYTES
)
def IsTypePackable(field_type):
"""Return true iff packable = true is valid for fields of this type.
Args:
field_type: a FieldDescriptor::Type value.
Returns:
True iff fields of this type are packable.
"""
return field_type not in NON_PACKABLE_TYPES

View File

@ -99,7 +99,7 @@ class Message(object):
Args: Args:
other_msg: Message to copy into the current one. other_msg: Message to copy into the current one.
""" """
if self == other_msg: if self is other_msg:
return return
self.Clear() self.Clear()
self.MergeFrom(other_msg) self.MergeFrom(other_msg)
@ -108,6 +108,15 @@ class Message(object):
"""Clears all data that was set in the message.""" """Clears all data that was set in the message."""
raise NotImplementedError raise NotImplementedError
def SetInParent(self):
"""Mark this as present in the parent.
This normally happens automatically when you assign a field of a
sub-message, but sometimes you want to make the sub-message
present while keeping it empty. If you find yourself using this,
you may want to reconsider your design."""
raise NotImplementedError
def IsInitialized(self): def IsInitialized(self):
"""Checks if the message is initialized. """Checks if the message is initialized.

File diff suppressed because it is too large Load Diff

View File

@ -149,6 +149,10 @@ def _MergeField(tokenizer, message):
name.append(tokenizer.ConsumeIdentifier()) name.append(tokenizer.ConsumeIdentifier())
name = '.'.join(name) name = '.'.join(name)
if not message_descriptor.is_extendable:
raise tokenizer.ParseErrorPreviousToken(
'Message type "%s" does not have extensions.' %
message_descriptor.full_name)
field = message.Extensions._FindExtensionByName(name) field = message.Extensions._FindExtensionByName(name)
if not field: if not field:
raise tokenizer.ParseErrorPreviousToken( raise tokenizer.ParseErrorPreviousToken(
@ -198,6 +202,7 @@ def _MergeField(tokenizer, message):
sub_message = message.Extensions[field] sub_message = message.Extensions[field]
else: else:
sub_message = getattr(message, field.name) sub_message = getattr(message, field.name)
sub_message.SetInParent()
while not tokenizer.TryConsume(end_token): while not tokenizer.TryConsume(end_token):
if tokenizer.AtEnd(): if tokenizer.AtEnd():

View File

@ -58,16 +58,13 @@ def MakeTestSuite():
generate_proto("../src/google/protobuf/unittest.proto") generate_proto("../src/google/protobuf/unittest.proto")
generate_proto("../src/google/protobuf/unittest_import.proto") generate_proto("../src/google/protobuf/unittest_import.proto")
generate_proto("../src/google/protobuf/unittest_mset.proto") generate_proto("../src/google/protobuf/unittest_mset.proto")
generate_proto("../src/google/protobuf/unittest_no_generic_services.proto")
generate_proto("google/protobuf/internal/more_extensions.proto") generate_proto("google/protobuf/internal/more_extensions.proto")
generate_proto("google/protobuf/internal/more_messages.proto") generate_proto("google/protobuf/internal/more_messages.proto")
import unittest import unittest
import google.protobuf.internal.generator_test as generator_test import google.protobuf.internal.generator_test as generator_test
import google.protobuf.internal.decoder_test as decoder_test
import google.protobuf.internal.descriptor_test as descriptor_test import google.protobuf.internal.descriptor_test as descriptor_test
import google.protobuf.internal.encoder_test as encoder_test
import google.protobuf.internal.input_stream_test as input_stream_test
import google.protobuf.internal.output_stream_test as output_stream_test
import google.protobuf.internal.reflection_test as reflection_test import google.protobuf.internal.reflection_test as reflection_test
import google.protobuf.internal.service_reflection_test \ import google.protobuf.internal.service_reflection_test \
as service_reflection_test as service_reflection_test
@ -77,11 +74,7 @@ def MakeTestSuite():
loader = unittest.defaultTestLoader loader = unittest.defaultTestLoader
suite = unittest.TestSuite() suite = unittest.TestSuite()
for test in [ generator_test, for test in [ generator_test,
decoder_test,
descriptor_test, descriptor_test,
encoder_test,
input_stream_test,
output_stream_test,
reflection_test, reflection_test,
service_reflection_test, service_reflection_test,
text_format_test, text_format_test,
@ -114,9 +107,7 @@ if __name__ == '__main__':
'google.protobuf.internal.containers', 'google.protobuf.internal.containers',
'google.protobuf.internal.decoder', 'google.protobuf.internal.decoder',
'google.protobuf.internal.encoder', 'google.protobuf.internal.encoder',
'google.protobuf.internal.input_stream',
'google.protobuf.internal.message_listener', 'google.protobuf.internal.message_listener',
'google.protobuf.internal.output_stream',
'google.protobuf.internal.type_checkers', 'google.protobuf.internal.type_checkers',
'google.protobuf.internal.wire_format', 'google.protobuf.internal.wire_format',
'google.protobuf.descriptor', 'google.protobuf.descriptor',

View File

@ -24,7 +24,8 @@ AM_LDFLAGS = $(PTHREAD_CFLAGS)
# If I say "dist_include_DATA", automake complains that $(includedir) is not # If I say "dist_include_DATA", automake complains that $(includedir) is not
# a "legitimate" directory for DATA. Screw you, automake. # a "legitimate" directory for DATA. Screw you, automake.
protodir = $(includedir) protodir = $(includedir)
nobase_dist_proto_DATA = google/protobuf/descriptor.proto nobase_dist_proto_DATA = google/protobuf/descriptor.proto \
google/protobuf/compiler/plugin.proto
# Not sure why these don't get cleaned automatically. # Not sure why these don't get cleaned automatically.
clean-local: clean-local:
@ -66,6 +67,8 @@ nobase_include_HEADERS = \
google/protobuf/compiler/command_line_interface.h \ google/protobuf/compiler/command_line_interface.h \
google/protobuf/compiler/importer.h \ google/protobuf/compiler/importer.h \
google/protobuf/compiler/parser.h \ google/protobuf/compiler/parser.h \
google/protobuf/compiler/plugin.h \
google/protobuf/compiler/plugin.pb.h \
google/protobuf/compiler/cpp/cpp_generator.h \ google/protobuf/compiler/cpp/cpp_generator.h \
google/protobuf/compiler/java/java_generator.h \ google/protobuf/compiler/java/java_generator.h \
google/protobuf/compiler/python/python_generator.h google/protobuf/compiler/python/python_generator.h
@ -87,6 +90,7 @@ libprotobuf_lite_la_SOURCES = \
google/protobuf/repeated_field.cc \ google/protobuf/repeated_field.cc \
google/protobuf/wire_format_lite.cc \ google/protobuf/wire_format_lite.cc \
google/protobuf/io/coded_stream.cc \ google/protobuf/io/coded_stream.cc \
google/protobuf/io/coded_stream_inl.h \
google/protobuf/io/zero_copy_stream.cc \ google/protobuf/io/zero_copy_stream.cc \
google/protobuf/io/zero_copy_stream_impl_lite.cc google/protobuf/io/zero_copy_stream_impl_lite.cc
@ -123,6 +127,10 @@ libprotoc_la_LDFLAGS = -version-info 5:0:0
libprotoc_la_SOURCES = \ libprotoc_la_SOURCES = \
google/protobuf/compiler/code_generator.cc \ google/protobuf/compiler/code_generator.cc \
google/protobuf/compiler/command_line_interface.cc \ google/protobuf/compiler/command_line_interface.cc \
google/protobuf/compiler/plugin.cc \
google/protobuf/compiler/plugin.pb.cc \
google/protobuf/compiler/subprocess.cc \
google/protobuf/compiler/subprocess.h \
google/protobuf/compiler/cpp/cpp_enum.cc \ google/protobuf/compiler/cpp/cpp_enum.cc \
google/protobuf/compiler/cpp/cpp_enum.h \ google/protobuf/compiler/cpp/cpp_enum.h \
google/protobuf/compiler/cpp/cpp_enum_field.cc \ google/protobuf/compiler/cpp/cpp_enum_field.cc \
@ -186,6 +194,7 @@ protoc_inputs = \
google/protobuf/unittest_lite.proto \ google/protobuf/unittest_lite.proto \
google/protobuf/unittest_import_lite.proto \ google/protobuf/unittest_import_lite.proto \
google/protobuf/unittest_lite_imports_nonlite.proto \ google/protobuf/unittest_lite_imports_nonlite.proto \
google/protobuf/unittest_no_generic_services.proto \
google/protobuf/compiler/cpp/cpp_test_bad_identifiers.proto google/protobuf/compiler/cpp/cpp_test_bad_identifiers.proto
EXTRA_DIST = \ EXTRA_DIST = \
@ -226,6 +235,8 @@ protoc_outputs = \
google/protobuf/unittest_custom_options.pb.h \ google/protobuf/unittest_custom_options.pb.h \
google/protobuf/unittest_lite_imports_nonlite.pb.cc \ google/protobuf/unittest_lite_imports_nonlite.pb.cc \
google/protobuf/unittest_lite_imports_nonlite.pb.h \ google/protobuf/unittest_lite_imports_nonlite.pb.h \
google/protobuf/unittest_no_generic_services.pb.cc \
google/protobuf/unittest_no_generic_services.pb.h \
google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.cc \ google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.cc \
google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.h google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.h
@ -265,7 +276,7 @@ COMMON_TEST_SOURCES = \
google/protobuf/testing/file.cc \ google/protobuf/testing/file.cc \
google/protobuf/testing/file.h google/protobuf/testing/file.h
check_PROGRAMS = protobuf-test protobuf-lazy-descriptor-test protobuf-lite-test $(GZCHECKPROGRAMS) check_PROGRAMS = protobuf-test protobuf-lazy-descriptor-test protobuf-lite-test test_plugin $(GZCHECKPROGRAMS)
protobuf_test_LDADD = $(PTHREAD_LIBS) libprotobuf.la libprotoc.la \ protobuf_test_LDADD = $(PTHREAD_LIBS) libprotobuf.la libprotoc.la \
$(top_builddir)/gtest/lib/libgtest.la \ $(top_builddir)/gtest/lib/libgtest.la \
$(top_builddir)/gtest/lib/libgtest_main.la $(top_builddir)/gtest/lib/libgtest_main.la
@ -297,9 +308,14 @@ protobuf_test_SOURCES = \
google/protobuf/io/zero_copy_stream_unittest.cc \ google/protobuf/io/zero_copy_stream_unittest.cc \
google/protobuf/compiler/command_line_interface_unittest.cc \ google/protobuf/compiler/command_line_interface_unittest.cc \
google/protobuf/compiler/importer_unittest.cc \ google/protobuf/compiler/importer_unittest.cc \
google/protobuf/compiler/mock_code_generator.cc \
google/protobuf/compiler/mock_code_generator.h \
google/protobuf/compiler/parser_unittest.cc \ google/protobuf/compiler/parser_unittest.cc \
google/protobuf/compiler/cpp/cpp_bootstrap_unittest.cc \ google/protobuf/compiler/cpp/cpp_bootstrap_unittest.cc \
google/protobuf/compiler/cpp/cpp_unittest.cc \ google/protobuf/compiler/cpp/cpp_unittest.cc \
google/protobuf/compiler/cpp/cpp_plugin_unittest.cc \
google/protobuf/compiler/java/java_plugin_unittest.cc \
google/protobuf/compiler/python/python_plugin_unittest.cc \
$(COMMON_TEST_SOURCES) $(COMMON_TEST_SOURCES)
nodist_protobuf_test_SOURCES = $(protoc_outputs) nodist_protobuf_test_SOURCES = $(protoc_outputs)
@ -325,6 +341,15 @@ protobuf_lite_test_SOURCES = \
google/protobuf/test_util_lite.h google/protobuf/test_util_lite.h
nodist_protobuf_lite_test_SOURCES = $(protoc_lite_outputs) nodist_protobuf_lite_test_SOURCES = $(protoc_lite_outputs)
# Test plugin binary.
test_plugin_LDADD = $(PTHREAD_LIBS) libprotobuf.la libprotoc.la \
$(top_builddir)/gtest/lib/libgtest.la
test_plugin_SOURCES = \
google/protobuf/compiler/mock_code_generator.cc \
google/protobuf/testing/file.cc \
google/protobuf/testing/file.h \
google/protobuf/compiler/test_plugin.cc
if HAVE_ZLIB if HAVE_ZLIB
zcgzip_LDADD = $(PTHREAD_LIBS) libprotobuf.la zcgzip_LDADD = $(PTHREAD_LIBS) libprotobuf.la
zcgzip_SOURCES = google/protobuf/testing/zcgzip.cc zcgzip_SOURCES = google/protobuf/testing/zcgzip.cc

View File

@ -34,6 +34,7 @@
#include <google/protobuf/compiler/code_generator.h> #include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
namespace google { namespace google {
@ -43,9 +44,15 @@ namespace compiler {
CodeGenerator::~CodeGenerator() {} CodeGenerator::~CodeGenerator() {}
OutputDirectory::~OutputDirectory() {} OutputDirectory::~OutputDirectory() {}
io::ZeroCopyOutputStream* OutputDirectory::OpenForInsert(
const string& filename, const string& insertion_point) {
GOOGLE_LOG(FATAL) << "This OutputDirectory does not support insertion.";
return NULL; // make compiler happy
}
// Parses a set of comma-delimited name/value pairs. // Parses a set of comma-delimited name/value pairs.
void ParseGeneratorParameter(const string& text, void ParseGeneratorParameter(const string& text,
vector<pair<string, string> >* output) { vector<pair<string, string> >* output) {
vector<string> parts; vector<string> parts;
SplitStringUsing(text, ",", &parts); SplitStringUsing(text, ",", &parts);

View File

@ -103,6 +103,13 @@ class LIBPROTOC_EXPORT OutputDirectory {
// contain "." or ".." components. // contain "." or ".." components.
virtual io::ZeroCopyOutputStream* Open(const string& filename) = 0; virtual io::ZeroCopyOutputStream* Open(const string& filename) = 0;
// Creates a ZeroCopyOutputStream which will insert code into the given file
// at the given insertion point. See plugin.proto for more information on
// insertion points. The default implementation assert-fails -- it exists
// only for backwards-compatibility.
virtual io::ZeroCopyOutputStream* OpenForInsert(
const string& filename, const string& insertion_point);
private: private:
GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(OutputDirectory); GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(OutputDirectory);
}; };
@ -114,7 +121,7 @@ class LIBPROTOC_EXPORT OutputDirectory {
// parses to the pairs: // parses to the pairs:
// ("foo", "bar"), ("baz", ""), ("qux", "corge") // ("foo", "bar"), ("baz", ""), ("qux", "corge")
extern void ParseGeneratorParameter(const string&, extern void ParseGeneratorParameter(const string&,
vector<pair<string, string> >*); vector<pair<string, string> >*);
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -32,6 +32,8 @@
// Based on original Protocol Buffers design by // Based on original Protocol Buffers design by
// Sanjay Ghemawat, Jeff Dean, and others. // Sanjay Ghemawat, Jeff Dean, and others.
#include <google/protobuf/compiler/command_line_interface.h>
#include <stdio.h> #include <stdio.h>
#include <sys/types.h> #include <sys/types.h>
#include <sys/stat.h> #include <sys/stat.h>
@ -46,15 +48,19 @@
#include <iostream> #include <iostream>
#include <ctype.h> #include <ctype.h>
#include <google/protobuf/compiler/command_line_interface.h>
#include <google/protobuf/compiler/importer.h> #include <google/protobuf/compiler/importer.h>
#include <google/protobuf/compiler/code_generator.h> #include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/compiler/plugin.pb.h>
#include <google/protobuf/compiler/subprocess.h>
#include <google/protobuf/descriptor.h> #include <google/protobuf/descriptor.h>
#include <google/protobuf/text_format.h> #include <google/protobuf/text_format.h>
#include <google/protobuf/dynamic_message.h> #include <google/protobuf/dynamic_message.h>
#include <google/protobuf/io/zero_copy_stream_impl.h> #include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/substitute.h>
#include <google/protobuf/stubs/map-util.h>
namespace google { namespace google {
@ -182,6 +188,8 @@ class CommandLineInterface::DiskOutputDirectory : public OutputDirectory {
// implements OutputDirectory -------------------------------------- // implements OutputDirectory --------------------------------------
io::ZeroCopyOutputStream* Open(const string& filename); io::ZeroCopyOutputStream* Open(const string& filename);
io::ZeroCopyOutputStream* OpenForInsert(
const string& filename, const string& insertion_point);
private: private:
string root_; string root_;
@ -209,11 +217,45 @@ class CommandLineInterface::ErrorReportingFileOutput
private: private:
scoped_ptr<io::FileOutputStream> file_stream_; scoped_ptr<io::FileOutputStream> file_stream_;
int file_descriptor_;
string filename_; string filename_;
DiskOutputDirectory* directory_; DiskOutputDirectory* directory_;
}; };
// Kind of like ErrorReportingFileOutput, but used when inserting
// (OutputDirectory::OpenForInsert()). In this case, we are writing to a
// temporary file, since we must copy data from the original. We copy the
// data up to the insertion point in the constructor, and the remainder in the
// destructor. We then replace the original file with the temporary, also in
// the destructor.
class CommandLineInterface::InsertionOutputStream
: public io::ZeroCopyOutputStream {
public:
InsertionOutputStream(
const string& filename,
const string& temp_filename,
const string& insertion_point,
int original_file_descriptor, // Takes ownership.
int temp_file_descriptor, // Takes ownership.
DiskOutputDirectory* directory); // Does not take ownership.
~InsertionOutputStream();
// implements ZeroCopyOutputStream ---------------------------------
bool Next(void** data, int* size) { return temp_file_->Next(data, size); }
void BackUp(int count) { temp_file_->BackUp(count); }
int64 ByteCount() const { return temp_file_->ByteCount(); }
private:
scoped_ptr<io::FileInputStream> original_file_;
scoped_ptr<io::FileOutputStream> temp_file_;
string filename_;
string temp_filename_;
DiskOutputDirectory* directory_;
// The contents of the line containing the insertion point.
string magic_line_;
};
// ------------------------------------------------------------------- // -------------------------------------------------------------------
CommandLineInterface::DiskOutputDirectory::DiskOutputDirectory( CommandLineInterface::DiskOutputDirectory::DiskOutputDirectory(
@ -242,6 +284,8 @@ bool CommandLineInterface::DiskOutputDirectory::VerifyExistence() {
return true; return true;
} }
// -------------------------------------------------------------------
io::ZeroCopyOutputStream* CommandLineInterface::DiskOutputDirectory::Open( io::ZeroCopyOutputStream* CommandLineInterface::DiskOutputDirectory::Open(
const string& filename) { const string& filename) {
// Recursively create parent directories to the output file. // Recursively create parent directories to the output file.
@ -286,7 +330,6 @@ CommandLineInterface::ErrorReportingFileOutput::ErrorReportingFileOutput(
const string& filename, const string& filename,
DiskOutputDirectory* directory) DiskOutputDirectory* directory)
: file_stream_(new io::FileOutputStream(file_descriptor)), : file_stream_(new io::FileOutputStream(file_descriptor)),
file_descriptor_(file_descriptor),
filename_(filename), filename_(filename),
directory_(directory) {} directory_(directory) {}
@ -304,6 +347,201 @@ CommandLineInterface::ErrorReportingFileOutput::~ErrorReportingFileOutput() {
} }
} }
// -------------------------------------------------------------------
io::ZeroCopyOutputStream*
CommandLineInterface::DiskOutputDirectory::OpenForInsert(
const string& filename, const string& insertion_point) {
string path = root_ + filename;
// Put the temp file in the same directory so that we can simply rename() it
// into place later.
string temp_path = path + ".protoc_temp";
// Open the original file.
int original_file;
do {
original_file = open(path.c_str(), O_RDONLY | O_BINARY);
} while (original_file < 0 && errno == EINTR);
if (original_file < 0) {
// Failed to open.
cerr << path << ": " << strerror(errno) << endl;
had_error_ = true;
// Return a dummy stream.
return new io::ArrayOutputStream(NULL, 0);
}
// Create the temp file.
int temp_file;
do {
temp_file =
open(temp_path.c_str(),
O_WRONLY | O_CREAT | O_TRUNC | O_BINARY, 0666);
} while (temp_file < 0 && errno == EINTR);
if (temp_file < 0) {
// Failed to open.
cerr << temp_path << ": " << strerror(errno) << endl;
had_error_ = true;
close(original_file);
// Return a dummy stream.
return new io::ArrayOutputStream(NULL, 0);
}
return new InsertionOutputStream(
path, temp_path, insertion_point, original_file, temp_file, this);
}
namespace {
// Helper for reading lines from a ZeroCopyInputStream.
// TODO(kenton): Put somewhere reusable?
class LineReader {
public:
LineReader(io::ZeroCopyInputStream* input)
: input_(input), buffer_(NULL), size_(0) {}
~LineReader() {
if (size_ > 0) {
input_->BackUp(size_);
}
}
bool ReadLine(string* line) {
line->clear();
while (true) {
for (int i = 0; i < size_; i++) {
if (buffer_[i] == '\n') {
line->append(buffer_, i + 1);
buffer_ += i + 1;
size_ -= i + 1;
return true;
}
}
line->append(buffer_, size_);
const void* void_buffer;
if (!input_->Next(&void_buffer, &size_)) {
buffer_ = NULL;
size_ = 0;
return false;
}
buffer_ = reinterpret_cast<const char*>(void_buffer);
}
}
private:
io::ZeroCopyInputStream* input_;
const char* buffer_;
int size_;
};
} // namespace
CommandLineInterface::InsertionOutputStream::InsertionOutputStream(
const string& filename,
const string& temp_filename,
const string& insertion_point,
int original_file_descriptor,
int temp_file_descriptor,
DiskOutputDirectory* directory)
: original_file_(new io::FileInputStream(original_file_descriptor)),
temp_file_(new io::FileOutputStream(temp_file_descriptor)),
filename_(filename),
temp_filename_(temp_filename),
directory_(directory) {
string magic_string = strings::Substitute(
"@@protoc_insertion_point($0)", insertion_point);
LineReader reader(original_file_.get());
io::Printer writer(temp_file_.get(), '$');
string line;
while (true) {
if (!reader.ReadLine(&line)) {
int error = temp_file_->GetErrno();
if (error == 0) {
cerr << filename << ": Insertion point not found: "
<< insertion_point << endl;
} else {
cerr << filename << ": " << strerror(error) << endl;
}
original_file_->Close();
original_file_.reset();
// Will finish handling error in the destructor.
break;
}
if (line.find(magic_string) != string::npos) {
// Found the magic line. Since we want to insert before it, save it for
// later.
magic_line_ = line;
break;
}
writer.PrintRaw(line);
}
}
CommandLineInterface::InsertionOutputStream::~InsertionOutputStream() {
// C-style error handling is teh best.
bool had_error = false;
if (original_file_ == NULL) {
// We had an error in the constructor.
had_error = true;
} else {
// Use CodedOutputStream for convenience, so we don't have to deal with
// copying buffers ourselves.
io::CodedOutputStream out(temp_file_.get());
out.WriteRaw(magic_line_.data(), magic_line_.size());
// Write the rest of the original file.
const void* buffer;
int size;
while (original_file_->Next(&buffer, &size)) {
out.WriteRaw(buffer, size);
}
// Close the original file.
if (!original_file_->Close()) {
cerr << filename_ << ": " << strerror(original_file_->GetErrno()) << endl;
had_error = true;
}
}
// Check if we had any errors while writing.
if (temp_file_->GetErrno() != 0) {
cerr << filename_ << ": " << strerror(temp_file_->GetErrno()) << endl;
had_error = true;
}
// Close the temp file.
if (!temp_file_->Close()) {
cerr << filename_ << ": " << strerror(temp_file_->GetErrno()) << endl;
had_error = true;
}
// If everything was successful, overwrite the original file with the temp
// file.
if (!had_error) {
if (rename(temp_filename_.c_str(), filename_.c_str()) < 0) {
cerr << filename_ << ": rename: " << strerror(errno) << endl;
had_error = true;
}
}
if (had_error) {
// We had some sort of error so let's try to delete the temp file.
remove(temp_filename_.c_str());
directory_->set_had_error(true);
}
}
// =================================================================== // ===================================================================
CommandLineInterface::CommandLineInterface() CommandLineInterface::CommandLineInterface()
@ -323,6 +561,10 @@ void CommandLineInterface::RegisterGenerator(const string& flag_name,
generators_[flag_name] = info; generators_[flag_name] = info;
} }
void CommandLineInterface::AllowPlugins(const string& exe_name_prefix) {
plugin_prefix_ = exe_name_prefix;
}
int CommandLineInterface::Run(int argc, const char* const argv[]) { int CommandLineInterface::Run(int argc, const char* const argv[]) {
Clear(); Clear();
if (!ParseArguments(argc, argv)) return 1; if (!ParseArguments(argc, argv)) return 1;
@ -346,7 +588,7 @@ int CommandLineInterface::Run(int argc, const char* const argv[]) {
vector<const FileDescriptor*> parsed_files; vector<const FileDescriptor*> parsed_files;
// Parse each file and generate output. // Parse each file.
for (int i = 0; i < input_files_.size(); i++) { for (int i = 0; i < input_files_.size(); i++) {
// Import the file. // Import the file.
const FileDescriptor* parsed_file = importer.Import(input_files_[i]); const FileDescriptor* parsed_file = importer.Import(input_files_[i]);
@ -359,13 +601,13 @@ int CommandLineInterface::Run(int argc, const char* const argv[]) {
"--disallow_services was used." << endl; "--disallow_services was used." << endl;
return 1; return 1;
} }
}
if (mode_ == MODE_COMPILE) { // Generate output.
// Generate output files. if (mode_ == MODE_COMPILE) {
for (int i = 0; i < output_directives_.size(); i++) { for (int i = 0; i < output_directives_.size(); i++) {
if (!GenerateOutput(parsed_file, output_directives_[i])) { if (!GenerateOutput(parsed_files, output_directives_[i])) {
return 1; return 1;
}
} }
} }
} }
@ -686,10 +928,37 @@ bool CommandLineInterface::InterpretArgument(const string& name,
return false; return false;
} }
} else if (name == "--plugin") {
if (plugin_prefix_.empty()) {
cerr << "This compiler does not support plugins." << endl;
return false;
}
string name;
string path;
string::size_type equals_pos = value.find_first_of('=');
if (equals_pos == string::npos) {
// Use the basename of the file.
string::size_type slash_pos = value.find_last_of('/');
if (slash_pos == string::npos) {
name = value;
} else {
name = value.substr(slash_pos + 1);
}
path = value;
} else {
name = value.substr(0, equals_pos);
path = value.substr(equals_pos + 1);
}
plugins_[name] = path;
} else { } else {
// Some other flag. Look it up in the generators list. // Some other flag. Look it up in the generators list.
GeneratorMap::const_iterator iter = generators_.find(name); const GeneratorInfo* generator_info = FindOrNull(generators_, name);
if (iter == generators_.end()) { if (generator_info == NULL &&
(plugin_prefix_.empty() || !HasSuffixString(name, "_out"))) {
cerr << "Unknown flag: " << name << endl; cerr << "Unknown flag: " << name << endl;
return false; return false;
} }
@ -703,7 +972,11 @@ bool CommandLineInterface::InterpretArgument(const string& name,
OutputDirective directive; OutputDirective directive;
directive.name = name; directive.name = name;
directive.generator = iter->second.generator; if (generator_info == NULL) {
directive.generator = NULL;
} else {
directive.generator = generator_info->generator;
}
// Split value at ':' to separate the generator parameter from the // Split value at ':' to separate the generator parameter from the
// filename. However, avoid doing this if the colon is part of a valid // filename. However, avoid doing this if the colon is part of a valid
@ -755,6 +1028,17 @@ void CommandLineInterface::PrintHelpText() {
" --error_format=FORMAT Set the format in which to print errors.\n" " --error_format=FORMAT Set the format in which to print errors.\n"
" FORMAT may be 'gcc' (the default) or 'msvs'\n" " FORMAT may be 'gcc' (the default) or 'msvs'\n"
" (Microsoft Visual Studio format)." << endl; " (Microsoft Visual Studio format)." << endl;
if (!plugin_prefix_.empty()) {
cerr <<
" --plugin=EXECUTABLE Specifies a plugin executable to use.\n"
" Normally, protoc searches the PATH for\n"
" plugins, but you may specify additional\n"
" executables not in the path using this flag.\n"
" Additionally, EXECUTABLE may be of the form\n"
" NAME=PATH, in which case the given plugin name\n"
" is mapped to the given executable even if\n"
" the executable's own name differs." << endl;
}
for (GeneratorMap::iterator iter = generators_.begin(); for (GeneratorMap::iterator iter = generators_.begin();
iter != generators_.end(); ++iter) { iter != generators_.end(); ++iter) {
@ -768,7 +1052,7 @@ void CommandLineInterface::PrintHelpText() {
} }
bool CommandLineInterface::GenerateOutput( bool CommandLineInterface::GenerateOutput(
const FileDescriptor* parsed_file, const vector<const FileDescriptor*>& parsed_files,
const OutputDirective& output_directive) { const OutputDirective& output_directive) {
// Create the output directory. // Create the output directory.
DiskOutputDirectory output_directory(output_directive.output_location); DiskOutputDirectory output_directory(output_directive.output_location);
@ -780,12 +1064,34 @@ bool CommandLineInterface::GenerateOutput(
// Call the generator. // Call the generator.
string error; string error;
if (!output_directive.generator->Generate( if (output_directive.generator == NULL) {
parsed_file, output_directive.parameter, &output_directory, &error)) { // This is a plugin.
// Generator returned an error. GOOGLE_CHECK(HasPrefixString(output_directive.name, "--") &&
cerr << parsed_file->name() << ": " << output_directive.name << ": " HasSuffixString(output_directive.name, "_out"))
<< error << endl; << "Bad name for plugin generator: " << output_directive.name;
return false;
// Strip the "--" and "_out" and add the plugin prefix.
string plugin_name = plugin_prefix_ + "gen-" +
output_directive.name.substr(2, output_directive.name.size() - 6);
if (!GeneratePluginOutput(parsed_files, plugin_name,
output_directive.parameter,
&output_directory, &error)) {
cerr << output_directive.name << ": " << error << endl;
return false;
}
} else {
// Regular generator.
for (int i = 0; i < parsed_files.size(); i++) {
if (!output_directive.generator->Generate(
parsed_files[i], output_directive.parameter,
&output_directory, &error)) {
// Generator returned an error.
cerr << output_directive.name << ": " << parsed_files[i]->name() << ": "
<< error << endl;
return false;
}
}
} }
// Check for write errors. // Check for write errors.
@ -796,6 +1102,84 @@ bool CommandLineInterface::GenerateOutput(
return true; return true;
} }
bool CommandLineInterface::GeneratePluginOutput(
const vector<const FileDescriptor*>& parsed_files,
const string& plugin_name,
const string& parameter,
OutputDirectory* output_directory,
string* error) {
CodeGeneratorRequest request;
CodeGeneratorResponse response;
// Build the request.
if (!parameter.empty()) {
request.set_parameter(parameter);
}
set<const FileDescriptor*> already_seen;
for (int i = 0; i < parsed_files.size(); i++) {
request.add_file_to_generate(parsed_files[i]->name());
GetTransitiveDependencies(parsed_files[i], &already_seen,
request.mutable_proto_file());
}
// Invoke the plugin.
Subprocess subprocess;
if (plugins_.count(plugin_name) > 0) {
subprocess.Start(plugins_[plugin_name], Subprocess::EXACT_NAME);
} else {
subprocess.Start(plugin_name, Subprocess::SEARCH_PATH);
}
string communicate_error;
if (!subprocess.Communicate(request, &response, &communicate_error)) {
*error = strings::Substitute("$0: $1", plugin_name, communicate_error);
return false;
}
// Write the files. We do this even if there was a generator error in order
// to match the behavior of a compiled-in generator.
scoped_ptr<io::ZeroCopyOutputStream> current_output;
for (int i = 0; i < response.file_size(); i++) {
const CodeGeneratorResponse::File& output_file = response.file(i);
if (!output_file.insertion_point().empty()) {
// Open a file for insert.
// We reset current_output to NULL first so that the old file is closed
// before the new one is opened.
current_output.reset();
current_output.reset(output_directory->OpenForInsert(
output_file.name(), output_file.insertion_point()));
} else if (!output_file.name().empty()) {
// Starting a new file. Open it.
// We reset current_output to NULL first so that the old file is closed
// before the new one is opened.
current_output.reset();
current_output.reset(output_directory->Open(output_file.name()));
} else if (current_output == NULL) {
*error = strings::Substitute(
"$0: First file chunk returned by plugin did not specify a file name.",
plugin_name);
return false;
}
// Use CodedOutputStream for convenience; otherwise we'd need to provide
// our own buffer-copying loop.
io::CodedOutputStream writer(current_output.get());
writer.WriteString(output_file.content());
}
// Check for errors.
if (!response.error().empty()) {
// Generator returned an error.
*error = response.error();
return false;
}
return true;
}
bool CommandLineInterface::EncodeOrDecode(const DescriptorPool* pool) { bool CommandLineInterface::EncodeOrDecode(const DescriptorPool* pool) {
// Look up the type. // Look up the type.
const Descriptor* type = pool->FindMessageTypeByName(codec_type_); const Descriptor* type = pool->FindMessageTypeByName(codec_type_);
@ -862,22 +1246,16 @@ bool CommandLineInterface::EncodeOrDecode(const DescriptorPool* pool) {
bool CommandLineInterface::WriteDescriptorSet( bool CommandLineInterface::WriteDescriptorSet(
const vector<const FileDescriptor*> parsed_files) { const vector<const FileDescriptor*> parsed_files) {
FileDescriptorSet file_set; FileDescriptorSet file_set;
set<const FileDescriptor*> already_added;
vector<const FileDescriptor*> to_add(parsed_files);
while (!to_add.empty()) { if (imports_in_descriptor_set_) {
const FileDescriptor* file = to_add.back(); set<const FileDescriptor*> already_seen;
to_add.pop_back(); for (int i = 0; i < parsed_files.size(); i++) {
if (already_added.insert(file).second) { GetTransitiveDependencies(
// This file was not already in the set. parsed_files[i], &already_seen, file_set.mutable_file());
file->CopyTo(file_set.add_file()); }
} else {
if (imports_in_descriptor_set_) { for (int i = 0; i < parsed_files.size(); i++) {
// Add all of this file's dependencies. parsed_files[i]->CopyTo(file_set.add_file());
for (int i = 0; i < file->dependency_count(); i++) {
to_add.push_back(file->dependency(i));
}
}
} }
} }
@ -906,6 +1284,24 @@ bool CommandLineInterface::WriteDescriptorSet(
return true; return true;
} }
void CommandLineInterface::GetTransitiveDependencies(
const FileDescriptor* file,
set<const FileDescriptor*>* already_seen,
RepeatedPtrField<FileDescriptorProto>* output) {
if (!already_seen->insert(file).second) {
// Already saw this file. Skip.
return;
}
// Add all dependencies.
for (int i = 0; i < file->dependency_count(); i++) {
GetTransitiveDependencies(file->dependency(i), already_seen, output);
}
// Add this file.
file->CopyTo(output->Add());
}
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -50,10 +50,13 @@ namespace protobuf {
class FileDescriptor; // descriptor.h class FileDescriptor; // descriptor.h
class DescriptorPool; // descriptor.h class DescriptorPool; // descriptor.h
class FileDescriptorProto; // descriptor.pb.h
template<typename T> class RepeatedPtrField; // repeated_field.h
namespace compiler { namespace compiler {
class CodeGenerator; // code_generator.h class CodeGenerator; // code_generator.h
class OutputDirectory; // code_generator.h
class DiskSourceTree; // importer.h class DiskSourceTree; // importer.h
// This class implements the command-line interface to the protocol compiler. // This class implements the command-line interface to the protocol compiler.
@ -109,6 +112,37 @@ class LIBPROTOC_EXPORT CommandLineInterface {
CodeGenerator* generator, CodeGenerator* generator,
const string& help_text); const string& help_text);
// Enables "plugins". In this mode, if a command-line flag ends with "_out"
// but does not match any registered generator, the compiler will attempt to
// find a "plugin" to implement the generator. Plugins are just executables.
// They should live somewhere in the PATH.
//
// The compiler determines the executable name to search for by concatenating
// exe_name_prefix with the unrecognized flag name, removing "_out". So, for
// example, if exe_name_prefix is "protoc-" and you pass the flag --foo_out,
// the compiler will try to run the program "protoc-foo".
//
// The plugin program should implement the following usage:
// plugin [--out=OUTDIR] [--parameter=PARAMETER] PROTO_FILES < DESCRIPTORS
// --out indicates the output directory (as passed to the --foo_out
// parameter); if omitted, the current directory should be used. --parameter
// gives the generator parameter, if any was provided. The PROTO_FILES list
// the .proto files which were given on the compiler command-line; these are
// the files for which the plugin is expected to generate output code.
// Finally, DESCRIPTORS is an encoded FileDescriptorSet (as defined in
// descriptor.proto). This is piped to the plugin's stdin. The set will
// include descriptors for all the files listed in PROTO_FILES as well as
// all files that they import. The plugin MUST NOT attempt to read the
// PROTO_FILES directly -- it must use the FileDescriptorSet.
//
// The plugin should generate whatever files are necessary, as code generators
// normally do. It should write the names of all files it generates to
// stdout. The names should be relative to the output directory, NOT absolute
// names or relative to the current directory. If any errors occur, error
// messages should be written to stderr. If an error is fatal, the plugin
// should exit with a non-zero exit code.
void AllowPlugins(const string& exe_name_prefix);
// Run the Protocol Compiler with the given command-line parameters. // Run the Protocol Compiler with the given command-line parameters.
// Returns the error code which should be returned by main(). // Returns the error code which should be returned by main().
// //
@ -142,6 +176,7 @@ class LIBPROTOC_EXPORT CommandLineInterface {
class ErrorPrinter; class ErrorPrinter;
class DiskOutputDirectory; class DiskOutputDirectory;
class ErrorReportingFileOutput; class ErrorReportingFileOutput;
class InsertionOutputStream;
// Clear state from previous Run(). // Clear state from previous Run().
void Clear(); void Clear();
@ -176,8 +211,13 @@ class LIBPROTOC_EXPORT CommandLineInterface {
// Generate the given output file from the given input. // Generate the given output file from the given input.
struct OutputDirective; // see below struct OutputDirective; // see below
bool GenerateOutput(const FileDescriptor* proto_file, bool GenerateOutput(const vector<const FileDescriptor*>& parsed_files,
const OutputDirective& output_directive); const OutputDirective& output_directive);
bool GeneratePluginOutput(const vector<const FileDescriptor*>& parsed_files,
const string& plugin_name,
const string& parameter,
OutputDirectory* output_directory,
string* error);
// Implements --encode and --decode. // Implements --encode and --decode.
bool EncodeOrDecode(const DescriptorPool* pool); bool EncodeOrDecode(const DescriptorPool* pool);
@ -185,6 +225,17 @@ class LIBPROTOC_EXPORT CommandLineInterface {
// Implements the --descriptor_set_out option. // Implements the --descriptor_set_out option.
bool WriteDescriptorSet(const vector<const FileDescriptor*> parsed_files); bool WriteDescriptorSet(const vector<const FileDescriptor*> parsed_files);
// Get all transitive dependencies of the given file (including the file
// itself), adding them to the given list of FileDescriptorProtos. The
// protos will be ordered such that every file is listed before any file that
// depends on it, so that you can call DescriptorPool::BuildFile() on them
// in order. Any files in *already_seen will not be added, and each file
// added will be inserted into *already_seen.
static void GetTransitiveDependencies(
const FileDescriptor* file,
set<const FileDescriptor*>* already_seen,
RepeatedPtrField<FileDescriptorProto>* output);
// ----------------------------------------------------------------- // -----------------------------------------------------------------
// The name of the executable as invoked (i.e. argv[0]). // The name of the executable as invoked (i.e. argv[0]).
@ -201,6 +252,14 @@ class LIBPROTOC_EXPORT CommandLineInterface {
typedef map<string, GeneratorInfo> GeneratorMap; typedef map<string, GeneratorInfo> GeneratorMap;
GeneratorMap generators_; GeneratorMap generators_;
// See AllowPlugins(). If this is empty, plugins aren't allowed.
string plugin_prefix_;
// Maps specific plugin names to files. When executing a plugin, this map
// is searched first to find the plugin executable. If not found here, the
// PATH (or other OS-specific search strategy) is searched.
map<string, string> plugins_;
// Stuff parsed from command line. // Stuff parsed from command line.
enum Mode { enum Mode {
MODE_COMPILE, // Normal mode: parse .proto files and compile them. MODE_COMPILE, // Normal mode: parse .proto files and compile them.
@ -223,8 +282,8 @@ class LIBPROTOC_EXPORT CommandLineInterface {
// output_directives_ lists all the files we are supposed to output and what // output_directives_ lists all the files we are supposed to output and what
// generator to use for each. // generator to use for each.
struct OutputDirective { struct OutputDirective {
string name; string name; // E.g. "--foo_out"
CodeGenerator* generator; CodeGenerator* generator; // NULL for plugins
string parameter; string parameter;
string output_location; string output_location;
}; };

View File

@ -123,8 +123,11 @@ TEST(BootstrapTest, GeneratedDescriptorMatches) {
Importer importer(&source_tree, &error_collector); Importer importer(&source_tree, &error_collector);
const FileDescriptor* proto_file = const FileDescriptor* proto_file =
importer.Import("google/protobuf/descriptor.proto"); importer.Import("google/protobuf/descriptor.proto");
const FileDescriptor* plugin_proto_file =
importer.Import("google/protobuf/compiler/plugin.proto");
EXPECT_EQ("", error_collector.text_); EXPECT_EQ("", error_collector.text_);
ASSERT_TRUE(proto_file != NULL); ASSERT_TRUE(proto_file != NULL);
ASSERT_TRUE(plugin_proto_file != NULL);
CppGenerator generator; CppGenerator generator;
MockOutputDirectory output_directory; MockOutputDirectory output_directory;
@ -133,11 +136,18 @@ TEST(BootstrapTest, GeneratedDescriptorMatches) {
parameter = "dllexport_decl=LIBPROTOBUF_EXPORT"; parameter = "dllexport_decl=LIBPROTOBUF_EXPORT";
ASSERT_TRUE(generator.Generate(proto_file, parameter, ASSERT_TRUE(generator.Generate(proto_file, parameter,
&output_directory, &error)); &output_directory, &error));
parameter = "dllexport_decl=LIBPROTOC_EXPORT";
ASSERT_TRUE(generator.Generate(plugin_proto_file, parameter,
&output_directory, &error));
output_directory.ExpectFileMatches("google/protobuf/descriptor.pb.h", output_directory.ExpectFileMatches("google/protobuf/descriptor.pb.h",
"google/protobuf/descriptor.pb.h"); "google/protobuf/descriptor.pb.h");
output_directory.ExpectFileMatches("google/protobuf/descriptor.pb.cc", output_directory.ExpectFileMatches("google/protobuf/descriptor.pb.cc",
"google/protobuf/descriptor.pb.cc"); "google/protobuf/descriptor.pb.cc");
output_directory.ExpectFileMatches("google/protobuf/compiler/plugin.pb.h",
"google/protobuf/compiler/plugin.pb.h");
output_directory.ExpectFileMatches("google/protobuf/compiler/plugin.pb.cc",
"google/protobuf/compiler/plugin.pb.cc");
} }
} // namespace } // namespace
@ -145,5 +155,4 @@ TEST(BootstrapTest, GeneratedDescriptorMatches) {
} // namespace cpp } // namespace cpp
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google

View File

@ -98,6 +98,7 @@ void EnumGenerator::GenerateDefinition(io::Printer* printer) {
"$dllexport$bool $classname$_IsValid(int value);\n" "$dllexport$bool $classname$_IsValid(int value);\n"
"const $classname$ $prefix$$short_name$_MIN = $prefix$$min_name$;\n" "const $classname$ $prefix$$short_name$_MIN = $prefix$$min_name$;\n"
"const $classname$ $prefix$$short_name$_MAX = $prefix$$max_name$;\n" "const $classname$ $prefix$$short_name$_MAX = $prefix$$max_name$;\n"
"const int $prefix$$short_name$_ARRAYSIZE = $prefix$$short_name$_MAX + 1;\n"
"\n"); "\n");
if (HasDescriptorMethods(descriptor_->file())) { if (HasDescriptorMethods(descriptor_->file())) {
@ -149,17 +150,21 @@ void EnumGenerator::GenerateSymbolImports(io::Printer* printer) {
"static const $nested_name$ $nested_name$_MIN =\n" "static const $nested_name$ $nested_name$_MIN =\n"
" $classname$_$nested_name$_MIN;\n" " $classname$_$nested_name$_MIN;\n"
"static const $nested_name$ $nested_name$_MAX =\n" "static const $nested_name$ $nested_name$_MAX =\n"
" $classname$_$nested_name$_MAX;\n"); " $classname$_$nested_name$_MAX;\n"
"static const int $nested_name$_ARRAYSIZE =\n"
" $classname$_$nested_name$_ARRAYSIZE;\n");
if (HasDescriptorMethods(descriptor_->file())) { if (HasDescriptorMethods(descriptor_->file())) {
printer->Print(vars, printer->Print(vars,
"static inline const ::google::protobuf::EnumDescriptor*\n" "static inline const ::google::protobuf::EnumDescriptor*\n"
"$nested_name$_descriptor() {\n" "$nested_name$_descriptor() {\n"
" return $classname$_descriptor();\n" " return $classname$_descriptor();\n"
"}\n" "}\n");
printer->Print(vars,
"static inline const ::std::string& $nested_name$_Name($nested_name$ value) {\n" "static inline const ::std::string& $nested_name$_Name($nested_name$ value) {\n"
" return $classname$_Name(value);\n" " return $classname$_Name(value);\n"
"}\n" "}\n");
printer->Print(vars,
"static inline bool $nested_name$_Parse(const ::std::string& name,\n" "static inline bool $nested_name$_Parse(const ::std::string& name,\n"
" $nested_name$* value) {\n" " $nested_name$* value) {\n"
" return $classname$_Parse(name, value);\n" " return $classname$_Parse(name, value);\n"
@ -240,7 +245,8 @@ void EnumGenerator::GenerateMethods(io::Printer* printer) {
} }
printer->Print(vars, printer->Print(vars,
"const $classname$ $parent$::$nested_name$_MIN;\n" "const $classname$ $parent$::$nested_name$_MIN;\n"
"const $classname$ $parent$::$nested_name$_MAX;\n"); "const $classname$ $parent$::$nested_name$_MAX;\n"
"const int $parent$::$nested_name$_ARRAYSIZE;\n");
printer->Print("#endif // _MSC_VER\n"); printer->Print("#endif // _MSC_VER\n");
} }

View File

@ -114,7 +114,9 @@ void EnumFieldGenerator::
GenerateMergeFromCodedStream(io::Printer* printer) const { GenerateMergeFromCodedStream(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"int value;\n" "int value;\n"
"DO_(::google::protobuf::internal::WireFormatLite::ReadEnum(input, &value));\n" "DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<\n"
" int, ::google::protobuf::internal::WireFormatLite::TYPE_ENUM>(\n"
" input, &value)));\n"
"if ($type$_IsValid(value)) {\n" "if ($type$_IsValid(value)) {\n"
" set_$name$(static_cast< $type$ >(value));\n"); " set_$name$(static_cast< $type$ >(value));\n");
if (HasUnknownFields(descriptor_->file())) { if (HasUnknownFields(descriptor_->file())) {
@ -170,24 +172,17 @@ GeneratePrivateMembers(io::Printer* printer) const {
void RepeatedEnumFieldGenerator:: void RepeatedEnumFieldGenerator::
GenerateAccessorDeclarations(io::Printer* printer) const { GenerateAccessorDeclarations(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"inline const ::google::protobuf::RepeatedField<int>& $name$() const$deprecation$;\n"
"inline ::google::protobuf::RepeatedField<int>* mutable_$name$()$deprecation$;\n"
"inline $type$ $name$(int index) const$deprecation$;\n" "inline $type$ $name$(int index) const$deprecation$;\n"
"inline void set_$name$(int index, $type$ value)$deprecation$;\n" "inline void set_$name$(int index, $type$ value)$deprecation$;\n"
"inline void add_$name$($type$ value)$deprecation$;\n"); "inline void add_$name$($type$ value)$deprecation$;\n");
printer->Print(variables_,
"inline const ::google::protobuf::RepeatedField<int>& $name$() const$deprecation$;\n"
"inline ::google::protobuf::RepeatedField<int>* mutable_$name$()$deprecation$;\n");
} }
void RepeatedEnumFieldGenerator:: void RepeatedEnumFieldGenerator::
GenerateInlineAccessorDefinitions(io::Printer* printer) const { GenerateInlineAccessorDefinitions(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"inline const ::google::protobuf::RepeatedField<int>&\n"
"$classname$::$name$() const {\n"
" return $name$_;\n"
"}\n"
"inline ::google::protobuf::RepeatedField<int>*\n"
"$classname$::mutable_$name$() {\n"
" return &$name$_;\n"
"}\n"
"inline $type$ $classname$::$name$(int index) const {\n" "inline $type$ $classname$::$name$(int index) const {\n"
" return static_cast< $type$ >($name$_.Get(index));\n" " return static_cast< $type$ >($name$_.Get(index));\n"
"}\n" "}\n"
@ -199,6 +194,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
" GOOGLE_DCHECK($type$_IsValid(value));\n" " GOOGLE_DCHECK($type$_IsValid(value));\n"
" $name$_.Add(value);\n" " $name$_.Add(value);\n"
"}\n"); "}\n");
printer->Print(variables_,
"inline const ::google::protobuf::RepeatedField<int>&\n"
"$classname$::$name$() const {\n"
" return $name$_;\n"
"}\n"
"inline ::google::protobuf::RepeatedField<int>*\n"
"$classname$::mutable_$name$() {\n"
" return &$name$_;\n"
"}\n");
} }
void RepeatedEnumFieldGenerator:: void RepeatedEnumFieldGenerator::
@ -223,7 +227,33 @@ GenerateConstructorCode(io::Printer* printer) const {
void RepeatedEnumFieldGenerator:: void RepeatedEnumFieldGenerator::
GenerateMergeFromCodedStream(io::Printer* printer) const { GenerateMergeFromCodedStream(io::Printer* printer) const {
if (descriptor_->options().packed()) { // Don't use ReadRepeatedPrimitive here so that the enum can be validated.
printer->Print(variables_,
"int value;\n"
"DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<\n"
" int, ::google::protobuf::internal::WireFormatLite::TYPE_ENUM>(\n"
" input, &value)));\n"
"if ($type$_IsValid(value)) {\n"
" add_$name$(static_cast< $type$ >(value));\n");
if (HasUnknownFields(descriptor_->file())) {
printer->Print(variables_,
"} else {\n"
" mutable_unknown_fields()->AddVarint($number$, value);\n");
}
printer->Print("}\n");
}
void RepeatedEnumFieldGenerator::
GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const {
if (!descriptor_->options().packed()) {
// We use a non-inlined implementation in this case, since this path will
// rarely be executed.
printer->Print(variables_,
"DO_((::google::protobuf::internal::WireFormatLite::ReadPackedEnumNoInline(\n"
" input,\n"
" &$type$_IsValid,\n"
" this->mutable_$name$())));\n");
} else {
printer->Print(variables_, printer->Print(variables_,
"::google::protobuf::uint32 length;\n" "::google::protobuf::uint32 length;\n"
"DO_(input->ReadVarint32(&length));\n" "DO_(input->ReadVarint32(&length));\n"
@ -231,25 +261,14 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
"input->PushLimit(length);\n" "input->PushLimit(length);\n"
"while (input->BytesUntilLimit() > 0) {\n" "while (input->BytesUntilLimit() > 0) {\n"
" int value;\n" " int value;\n"
" DO_(::google::protobuf::internal::WireFormatLite::ReadEnum(input, &value));\n" " DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<\n"
" int, ::google::protobuf::internal::WireFormatLite::TYPE_ENUM>(\n"
" input, &value)));\n"
" if ($type$_IsValid(value)) {\n" " if ($type$_IsValid(value)) {\n"
" add_$name$(static_cast< $type$ >(value));\n" " add_$name$(static_cast< $type$ >(value));\n"
" }\n" " }\n"
"}\n" "}\n"
"input->PopLimit(limit);\n"); "input->PopLimit(limit);\n");
} else {
printer->Print(variables_,
"int value;\n"
"DO_(::google::protobuf::internal::WireFormatLite::ReadEnum(input, &value));\n"
"if ($type$_IsValid(value)) {\n"
" add_$name$(static_cast< $type$ >(value));\n");
if (HasUnknownFields(descriptor_->file())) {
printer->Print(variables_,
"} else {\n"
" mutable_unknown_fields()->AddVarint($number$, value);\n");
}
printer->Print(variables_,
"}\n");
} }
} }

View File

@ -83,6 +83,7 @@ class RepeatedEnumFieldGenerator : public FieldGenerator {
void GenerateSwappingCode(io::Printer* printer) const; void GenerateSwappingCode(io::Printer* printer) const;
void GenerateConstructorCode(io::Printer* printer) const; void GenerateConstructorCode(io::Printer* printer) const;
void GenerateMergeFromCodedStream(io::Printer* printer) const; void GenerateMergeFromCodedStream(io::Printer* printer) const;
void GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const;
void GenerateSerializeWithCachedSizes(io::Printer* printer) const; void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const; void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
void GenerateByteSize(io::Printer* printer) const; void GenerateByteSize(io::Printer* printer) const;

View File

@ -33,6 +33,7 @@
// Sanjay Ghemawat, Jeff Dean, and others. // Sanjay Ghemawat, Jeff Dean, and others.
#include <google/protobuf/compiler/cpp/cpp_extension.h> #include <google/protobuf/compiler/cpp/cpp_extension.h>
#include <map>
#include <google/protobuf/compiler/cpp/cpp_helpers.h> #include <google/protobuf/compiler/cpp/cpp_helpers.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/io/printer.h> #include <google/protobuf/io/printer.h>
@ -43,6 +44,18 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace cpp { namespace cpp {
namespace {
// Returns the fully-qualified class name of the message that this field
// extends. This function is used in the Google-internal code to handle some
// legacy cases.
string ExtendeeClassName(const FieldDescriptor* descriptor) {
const Descriptor* extendee = descriptor->containing_type();
return ClassName(extendee, true);
}
} // anonymous namespace
ExtensionGenerator::ExtensionGenerator(const FieldDescriptor* descriptor, ExtensionGenerator::ExtensionGenerator(const FieldDescriptor* descriptor,
const string& dllexport_decl) const string& dllexport_decl)
: descriptor_(descriptor), : descriptor_(descriptor),
@ -80,7 +93,7 @@ ExtensionGenerator::~ExtensionGenerator() {}
void ExtensionGenerator::GenerateDeclaration(io::Printer* printer) { void ExtensionGenerator::GenerateDeclaration(io::Printer* printer) {
map<string, string> vars; map<string, string> vars;
vars["extendee" ] = ClassName(descriptor_->containing_type(), true); vars["extendee" ] = ExtendeeClassName(descriptor_);
vars["number" ] = SimpleItoa(descriptor_->number()); vars["number" ] = SimpleItoa(descriptor_->number());
vars["type_traits" ] = type_traits_; vars["type_traits" ] = type_traits_;
vars["name" ] = descriptor_->name(); vars["name" ] = descriptor_->name();
@ -106,6 +119,7 @@ void ExtensionGenerator::GenerateDeclaration(io::Printer* printer) {
" ::google::protobuf::internal::$type_traits$, $field_type$, $packed$ >\n" " ::google::protobuf::internal::$type_traits$, $field_type$, $packed$ >\n"
" $name$;\n" " $name$;\n"
); );
} }
void ExtensionGenerator::GenerateDefinition(io::Printer* printer) { void ExtensionGenerator::GenerateDefinition(io::Printer* printer) {
@ -115,7 +129,7 @@ void ExtensionGenerator::GenerateDefinition(io::Printer* printer) {
string name = scope + descriptor_->name(); string name = scope + descriptor_->name();
map<string, string> vars; map<string, string> vars;
vars["extendee" ] = ClassName(descriptor_->containing_type(), true); vars["extendee" ] = ExtendeeClassName(descriptor_);
vars["type_traits" ] = type_traits_; vars["type_traits" ] = type_traits_;
vars["name" ] = name; vars["name" ] = name;
vars["constant_name"] = FieldConstantName(descriptor_); vars["constant_name"] = FieldConstantName(descriptor_);
@ -154,7 +168,7 @@ void ExtensionGenerator::GenerateDefinition(io::Printer* printer) {
void ExtensionGenerator::GenerateRegistration(io::Printer* printer) { void ExtensionGenerator::GenerateRegistration(io::Printer* printer) {
map<string, string> vars; map<string, string> vars;
vars["extendee" ] = ClassName(descriptor_->containing_type(), true); vars["extendee" ] = ExtendeeClassName(descriptor_);
vars["number" ] = SimpleItoa(descriptor_->number()); vars["number" ] = SimpleItoa(descriptor_->number());
vars["field_type" ] = SimpleItoa(static_cast<int>(descriptor_->type())); vars["field_type" ] = SimpleItoa(static_cast<int>(descriptor_->type()));
vars["is_repeated"] = descriptor_->is_repeated() ? "true" : "false"; vars["is_repeated"] = descriptor_->is_repeated() ? "true" : "false";
@ -193,5 +207,4 @@ void ExtensionGenerator::GenerateRegistration(io::Printer* printer) {
} // namespace cpp } // namespace cpp
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google

View File

@ -40,6 +40,7 @@
#include <google/protobuf/compiler/cpp/cpp_message_field.h> #include <google/protobuf/compiler/cpp/cpp_message_field.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/wire_format.h> #include <google/protobuf/wire_format.h>
#include <google/protobuf/io/printer.h>
#include <google/protobuf/stubs/common.h> #include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
@ -61,11 +62,24 @@ void SetCommonFieldVariables(const FieldDescriptor* descriptor,
(*variables)["tag_size"] = SimpleItoa( (*variables)["tag_size"] = SimpleItoa(
WireFormat::TagSize(descriptor->number(), descriptor->type())); WireFormat::TagSize(descriptor->number(), descriptor->type()));
(*variables)["deprecation"] = descriptor->options().deprecated() (*variables)["deprecation"] = descriptor->options().deprecated()
? " DEPRECATED_PROTOBUF_FIELD" : ""; ? " PROTOBUF_DEPRECATED" : "";
} }
FieldGenerator::~FieldGenerator() {} FieldGenerator::~FieldGenerator() {}
void FieldGenerator::
GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const {
// Reaching here indicates a bug. Cases are:
// - This FieldGenerator should support packing, but this method should be
// overridden.
// - This FieldGenerator doesn't support packing, and this method should
// never have been called.
GOOGLE_LOG(FATAL) << "GenerateMergeFromCodedStreamWithPacking() "
<< "called on field generator that does not support packing.";
}
FieldGeneratorMap::FieldGeneratorMap(const Descriptor* descriptor) FieldGeneratorMap::FieldGeneratorMap(const Descriptor* descriptor)
: descriptor_(descriptor), : descriptor_(descriptor),
field_generators_( field_generators_(
@ -82,7 +96,11 @@ FieldGenerator* FieldGeneratorMap::MakeGenerator(const FieldDescriptor* field) {
case FieldDescriptor::CPPTYPE_MESSAGE: case FieldDescriptor::CPPTYPE_MESSAGE:
return new RepeatedMessageFieldGenerator(field); return new RepeatedMessageFieldGenerator(field);
case FieldDescriptor::CPPTYPE_STRING: case FieldDescriptor::CPPTYPE_STRING:
return new RepeatedStringFieldGenerator(field); switch (field->options().ctype()) {
default: // RepeatedStringFieldGenerator handles unknown ctypes.
case FieldOptions::STRING:
return new RepeatedStringFieldGenerator(field);
}
case FieldDescriptor::CPPTYPE_ENUM: case FieldDescriptor::CPPTYPE_ENUM:
return new RepeatedEnumFieldGenerator(field); return new RepeatedEnumFieldGenerator(field);
default: default:
@ -93,7 +111,11 @@ FieldGenerator* FieldGeneratorMap::MakeGenerator(const FieldDescriptor* field) {
case FieldDescriptor::CPPTYPE_MESSAGE: case FieldDescriptor::CPPTYPE_MESSAGE:
return new MessageFieldGenerator(field); return new MessageFieldGenerator(field);
case FieldDescriptor::CPPTYPE_STRING: case FieldDescriptor::CPPTYPE_STRING:
return new StringFieldGenerator(field); switch (field->options().ctype()) {
default: // StringFieldGenerator handles unknown ctypes.
case FieldOptions::STRING:
return new StringFieldGenerator(field);
}
case FieldDescriptor::CPPTYPE_ENUM: case FieldDescriptor::CPPTYPE_ENUM:
return new EnumFieldGenerator(field); return new EnumFieldGenerator(field);
default: default:
@ -110,6 +132,7 @@ const FieldGenerator& FieldGeneratorMap::get(
return *field_generators_[field->index()]; return *field_generators_[field->index()];
} }
} // namespace cpp } // namespace cpp
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -118,6 +118,11 @@ class FieldGenerator {
// message's MergeFromCodedStream() method. // message's MergeFromCodedStream() method.
virtual void GenerateMergeFromCodedStream(io::Printer* printer) const = 0; virtual void GenerateMergeFromCodedStream(io::Printer* printer) const = 0;
// Generate lines to decode this field from a packed value, which will be
// placed inside the message's MergeFromCodedStream() method.
virtual void GenerateMergeFromCodedStreamWithPacking(io::Printer* printer)
const;
// Generate lines to serialize this field, which are placed within the // Generate lines to serialize this field, which are placed within the
// message's SerializeWithCachedSizes() method. // message's SerializeWithCachedSizes() method.
virtual void GenerateSerializeWithCachedSizes(io::Printer* printer) const = 0; virtual void GenerateSerializeWithCachedSizes(io::Printer* printer) const = 0;
@ -153,6 +158,7 @@ class FieldGeneratorMap {
GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FieldGeneratorMap); GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(FieldGeneratorMap);
}; };
} // namespace cpp } // namespace cpp
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -38,6 +38,7 @@
#include <google/protobuf/compiler/cpp/cpp_extension.h> #include <google/protobuf/compiler/cpp/cpp_extension.h>
#include <google/protobuf/compiler/cpp/cpp_helpers.h> #include <google/protobuf/compiler/cpp/cpp_helpers.h>
#include <google/protobuf/compiler/cpp/cpp_message.h> #include <google/protobuf/compiler/cpp/cpp_message.h>
#include <google/protobuf/compiler/cpp/cpp_field.h>
#include <google/protobuf/io/printer.h> #include <google/protobuf/io/printer.h>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
@ -93,12 +94,14 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
// Generate top of header. // Generate top of header.
printer->Print( printer->Print(
"// Generated by the protocol buffer compiler. DO NOT EDIT!\n" "// Generated by the protocol buffer compiler. DO NOT EDIT!\n"
"// source: $filename$\n"
"\n" "\n"
"#ifndef PROTOBUF_$filename_identifier$__INCLUDED\n" "#ifndef PROTOBUF_$filename_identifier$__INCLUDED\n"
"#define PROTOBUF_$filename_identifier$__INCLUDED\n" "#define PROTOBUF_$filename_identifier$__INCLUDED\n"
"\n" "\n"
"#include <string>\n" "#include <string>\n"
"\n", "\n",
"filename", file_->name(),
"filename_identifier", filename_identifier); "filename_identifier", filename_identifier);
printer->Print( printer->Print(
@ -132,19 +135,23 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
if (HasDescriptorMethods(file_)) { if (HasDescriptorMethods(file_)) {
printer->Print( printer->Print(
"#include <google/protobuf/generated_message_reflection.h>\n"); "#include <google/protobuf/generated_message_reflection.h>\n");
if (file_->service_count() > 0) {
printer->Print(
"#include <google/protobuf/service.h>\n");
}
} }
if (HasGenericServices(file_)) {
printer->Print(
"#include <google/protobuf/service.h>\n");
}
for (int i = 0; i < file_->dependency_count(); i++) { for (int i = 0; i < file_->dependency_count(); i++) {
printer->Print( printer->Print(
"#include \"$dependency$.pb.h\"\n", "#include \"$dependency$.pb.h\"\n",
"dependency", StripProto(file_->dependency(i)->name())); "dependency", StripProto(file_->dependency(i)->name()));
} }
printer->Print(
"// @@protoc_insertion_point(includes)\n");
// Open namespace. // Open namespace.
GenerateNamespaceOpeners(printer); GenerateNamespaceOpeners(printer);
@ -198,7 +205,7 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
printer->Print(kThickSeparator); printer->Print(kThickSeparator);
printer->Print("\n"); printer->Print("\n");
if (HasDescriptorMethods(file_)) { if (HasGenericServices(file_)) {
// Generate service definitions. // Generate service definitions.
for (int i = 0; i < file_->service_count(); i++) { for (int i = 0; i < file_->service_count(); i++) {
if (i > 0) { if (i > 0) {
@ -232,6 +239,10 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
message_generators_[i]->GenerateInlineMethods(printer); message_generators_[i]->GenerateInlineMethods(printer);
} }
printer->Print(
"\n"
"// @@protoc_insertion_point(namespace_scope)\n");
// Close up namespace. // Close up namespace.
GenerateNamespaceClosers(printer); GenerateNamespaceClosers(printer);
@ -255,10 +266,14 @@ void FileGenerator::GenerateHeader(io::Printer* printer) {
printer->Print( printer->Print(
"\n" "\n"
"} // namespace google\n} // namespace protobuf\n" "} // namespace google\n} // namespace protobuf\n"
"#endif // SWIG\n" "#endif // SWIG\n");
"\n");
} }
printer->Print(
"\n"
"// @@protoc_insertion_point(global_scope)\n"
"\n");
printer->Print( printer->Print(
"#endif // PROTOBUF_$filename_identifier$__INCLUDED\n", "#endif // PROTOBUF_$filename_identifier$__INCLUDED\n",
"filename_identifier", filename_identifier); "filename_identifier", filename_identifier);
@ -285,6 +300,9 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
"#include <google/protobuf/wire_format.h>\n"); "#include <google/protobuf/wire_format.h>\n");
} }
printer->Print(
"// @@protoc_insertion_point(includes)\n");
GenerateNamespaceOpeners(printer); GenerateNamespaceOpeners(printer);
if (HasDescriptorMethods(file_)) { if (HasDescriptorMethods(file_)) {
@ -300,10 +318,13 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
"const ::google::protobuf::EnumDescriptor* $name$_descriptor_ = NULL;\n", "const ::google::protobuf::EnumDescriptor* $name$_descriptor_ = NULL;\n",
"name", ClassName(file_->enum_type(i), false)); "name", ClassName(file_->enum_type(i), false));
} }
for (int i = 0; i < file_->service_count(); i++) {
printer->Print( if (HasGenericServices(file_)) {
"const ::google::protobuf::ServiceDescriptor* $name$_descriptor_ = NULL;\n", for (int i = 0; i < file_->service_count(); i++) {
"name", file_->service(i)->name()); printer->Print(
"const ::google::protobuf::ServiceDescriptor* $name$_descriptor_ = NULL;\n",
"name", file_->service(i)->name());
}
} }
printer->Print( printer->Print(
@ -329,7 +350,7 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
message_generators_[i]->GenerateClassMethods(printer); message_generators_[i]->GenerateClassMethods(printer);
} }
if (HasDescriptorMethods(file_)) { if (HasGenericServices(file_)) {
// Generate services. // Generate services.
for (int i = 0; i < file_->service_count(); i++) { for (int i = 0; i < file_->service_count(); i++) {
if (i == 0) printer->Print("\n"); if (i == 0) printer->Print("\n");
@ -344,7 +365,15 @@ void FileGenerator::GenerateSource(io::Printer* printer) {
extension_generators_[i]->GenerateDefinition(printer); extension_generators_[i]->GenerateDefinition(printer);
} }
printer->Print(
"\n"
"// @@protoc_insertion_point(namespace_scope)\n");
GenerateNamespaceClosers(printer); GenerateNamespaceClosers(printer);
printer->Print(
"\n"
"// @@protoc_insertion_point(global_scope)\n");
} }
void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) { void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
@ -397,8 +426,10 @@ void FileGenerator::GenerateBuildDescriptors(io::Printer* printer) {
for (int i = 0; i < file_->enum_type_count(); i++) { for (int i = 0; i < file_->enum_type_count(); i++) {
enum_generators_[i]->GenerateDescriptorInitializer(printer, i); enum_generators_[i]->GenerateDescriptorInitializer(printer, i);
} }
for (int i = 0; i < file_->service_count(); i++) { if (HasGenericServices(file_)) {
service_generators_[i]->GenerateDescriptorInitializer(printer, i); for (int i = 0; i < file_->service_count(); i++) {
service_generators_[i]->GenerateDescriptorInitializer(printer, i);
}
} }
printer->Outdent(); printer->Outdent();

View File

@ -32,6 +32,7 @@
// Based on original Protocol Buffers design by // Based on original Protocol Buffers design by
// Sanjay Ghemawat, Jeff Dean, and others. // Sanjay Ghemawat, Jeff Dean, and others.
#include <limits>
#include <vector> #include <vector>
#include <google/protobuf/stubs/hash.h> #include <google/protobuf/stubs/hash.h>
@ -40,6 +41,7 @@
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/substitute.h> #include <google/protobuf/stubs/substitute.h>
namespace google { namespace google {
namespace protobuf { namespace protobuf {
namespace compiler { namespace compiler {
@ -111,6 +113,7 @@ const char kThinSeparator[] =
"// -------------------------------------------------------------------\n"; "// -------------------------------------------------------------------\n";
string ClassName(const Descriptor* descriptor, bool qualified) { string ClassName(const Descriptor* descriptor, bool qualified) {
// Find "outer", the descriptor of the top-level message in which // Find "outer", the descriptor of the top-level message in which
// "descriptor" is embedded. // "descriptor" is embedded.
const Descriptor* outer = descriptor; const Descriptor* outer = descriptor;
@ -141,6 +144,12 @@ string ClassName(const EnumDescriptor* enum_descriptor, bool qualified) {
} }
} }
string SuperClassName(const Descriptor* descriptor) {
return HasDescriptorMethods(descriptor->file()) ?
"::google::protobuf::Message" : "::google::protobuf::MessageLite";
}
string FieldName(const FieldDescriptor* field) { string FieldName(const FieldDescriptor* field) {
string result = field->name(); string result = field->name();
LowerString(&result); LowerString(&result);
@ -166,6 +175,12 @@ string FieldConstantName(const FieldDescriptor *field) {
return result; return result;
} }
string FieldMessageTypeName(const FieldDescriptor* field) {
// Note: The Google-internal version of Protocol Buffers uses this function
// as a hook point for hacks to support legacy code.
return ClassName(field->message_type(), true);
}
string StripProto(const string& filename) { string StripProto(const string& filename) {
if (HasSuffixString(filename, ".protodevel")) { if (HasSuffixString(filename, ".protodevel")) {
return StripSuffixString(filename, ".protodevel"); return StripSuffixString(filename, ".protodevel");
@ -235,17 +250,37 @@ string DefaultValue(const FieldDescriptor* field) {
return "GOOGLE_LONGLONG(" + SimpleItoa(field->default_value_int64()) + ")"; return "GOOGLE_LONGLONG(" + SimpleItoa(field->default_value_int64()) + ")";
case FieldDescriptor::CPPTYPE_UINT64: case FieldDescriptor::CPPTYPE_UINT64:
return "GOOGLE_ULONGLONG(" + SimpleItoa(field->default_value_uint64())+ ")"; return "GOOGLE_ULONGLONG(" + SimpleItoa(field->default_value_uint64())+ ")";
case FieldDescriptor::CPPTYPE_DOUBLE: case FieldDescriptor::CPPTYPE_DOUBLE: {
return SimpleDtoa(field->default_value_double()); double value = field->default_value_double();
if (value == numeric_limits<double>::infinity()) {
return "::google::protobuf::internal::Infinity()";
} else if (value == -numeric_limits<double>::infinity()) {
return "-::google::protobuf::internal::Infinity()";
} else if (value != value) {
return "::google::protobuf::internal::NaN()";
} else {
return SimpleDtoa(value);
}
}
case FieldDescriptor::CPPTYPE_FLOAT: case FieldDescriptor::CPPTYPE_FLOAT:
{ {
// If floating point value contains a period (.) or an exponent (either float value = field->default_value_float();
// E or e), then append suffix 'f' to make it a floating-point literal. if (value == numeric_limits<float>::infinity()) {
string float_value = SimpleFtoa(field->default_value_float()); return "static_cast<float>(::google::protobuf::internal::Infinity())";
if (float_value.find_first_of(".eE") != string::npos) { } else if (value == -numeric_limits<float>::infinity()) {
float_value.push_back('f'); return "static_cast<float>(-::google::protobuf::internal::Infinity())";
} else if (value != value) {
return "static_cast<float>(::google::protobuf::internal::NaN())";
} else {
string float_value = SimpleFtoa(value);
// If floating point value contains a period (.) or an exponent
// (either E or e), then append suffix 'f' to make it a float
// literal.
if (float_value.find_first_of(".eE") != string::npos) {
float_value.push_back('f');
}
return float_value;
} }
return float_value;
} }
case FieldDescriptor::CPPTYPE_BOOL: case FieldDescriptor::CPPTYPE_BOOL:
return field->default_value_bool() ? "true" : "false"; return field->default_value_bool() ? "true" : "false";
@ -259,7 +294,7 @@ string DefaultValue(const FieldDescriptor* field) {
case FieldDescriptor::CPPTYPE_STRING: case FieldDescriptor::CPPTYPE_STRING:
return "\"" + CEscape(field->default_value_string()) + "\""; return "\"" + CEscape(field->default_value_string()) + "\"";
case FieldDescriptor::CPPTYPE_MESSAGE: case FieldDescriptor::CPPTYPE_MESSAGE:
return ClassName(field->message_type(), true) + "::default_instance()"; return FieldMessageTypeName(field) + "::default_instance()";
} }
// Can't actually get here; make compiler happy. (We could add a default // Can't actually get here; make compiler happy. (We could add a default
// case above but then we wouldn't get the nice compiler warning when a // case above but then we wouldn't get the nice compiler warning when a

View File

@ -60,6 +60,8 @@ extern const char kThinSeparator[];
string ClassName(const Descriptor* descriptor, bool qualified); string ClassName(const Descriptor* descriptor, bool qualified);
string ClassName(const EnumDescriptor* enum_descriptor, bool qualified); string ClassName(const EnumDescriptor* enum_descriptor, bool qualified);
string SuperClassName(const Descriptor* descriptor);
// Get the (unqualified) name that should be used for this field in C++ code. // Get the (unqualified) name that should be used for this field in C++ code.
// The name is coerced to lower-case to emulate proto1 behavior. People // The name is coerced to lower-case to emulate proto1 behavior. People
// should be using lowercase-with-underscores style for proto field names // should be using lowercase-with-underscores style for proto field names
@ -77,6 +79,10 @@ inline const Descriptor* FieldScope(const FieldDescriptor* field) {
field->extension_scope() : field->containing_type(); field->extension_scope() : field->containing_type();
} }
// Returns the fully-qualified type name field->message_type(). Usually this
// is just ClassName(field->message_type(), true);
string FieldMessageTypeName(const FieldDescriptor* field);
// Strips ".proto" or ".protodevel" from the end of a filename. // Strips ".proto" or ".protodevel" from the end of a filename.
string StripProto(const string& filename); string StripProto(const string& filename);
@ -107,33 +113,41 @@ string GlobalAssignDescriptorsName(const string& filename);
string GlobalShutdownFileName(const string& filename); string GlobalShutdownFileName(const string& filename);
// Do message classes in this file keep track of unknown fields? // Do message classes in this file keep track of unknown fields?
inline const bool HasUnknownFields(const FileDescriptor *file) { inline bool HasUnknownFields(const FileDescriptor *file) {
return file->options().optimize_for() != FileOptions::LITE_RUNTIME; return file->options().optimize_for() != FileOptions::LITE_RUNTIME;
} }
// Does this file have generated parsing, serialization, and other // Does this file have generated parsing, serialization, and other
// standard methods for which reflection-based fallback implementations exist? // standard methods for which reflection-based fallback implementations exist?
inline const bool HasGeneratedMethods(const FileDescriptor *file) { inline bool HasGeneratedMethods(const FileDescriptor *file) {
return file->options().optimize_for() != FileOptions::CODE_SIZE; return file->options().optimize_for() != FileOptions::CODE_SIZE;
} }
// Do message classes in this file have descriptor and refelction methods? // Do message classes in this file have descriptor and refelction methods?
inline const bool HasDescriptorMethods(const FileDescriptor *file) { inline bool HasDescriptorMethods(const FileDescriptor *file) {
return file->options().optimize_for() != FileOptions::LITE_RUNTIME; return file->options().optimize_for() != FileOptions::LITE_RUNTIME;
} }
// Should we generate generic services for this file?
inline bool HasGenericServices(const FileDescriptor *file) {
return file->service_count() > 0 &&
file->options().optimize_for() != FileOptions::LITE_RUNTIME &&
file->options().cc_generic_services();
}
// Should string fields in this file verify that their contents are UTF-8? // Should string fields in this file verify that their contents are UTF-8?
inline const bool HasUtf8Verification(const FileDescriptor* file) { inline bool HasUtf8Verification(const FileDescriptor* file) {
return file->options().optimize_for() != FileOptions::LITE_RUNTIME; return file->options().optimize_for() != FileOptions::LITE_RUNTIME;
} }
// Should we generate a separate, super-optimized code path for serializing to // Should we generate a separate, super-optimized code path for serializing to
// flat arrays? We don't do this in Lite mode because we'd rather reduce code // flat arrays? We don't do this in Lite mode because we'd rather reduce code
// size. // size.
inline const bool HasFastArraySerialization(const FileDescriptor* file) { inline bool HasFastArraySerialization(const FileDescriptor* file) {
return file->options().optimize_for() == FileOptions::SPEED; return file->options().optimize_for() == FileOptions::SPEED;
} }
} // namespace cpp } // namespace cpp
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -308,11 +308,10 @@ GenerateClassDefinition(io::Printer* printer) {
} else { } else {
vars["dllexport"] = dllexport_decl_ + " "; vars["dllexport"] = dllexport_decl_ + " ";
} }
vars["superclass"] = HasDescriptorMethods(descriptor_->file()) ? vars["superclass"] = SuperClassName(descriptor_);
"Message" : "MessageLite";
printer->Print(vars, printer->Print(vars,
"class $dllexport$$classname$ : public ::google::protobuf::$superclass$ {\n" "class $dllexport$$classname$ : public $superclass$ {\n"
" public:\n"); " public:\n");
printer->Indent(); printer->Indent();
@ -349,6 +348,10 @@ GenerateClassDefinition(io::Printer* printer) {
printer->Print(vars, printer->Print(vars,
"static const $classname$& default_instance();\n" "static const $classname$& default_instance();\n"
"\n");
printer->Print(vars,
"void Swap($classname$* other);\n" "void Swap($classname$* other);\n"
"\n" "\n"
"// implements Message ----------------------------------------------\n" "// implements Message ----------------------------------------------\n"
@ -387,7 +390,7 @@ GenerateClassDefinition(io::Printer* printer) {
"private:\n" "private:\n"
"void SharedCtor();\n" "void SharedCtor();\n"
"void SharedDtor();\n" "void SharedDtor();\n"
"void SetCachedSize(int size) const { _cached_size_ = size; }\n" "void SetCachedSize(int size) const;\n"
"public:\n" "public:\n"
"\n"); "\n");
@ -436,6 +439,11 @@ GenerateClassDefinition(io::Printer* printer) {
extension_generators_[i]->GenerateDeclaration(printer); extension_generators_[i]->GenerateDeclaration(printer);
} }
printer->Print(
"// @@protoc_insertion_point(class_scope:$full_name$)\n",
"full_name", descriptor_->full_name());
// Generate private members for fields. // Generate private members for fields.
printer->Outdent(); printer->Outdent();
printer->Print(" private:\n"); printer->Print(" private:\n");
@ -623,6 +631,7 @@ GenerateDefaultInstanceAllocator(io::Printer* printer) {
for (int i = 0; i < descriptor_->nested_type_count(); i++) { for (int i = 0; i < descriptor_->nested_type_count(); i++) {
nested_generators_[i]->GenerateDefaultInstanceAllocator(printer); nested_generators_[i]->GenerateDefaultInstanceAllocator(printer);
} }
} }
void MessageGenerator:: void MessageGenerator::
@ -751,6 +760,7 @@ GenerateClassMethods(io::Printer* printer) {
"classname", classname_, "classname", classname_,
"type_name", descriptor_->full_name()); "type_name", descriptor_->full_name());
} }
} }
void MessageGenerator:: void MessageGenerator::
@ -833,9 +843,8 @@ GenerateSharedDestructorCode(io::Printer* printer) {
void MessageGenerator:: void MessageGenerator::
GenerateStructors(io::Printer* printer) { GenerateStructors(io::Printer* printer) {
string superclass = HasDescriptorMethods(descriptor_->file()) ? string superclass = SuperClassName(descriptor_);
"Message" : "MessageLite";
// Generate the default constructor. // Generate the default constructor.
printer->Print( printer->Print(
"$classname$::$classname$()\n" "$classname$::$classname$()\n"
@ -864,7 +873,7 @@ GenerateStructors(io::Printer* printer) {
printer->Print( printer->Print(
" $name$_ = const_cast< $type$*>(&$type$::default_instance());\n", " $name$_ = const_cast< $type$*>(&$type$::default_instance());\n",
"name", FieldName(field), "name", FieldName(field),
"type", ClassName(field->message_type(), true)); "type", FieldMessageTypeName(field));
} }
} }
printer->Print( printer->Print(
@ -896,6 +905,15 @@ GenerateStructors(io::Printer* printer) {
// Generate the shared destructor code. // Generate the shared destructor code.
GenerateSharedDestructorCode(printer); GenerateSharedDestructorCode(printer);
// Generate SetCachedSize.
printer->Print(
"void $classname$::SetCachedSize(int size) const {\n"
" GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();\n"
" _cached_size_ = size;\n"
" GOOGLE_SAFE_CONCURRENT_WRITES_END();\n"
"}\n",
"classname", classname_);
// Only generate this member if it's not disabled. // Only generate this member if it's not disabled.
if (HasDescriptorMethods(descriptor_->file()) && if (HasDescriptorMethods(descriptor_->file()) &&
!descriptor_->options().no_standard_descriptor_accessor()) { !descriptor_->options().no_standard_descriptor_accessor()) {
@ -924,6 +942,7 @@ GenerateStructors(io::Printer* printer) {
"classname", classname_, "classname", classname_,
"adddescriptorsname", "adddescriptorsname",
GlobalAddDescriptorsName(descriptor_->file()->name())); GlobalAddDescriptorsName(descriptor_->file()->name()));
} }
void MessageGenerator:: void MessageGenerator::
@ -1237,12 +1256,15 @@ GenerateMergeFromCodedStream(io::Printer* printer) {
PrintFieldComment(printer, field); PrintFieldComment(printer, field);
printer->Print( printer->Print(
"case $number$: {\n" "case $number$: {\n",
" if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) !=\n" "number", SimpleItoa(field->number()));
" ::google::protobuf::internal::WireFormatLite::WIRETYPE_$wiretype$) {\n" printer->Indent();
" goto handle_uninterpreted;\n" const FieldGenerator& field_generator = field_generators_.get(field);
" }\n",
"number", SimpleItoa(field->number()), // Emit code to parse the common, expected case.
printer->Print(
"if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag) ==\n"
" ::google::protobuf::internal::WireFormatLite::WIRETYPE_$wiretype$) {\n",
"wiretype", kWireTypeNames[WireFormat::WireTypeForField(field)]); "wiretype", kWireTypeNames[WireFormat::WireTypeForField(field)]);
if (i > 0 || (field->is_repeated() && !field->options().packed())) { if (i > 0 || (field->is_repeated() && !field->options().packed())) {
@ -1252,8 +1274,38 @@ GenerateMergeFromCodedStream(io::Printer* printer) {
} }
printer->Indent(); printer->Indent();
if (field->options().packed()) {
field_generator.GenerateMergeFromCodedStreamWithPacking(printer);
} else {
field_generator.GenerateMergeFromCodedStream(printer);
}
printer->Outdent();
field_generators_.get(field).GenerateMergeFromCodedStream(printer); // Emit code to parse unexpectedly packed or unpacked values.
if (field->is_packable() && field->options().packed()) {
printer->Print(
"} else if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag)\n"
" == ::google::protobuf::internal::WireFormatLite::\n"
" WIRETYPE_$wiretype$) {\n",
"wiretype",
kWireTypeNames[WireFormat::WireTypeForFieldType(field->type())]);
printer->Indent();
field_generator.GenerateMergeFromCodedStream(printer);
printer->Outdent();
} else if (field->is_packable() && !field->options().packed()) {
printer->Print(
"} else if (::google::protobuf::internal::WireFormatLite::GetTagWireType(tag)\n"
" == ::google::protobuf::internal::WireFormatLite::\n"
" WIRETYPE_LENGTH_DELIMITED) {\n");
printer->Indent();
field_generator.GenerateMergeFromCodedStreamWithPacking(printer);
printer->Outdent();
}
printer->Print(
"} else {\n"
" goto handle_uninterpreted;\n"
"}\n");
// switch() is slow since it can't be predicted well. Insert some if()s // switch() is slow since it can't be predicted well. Insert some if()s
// here that attempt to predict the next tag. // here that attempt to predict the next tag.
@ -1434,18 +1486,6 @@ GenerateSerializeWithCachedSizes(io::Printer* printer) {
"classname", classname_); "classname", classname_);
printer->Indent(); printer->Indent();
if (HasFastArraySerialization(descriptor_->file())) {
printer->Print(
"::google::protobuf::uint8* raw_buffer = "
"output->GetDirectBufferForNBytesAndAdvance(_cached_size_);\n"
"if (raw_buffer != NULL) {\n"
" $classname$::SerializeWithCachedSizesToArray(raw_buffer);\n"
" return;\n"
"}\n"
"\n",
"classname", classname_);
}
GenerateSerializeWithCachedSizesBody(printer, false); GenerateSerializeWithCachedSizesBody(printer, false);
printer->Outdent(); printer->Outdent();
@ -1555,7 +1595,9 @@ GenerateByteSize(io::Printer* printer) {
" ComputeUnknownMessageSetItemsSize(unknown_fields());\n"); " ComputeUnknownMessageSetItemsSize(unknown_fields());\n");
} }
printer->Print( printer->Print(
" GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();\n"
" _cached_size_ = total_size;\n" " _cached_size_ = total_size;\n"
" GOOGLE_SAFE_CONCURRENT_WRITES_END();\n"
" return total_size;\n" " return total_size;\n"
"}\n"); "}\n");
return; return;
@ -1647,7 +1689,9 @@ GenerateByteSize(io::Printer* printer) {
// exact same value, it works on all common processors. In a future version // exact same value, it works on all common processors. In a future version
// of C++, _cached_size_ should be made into an atomic<int>. // of C++, _cached_size_ should be made into an atomic<int>.
printer->Print( printer->Print(
"GOOGLE_SAFE_CONCURRENT_WRITES_BEGIN();\n"
"_cached_size_ = total_size;\n" "_cached_size_ = total_size;\n"
"GOOGLE_SAFE_CONCURRENT_WRITES_END();\n"
"return total_size;\n"); "return total_size;\n");
printer->Outdent(); printer->Outdent();
@ -1719,6 +1763,7 @@ GenerateIsInitialized(io::Printer* printer) {
"}\n"); "}\n");
} }
} // namespace cpp } // namespace cpp
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -150,6 +150,7 @@ class MessageGenerator {
io::Printer* printer, const Descriptor::ExtensionRange* range, io::Printer* printer, const Descriptor::ExtensionRange* range,
bool unbounded); bool unbounded);
const Descriptor* descriptor_; const Descriptor* descriptor_;
string classname_; string classname_;
string dllexport_decl_; string dllexport_decl_;

View File

@ -47,7 +47,11 @@ namespace {
void SetMessageVariables(const FieldDescriptor* descriptor, void SetMessageVariables(const FieldDescriptor* descriptor,
map<string, string>* variables) { map<string, string>* variables) {
SetCommonFieldVariables(descriptor, variables); SetCommonFieldVariables(descriptor, variables);
(*variables)["type"] = ClassName(descriptor->message_type(), true); (*variables)["type"] = FieldMessageTypeName(descriptor);
(*variables)["stream_writer"] = (*variables)["declared_type"] +
(HasFastArraySerialization(descriptor->message_type()->file()) ?
"MaybeToArray" :
"");
} }
} // namespace } // namespace
@ -125,7 +129,7 @@ GenerateMergeFromCodedStream(io::Printer* printer) const {
void MessageFieldGenerator:: void MessageFieldGenerator::
GenerateSerializeWithCachedSizes(io::Printer* printer) const { GenerateSerializeWithCachedSizes(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"::google::protobuf::internal::WireFormatLite::Write$declared_type$NoVirtual(\n" "::google::protobuf::internal::WireFormatLite::Write$stream_writer$(\n"
" $number$, this->$name$(), output);\n"); " $number$, this->$name$(), output);\n");
} }
@ -164,26 +168,19 @@ GeneratePrivateMembers(io::Printer* printer) const {
void RepeatedMessageFieldGenerator:: void RepeatedMessageFieldGenerator::
GenerateAccessorDeclarations(io::Printer* printer) const { GenerateAccessorDeclarations(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"inline const ::google::protobuf::RepeatedPtrField< $type$ >& $name$() const"
"$deprecation$;\n"
"inline ::google::protobuf::RepeatedPtrField< $type$ >* mutable_$name$()"
"$deprecation$;\n"
"inline const $type$& $name$(int index) const$deprecation$;\n" "inline const $type$& $name$(int index) const$deprecation$;\n"
"inline $type$* mutable_$name$(int index)$deprecation$;\n" "inline $type$* mutable_$name$(int index)$deprecation$;\n"
"inline $type$* add_$name$()$deprecation$;\n"); "inline $type$* add_$name$()$deprecation$;\n");
printer->Print(variables_,
"inline const ::google::protobuf::RepeatedPtrField< $type$ >&\n"
" $name$() const$deprecation$;\n"
"inline ::google::protobuf::RepeatedPtrField< $type$ >*\n"
" mutable_$name$()$deprecation$;\n");
} }
void RepeatedMessageFieldGenerator:: void RepeatedMessageFieldGenerator::
GenerateInlineAccessorDefinitions(io::Printer* printer) const { GenerateInlineAccessorDefinitions(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"inline const ::google::protobuf::RepeatedPtrField< $type$ >&\n"
"$classname$::$name$() const {\n"
" return $name$_;\n"
"}\n"
"inline ::google::protobuf::RepeatedPtrField< $type$ >*\n"
"$classname$::mutable_$name$() {\n"
" return &$name$_;\n"
"}\n"
"inline const $type$& $classname$::$name$(int index) const {\n" "inline const $type$& $classname$::$name$(int index) const {\n"
" return $name$_.Get(index);\n" " return $name$_.Get(index);\n"
"}\n" "}\n"
@ -193,6 +190,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
"inline $type$* $classname$::add_$name$() {\n" "inline $type$* $classname$::add_$name$() {\n"
" return $name$_.Add();\n" " return $name$_.Add();\n"
"}\n"); "}\n");
printer->Print(variables_,
"inline const ::google::protobuf::RepeatedPtrField< $type$ >&\n"
"$classname$::$name$() const {\n"
" return $name$_;\n"
"}\n"
"inline ::google::protobuf::RepeatedPtrField< $type$ >*\n"
"$classname$::mutable_$name$() {\n"
" return &$name$_;\n"
"}\n");
} }
void RepeatedMessageFieldGenerator:: void RepeatedMessageFieldGenerator::
@ -232,7 +238,7 @@ void RepeatedMessageFieldGenerator::
GenerateSerializeWithCachedSizes(io::Printer* printer) const { GenerateSerializeWithCachedSizes(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"for (int i = 0; i < this->$name$_size(); i++) {\n" "for (int i = 0; i < this->$name$_size(); i++) {\n"
" ::google::protobuf::internal::WireFormatLite::Write$declared_type$NoVirtual(\n" " ::google::protobuf::internal::WireFormatLite::Write$stream_writer$(\n"
" $number$, this->$name$(i), output);\n" " $number$, this->$name$(i), output);\n"
"}\n"); "}\n");
} }

View File

@ -84,10 +84,14 @@ void SetPrimitiveVariables(const FieldDescriptor* descriptor,
SetCommonFieldVariables(descriptor, variables); SetCommonFieldVariables(descriptor, variables);
(*variables)["type"] = PrimitiveTypeName(descriptor->cpp_type()); (*variables)["type"] = PrimitiveTypeName(descriptor->cpp_type());
(*variables)["default"] = DefaultValue(descriptor); (*variables)["default"] = DefaultValue(descriptor);
(*variables)["tag"] = SimpleItoa(internal::WireFormat::MakeTag(descriptor));
int fixed_size = FixedSize(descriptor->type()); int fixed_size = FixedSize(descriptor->type());
if (fixed_size != -1) { if (fixed_size != -1) {
(*variables)["fixed_size"] = SimpleItoa(fixed_size); (*variables)["fixed_size"] = SimpleItoa(fixed_size);
} }
(*variables)["wire_format_field_type"] =
"::google::protobuf::internal::WireFormatLite::" + FieldDescriptorProto_Type_Name(
static_cast<FieldDescriptorProto_Type>(descriptor->type()));
} }
} // namespace } // namespace
@ -149,8 +153,9 @@ GenerateConstructorCode(io::Printer* printer) const {
void PrimitiveFieldGenerator:: void PrimitiveFieldGenerator::
GenerateMergeFromCodedStream(io::Printer* printer) const { GenerateMergeFromCodedStream(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"DO_(::google::protobuf::internal::WireFormatLite::Read$declared_type$(\n" "DO_((::google::protobuf::internal::WireFormatLite::ReadPrimitive<\n"
" input, &$name$_));\n" " $type$, $wire_format_field_type$>(\n"
" input, &$name$_)));\n"
"_set_bit($index$);\n"); "_set_bit($index$);\n");
} }
@ -188,6 +193,14 @@ RepeatedPrimitiveFieldGenerator::
RepeatedPrimitiveFieldGenerator(const FieldDescriptor* descriptor) RepeatedPrimitiveFieldGenerator(const FieldDescriptor* descriptor)
: descriptor_(descriptor) { : descriptor_(descriptor) {
SetPrimitiveVariables(descriptor, &variables_); SetPrimitiveVariables(descriptor, &variables_);
if (descriptor->options().packed()) {
variables_["packed_reader"] = "ReadPackedPrimitive";
variables_["repeated_reader"] = "ReadRepeatedPrimitiveNoInline";
} else {
variables_["packed_reader"] = "ReadPackedPrimitiveNoInline";
variables_["repeated_reader"] = "ReadRepeatedPrimitive";
}
} }
RepeatedPrimitiveFieldGenerator::~RepeatedPrimitiveFieldGenerator() {} RepeatedPrimitiveFieldGenerator::~RepeatedPrimitiveFieldGenerator() {}
@ -205,25 +218,19 @@ GeneratePrivateMembers(io::Printer* printer) const {
void RepeatedPrimitiveFieldGenerator:: void RepeatedPrimitiveFieldGenerator::
GenerateAccessorDeclarations(io::Printer* printer) const { GenerateAccessorDeclarations(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"inline const ::google::protobuf::RepeatedField< $type$ >& $name$() const\n"
" $deprecation$;\n"
"inline ::google::protobuf::RepeatedField< $type$ >* mutable_$name$()$deprecation$;\n"
"inline $type$ $name$(int index) const$deprecation$;\n" "inline $type$ $name$(int index) const$deprecation$;\n"
"inline void set_$name$(int index, $type$ value)$deprecation$;\n" "inline void set_$name$(int index, $type$ value)$deprecation$;\n"
"inline void add_$name$($type$ value)$deprecation$;\n"); "inline void add_$name$($type$ value)$deprecation$;\n");
printer->Print(variables_,
"inline const ::google::protobuf::RepeatedField< $type$ >&\n"
" $name$() const$deprecation$;\n"
"inline ::google::protobuf::RepeatedField< $type$ >*\n"
" mutable_$name$()$deprecation$;\n");
} }
void RepeatedPrimitiveFieldGenerator:: void RepeatedPrimitiveFieldGenerator::
GenerateInlineAccessorDefinitions(io::Printer* printer) const { GenerateInlineAccessorDefinitions(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"inline const ::google::protobuf::RepeatedField< $type$ >&\n"
"$classname$::$name$() const {\n"
" return $name$_;\n"
"}\n"
"inline ::google::protobuf::RepeatedField< $type$ >*\n"
"$classname$::mutable_$name$() {\n"
" return &$name$_;\n"
"}\n"
"inline $type$ $classname$::$name$(int index) const {\n" "inline $type$ $classname$::$name$(int index) const {\n"
" return $name$_.Get(index);\n" " return $name$_.Get(index);\n"
"}\n" "}\n"
@ -233,6 +240,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
"inline void $classname$::add_$name$($type$ value) {\n" "inline void $classname$::add_$name$($type$ value) {\n"
" $name$_.Add(value);\n" " $name$_.Add(value);\n"
"}\n"); "}\n");
printer->Print(variables_,
"inline const ::google::protobuf::RepeatedField< $type$ >&\n"
"$classname$::$name$() const {\n"
" return $name$_;\n"
"}\n"
"inline ::google::protobuf::RepeatedField< $type$ >*\n"
"$classname$::mutable_$name$() {\n"
" return &$name$_;\n"
"}\n");
} }
void RepeatedPrimitiveFieldGenerator:: void RepeatedPrimitiveFieldGenerator::
@ -257,30 +273,18 @@ GenerateConstructorCode(io::Printer* printer) const {
void RepeatedPrimitiveFieldGenerator:: void RepeatedPrimitiveFieldGenerator::
GenerateMergeFromCodedStream(io::Printer* printer) const { GenerateMergeFromCodedStream(io::Printer* printer) const {
if (descriptor_->options().packed()) { printer->Print(variables_,
printer->Print("{\n"); "DO_((::google::protobuf::internal::WireFormatLite::$repeated_reader$<\n"
printer->Indent(); " $type$, $wire_format_field_type$>(\n"
printer->Print(variables_, " $tag_size$, $tag$, input, this->mutable_$name$())));\n");
"::google::protobuf::uint32 length;\n" }
"DO_(input->ReadVarint32(&length));\n"
"::google::protobuf::io::CodedInputStream::Limit limit =\n" void RepeatedPrimitiveFieldGenerator::
" input->PushLimit(length);\n" GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const {
"while (input->BytesUntilLimit() > 0) {\n" printer->Print(variables_,
" $type$ value;\n" "DO_((::google::protobuf::internal::WireFormatLite::$packed_reader$<\n"
" DO_(::google::protobuf::internal::WireFormatLite::Read$declared_type$(\n" " $type$, $wire_format_field_type$>(\n"
" input, &value));\n" " input, this->mutable_$name$())));\n");
" add_$name$(value);\n"
"}\n"
"input->PopLimit(limit);\n");
printer->Outdent();
printer->Print("}\n");
} else {
printer->Print(variables_,
"$type$ value;\n"
"DO_(::google::protobuf::internal::WireFormatLite::Read$declared_type$(\n"
" input, &value));\n"
"add_$name$(value);\n");
}
} }
void RepeatedPrimitiveFieldGenerator:: void RepeatedPrimitiveFieldGenerator::

View File

@ -83,6 +83,7 @@ class RepeatedPrimitiveFieldGenerator : public FieldGenerator {
void GenerateSwappingCode(io::Printer* printer) const; void GenerateSwappingCode(io::Printer* printer) const;
void GenerateConstructorCode(io::Printer* printer) const; void GenerateConstructorCode(io::Printer* printer) const;
void GenerateMergeFromCodedStream(io::Printer* printer) const; void GenerateMergeFromCodedStream(io::Printer* printer) const;
void GenerateMergeFromCodedStreamWithPacking(io::Printer* printer) const;
void GenerateSerializeWithCachedSizes(io::Printer* printer) const; void GenerateSerializeWithCachedSizes(io::Printer* printer) const;
void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const; void GenerateSerializeWithCachedSizesToArray(io::Printer* printer) const;
void GenerateByteSize(io::Printer* printer) const; void GenerateByteSize(io::Printer* printer) const;

View File

@ -91,7 +91,7 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
// files that applied the ctype. The field can still be accessed via the // files that applied the ctype. The field can still be accessed via the
// reflection interface since the reflection interface is independent of // reflection interface since the reflection interface is independent of
// the string's underlying representation. // the string's underlying representation.
if (descriptor_->options().has_ctype()) { if (descriptor_->options().ctype() != FieldOptions::STRING) {
printer->Outdent(); printer->Outdent();
printer->Print( printer->Print(
" private:\n" " private:\n"
@ -107,7 +107,7 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
"$deprecation$;\n" "$deprecation$;\n"
"inline ::std::string* mutable_$name$()$deprecation$;\n"); "inline ::std::string* mutable_$name$()$deprecation$;\n");
if (descriptor_->options().has_ctype()) { if (descriptor_->options().ctype() != FieldOptions::STRING) {
printer->Outdent(); printer->Outdent();
printer->Print(" public:\n"); printer->Print(" public:\n");
printer->Indent(); printer->Indent();
@ -278,7 +278,7 @@ GeneratePrivateMembers(io::Printer* printer) const {
void RepeatedStringFieldGenerator:: void RepeatedStringFieldGenerator::
GenerateAccessorDeclarations(io::Printer* printer) const { GenerateAccessorDeclarations(io::Printer* printer) const {
// See comment above about unknown ctypes. // See comment above about unknown ctypes.
if (descriptor_->options().has_ctype()) { if (descriptor_->options().ctype() != FieldOptions::STRING) {
printer->Outdent(); printer->Outdent();
printer->Print( printer->Print(
" private:\n" " private:\n"
@ -287,10 +287,6 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
} }
printer->Print(variables_, printer->Print(variables_,
"inline const ::google::protobuf::RepeatedPtrField< ::std::string>& $name$() const"
"$deprecation$;\n"
"inline ::google::protobuf::RepeatedPtrField< ::std::string>* mutable_$name$()"
"$deprecation$;\n"
"inline const ::std::string& $name$(int index) const$deprecation$;\n" "inline const ::std::string& $name$(int index) const$deprecation$;\n"
"inline ::std::string* mutable_$name$(int index)$deprecation$;\n" "inline ::std::string* mutable_$name$(int index)$deprecation$;\n"
"inline void set_$name$(int index, const ::std::string& value)$deprecation$;\n" "inline void set_$name$(int index, const ::std::string& value)$deprecation$;\n"
@ -304,7 +300,13 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
"inline void add_$name$(const $pointer_type$* value, size_t size)" "inline void add_$name$(const $pointer_type$* value, size_t size)"
"$deprecation$;\n"); "$deprecation$;\n");
if (descriptor_->options().has_ctype()) { printer->Print(variables_,
"inline const ::google::protobuf::RepeatedPtrField< ::std::string>& $name$() const"
"$deprecation$;\n"
"inline ::google::protobuf::RepeatedPtrField< ::std::string>* mutable_$name$()"
"$deprecation$;\n");
if (descriptor_->options().ctype() != FieldOptions::STRING) {
printer->Outdent(); printer->Outdent();
printer->Print(" public:\n"); printer->Print(" public:\n");
printer->Indent(); printer->Indent();
@ -314,14 +316,6 @@ GenerateAccessorDeclarations(io::Printer* printer) const {
void RepeatedStringFieldGenerator:: void RepeatedStringFieldGenerator::
GenerateInlineAccessorDefinitions(io::Printer* printer) const { GenerateInlineAccessorDefinitions(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"inline const ::google::protobuf::RepeatedPtrField< ::std::string>&\n"
"$classname$::$name$() const {\n"
" return $name$_;\n"
"}\n"
"inline ::google::protobuf::RepeatedPtrField< ::std::string>*\n"
"$classname$::mutable_$name$() {\n"
" return &$name$_;\n"
"}\n"
"inline const ::std::string& $classname$::$name$(int index) const {\n" "inline const ::std::string& $classname$::$name$(int index) const {\n"
" return $name$_.Get(index);\n" " return $name$_.Get(index);\n"
"}\n" "}\n"
@ -353,6 +347,15 @@ GenerateInlineAccessorDefinitions(io::Printer* printer) const {
"$classname$::add_$name$(const $pointer_type$* value, size_t size) {\n" "$classname$::add_$name$(const $pointer_type$* value, size_t size) {\n"
" $name$_.Add()->assign(reinterpret_cast<const char*>(value), size);\n" " $name$_.Add()->assign(reinterpret_cast<const char*>(value), size);\n"
"}\n"); "}\n");
printer->Print(variables_,
"inline const ::google::protobuf::RepeatedPtrField< ::std::string>&\n"
"$classname$::$name$() const {\n"
" return $name$_;\n"
"}\n"
"inline ::google::protobuf::RepeatedPtrField< ::std::string>*\n"
"$classname$::mutable_$name$() {\n"
" return &$name$_;\n"
"}\n");
} }
void RepeatedStringFieldGenerator:: void RepeatedStringFieldGenerator::

View File

@ -49,6 +49,7 @@
#include <google/protobuf/unittest.pb.h> #include <google/protobuf/unittest.pb.h>
#include <google/protobuf/unittest_optimize_for.pb.h> #include <google/protobuf/unittest_optimize_for.pb.h>
#include <google/protobuf/unittest_embed_optimize_for.pb.h> #include <google/protobuf/unittest_embed_optimize_for.pb.h>
#include <google/protobuf/unittest_no_generic_services.pb.h>
#include <google/protobuf/test_util.h> #include <google/protobuf/test_util.h>
#include <google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.h> #include <google/protobuf/compiler/cpp/cpp_test_bad_identifiers.pb.h>
#include <google/protobuf/compiler/importer.h> #include <google/protobuf/compiler/importer.h>
@ -154,6 +155,16 @@ TEST(GeneratedMessageTest, FloatingPointDefaults) {
EXPECT_EQ(-1.5f, extreme_default.negative_float()); EXPECT_EQ(-1.5f, extreme_default.negative_float());
EXPECT_EQ(2.0e8f, extreme_default.large_float()); EXPECT_EQ(2.0e8f, extreme_default.large_float());
EXPECT_EQ(-8e-28f, extreme_default.small_negative_float()); EXPECT_EQ(-8e-28f, extreme_default.small_negative_float());
EXPECT_EQ(numeric_limits<double>::infinity(),
extreme_default.inf_double());
EXPECT_EQ(-numeric_limits<double>::infinity(),
extreme_default.neg_inf_double());
EXPECT_TRUE(extreme_default.nan_double() != extreme_default.nan_double());
EXPECT_EQ(numeric_limits<float>::infinity(),
extreme_default.inf_float());
EXPECT_EQ(-numeric_limits<float>::infinity(),
extreme_default.neg_inf_float());
EXPECT_TRUE(extreme_default.nan_float() != extreme_default.nan_float());
} }
TEST(GeneratedMessageTest, Accessors) { TEST(GeneratedMessageTest, Accessors) {
@ -779,22 +790,39 @@ TEST(GeneratedEnumTest, IsValidValue) {
} }
TEST(GeneratedEnumTest, MinAndMax) { TEST(GeneratedEnumTest, MinAndMax) {
EXPECT_EQ(unittest::TestAllTypes::FOO,unittest::TestAllTypes::NestedEnum_MIN); EXPECT_EQ(unittest::TestAllTypes::FOO,
EXPECT_EQ(unittest::TestAllTypes::BAZ,unittest::TestAllTypes::NestedEnum_MAX); unittest::TestAllTypes::NestedEnum_MIN);
EXPECT_EQ(unittest::TestAllTypes::BAZ,
unittest::TestAllTypes::NestedEnum_MAX);
EXPECT_EQ(4, unittest::TestAllTypes::NestedEnum_ARRAYSIZE);
EXPECT_EQ(unittest::FOREIGN_FOO, unittest::ForeignEnum_MIN); EXPECT_EQ(unittest::FOREIGN_FOO, unittest::ForeignEnum_MIN);
EXPECT_EQ(unittest::FOREIGN_BAZ, unittest::ForeignEnum_MAX); EXPECT_EQ(unittest::FOREIGN_BAZ, unittest::ForeignEnum_MAX);
EXPECT_EQ(7, unittest::ForeignEnum_ARRAYSIZE);
EXPECT_EQ(1, unittest::TestEnumWithDupValue_MIN); EXPECT_EQ(1, unittest::TestEnumWithDupValue_MIN);
EXPECT_EQ(3, unittest::TestEnumWithDupValue_MAX); EXPECT_EQ(3, unittest::TestEnumWithDupValue_MAX);
EXPECT_EQ(4, unittest::TestEnumWithDupValue_ARRAYSIZE);
EXPECT_EQ(unittest::SPARSE_E, unittest::TestSparseEnum_MIN); EXPECT_EQ(unittest::SPARSE_E, unittest::TestSparseEnum_MIN);
EXPECT_EQ(unittest::SPARSE_C, unittest::TestSparseEnum_MAX); EXPECT_EQ(unittest::SPARSE_C, unittest::TestSparseEnum_MAX);
EXPECT_EQ(12589235, unittest::TestSparseEnum_ARRAYSIZE);
// Make sure we can use _MIN and _MAX as switch cases. // Make sure we can take the address of _MIN, _MAX and _ARRAYSIZE.
switch(unittest::SPARSE_A) { void* nullptr = 0; // NULL may be integer-type, not pointer-type.
EXPECT_NE(nullptr, &unittest::TestAllTypes::NestedEnum_MIN);
EXPECT_NE(nullptr, &unittest::TestAllTypes::NestedEnum_MAX);
EXPECT_NE(nullptr, &unittest::TestAllTypes::NestedEnum_ARRAYSIZE);
EXPECT_NE(nullptr, &unittest::ForeignEnum_MIN);
EXPECT_NE(nullptr, &unittest::ForeignEnum_MAX);
EXPECT_NE(nullptr, &unittest::ForeignEnum_ARRAYSIZE);
// Make sure we can use _MIN, _MAX and _ARRAYSIZE as switch cases.
switch (unittest::SPARSE_A) {
case unittest::TestSparseEnum_MIN: case unittest::TestSparseEnum_MIN:
case unittest::TestSparseEnum_MAX: case unittest::TestSparseEnum_MAX:
case unittest::TestSparseEnum_ARRAYSIZE:
break; break;
default: default:
break; break;
@ -1136,6 +1164,43 @@ TEST_F(GeneratedServiceTest, NotImplemented) {
EXPECT_TRUE(controller.called_); EXPECT_TRUE(controller.called_);
} }
} // namespace cpp_unittest
} // namespace cpp
} // namespace compiler
namespace no_generic_services_test {
// Verify that no class called "TestService" was defined in
// unittest_no_generic_services.pb.h by defining a different type by the same
// name. If such a service was generated, this will not compile.
struct TestService {
int i;
};
}
namespace compiler {
namespace cpp {
namespace cpp_unittest {
TEST_F(GeneratedServiceTest, NoGenericServices) {
// Verify that non-services in unittest_no_generic_services.proto were
// generated.
no_generic_services_test::TestMessage message;
message.set_a(1);
message.SetExtension(no_generic_services_test::test_extension, 123);
no_generic_services_test::TestEnum e = no_generic_services_test::FOO;
EXPECT_EQ(e, 1);
// Verify that a ServiceDescriptor is generated for the service even if the
// class itself is not.
const FileDescriptor* file =
no_generic_services_test::TestMessage::descriptor()->file();
ASSERT_EQ(1, file->service_count());
EXPECT_EQ("TestService", file->service(0)->name());
ASSERT_EQ(1, file->service(0)->method_count());
EXPECT_EQ("Foo", file->service(0)->method(0)->name());
}
#endif // !PROTOBUF_TEST_NO_DESCRIPTORS #endif // !PROTOBUF_TEST_NO_DESCRIPTORS
// =================================================================== // ===================================================================

View File

@ -223,6 +223,11 @@ void EnumGenerator::Generate(io::Printer* printer) {
"file", ClassName(descriptor_->file())); "file", ClassName(descriptor_->file()));
} }
printer->Print(
"\n"
"// @@protoc_insertion_point(enum_scope:$full_name$)\n",
"full_name", descriptor_->full_name());
printer->Outdent(); printer->Outdent();
printer->Print("}\n\n"); printer->Print("}\n\n");
} }

View File

@ -62,7 +62,7 @@ void SetEnumVariables(const FieldDescriptor* descriptor,
(*variables)["default"] = DefaultValue(descriptor); (*variables)["default"] = DefaultValue(descriptor);
(*variables)["tag"] = SimpleItoa(internal::WireFormat::MakeTag(descriptor)); (*variables)["tag"] = SimpleItoa(internal::WireFormat::MakeTag(descriptor));
(*variables)["tag_size"] = SimpleItoa( (*variables)["tag_size"] = SimpleItoa(
internal::WireFormat::TagSize(descriptor->number(), descriptor->type())); internal::WireFormat::TagSize(descriptor->number(), GetType(descriptor)));
} }
} // namespace } // namespace
@ -81,7 +81,7 @@ void EnumFieldGenerator::
GenerateMembers(io::Printer* printer) const { GenerateMembers(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"private boolean has$capitalized_name$;\n" "private boolean has$capitalized_name$;\n"
"private $type$ $name$_ = $default$;\n" "private $type$ $name$_;\n"
"public boolean has$capitalized_name$() { return has$capitalized_name$; }\n" "public boolean has$capitalized_name$() { return has$capitalized_name$; }\n"
"public $type$ get$capitalized_name$() { return $name$_; }\n"); "public $type$ get$capitalized_name$() { return $name$_; }\n");
} }
@ -110,6 +110,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
"}\n"); "}\n");
} }
void EnumFieldGenerator::
GenerateInitializationCode(io::Printer* printer) const {
printer->Print(variables_, "$name$_ = $default$;\n");
}
void EnumFieldGenerator:: void EnumFieldGenerator::
GenerateMergingCode(io::Printer* printer) const { GenerateMergingCode(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
@ -240,6 +245,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
"}\n"); "}\n");
} }
void RepeatedEnumFieldGenerator::
GenerateInitializationCode(io::Printer* printer) const {
// Initialized inline.
}
void RepeatedEnumFieldGenerator:: void RepeatedEnumFieldGenerator::
GenerateMergingCode(io::Printer* printer) const { GenerateMergingCode(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
@ -262,15 +272,6 @@ GenerateBuildingCode(io::Printer* printer) const {
void RepeatedEnumFieldGenerator:: void RepeatedEnumFieldGenerator::
GenerateParsingCode(io::Printer* printer) const { GenerateParsingCode(io::Printer* printer) const {
// If packed, set up the while loop
if (descriptor_->options().packed()) {
printer->Print(variables_,
"int length = input.readRawVarint32();\n"
"int oldLimit = input.pushLimit(length);\n"
"while(input.getBytesUntilLimit() > 0) {\n");
printer->Indent();
}
// Read and store the enum // Read and store the enum
printer->Print(variables_, printer->Print(variables_,
"int rawValue = input.readEnum();\n" "int rawValue = input.readEnum();\n"
@ -287,13 +288,24 @@ GenerateParsingCode(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
" add$capitalized_name$(value);\n" " add$capitalized_name$(value);\n"
"}\n"); "}\n");
}
if (descriptor_->options().packed()) { void RepeatedEnumFieldGenerator::
printer->Outdent(); GenerateParsingCodeFromPacked(io::Printer* printer) const {
printer->Print(variables_, // Wrap GenerateParsingCode's contents with a while loop.
"}\n"
"input.popLimit(oldLimit);\n"); printer->Print(variables_,
} "int length = input.readRawVarint32();\n"
"int oldLimit = input.pushLimit(length);\n"
"while(input.getBytesUntilLimit() > 0) {\n");
printer->Indent();
GenerateParsingCode(printer);
printer->Outdent();
printer->Print(variables_,
"}\n"
"input.popLimit(oldLimit);\n");
} }
void RepeatedEnumFieldGenerator:: void RepeatedEnumFieldGenerator::

View File

@ -52,6 +52,7 @@ class EnumFieldGenerator : public FieldGenerator {
// implements FieldGenerator --------------------------------------- // implements FieldGenerator ---------------------------------------
void GenerateMembers(io::Printer* printer) const; void GenerateMembers(io::Printer* printer) const;
void GenerateBuilderMembers(io::Printer* printer) const; void GenerateBuilderMembers(io::Printer* printer) const;
void GenerateInitializationCode(io::Printer* printer) const;
void GenerateMergingCode(io::Printer* printer) const; void GenerateMergingCode(io::Printer* printer) const;
void GenerateBuildingCode(io::Printer* printer) const; void GenerateBuildingCode(io::Printer* printer) const;
void GenerateParsingCode(io::Printer* printer) const; void GenerateParsingCode(io::Printer* printer) const;
@ -75,9 +76,11 @@ class RepeatedEnumFieldGenerator : public FieldGenerator {
// implements FieldGenerator --------------------------------------- // implements FieldGenerator ---------------------------------------
void GenerateMembers(io::Printer* printer) const; void GenerateMembers(io::Printer* printer) const;
void GenerateBuilderMembers(io::Printer* printer) const; void GenerateBuilderMembers(io::Printer* printer) const;
void GenerateInitializationCode(io::Printer* printer) const;
void GenerateMergingCode(io::Printer* printer) const; void GenerateMergingCode(io::Printer* printer) const;
void GenerateBuildingCode(io::Printer* printer) const; void GenerateBuildingCode(io::Printer* printer) const;
void GenerateParsingCode(io::Printer* printer) const; void GenerateParsingCode(io::Printer* printer) const;
void GenerateParsingCodeFromPacked(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const; void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const; void GenerateSerializedSizeCode(io::Printer* printer) const;

View File

@ -133,7 +133,7 @@ void ExtensionGenerator::GenerateInitializationCode(io::Printer* printer) {
vars["extendee"] = ClassName(descriptor_->containing_type()); vars["extendee"] = ClassName(descriptor_->containing_type());
vars["default"] = descriptor_->is_repeated() ? "" : DefaultValue(descriptor_); vars["default"] = descriptor_->is_repeated() ? "" : DefaultValue(descriptor_);
vars["number"] = SimpleItoa(descriptor_->number()); vars["number"] = SimpleItoa(descriptor_->number());
vars["type_constant"] = TypeName(descriptor_->type()); vars["type_constant"] = TypeName(GetType(descriptor_));
vars["packed"] = descriptor_->options().packed() ? "true" : "false"; vars["packed"] = descriptor_->options().packed() ? "true" : "false";
vars["enum_map"] = "null"; vars["enum_map"] = "null";
vars["prototype"] = "null"; vars["prototype"] = "null";
@ -208,5 +208,4 @@ void ExtensionGenerator::GenerateRegistrationCode(io::Printer* printer) {
} // namespace java } // namespace java
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf
} // namespace google } // namespace google

View File

@ -46,6 +46,16 @@ namespace java {
FieldGenerator::~FieldGenerator() {} FieldGenerator::~FieldGenerator() {}
void FieldGenerator::GenerateParsingCodeFromPacked(io::Printer* printer) const {
// Reaching here indicates a bug. Cases are:
// - This FieldGenerator should support packing, but this method should be
// overridden.
// - This FieldGenerator doesn't support packing, and this method should
// never have been called.
GOOGLE_LOG(FATAL) << "GenerateParsingCodeFromPacked() "
<< "called on field generator that does not support packing.";
}
FieldGeneratorMap::FieldGeneratorMap(const Descriptor* descriptor) FieldGeneratorMap::FieldGeneratorMap(const Descriptor* descriptor)
: descriptor_(descriptor), : descriptor_(descriptor),
field_generators_( field_generators_(

View File

@ -57,9 +57,11 @@ class FieldGenerator {
virtual void GenerateMembers(io::Printer* printer) const = 0; virtual void GenerateMembers(io::Printer* printer) const = 0;
virtual void GenerateBuilderMembers(io::Printer* printer) const = 0; virtual void GenerateBuilderMembers(io::Printer* printer) const = 0;
virtual void GenerateInitializationCode(io::Printer* printer) const = 0;
virtual void GenerateMergingCode(io::Printer* printer) const = 0; virtual void GenerateMergingCode(io::Printer* printer) const = 0;
virtual void GenerateBuildingCode(io::Printer* printer) const = 0; virtual void GenerateBuildingCode(io::Printer* printer) const = 0;
virtual void GenerateParsingCode(io::Printer* printer) const = 0; virtual void GenerateParsingCode(io::Printer* printer) const = 0;
virtual void GenerateParsingCodeFromPacked(io::Printer* printer) const;
virtual void GenerateSerializationCode(io::Printer* printer) const = 0; virtual void GenerateSerializationCode(io::Printer* printer) const = 0;
virtual void GenerateSerializedSizeCode(io::Printer* printer) const = 0; virtual void GenerateSerializedSizeCode(io::Printer* printer) const = 0;

View File

@ -64,7 +64,7 @@ bool UsesExtensions(const Message& message) {
for (int i = 0; i < fields.size(); i++) { for (int i = 0; i < fields.size(); i++) {
if (fields[i]->is_extension()) return true; if (fields[i]->is_extension()) return true;
if (fields[i]->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (GetJavaType(fields[i]) == JAVATYPE_MESSAGE) {
if (fields[i]->is_repeated()) { if (fields[i]->is_repeated()) {
int size = reflection->FieldSize(message, fields[i]); int size = reflection->FieldSize(message, fields[i]);
for (int j = 0; j < size; j++) { for (int j = 0; j < size; j++) {
@ -82,6 +82,7 @@ bool UsesExtensions(const Message& message) {
return false; return false;
} }
} // namespace } // namespace
FileGenerator::FileGenerator(const FileDescriptor* file) FileGenerator::FileGenerator(const FileDescriptor* file)
@ -134,7 +135,9 @@ void FileGenerator::Generate(io::Printer* printer) {
// fully-qualified names in the generated source. // fully-qualified names in the generated source.
printer->Print( printer->Print(
"// Generated by the protocol buffer compiler. DO NOT EDIT!\n" "// Generated by the protocol buffer compiler. DO NOT EDIT!\n"
"\n"); "// source: $filename$\n"
"\n",
"filename", file_->name());
if (!java_package_.empty()) { if (!java_package_.empty()) {
printer->Print( printer->Print(
"package $package$;\n" "package $package$;\n"
@ -178,8 +181,10 @@ void FileGenerator::Generate(io::Printer* printer) {
for (int i = 0; i < file_->message_type_count(); i++) { for (int i = 0; i < file_->message_type_count(); i++) {
MessageGenerator(file_->message_type(i)).Generate(printer); MessageGenerator(file_->message_type(i)).Generate(printer);
} }
for (int i = 0; i < file_->service_count(); i++) { if (HasGenericServices(file_)) {
ServiceGenerator(file_->service(i)).Generate(printer); for (int i = 0; i < file_->service_count(); i++) {
ServiceGenerator(file_->service(i)).Generate(printer);
}
} }
} }
@ -228,6 +233,10 @@ void FileGenerator::Generate(io::Printer* printer) {
"\n" "\n"
"public static void internalForceInit() {}\n"); "public static void internalForceInit() {}\n");
printer->Print(
"\n"
"// @@protoc_insertion_point(outer_class_scope)\n");
printer->Outdent(); printer->Outdent();
printer->Print("}\n"); printer->Print("}\n");
} }
@ -245,6 +254,7 @@ void FileGenerator::GenerateEmbeddedDescriptor(io::Printer* printer) {
// embedded raw, which is what we want. // embedded raw, which is what we want.
FileDescriptorProto file_proto; FileDescriptorProto file_proto;
file_->CopyTo(&file_proto); file_->CopyTo(&file_proto);
string file_data; string file_data;
file_proto.SerializeToString(&file_data); file_proto.SerializeToString(&file_data);
@ -343,9 +353,11 @@ void FileGenerator::GenerateEmbeddedDescriptor(io::Printer* printer) {
" new com.google.protobuf.Descriptors.FileDescriptor[] {\n"); " new com.google.protobuf.Descriptors.FileDescriptor[] {\n");
for (int i = 0; i < file_->dependency_count(); i++) { for (int i = 0; i < file_->dependency_count(); i++) {
printer->Print( if (ShouldIncludeDependency(file_->dependency(i))) {
" $dependency$.getDescriptor(),\n", printer->Print(
"dependency", ClassName(file_->dependency(i))); " $dependency$.getDescriptor(),\n",
"dependency", ClassName(file_->dependency(i)));
}
} }
printer->Print( printer->Print(
@ -396,14 +408,20 @@ void FileGenerator::GenerateSiblings(const string& package_dir,
file_->message_type(i), file_->message_type(i),
output_directory, file_list); output_directory, file_list);
} }
for (int i = 0; i < file_->service_count(); i++) { if (HasGenericServices(file_)) {
GenerateSibling<ServiceGenerator>(package_dir, java_package_, for (int i = 0; i < file_->service_count(); i++) {
file_->service(i), GenerateSibling<ServiceGenerator>(package_dir, java_package_,
output_directory, file_list); file_->service(i),
output_directory, file_list);
}
} }
} }
} }
bool FileGenerator::ShouldIncludeDependency(const FileDescriptor* descriptor) {
return true;
}
} // namespace java } // namespace java
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -77,6 +77,11 @@ class FileGenerator {
const string& classname() { return classname_; } const string& classname() { return classname_; }
private: private:
// Returns whether the dependency should be included in the output file.
// Always returns true for opensource, but used internally at Google to help
// improve compatibility with version 1 of protocol buffers.
bool ShouldIncludeDependency(const FileDescriptor* descriptor);
const FileDescriptor* file_; const FileDescriptor* file_;
string java_package_; string java_package_;
string classname_; string classname_;

View File

@ -45,6 +45,7 @@ namespace protobuf {
namespace compiler { namespace compiler {
namespace java { namespace java {
JavaGenerator::JavaGenerator() {} JavaGenerator::JavaGenerator() {}
JavaGenerator::~JavaGenerator() {} JavaGenerator::~JavaGenerator() {}

View File

@ -32,6 +32,7 @@
// Based on original Protocol Buffers design by // Based on original Protocol Buffers design by
// Sanjay Ghemawat, Jeff Dean, and others. // Sanjay Ghemawat, Jeff Dean, and others.
#include <limits>
#include <vector> #include <vector>
#include <google/protobuf/compiler/java/java_helpers.h> #include <google/protobuf/compiler/java/java_helpers.h>
@ -57,7 +58,7 @@ const string& FieldName(const FieldDescriptor* field) {
// Groups are hacky: The name of the field is just the lower-cased name // Groups are hacky: The name of the field is just the lower-cased name
// of the group type. In Java, though, we would like to retain the original // of the group type. In Java, though, we would like to retain the original
// capitalization of the type name. // capitalization of the type name.
if (field->type() == FieldDescriptor::TYPE_GROUP) { if (GetType(field) == FieldDescriptor::TYPE_GROUP) {
return field->message_type()->name(); return field->message_type()->name();
} else { } else {
return field->name(); return field->name();
@ -178,8 +179,12 @@ string FieldConstantName(const FieldDescriptor *field) {
return name; return name;
} }
JavaType GetJavaType(FieldDescriptor::Type field_type) { FieldDescriptor::Type GetType(const FieldDescriptor* field) {
switch (field_type) { return field->type();
}
JavaType GetJavaType(const FieldDescriptor* field) {
switch (GetType(field)) {
case FieldDescriptor::TYPE_INT32: case FieldDescriptor::TYPE_INT32:
case FieldDescriptor::TYPE_UINT32: case FieldDescriptor::TYPE_UINT32:
case FieldDescriptor::TYPE_SINT32: case FieldDescriptor::TYPE_SINT32:
@ -254,7 +259,7 @@ bool AllAscii(const string& text) {
} }
string DefaultValue(const FieldDescriptor* field) { string DefaultValue(const FieldDescriptor* field) {
// Switch on cpp_type since we need to know which default_value_* method // Switch on CppType since we need to know which default_value_* method
// of FieldDescriptor to call. // of FieldDescriptor to call.
switch (field->cpp_type()) { switch (field->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: case FieldDescriptor::CPPTYPE_INT32:
@ -267,14 +272,34 @@ string DefaultValue(const FieldDescriptor* field) {
case FieldDescriptor::CPPTYPE_UINT64: case FieldDescriptor::CPPTYPE_UINT64:
return SimpleItoa(static_cast<int64>(field->default_value_uint64())) + return SimpleItoa(static_cast<int64>(field->default_value_uint64())) +
"L"; "L";
case FieldDescriptor::CPPTYPE_DOUBLE: case FieldDescriptor::CPPTYPE_DOUBLE: {
return SimpleDtoa(field->default_value_double()) + "D"; double value = field->default_value_double();
case FieldDescriptor::CPPTYPE_FLOAT: if (value == numeric_limits<double>::infinity()) {
return SimpleFtoa(field->default_value_float()) + "F"; return "Double.POSITIVE_INFINITY";
} else if (value == -numeric_limits<double>::infinity()) {
return "Double.NEGATIVE_INFINITY";
} else if (value != value) {
return "Double.NaN";
} else {
return SimpleDtoa(value) + "D";
}
}
case FieldDescriptor::CPPTYPE_FLOAT: {
float value = field->default_value_float();
if (value == numeric_limits<float>::infinity()) {
return "Float.POSITIVE_INFINITY";
} else if (value == -numeric_limits<float>::infinity()) {
return "Float.NEGATIVE_INFINITY";
} else if (value != value) {
return "Float.NaN";
} else {
return SimpleFtoa(value) + "F";
}
}
case FieldDescriptor::CPPTYPE_BOOL: case FieldDescriptor::CPPTYPE_BOOL:
return field->default_value_bool() ? "true" : "false"; return field->default_value_bool() ? "true" : "false";
case FieldDescriptor::CPPTYPE_STRING: case FieldDescriptor::CPPTYPE_STRING:
if (field->type() == FieldDescriptor::TYPE_BYTES) { if (GetType(field) == FieldDescriptor::TYPE_BYTES) {
if (field->has_default_value()) { if (field->has_default_value()) {
// See comments in Internal.java for gory details. // See comments in Internal.java for gory details.
return strings::Substitute( return strings::Substitute(

View File

@ -93,6 +93,11 @@ string ClassName(const FileDescriptor* descriptor);
// number constant. // number constant.
string FieldConstantName(const FieldDescriptor *field); string FieldConstantName(const FieldDescriptor *field);
// Returns the type of the FieldDescriptor.
// This does nothing interesting for the open source release, but is used for
// hacks that improve compatability with version 1 protocol buffers at Google.
FieldDescriptor::Type GetType(const FieldDescriptor* field);
enum JavaType { enum JavaType {
JAVATYPE_INT, JAVATYPE_INT,
JAVATYPE_LONG, JAVATYPE_LONG,
@ -105,11 +110,7 @@ enum JavaType {
JAVATYPE_MESSAGE JAVATYPE_MESSAGE
}; };
JavaType GetJavaType(FieldDescriptor::Type field_type); JavaType GetJavaType(const FieldDescriptor* field);
inline JavaType GetJavaType(const FieldDescriptor* field) {
return GetJavaType(field->type());
}
// Get the fully-qualified class name for a boxed primitive type, e.g. // Get the fully-qualified class name for a boxed primitive type, e.g.
// "java.lang.Integer" for JAVATYPE_INT. Returns NULL for enum and message // "java.lang.Integer" for JAVATYPE_INT. Returns NULL for enum and message
@ -145,6 +146,13 @@ inline bool HasDescriptorMethods(const FileDescriptor* descriptor) {
FileOptions::LITE_RUNTIME; FileOptions::LITE_RUNTIME;
} }
// Should we generate generic services for this file?
inline bool HasGenericServices(const FileDescriptor *file) {
return file->service_count() > 0 &&
file->options().optimize_for() != FileOptions::LITE_RUNTIME &&
file->options().java_generic_services();
}
} // namespace java } // namespace java
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -127,7 +127,7 @@ static bool HasRequiredFields(
if (field->is_required()) { if (field->is_required()) {
return true; return true;
} }
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (GetJavaType(field) == JAVATYPE_MESSAGE) {
if (HasRequiredFields(field->message_type(), already_seen)) { if (HasRequiredFields(field->message_type(), already_seen)) {
return true; return true;
} }
@ -292,9 +292,14 @@ void MessageGenerator::Generate(io::Printer* printer) {
printer->Indent(); printer->Indent();
printer->Print( printer->Print(
"// Use $classname$.newBuilder() to construct.\n" "// Use $classname$.newBuilder() to construct.\n"
"private $classname$() {}\n" "private $classname$() {\n"
" initFields();\n"
"}\n"
// Used when constructing the default instance, which cannot be initialized
// immediately because it may cyclically refer to other default instances.
"private $classname$(boolean noInit) {}\n"
"\n" "\n"
"private static final $classname$ defaultInstance = new $classname$();\n" "private static final $classname$ defaultInstance;\n"
"public static $classname$ getDefaultInstance() {\n" "public static $classname$ getDefaultInstance() {\n"
" return defaultInstance;\n" " return defaultInstance;\n"
"}\n" "}\n"
@ -344,6 +349,17 @@ void MessageGenerator::Generate(io::Printer* printer) {
printer->Print("\n"); printer->Print("\n");
} }
// Called by the constructor, except in the case of the default instance,
// in which case this is called by static init code later on.
printer->Print("private void initFields() {\n");
printer->Indent();
for (int i = 0; i < descriptor_->field_count(); i++) {
field_generators_.get(descriptor_->field(i))
.GenerateInitializationCode(printer);
}
printer->Outdent();
printer->Print("}\n");
if (HasGeneratedMethods(descriptor_)) { if (HasGeneratedMethods(descriptor_)) {
GenerateIsInitialized(printer); GenerateIsInitialized(printer);
GenerateMessageSerializationMethods(printer); GenerateMessageSerializationMethods(printer);
@ -352,25 +368,23 @@ void MessageGenerator::Generate(io::Printer* printer) {
GenerateParseFromMethods(printer); GenerateParseFromMethods(printer);
GenerateBuilder(printer); GenerateBuilder(printer);
if (HasDescriptorMethods(descriptor_)) {
// Force the static initialization code for the file to run, since it may
// initialize static variables declared in this class.
printer->Print(
"\n"
"static {\n"
" $file$.getDescriptor();\n"
"}\n",
"file", ClassName(descriptor_->file()));
}
// Force initialization of outer class. Otherwise, nested extensions may // Force initialization of outer class. Otherwise, nested extensions may
// not be initialized. // not be initialized. Also carefully initialize the default instance in
// such a way that it doesn't conflict with other initialization.
printer->Print( printer->Print(
"\n" "\n"
"static {\n" "static {\n"
" defaultInstance = new $classname$(true);\n"
" $file$.internalForceInit();\n" " $file$.internalForceInit();\n"
" defaultInstance.initFields();\n"
"}\n", "}\n",
"file", ClassName(descriptor_->file())); "file", ClassName(descriptor_->file()),
"classname", descriptor_->name());
printer->Print(
"\n"
"// @@protoc_insertion_point(class_scope:$full_name$)\n",
"full_name", descriptor_->full_name());
printer->Outdent(); printer->Outdent();
printer->Print("}\n\n"); printer->Print("}\n\n");
@ -529,14 +543,23 @@ GenerateParseFromMethods(io::Printer* printer) {
"}\n" "}\n"
"public static $classname$ parseDelimitedFrom(java.io.InputStream input)\n" "public static $classname$ parseDelimitedFrom(java.io.InputStream input)\n"
" throws java.io.IOException {\n" " throws java.io.IOException {\n"
" return newBuilder().mergeDelimitedFrom(input).buildParsed();\n" " Builder builder = newBuilder();\n"
" if (builder.mergeDelimitedFrom(input)) {\n"
" return builder.buildParsed();\n"
" } else {\n"
" return null;\n"
" }\n"
"}\n" "}\n"
"public static $classname$ parseDelimitedFrom(\n" "public static $classname$ parseDelimitedFrom(\n"
" java.io.InputStream input,\n" " java.io.InputStream input,\n"
" com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n" " com.google.protobuf.ExtensionRegistryLite extensionRegistry)\n"
" throws java.io.IOException {\n" " throws java.io.IOException {\n"
" return newBuilder().mergeDelimitedFrom(input, extensionRegistry)\n" " Builder builder = newBuilder();\n"
" .buildParsed();\n" " if (builder.mergeDelimitedFrom(input, extensionRegistry)) {\n"
" return builder.buildParsed();\n"
" } else {\n"
" return null;\n"
" }\n"
"}\n" "}\n"
"public static $classname$ parseFrom(\n" "public static $classname$ parseFrom(\n"
" com.google.protobuf.CodedInputStream input)\n" " com.google.protobuf.CodedInputStream input)\n"
@ -827,7 +850,7 @@ void MessageGenerator::GenerateBuilderParsingMethods(io::Printer* printer) {
for (int i = 0; i < descriptor_->field_count(); i++) { for (int i = 0; i < descriptor_->field_count(); i++) {
const FieldDescriptor* field = sorted_fields[i]; const FieldDescriptor* field = sorted_fields[i];
uint32 tag = WireFormatLite::MakeTag(field->number(), uint32 tag = WireFormatLite::MakeTag(field->number(),
WireFormat::WireTypeForField(field)); WireFormat::WireTypeForFieldType(field->type()));
printer->Print( printer->Print(
"case $tag$: {\n", "case $tag$: {\n",
@ -840,6 +863,24 @@ void MessageGenerator::GenerateBuilderParsingMethods(io::Printer* printer) {
printer->Print( printer->Print(
" break;\n" " break;\n"
"}\n"); "}\n");
if (field->is_packable()) {
// To make packed = true wire compatible, we generate parsing code from a
// packed version of this field regardless of field->options().packed().
uint32 packed_tag = WireFormatLite::MakeTag(field->number(),
WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
printer->Print(
"case $tag$: {\n",
"tag", SimpleItoa(packed_tag));
printer->Indent();
field_generators_.get(field).GenerateParsingCodeFromPacked(printer);
printer->Outdent();
printer->Print(
" break;\n"
"}\n");
}
} }
printer->Outdent(); printer->Outdent();
@ -875,7 +916,7 @@ void MessageGenerator::GenerateIsInitialized(io::Printer* printer) {
// Now check that all embedded messages are initialized. // Now check that all embedded messages are initialized.
for (int i = 0; i < descriptor_->field_count(); i++) { for (int i = 0; i < descriptor_->field_count(); i++) {
const FieldDescriptor* field = descriptor_->field(i); const FieldDescriptor* field = descriptor_->field(i);
if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE && if (GetJavaType(field) == JAVATYPE_MESSAGE &&
HasRequiredFields(field->message_type())) { HasRequiredFields(field->message_type())) {
switch (field->label()) { switch (field->label()) {
case FieldDescriptor::LABEL_REQUIRED: case FieldDescriptor::LABEL_REQUIRED:

View File

@ -59,7 +59,7 @@ void SetMessageVariables(const FieldDescriptor* descriptor,
(*variables)["number"] = SimpleItoa(descriptor->number()); (*variables)["number"] = SimpleItoa(descriptor->number());
(*variables)["type"] = ClassName(descriptor->message_type()); (*variables)["type"] = ClassName(descriptor->message_type());
(*variables)["group_or_message"] = (*variables)["group_or_message"] =
(descriptor->type() == FieldDescriptor::TYPE_GROUP) ? (GetType(descriptor) == FieldDescriptor::TYPE_GROUP) ?
"Group" : "Message"; "Group" : "Message";
} }
@ -79,7 +79,7 @@ void MessageFieldGenerator::
GenerateMembers(io::Printer* printer) const { GenerateMembers(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"private boolean has$capitalized_name$;\n" "private boolean has$capitalized_name$;\n"
"private $type$ $name$_ = $type$.getDefaultInstance();\n" "private $type$ $name$_;\n"
"public boolean has$capitalized_name$() { return has$capitalized_name$; }\n" "public boolean has$capitalized_name$() { return has$capitalized_name$; }\n"
"public $type$ get$capitalized_name$() { return $name$_; }\n"); "public $type$ get$capitalized_name$() { return $name$_; }\n");
} }
@ -124,6 +124,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
"}\n"); "}\n");
} }
void MessageFieldGenerator::
GenerateInitializationCode(io::Printer* printer) const {
printer->Print(variables_, "$name$_ = $type$.getDefaultInstance();\n");
}
void MessageFieldGenerator:: void MessageFieldGenerator::
GenerateMergingCode(io::Printer* printer) const { GenerateMergingCode(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
@ -145,7 +150,7 @@ GenerateParsingCode(io::Printer* printer) const {
" subBuilder.mergeFrom(get$capitalized_name$());\n" " subBuilder.mergeFrom(get$capitalized_name$());\n"
"}\n"); "}\n");
if (descriptor_->type() == FieldDescriptor::TYPE_GROUP) { if (GetType(descriptor_) == FieldDescriptor::TYPE_GROUP) {
printer->Print(variables_, printer->Print(variables_,
"input.readGroup($number$, subBuilder, extensionRegistry);\n"); "input.readGroup($number$, subBuilder, extensionRegistry);\n");
} else { } else {
@ -261,6 +266,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
"}\n"); "}\n");
} }
void RepeatedMessageFieldGenerator::
GenerateInitializationCode(io::Printer* printer) const {
// Initialized inline.
}
void RepeatedMessageFieldGenerator:: void RepeatedMessageFieldGenerator::
GenerateMergingCode(io::Printer* printer) const { GenerateMergingCode(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
@ -286,7 +296,7 @@ GenerateParsingCode(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
"$type$.Builder subBuilder = $type$.newBuilder();\n"); "$type$.Builder subBuilder = $type$.newBuilder();\n");
if (descriptor_->type() == FieldDescriptor::TYPE_GROUP) { if (GetType(descriptor_) == FieldDescriptor::TYPE_GROUP) {
printer->Print(variables_, printer->Print(variables_,
"input.readGroup($number$, subBuilder, extensionRegistry);\n"); "input.readGroup($number$, subBuilder, extensionRegistry);\n");
} else { } else {

View File

@ -52,6 +52,7 @@ class MessageFieldGenerator : public FieldGenerator {
// implements FieldGenerator --------------------------------------- // implements FieldGenerator ---------------------------------------
void GenerateMembers(io::Printer* printer) const; void GenerateMembers(io::Printer* printer) const;
void GenerateBuilderMembers(io::Printer* printer) const; void GenerateBuilderMembers(io::Printer* printer) const;
void GenerateInitializationCode(io::Printer* printer) const;
void GenerateMergingCode(io::Printer* printer) const; void GenerateMergingCode(io::Printer* printer) const;
void GenerateBuildingCode(io::Printer* printer) const; void GenerateBuildingCode(io::Printer* printer) const;
void GenerateParsingCode(io::Printer* printer) const; void GenerateParsingCode(io::Printer* printer) const;
@ -75,6 +76,7 @@ class RepeatedMessageFieldGenerator : public FieldGenerator {
// implements FieldGenerator --------------------------------------- // implements FieldGenerator ---------------------------------------
void GenerateMembers(io::Printer* printer) const; void GenerateMembers(io::Printer* printer) const;
void GenerateBuilderMembers(io::Printer* printer) const; void GenerateBuilderMembers(io::Printer* printer) const;
void GenerateInitializationCode(io::Printer* printer) const;
void GenerateMergingCode(io::Printer* printer) const; void GenerateMergingCode(io::Printer* printer) const;
void GenerateBuildingCode(io::Printer* printer) const; void GenerateBuildingCode(io::Printer* printer) const;
void GenerateParsingCode(io::Printer* printer) const; void GenerateParsingCode(io::Printer* printer) const;

View File

@ -93,7 +93,7 @@ bool IsReferenceType(JavaType type) {
} }
const char* GetCapitalizedType(const FieldDescriptor* field) { const char* GetCapitalizedType(const FieldDescriptor* field) {
switch (field->type()) { switch (GetType(field)) {
case FieldDescriptor::TYPE_INT32 : return "Int32" ; case FieldDescriptor::TYPE_INT32 : return "Int32" ;
case FieldDescriptor::TYPE_UINT32 : return "UInt32" ; case FieldDescriptor::TYPE_UINT32 : return "UInt32" ;
case FieldDescriptor::TYPE_SINT32 : return "SInt32" ; case FieldDescriptor::TYPE_SINT32 : return "SInt32" ;
@ -166,7 +166,7 @@ void SetPrimitiveVariables(const FieldDescriptor* descriptor,
(*variables)["capitalized_type"] = GetCapitalizedType(descriptor); (*variables)["capitalized_type"] = GetCapitalizedType(descriptor);
(*variables)["tag"] = SimpleItoa(WireFormat::MakeTag(descriptor)); (*variables)["tag"] = SimpleItoa(WireFormat::MakeTag(descriptor));
(*variables)["tag_size"] = SimpleItoa( (*variables)["tag_size"] = SimpleItoa(
WireFormat::TagSize(descriptor->number(), descriptor->type())); WireFormat::TagSize(descriptor->number(), GetType(descriptor)));
if (IsReferenceType(GetJavaType(descriptor))) { if (IsReferenceType(GetJavaType(descriptor))) {
(*variables)["null_check"] = (*variables)["null_check"] =
" if (value == null) {\n" " if (value == null) {\n"
@ -175,7 +175,7 @@ void SetPrimitiveVariables(const FieldDescriptor* descriptor,
} else { } else {
(*variables)["null_check"] = ""; (*variables)["null_check"] = "";
} }
int fixed_size = FixedSize(descriptor->type()); int fixed_size = FixedSize(GetType(descriptor));
if (fixed_size != -1) { if (fixed_size != -1) {
(*variables)["fixed_size"] = SimpleItoa(fixed_size); (*variables)["fixed_size"] = SimpleItoa(fixed_size);
} }
@ -218,7 +218,8 @@ GenerateBuilderMembers(io::Printer* printer) const {
"}\n" "}\n"
"public Builder clear$capitalized_name$() {\n" "public Builder clear$capitalized_name$() {\n"
" result.has$capitalized_name$ = false;\n"); " result.has$capitalized_name$ = false;\n");
if (descriptor_->cpp_type() == FieldDescriptor::CPPTYPE_STRING) { JavaType type = GetJavaType(descriptor_);
if (type == JAVATYPE_STRING || type == JAVATYPE_BYTES) {
// The default value is not a simple literal so we want to avoid executing // The default value is not a simple literal so we want to avoid executing
// it multiple times. Instead, get the default out of the default instance. // it multiple times. Instead, get the default out of the default instance.
printer->Print(variables_, printer->Print(variables_,
@ -232,6 +233,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
"}\n"); "}\n");
} }
void PrimitiveFieldGenerator::
GenerateInitializationCode(io::Printer* printer) const {
// Initialized inline.
}
void PrimitiveFieldGenerator:: void PrimitiveFieldGenerator::
GenerateMergingCode(io::Printer* printer) const { GenerateMergingCode(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
@ -345,6 +351,11 @@ GenerateBuilderMembers(io::Printer* printer) const {
"}\n"); "}\n");
} }
void RepeatedPrimitiveFieldGenerator::
GenerateInitializationCode(io::Printer* printer) const {
// Initialized inline.
}
void RepeatedPrimitiveFieldGenerator:: void RepeatedPrimitiveFieldGenerator::
GenerateMergingCode(io::Printer* printer) const { GenerateMergingCode(io::Printer* printer) const {
printer->Print(variables_, printer->Print(variables_,
@ -367,18 +378,19 @@ GenerateBuildingCode(io::Printer* printer) const {
void RepeatedPrimitiveFieldGenerator:: void RepeatedPrimitiveFieldGenerator::
GenerateParsingCode(io::Printer* printer) const { GenerateParsingCode(io::Printer* printer) const {
if (descriptor_->options().packed()) { printer->Print(variables_,
printer->Print(variables_, "add$capitalized_name$(input.read$capitalized_type$());\n");
"int length = input.readRawVarint32();\n" }
"int limit = input.pushLimit(length);\n"
"while (input.getBytesUntilLimit() > 0) {\n" void RepeatedPrimitiveFieldGenerator::
" add$capitalized_name$(input.read$capitalized_type$());\n" GenerateParsingCodeFromPacked(io::Printer* printer) const {
"}\n" printer->Print(variables_,
"input.popLimit(limit);\n"); "int length = input.readRawVarint32();\n"
} else { "int limit = input.pushLimit(length);\n"
printer->Print(variables_, "while (input.getBytesUntilLimit() > 0) {\n"
"add$capitalized_name$(input.read$capitalized_type$());\n"); " add$capitalized_name$(input.read$capitalized_type$());\n"
} "}\n"
"input.popLimit(limit);\n");
} }
void RepeatedPrimitiveFieldGenerator:: void RepeatedPrimitiveFieldGenerator::
@ -407,7 +419,7 @@ GenerateSerializedSizeCode(io::Printer* printer) const {
" int dataSize = 0;\n"); " int dataSize = 0;\n");
printer->Indent(); printer->Indent();
if (FixedSize(descriptor_->type()) == -1) { if (FixedSize(GetType(descriptor_)) == -1) {
printer->Print(variables_, printer->Print(variables_,
"for ($type$ element : get$capitalized_name$List()) {\n" "for ($type$ element : get$capitalized_name$List()) {\n"
" dataSize += com.google.protobuf.CodedOutputStream\n" " dataSize += com.google.protobuf.CodedOutputStream\n"

View File

@ -52,6 +52,7 @@ class PrimitiveFieldGenerator : public FieldGenerator {
// implements FieldGenerator --------------------------------------- // implements FieldGenerator ---------------------------------------
void GenerateMembers(io::Printer* printer) const; void GenerateMembers(io::Printer* printer) const;
void GenerateBuilderMembers(io::Printer* printer) const; void GenerateBuilderMembers(io::Printer* printer) const;
void GenerateInitializationCode(io::Printer* printer) const;
void GenerateMergingCode(io::Printer* printer) const; void GenerateMergingCode(io::Printer* printer) const;
void GenerateBuildingCode(io::Printer* printer) const; void GenerateBuildingCode(io::Printer* printer) const;
void GenerateParsingCode(io::Printer* printer) const; void GenerateParsingCode(io::Printer* printer) const;
@ -75,9 +76,11 @@ class RepeatedPrimitiveFieldGenerator : public FieldGenerator {
// implements FieldGenerator --------------------------------------- // implements FieldGenerator ---------------------------------------
void GenerateMembers(io::Printer* printer) const; void GenerateMembers(io::Printer* printer) const;
void GenerateBuilderMembers(io::Printer* printer) const; void GenerateBuilderMembers(io::Printer* printer) const;
void GenerateInitializationCode(io::Printer* printer) const;
void GenerateMergingCode(io::Printer* printer) const; void GenerateMergingCode(io::Printer* printer) const;
void GenerateBuildingCode(io::Printer* printer) const; void GenerateBuildingCode(io::Printer* printer) const;
void GenerateParsingCode(io::Printer* printer) const; void GenerateParsingCode(io::Printer* printer) const;
void GenerateParsingCodeFromPacked(io::Printer* printer) const;
void GenerateSerializationCode(io::Printer* printer) const; void GenerateSerializationCode(io::Printer* printer) const;
void GenerateSerializedSizeCode(io::Printer* printer) const; void GenerateSerializedSizeCode(io::Printer* printer) const;

View File

@ -39,6 +39,7 @@
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
google::protobuf::compiler::CommandLineInterface cli; google::protobuf::compiler::CommandLineInterface cli;
cli.AllowPlugins("protoc-");
// Proto2 C++ // Proto2 C++
google::protobuf::compiler::cpp::CppGenerator cpp_generator; google::protobuf::compiler::cpp::CppGenerator cpp_generator;

View File

@ -34,8 +34,9 @@
// //
// Recursive descent FTW. // Recursive descent FTW.
#include <google/protobuf/stubs/hash.h>
#include <float.h> #include <float.h>
#include <google/protobuf/stubs/hash.h>
#include <limits>
#include <google/protobuf/compiler/parser.h> #include <google/protobuf/compiler/parser.h>
@ -206,6 +207,14 @@ bool Parser::ConsumeNumber(double* output, const char* error) {
*output = value; *output = value;
input_->Next(); input_->Next();
return true; return true;
} else if (LookingAt("inf")) {
*output = numeric_limits<double>::infinity();
input_->Next();
return true;
} else if (LookingAt("nan")) {
*output = numeric_limits<double>::quiet_NaN();
input_->Next();
return true;
} else { } else {
AddError(error); AddError(error);
return false; return false;

View File

@ -336,6 +336,9 @@ TEST_F(ParseMessageTest, FieldDefaults) {
" required double foo = 1 [default= 10.5];\n" " required double foo = 1 [default= 10.5];\n"
" required double foo = 1 [default=-11.5];\n" " required double foo = 1 [default=-11.5];\n"
" required double foo = 1 [default= 12 ];\n" " required double foo = 1 [default= 12 ];\n"
" required double foo = 1 [default= inf ];\n"
" required double foo = 1 [default=-inf ];\n"
" required double foo = 1 [default= nan ];\n"
" required string foo = 1 [default='13\\001'];\n" " required string foo = 1 [default='13\\001'];\n"
" required string foo = 1 [default='a' \"b\" \n \"c\"];\n" " required string foo = 1 [default='a' \"b\" \n \"c\"];\n"
" required bytes foo = 1 [default='14\\002'];\n" " required bytes foo = 1 [default='14\\002'];\n"
@ -367,6 +370,9 @@ TEST_F(ParseMessageTest, FieldDefaults) {
" field { type:TYPE_DOUBLE default_value:\"10.5\" "ETC" }" " field { type:TYPE_DOUBLE default_value:\"10.5\" "ETC" }"
" field { type:TYPE_DOUBLE default_value:\"-11.5\" "ETC" }" " field { type:TYPE_DOUBLE default_value:\"-11.5\" "ETC" }"
" field { type:TYPE_DOUBLE default_value:\"12\" "ETC" }" " field { type:TYPE_DOUBLE default_value:\"12\" "ETC" }"
" field { type:TYPE_DOUBLE default_value:\"inf\" "ETC" }"
" field { type:TYPE_DOUBLE default_value:\"-inf\" "ETC" }"
" field { type:TYPE_DOUBLE default_value:\"nan\" "ETC" }"
" field { type:TYPE_STRING default_value:\"13\\001\" "ETC" }" " field { type:TYPE_STRING default_value:\"13\\001\" "ETC" }"
" field { type:TYPE_STRING default_value:\"abc\" "ETC" }" " field { type:TYPE_STRING default_value:\"abc\" "ETC" }"
" field { type:TYPE_BYTES default_value:\"14\\\\002\" "ETC" }" " field { type:TYPE_BYTES default_value:\"14\\\\002\" "ETC" }"

View File

@ -42,8 +42,9 @@
// performance-minded Python code leverage the fast C++ implementation // performance-minded Python code leverage the fast C++ implementation
// directly. // directly.
#include <utility> #include <limits>
#include <map> #include <map>
#include <utility>
#include <string> #include <string>
#include <vector> #include <vector>
@ -105,6 +106,13 @@ string NamePrefixedWithNestedTypes(const DescriptorT& descriptor,
const char kDescriptorKey[] = "DESCRIPTOR"; const char kDescriptorKey[] = "DESCRIPTOR";
// Should we generate generic services for this file?
inline bool HasGenericServices(const FileDescriptor *file) {
return file->service_count() > 0 &&
file->options().py_generic_services();
}
// Prints the common boilerplate needed at the top of every .py // Prints the common boilerplate needed at the top of every .py
// file output by this generator. // file output by this generator.
void PrintTopBoilerplate( void PrintTopBoilerplate(
@ -115,14 +123,21 @@ void PrintTopBoilerplate(
"\n" "\n"
"from google.protobuf import descriptor\n" "from google.protobuf import descriptor\n"
"from google.protobuf import message\n" "from google.protobuf import message\n"
"from google.protobuf import reflection\n" "from google.protobuf import reflection\n");
"from google.protobuf import service\n" if (HasGenericServices(file)) {
"from google.protobuf import service_reflection\n"); printer->Print(
"from google.protobuf import service\n"
"from google.protobuf import service_reflection\n");
}
// Avoid circular imports if this module is descriptor_pb2. // Avoid circular imports if this module is descriptor_pb2.
if (!descriptor_proto) { if (!descriptor_proto) {
printer->Print( printer->Print(
"from google.protobuf import descriptor_pb2\n"); "from google.protobuf import descriptor_pb2\n");
} }
printer->Print(
"# @@protoc_insertion_point(imports)\n");
printer->Print("\n\n");
} }
@ -150,10 +165,30 @@ string StringifyDefaultValue(const FieldDescriptor& field) {
return SimpleItoa(field.default_value_int64()); return SimpleItoa(field.default_value_int64());
case FieldDescriptor::CPPTYPE_UINT64: case FieldDescriptor::CPPTYPE_UINT64:
return SimpleItoa(field.default_value_uint64()); return SimpleItoa(field.default_value_uint64());
case FieldDescriptor::CPPTYPE_DOUBLE: case FieldDescriptor::CPPTYPE_DOUBLE: {
return SimpleDtoa(field.default_value_double()); double value = field.default_value_double();
case FieldDescriptor::CPPTYPE_FLOAT: if (value == numeric_limits<double>::infinity()) {
return SimpleFtoa(field.default_value_float()); return "float('inf')";
} else if (value == -numeric_limits<double>::infinity()) {
return "float('-inf')";
} else if (value != value) {
return "float('nan')";
} else {
return SimpleDtoa(value);
}
}
case FieldDescriptor::CPPTYPE_FLOAT: {
float value = field.default_value_float();
if (value == numeric_limits<float>::infinity()) {
return "float('inf')";
} else if (value == -numeric_limits<float>::infinity()) {
return "float('-inf')";
} else if (value != value) {
return "float('nan')";
} else {
return SimpleFtoa(value);
}
}
case FieldDescriptor::CPPTYPE_BOOL: case FieldDescriptor::CPPTYPE_BOOL:
return field.default_value_bool() ? "True" : "False"; return field.default_value_bool() ? "True" : "False";
case FieldDescriptor::CPPTYPE_ENUM: case FieldDescriptor::CPPTYPE_ENUM:
@ -204,6 +239,10 @@ bool Generator::Generate(const FileDescriptor* file,
StripString(&filename, ".", '/'); StripString(&filename, ".", '/');
filename += ".py"; filename += ".py";
FileDescriptorProto fdp;
file_->CopyTo(&fdp);
fdp.SerializeToString(&file_descriptor_serialized_);
scoped_ptr<io::ZeroCopyOutputStream> output(output_directory->Open(filename)); scoped_ptr<io::ZeroCopyOutputStream> output(output_directory->Open(filename));
GOOGLE_CHECK(output.get()); GOOGLE_CHECK(output.get());
@ -211,6 +250,7 @@ bool Generator::Generate(const FileDescriptor* file,
printer_ = &printer; printer_ = &printer;
PrintTopBoilerplate(printer_, file_, GeneratingDescriptorProto()); PrintTopBoilerplate(printer_, file_, GeneratingDescriptorProto());
PrintFileDescriptor();
PrintTopLevelEnums(); PrintTopLevelEnums();
PrintTopLevelExtensions(); PrintTopLevelExtensions();
PrintAllNestedEnumsInFile(); PrintAllNestedEnumsInFile();
@ -224,7 +264,13 @@ bool Generator::Generate(const FileDescriptor* file,
// since they need to call static RegisterExtension() methods on these // since they need to call static RegisterExtension() methods on these
// classes. // classes.
FixForeignFieldsInExtensions(); FixForeignFieldsInExtensions();
PrintServices(); if (HasGenericServices(file)) {
PrintServices();
}
printer.Print(
"# @@protoc_insertion_point(module_scope)\n");
return !printer.failed(); return !printer.failed();
} }
@ -238,6 +284,30 @@ void Generator::PrintImports() const {
printer_->Print("\n"); printer_->Print("\n");
} }
// Prints the single file descriptor for this file.
void Generator::PrintFileDescriptor() const {
map<string, string> m;
m["descriptor_name"] = kDescriptorKey;
m["name"] = file_->name();
m["package"] = file_->package();
const char file_descriptor_template[] =
"$descriptor_name$ = descriptor.FileDescriptor(\n"
" name='$name$',\n"
" package='$package$',\n";
printer_->Print(m, file_descriptor_template);
printer_->Indent();
printer_->Print(
"serialized_pb='$value$'",
"value", strings::CHexEscape(file_descriptor_serialized_));
// TODO(falk): Also print options and fix the message_type, enum_type,
// service and extension later in the generation.
printer_->Outdent();
printer_->Print(")\n");
printer_->Print("\n");
}
// Prints descriptors and module-level constants for all top-level // Prints descriptors and module-level constants for all top-level
// enums defined in |file|. // enums defined in |file|.
void Generator::PrintTopLevelEnums() const { void Generator::PrintTopLevelEnums() const {
@ -277,12 +347,13 @@ void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
m["descriptor_name"] = ModuleLevelDescriptorName(enum_descriptor); m["descriptor_name"] = ModuleLevelDescriptorName(enum_descriptor);
m["name"] = enum_descriptor.name(); m["name"] = enum_descriptor.name();
m["full_name"] = enum_descriptor.full_name(); m["full_name"] = enum_descriptor.full_name();
m["filename"] = enum_descriptor.name(); m["file"] = kDescriptorKey;
const char enum_descriptor_template[] = const char enum_descriptor_template[] =
"$descriptor_name$ = descriptor.EnumDescriptor(\n" "$descriptor_name$ = descriptor.EnumDescriptor(\n"
" name='$name$',\n" " name='$name$',\n"
" full_name='$full_name$',\n" " full_name='$full_name$',\n"
" filename='$filename$',\n" " filename=None,\n"
" file=$file$,\n"
" values=[\n"; " values=[\n";
string options_string; string options_string;
enum_descriptor.options().SerializeToString(&options_string); enum_descriptor.options().SerializeToString(&options_string);
@ -295,9 +366,12 @@ void Generator::PrintEnum(const EnumDescriptor& enum_descriptor) const {
} }
printer_->Outdent(); printer_->Outdent();
printer_->Print("],\n"); printer_->Print("],\n");
printer_->Print("containing_type=None,\n");
printer_->Print("options=$options_value$,\n", printer_->Print("options=$options_value$,\n",
"options_value", "options_value",
OptionsValue("EnumOptions", CEscape(options_string))); OptionsValue("EnumOptions", CEscape(options_string)));
EnumDescriptorProto edp;
PrintSerializedPbInterval(enum_descriptor, edp);
printer_->Outdent(); printer_->Outdent();
printer_->Print(")\n"); printer_->Print(")\n");
printer_->Print("\n"); printer_->Print("\n");
@ -362,15 +436,21 @@ void Generator::PrintServiceDescriptor(
map<string, string> m; map<string, string> m;
m["name"] = descriptor.name(); m["name"] = descriptor.name();
m["full_name"] = descriptor.full_name(); m["full_name"] = descriptor.full_name();
m["file"] = kDescriptorKey;
m["index"] = SimpleItoa(descriptor.index()); m["index"] = SimpleItoa(descriptor.index());
m["options_value"] = OptionsValue("ServiceOptions", options_string); m["options_value"] = OptionsValue("ServiceOptions", options_string);
const char required_function_arguments[] = const char required_function_arguments[] =
"name='$name$',\n" "name='$name$',\n"
"full_name='$full_name$',\n" "full_name='$full_name$',\n"
"file=$file$,\n"
"index=$index$,\n" "index=$index$,\n"
"options=$options_value$,\n" "options=$options_value$,\n";
"methods=[\n";
printer_->Print(m, required_function_arguments); printer_->Print(m, required_function_arguments);
ServiceDescriptorProto sdp;
PrintSerializedPbInterval(descriptor, sdp);
printer_->Print("methods=[\n");
for (int i = 0; i < descriptor.method_count(); ++i) { for (int i = 0; i < descriptor.method_count(); ++i) {
const MethodDescriptor* method = descriptor.method(i); const MethodDescriptor* method = descriptor.method(i);
string options_string; string options_string;
@ -444,17 +524,27 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
map<string, string> m; map<string, string> m;
m["name"] = message_descriptor.name(); m["name"] = message_descriptor.name();
m["full_name"] = message_descriptor.full_name(); m["full_name"] = message_descriptor.full_name();
m["filename"] = message_descriptor.file()->name(); m["file"] = kDescriptorKey;
const char required_function_arguments[] = const char required_function_arguments[] =
"name='$name$',\n" "name='$name$',\n"
"full_name='$full_name$',\n" "full_name='$full_name$',\n"
"filename='$filename$',\n" "filename=None,\n"
"containing_type=None,\n"; // TODO(robinson): Implement containing_type. "file=$file$,\n"
"containing_type=None,\n";
printer_->Print(m, required_function_arguments); printer_->Print(m, required_function_arguments);
PrintFieldsInDescriptor(message_descriptor); PrintFieldsInDescriptor(message_descriptor);
PrintExtensionsInDescriptor(message_descriptor); PrintExtensionsInDescriptor(message_descriptor);
// TODO(robinson): implement printing of nested_types.
printer_->Print("nested_types=[], # TODO(robinson): Implement.\n"); // Nested types
printer_->Print("nested_types=[");
for (int i = 0; i < message_descriptor.nested_type_count(); ++i) {
const string nested_name = ModuleLevelDescriptorName(
*message_descriptor.nested_type(i));
printer_->Print("$name$, ", "name", nested_name);
}
printer_->Print("],\n");
// Enum types
printer_->Print("enum_types=[\n"); printer_->Print("enum_types=[\n");
printer_->Indent(); printer_->Indent();
for (int i = 0; i < message_descriptor.enum_type_count(); ++i) { for (int i = 0; i < message_descriptor.enum_type_count(); ++i) {
@ -468,8 +558,28 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const {
string options_string; string options_string;
message_descriptor.options().SerializeToString(&options_string); message_descriptor.options().SerializeToString(&options_string);
printer_->Print( printer_->Print(
"options=$options_value$", "options=$options_value$,\n"
"options_value", OptionsValue("MessageOptions", options_string)); "is_extendable=$extendable$",
"options_value", OptionsValue("MessageOptions", options_string),
"extendable", message_descriptor.extension_range_count() > 0 ?
"True" : "False");
printer_->Print(",\n");
// Extension ranges
printer_->Print("extension_ranges=[");
for (int i = 0; i < message_descriptor.extension_range_count(); ++i) {
const Descriptor::ExtensionRange* range =
message_descriptor.extension_range(i);
printer_->Print("($start$, $end$), ",
"start", SimpleItoa(range->start),
"end", SimpleItoa(range->end));
}
printer_->Print("],\n");
// Serialization of proto
DescriptorProto edp;
PrintSerializedPbInterval(message_descriptor, edp);
printer_->Outdent(); printer_->Outdent();
printer_->Print(")\n"); printer_->Print(")\n");
} }
@ -511,6 +621,12 @@ void Generator::PrintMessage(
m["descriptor_key"] = kDescriptorKey; m["descriptor_key"] = kDescriptorKey;
m["descriptor_name"] = ModuleLevelDescriptorName(message_descriptor); m["descriptor_name"] = ModuleLevelDescriptorName(message_descriptor);
printer_->Print(m, "$descriptor_key$ = $descriptor_name$\n"); printer_->Print(m, "$descriptor_key$ = $descriptor_name$\n");
printer_->Print(
"\n"
"# @@protoc_insertion_point(class_scope:$full_name$)\n",
"full_name", message_descriptor.full_name());
printer_->Outdent(); printer_->Outdent();
} }
@ -527,16 +643,27 @@ void Generator::PrintNestedMessages(
// Recursively fixes foreign fields in all nested types in |descriptor|, then // Recursively fixes foreign fields in all nested types in |descriptor|, then
// sets the message_type and enum_type of all message and enum fields to point // sets the message_type and enum_type of all message and enum fields to point
// to their respective descriptors. // to their respective descriptors.
// Args:
// descriptor: descriptor to print fields for.
// containing_descriptor: if descriptor is a nested type, this is its
// containing type, or NULL if this is a root/top-level type.
void Generator::FixForeignFieldsInDescriptor( void Generator::FixForeignFieldsInDescriptor(
const Descriptor& descriptor) const { const Descriptor& descriptor,
const Descriptor* containing_descriptor) const {
for (int i = 0; i < descriptor.nested_type_count(); ++i) { for (int i = 0; i < descriptor.nested_type_count(); ++i) {
FixForeignFieldsInDescriptor(*descriptor.nested_type(i)); FixForeignFieldsInDescriptor(*descriptor.nested_type(i), &descriptor);
} }
for (int i = 0; i < descriptor.field_count(); ++i) { for (int i = 0; i < descriptor.field_count(); ++i) {
const FieldDescriptor& field_descriptor = *descriptor.field(i); const FieldDescriptor& field_descriptor = *descriptor.field(i);
FixForeignFieldsInField(&descriptor, field_descriptor, "fields_by_name"); FixForeignFieldsInField(&descriptor, field_descriptor, "fields_by_name");
} }
FixContainingTypeInDescriptor(descriptor, containing_descriptor);
for (int i = 0; i < descriptor.enum_type_count(); ++i) {
const EnumDescriptor& enum_descriptor = *descriptor.enum_type(i);
FixContainingTypeInDescriptor(enum_descriptor, &descriptor);
}
} }
// Sets any necessary message_type and enum_type attributes // Sets any necessary message_type and enum_type attributes
@ -593,13 +720,29 @@ string Generator::FieldReferencingExpression(
python_dict_name, field.name()); python_dict_name, field.name());
} }
// Prints containing_type for nested descriptors or enum descriptors.
template <typename DescriptorT>
void Generator::FixContainingTypeInDescriptor(
const DescriptorT& descriptor,
const Descriptor* containing_descriptor) const {
if (containing_descriptor != NULL) {
const string nested_name = ModuleLevelDescriptorName(descriptor);
const string parent_name = ModuleLevelDescriptorName(
*containing_descriptor);
printer_->Print(
"$nested_name$.containing_type = $parent_name$;\n",
"nested_name", nested_name,
"parent_name", parent_name);
}
}
// Prints statements setting the message_type and enum_type fields in the // Prints statements setting the message_type and enum_type fields in the
// Python descriptor objects we've already output in ths file. We must // Python descriptor objects we've already output in ths file. We must
// do this in a separate step due to circular references (otherwise, we'd // do this in a separate step due to circular references (otherwise, we'd
// just set everything in the initial assignment statements). // just set everything in the initial assignment statements).
void Generator::FixForeignFieldsInDescriptors() const { void Generator::FixForeignFieldsInDescriptors() const {
for (int i = 0; i < file_->message_type_count(); ++i) { for (int i = 0; i < file_->message_type_count(); ++i) {
FixForeignFieldsInDescriptor(*file_->message_type(i)); FixForeignFieldsInDescriptor(*file_->message_type(i), NULL);
} }
printer_->Print("\n"); printer_->Print("\n");
} }
@ -696,6 +839,7 @@ void Generator::PrintFieldDescriptor(
m["type"] = SimpleItoa(field.type()); m["type"] = SimpleItoa(field.type());
m["cpp_type"] = SimpleItoa(field.cpp_type()); m["cpp_type"] = SimpleItoa(field.cpp_type());
m["label"] = SimpleItoa(field.label()); m["label"] = SimpleItoa(field.label());
m["has_default_value"] = field.has_default_value() ? "True" : "False";
m["default_value"] = StringifyDefaultValue(field); m["default_value"] = StringifyDefaultValue(field);
m["is_extension"] = is_extension ? "True" : "False"; m["is_extension"] = is_extension ? "True" : "False";
m["options"] = OptionsValue("FieldOptions", options_string); m["options"] = OptionsValue("FieldOptions", options_string);
@ -703,13 +847,13 @@ void Generator::PrintFieldDescriptor(
// these fields in correctly after all referenced descriptors have been // these fields in correctly after all referenced descriptors have been
// defined and/or imported (see FixForeignFieldsInDescriptors()). // defined and/or imported (see FixForeignFieldsInDescriptors()).
const char field_descriptor_decl[] = const char field_descriptor_decl[] =
"descriptor.FieldDescriptor(\n" "descriptor.FieldDescriptor(\n"
" name='$name$', full_name='$full_name$', index=$index$,\n" " name='$name$', full_name='$full_name$', index=$index$,\n"
" number=$number$, type=$type$, cpp_type=$cpp_type$, label=$label$,\n" " number=$number$, type=$type$, cpp_type=$cpp_type$, label=$label$,\n"
" default_value=$default_value$,\n" " has_default_value=$has_default_value$, default_value=$default_value$,\n"
" message_type=None, enum_type=None, containing_type=None,\n" " message_type=None, enum_type=None, containing_type=None,\n"
" is_extension=$is_extension$, extension_scope=None,\n" " is_extension=$is_extension$, extension_scope=None,\n"
" options=$options$)"; " options=$options$)";
printer_->Print(m, field_descriptor_decl); printer_->Print(m, field_descriptor_decl);
} }
@ -811,6 +955,29 @@ string Generator::ModuleLevelServiceDescriptorName(
return name; return name;
} }
// Prints standard constructor arguments serialized_start and serialized_end.
// Args:
// descriptor: The cpp descriptor to have a serialized reference.
// proto: A proto
// Example printer output:
// serialized_start=41,
// serialized_end=43,
//
template <typename DescriptorT, typename DescriptorProtoT>
void Generator::PrintSerializedPbInterval(
const DescriptorT& descriptor, DescriptorProtoT& proto) const {
descriptor.CopyTo(&proto);
string sp;
proto.SerializeToString(&sp);
int offset = file_descriptor_serialized_.find(sp);
GOOGLE_CHECK_GE(offset, 0);
printer_->Print("serialized_start=$serialized_start$,\n"
"serialized_end=$serialized_end$,\n",
"serialized_start", SimpleItoa(offset),
"serialized_end", SimpleItoa(offset + sp.size()));
}
} // namespace python } // namespace python
} // namespace compiler } // namespace compiler
} // namespace protobuf } // namespace protobuf

View File

@ -71,6 +71,7 @@ class LIBPROTOC_EXPORT Generator : public CodeGenerator {
private: private:
void PrintImports() const; void PrintImports() const;
void PrintFileDescriptor() const;
void PrintTopLevelEnums() const; void PrintTopLevelEnums() const;
void PrintAllNestedEnumsInFile() const; void PrintAllNestedEnumsInFile() const;
void PrintNestedEnums(const Descriptor& descriptor) const; void PrintNestedEnums(const Descriptor& descriptor) const;
@ -97,13 +98,19 @@ class LIBPROTOC_EXPORT Generator : public CodeGenerator {
void PrintNestedMessages(const Descriptor& containing_descriptor) const; void PrintNestedMessages(const Descriptor& containing_descriptor) const;
void FixForeignFieldsInDescriptors() const; void FixForeignFieldsInDescriptors() const;
void FixForeignFieldsInDescriptor(const Descriptor& descriptor) const; void FixForeignFieldsInDescriptor(
const Descriptor& descriptor,
const Descriptor* containing_descriptor) const;
void FixForeignFieldsInField(const Descriptor* containing_type, void FixForeignFieldsInField(const Descriptor* containing_type,
const FieldDescriptor& field, const FieldDescriptor& field,
const string& python_dict_name) const; const string& python_dict_name) const;
string FieldReferencingExpression(const Descriptor* containing_type, string FieldReferencingExpression(const Descriptor* containing_type,
const FieldDescriptor& field, const FieldDescriptor& field,
const string& python_dict_name) const; const string& python_dict_name) const;
template <typename DescriptorT>
void FixContainingTypeInDescriptor(
const DescriptorT& descriptor,
const Descriptor* containing_descriptor) const;
void FixForeignFieldsInExtensions() const; void FixForeignFieldsInExtensions() const;
void FixForeignFieldsInExtension( void FixForeignFieldsInExtension(
@ -126,10 +133,15 @@ class LIBPROTOC_EXPORT Generator : public CodeGenerator {
string ModuleLevelServiceDescriptorName( string ModuleLevelServiceDescriptorName(
const ServiceDescriptor& descriptor) const; const ServiceDescriptor& descriptor) const;
template <typename DescriptorT, typename DescriptorProtoT>
void PrintSerializedPbInterval(
const DescriptorT& descriptor, DescriptorProtoT& proto) const;
// Very coarse-grained lock to ensure that Generate() is reentrant. // Very coarse-grained lock to ensure that Generate() is reentrant.
// Guards file_ and printer_. // Guards file_, printer_ and file_descriptor_serialized_.
mutable Mutex mutex_; mutable Mutex mutex_;
mutable const FileDescriptor* file_; // Set in Generate(). Under mutex_. mutable const FileDescriptor* file_; // Set in Generate(). Under mutex_.
mutable string file_descriptor_serialized_;
mutable io::Printer* printer_; // Set in Generate(). Under mutex_. mutable io::Printer* printer_; // Set in Generate(). Under mutex_.
GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Generator); GOOGLE_DISALLOW_EVIL_CONSTRUCTORS(Generator);

View File

@ -796,9 +796,10 @@ bool DescriptorPool::InternalIsFileLoaded(const string& filename) const {
namespace { namespace {
EncodedDescriptorDatabase* generated_database_ = NULL; EncodedDescriptorDatabase* generated_database_ = NULL;
DescriptorPool* generated_pool_ = NULL; DescriptorPool* generated_pool_ = NULL;
GOOGLE_PROTOBUF_DECLARE_ONCE(generated_pool_init_); GoogleOnceType generated_pool_init_;
void DeleteGeneratedPool() { void DeleteGeneratedPool() {
delete generated_database_; delete generated_database_;
@ -810,6 +811,7 @@ void DeleteGeneratedPool() {
void InitGeneratedPool() { void InitGeneratedPool() {
generated_database_ = new EncodedDescriptorDatabase; generated_database_ = new EncodedDescriptorDatabase;
generated_pool_ = new DescriptorPool(generated_database_); generated_pool_ = new DescriptorPool(generated_database_);
internal::OnShutdown(&DeleteGeneratedPool); internal::OnShutdown(&DeleteGeneratedPool);
} }
@ -3651,17 +3653,11 @@ void DescriptorBuilder::ValidateFieldOptions(FieldDescriptor* field,
} }
// Only repeated primitive fields may be packed. // Only repeated primitive fields may be packed.
if (field->options().packed()) { if (field->options().packed() && !field->is_packable()) {
if (!field->is_repeated() || AddError(
field->type() == FieldDescriptor::TYPE_STRING || field->full_name(), proto,
field->type() == FieldDescriptor::TYPE_GROUP || DescriptorPool::ErrorCollector::TYPE,
field->type() == FieldDescriptor::TYPE_MESSAGE || "[packed = true] can only be specified for repeated primitive fields.");
field->type() == FieldDescriptor::TYPE_BYTES) {
AddError(
field->full_name(), proto,
DescriptorPool::ErrorCollector::TYPE,
"[packed = true] can only be specified for repeated primitive fields.");
}
} }
// Note: Default instance may not yet be initialized here, so we have to // Note: Default instance may not yet be initialized here, so we have to

View File

@ -395,6 +395,8 @@ class LIBPROTOBUF_EXPORT FieldDescriptor {
bool is_required() const; // shorthand for label() == LABEL_REQUIRED bool is_required() const; // shorthand for label() == LABEL_REQUIRED
bool is_optional() const; // shorthand for label() == LABEL_OPTIONAL bool is_optional() const; // shorthand for label() == LABEL_OPTIONAL
bool is_repeated() const; // shorthand for label() == LABEL_REPEATED bool is_repeated() const; // shorthand for label() == LABEL_REPEATED
bool is_packable() const; // shorthand for is_repeated() &&
// IsTypePackable(type())
// Index of this field within the message's field array, or the file or // Index of this field within the message's field array, or the file or
// extension scope's extensions array. // extension scope's extensions array.
@ -474,6 +476,9 @@ class LIBPROTOBUF_EXPORT FieldDescriptor {
// Helper method to get the CppType for a particular Type. // Helper method to get the CppType for a particular Type.
static CppType TypeToCppType(Type type); static CppType TypeToCppType(Type type);
// Return true iff [packed = true] is valid for fields of this type.
static inline bool IsTypePackable(Type field_type);
private: private:
typedef FieldOptions OptionsType; typedef FieldOptions OptionsType;
@ -1069,10 +1074,6 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
// These methods may contain hidden pitfalls and may be removed in a // These methods may contain hidden pitfalls and may be removed in a
// future library version. // future library version.
// DEPRECATED: Use of underlays can lead to many subtle gotchas. Instead,
// try to formulate what you want to do in terms of DescriptorDatabases.
// This constructor will be removed soon.
//
// Create a DescriptorPool which is overlaid on top of some other pool. // Create a DescriptorPool which is overlaid on top of some other pool.
// If you search for a descriptor in the overlay and it is not found, the // If you search for a descriptor in the overlay and it is not found, the
// underlay will be searched as a backup. If the underlay has its own // underlay will be searched as a backup. If the underlay has its own
@ -1090,6 +1091,9 @@ class LIBPROTOBUF_EXPORT DescriptorPool {
// types directly into generated_pool(): this is not allowed, and would be // types directly into generated_pool(): this is not allowed, and would be
// bad design anyway. So, instead, you could use generated_pool() as an // bad design anyway. So, instead, you could use generated_pool() as an
// underlay for a new DescriptorPool in which you add only the new file. // underlay for a new DescriptorPool in which you add only the new file.
//
// WARNING: Use of underlays can lead to many subtle gotchas. Instead,
// try to formulate what you want to do in terms of DescriptorDatabases.
explicit DescriptorPool(const DescriptorPool* underlay); explicit DescriptorPool(const DescriptorPool* underlay);
// Called by generated classes at init time to add their descriptors to // Called by generated classes at init time to add their descriptors to
@ -1294,6 +1298,10 @@ inline bool FieldDescriptor::is_repeated() const {
return label() == LABEL_REPEATED; return label() == LABEL_REPEATED;
} }
inline bool FieldDescriptor::is_packable() const {
return is_repeated() && IsTypePackable(type());
}
// To save space, index() is computed by looking at the descriptor's position // To save space, index() is computed by looking at the descriptor's position
// in the parent's array of children. // in the parent's array of children.
inline int FieldDescriptor::index() const { inline int FieldDescriptor::index() const {
@ -1342,6 +1350,13 @@ inline FieldDescriptor::CppType FieldDescriptor::TypeToCppType(Type type) {
return kTypeToCppTypeMap[type]; return kTypeToCppTypeMap[type];
} }
inline bool FieldDescriptor::IsTypePackable(Type field_type) {
return (field_type != FieldDescriptor::TYPE_STRING &&
field_type != FieldDescriptor::TYPE_GROUP &&
field_type != FieldDescriptor::TYPE_MESSAGE &&
field_type != FieldDescriptor::TYPE_BYTES);
}
inline const FileDescriptor* FileDescriptor::dependency(int index) const { inline const FileDescriptor* FileDescriptor::dependency(int index) const {
return dependencies_[index]; return dependencies_[index];
} }

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -256,6 +256,22 @@ message FileOptions {
// Should generic services be generated in each language? "Generic" services
// are not specific to any particular RPC system. They are generated by the
// main code generators in each language (without additional plugins).
// Generic services were the only kind of service generation supported by
// early versions of proto2.
//
// Generic services are now considered deprecated in favor of using plugins
// that generate code specific to your particular RPC system. If you are
// using such a plugin, set these to false. In the future, we may change
// the default to false, so if you explicitly want generic services, you
// should explicitly set these to true.
optional bool cc_generic_services = 16 [default=true];
optional bool java_generic_services = 17 [default=true];
optional bool py_generic_services = 18 [default=true];
// The parser stores options it doesn't recognize here. See above. // The parser stores options it doesn't recognize here. See above.
repeated UninterpretedOption uninterpreted_option = 999; repeated UninterpretedOption uninterpreted_option = 999;
@ -301,8 +317,11 @@ message FieldOptions {
// representation of the field than it normally would. See the specific // representation of the field than it normally would. See the specific
// options below. This option is not yet implemented in the open source // options below. This option is not yet implemented in the open source
// release -- sorry, we'll try to include it in a future version! // release -- sorry, we'll try to include it in a future version!
optional CType ctype = 1; optional CType ctype = 1 [default = STRING];
enum CType { enum CType {
// Default mode.
STRING = 0;
CORD = 1; CORD = 1;
STRING_PIECE = 2; STRING_PIECE = 2;
@ -313,6 +332,7 @@ message FieldOptions {
// a single length-delimited blob. // a single length-delimited blob.
optional bool packed = 2; optional bool packed = 2;
// Is this field deprecated? // Is this field deprecated?
// Depending on the target platform, this can emit Deprecated annotations // Depending on the target platform, this can emit Deprecated annotations
// for accessors, or it will be completely ignored; in the very least, this // for accessors, or it will be completely ignored; in the very least, this

View File

@ -37,6 +37,7 @@
#include <set> #include <set>
#include <google/protobuf/descriptor.pb.h> #include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/wire_format_lite_inl.h>
#include <google/protobuf/stubs/strutil.h> #include <google/protobuf/stubs/strutil.h>
#include <google/protobuf/stubs/stl_util-inl.h> #include <google/protobuf/stubs/stl_util-inl.h>
#include <google/protobuf/stubs/map-util.h> #include <google/protobuf/stubs/map-util.h>
@ -336,6 +337,35 @@ bool EncodedDescriptorDatabase::FindFileContainingSymbol(
return MaybeParse(index_.FindSymbol(symbol_name), output); return MaybeParse(index_.FindSymbol(symbol_name), output);
} }
bool EncodedDescriptorDatabase::FindNameOfFileContainingSymbol(
const string& symbol_name,
string* output) {
pair<const void*, int> encoded_file = index_.FindSymbol(symbol_name);
if (encoded_file.first == NULL) return false;
// Optimization: The name should be the first field in the encoded message.
// Try to just read it directly.
io::CodedInputStream input(reinterpret_cast<const uint8*>(encoded_file.first),
encoded_file.second);
const uint32 kNameTag = internal::WireFormatLite::MakeTag(
FileDescriptorProto::kNameFieldNumber,
internal::WireFormatLite::WIRETYPE_LENGTH_DELIMITED);
if (input.ReadTag() == kNameTag) {
// Success!
return internal::WireFormatLite::ReadString(&input, output);
} else {
// Slow path. Parse whole message.
FileDescriptorProto file_proto;
if (!file_proto.ParseFromArray(encoded_file.first, encoded_file.second)) {
return false;
}
*output = file_proto.name();
return true;
}
}
bool EncodedDescriptorDatabase::FindFileContainingExtension( bool EncodedDescriptorDatabase::FindFileContainingExtension(
const string& containing_type, const string& containing_type,
int field_number, int field_number,

View File

@ -280,6 +280,10 @@ class LIBPROTOBUF_EXPORT EncodedDescriptorDatabase : public DescriptorDatabase {
// need to keep it around. // need to keep it around.
bool AddCopy(const void* encoded_file_descriptor, int size); bool AddCopy(const void* encoded_file_descriptor, int size);
// Like FindFileContainingSymbol but returns only the name of the file.
bool FindNameOfFileContainingSymbol(const string& symbol_name,
string* output);
// implements DescriptorDatabase ----------------------------------- // implements DescriptorDatabase -----------------------------------
bool FindFileByName(const string& filename, bool FindFileByName(const string& filename,
FileDescriptorProto* output); FileDescriptorProto* output);

Some files were not shown because too many files have changed in this diff Show More