diff --git a/ruby/ext/google/protobuf_c/storage.c b/ruby/ext/google/protobuf_c/storage.c index b1f65f413..1c8397814 100644 --- a/ruby/ext/google/protobuf_c/storage.c +++ b/ruby/ext/google/protobuf_c/storage.c @@ -57,6 +57,37 @@ size_t native_slot_size(upb_fieldtype_t type) { } } +static VALUE value_from_default(const upb_fielddef *field) { + switch (upb_fielddef_type(field)) { + case UPB_TYPE_FLOAT: return DBL2NUM(upb_fielddef_defaultfloat(field)); + case UPB_TYPE_DOUBLE: return DBL2NUM(upb_fielddef_defaultdouble(field)); + case UPB_TYPE_BOOL: + return upb_fielddef_defaultbool(field) ? Qtrue : Qfalse; + case UPB_TYPE_MESSAGE: return Qnil; + case UPB_TYPE_ENUM: { + const upb_enumdef *enumdef = upb_fielddef_enumsubdef(field); + int32_t num = upb_fielddef_defaultint32(field); + const char *label = upb_enumdef_iton(enumdef, num); + if (label) { + return ID2SYM(rb_intern(label)); + } else { + return INT2NUM(num); + } + } + case UPB_TYPE_INT32: return INT2NUM(upb_fielddef_defaultint32(field)); + case UPB_TYPE_INT64: return LL2NUM(upb_fielddef_defaultint64(field));; + case UPB_TYPE_UINT32: return UINT2NUM(upb_fielddef_defaultuint32(field)); + case UPB_TYPE_UINT64: return ULL2NUM(upb_fielddef_defaultuint64(field)); + case UPB_TYPE_STRING: + case UPB_TYPE_BYTES: { + size_t size; + const char *str = upb_fielddef_defaultstr(field, &size); + return rb_str_new(str, size); + } + default: return Qnil; + } +} + static bool is_ruby_num(VALUE value) { return (TYPE(value) == T_FLOAT || TYPE(value) == T_FIXNUM || @@ -537,7 +568,7 @@ VALUE layout_get(MessageLayout* layout, if (upb_fielddef_containingoneof(field)) { if (*oneof_case != upb_fielddef_number(field)) { - return Qnil; + return value_from_default(field); } return native_slot_get(upb_fielddef_type(field), field_type_class(field), diff --git a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java index 39213c4d1..12893f731 100644 --- a/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java +++ b/ruby/src/main/java/com/google/protobuf/jruby/RubyMessage.java @@ -592,13 +592,17 @@ public class RubyMessage extends RubyObject { protected IRubyObject getField(ThreadContext context, Descriptors.FieldDescriptor fieldDescriptor) { Descriptors.OneofDescriptor oneofDescriptor = fieldDescriptor.getContainingOneof(); if (oneofDescriptor != null) { - if (oneofCases.containsKey(oneofDescriptor)) { - if (oneofCases.get(oneofDescriptor) != fieldDescriptor) - return context.runtime.getNil(); + if (oneofCases.get(oneofDescriptor) == fieldDescriptor) { return fields.get(fieldDescriptor); } else { Descriptors.FieldDescriptor oneofCase = builder.getOneofFieldDescriptor(oneofDescriptor); - if (oneofCase != fieldDescriptor) return context.runtime.getNil(); + if (oneofCase != fieldDescriptor) { + if (fieldDescriptor.getType() == Descriptors.FieldDescriptor.Type.MESSAGE) { + return context.runtime.getNil(); + } else { + return wrapField(context, fieldDescriptor, fieldDescriptor.getDefaultValue()); + } + } IRubyObject value = wrapField(context, oneofCase, builder.getField(oneofCase)); fields.put(fieldDescriptor, value); return value; diff --git a/ruby/tests/basic.rb b/ruby/tests/basic.rb index 77c186ef3..fee07e333 100644 --- a/ruby/tests/basic.rb +++ b/ruby/tests/basic.rb @@ -703,36 +703,36 @@ module BasicTest def test_oneof d = OneofMessage.new - assert d.a == nil - assert d.b == nil + assert d.a == "" + assert d.b == 0 assert d.c == nil - assert d.d == nil + assert d.d == :Default assert d.my_oneof == nil d.a = "hi" assert d.a == "hi" - assert d.b == nil + assert d.b == 0 assert d.c == nil - assert d.d == nil + assert d.d == :Default assert d.my_oneof == :a d.b = 42 - assert d.a == nil + assert d.a == "" assert d.b == 42 assert d.c == nil - assert d.d == nil + assert d.d == :Default assert d.my_oneof == :b d.c = TestMessage2.new(:foo => 100) - assert d.a == nil - assert d.b == nil + assert d.a == "" + assert d.b == 0 assert d.c.foo == 100 - assert d.d == nil + assert d.d == :Default assert d.my_oneof == :c d.d = :C - assert d.a == nil - assert d.b == nil + assert d.a == "" + assert d.b == 0 assert d.c == nil assert d.d == :C assert d.my_oneof == :d @@ -748,23 +748,23 @@ module BasicTest d3 = OneofMessage.decode( encoded_field_c + encoded_field_a + encoded_field_d) - assert d3.a == nil - assert d3.b == nil + assert d3.a == "" + assert d3.b == 0 assert d3.c == nil assert d3.d == :B d4 = OneofMessage.decode( encoded_field_c + encoded_field_a + encoded_field_d + encoded_field_c) - assert d4.a == nil - assert d4.b == nil + assert d4.a == "" + assert d4.b == 0 assert d4.c.foo == 1 - assert d4.d == nil + assert d4.d == :Default d5 = OneofMessage.new(:a => "hello") - assert d5.a != nil + assert d5.a == "hello" d5.a = nil - assert d5.a == nil + assert d5.a == "" assert OneofMessage.encode(d5) == '' assert d5.my_oneof == nil end