diff --git a/src/amberscript/parser.cc b/src/amberscript/parser.cc index 72b3650e..ec8ad1e8 100644 --- a/src/amberscript/parser.cc +++ b/src/amberscript/parser.cc @@ -2952,7 +2952,8 @@ Result Parser::ParseBufferInitializerFill(Buffer* buffer, } auto fmt = buffer->GetFormat(); - bool is_double_data = fmt->IsFloat32() || fmt->IsFloat64(); + bool is_float_data = + fmt->IsFloat16() || fmt->IsFloat32() || fmt->IsFloat64(); // Inflate the size because our items are multi-dimensional. size_in_items = size_in_items * fmt->InputNeededPerElement(); @@ -2960,7 +2961,7 @@ Result Parser::ParseBufferInitializerFill(Buffer* buffer, std::vector values; values.resize(size_in_items); for (size_t i = 0; i < size_in_items; ++i) { - if (is_double_data) { + if (is_float_data) { values[i].SetDoubleValue(token->AsDouble()); } else { values[i].SetIntValue(token->AsUint64()); @@ -2994,7 +2995,8 @@ Result Parser::ParseBufferInitializerSeries(Buffer* buffer, auto n = type->AsNumber(); FormatMode mode = n->GetFormatMode(); uint32_t num_bits = n->NumBits(); - if (type::Type::IsFloat32(mode, num_bits) || + if (type::Type::IsFloat16(mode, num_bits) || + type::Type::IsFloat32(mode, num_bits) || type::Type::IsFloat64(mode, num_bits)) { counter.SetDoubleValue(token->AsDouble()); } else { @@ -3020,7 +3022,8 @@ Result Parser::ParseBufferInitializerSeries(Buffer* buffer, std::vector values; values.resize(size_in_items); for (size_t i = 0; i < size_in_items; ++i) { - if (type::Type::IsFloat32(mode, num_bits) || + if (type::Type::IsFloat16(mode, num_bits) || + type::Type::IsFloat32(mode, num_bits) || type::Type::IsFloat64(mode, num_bits)) { double value = counter.AsDouble(); values[i].SetDoubleValue(value); diff --git a/src/amberscript/parser_buffer_test.cc b/src/amberscript/parser_buffer_test.cc index ab4824ad..7ce18074 100644 --- a/src/amberscript/parser_buffer_test.cc +++ b/src/amberscript/parser_buffer_test.cc @@ -14,6 +14,7 @@ #include "gtest/gtest.h" #include "src/amberscript/parser.h" +#include "src/float16_helper.h" namespace amber { namespace amberscript { @@ -400,6 +401,34 @@ TEST_F(AmberScriptParserTest, BufferFillFloat) { } } +TEST_F(AmberScriptParserTest, BufferFillFloat16) { + std::string in = "BUFFER my_buffer DATA_TYPE float16 SIZE 5 FILL 5.5"; + + Parser parser; + Result r = parser.Parse(in); + ASSERT_TRUE(r.IsSuccess()) << r.Error(); + + auto script = parser.GetScript(); + const auto& buffers = script->GetBuffers(); + ASSERT_EQ(1U, buffers.size()); + + ASSERT_TRUE(buffers[0] != nullptr); + + auto* buffer = buffers[0].get(); + EXPECT_EQ("my_buffer", buffer->GetName()); + EXPECT_EQ(Format::Layout::kStd430, buffer->GetFormat()->GetLayout()); + EXPECT_EQ(5U, buffer->ElementCount()); + EXPECT_EQ(5U, buffer->ValueCount()); + EXPECT_EQ(5U * sizeof(uint16_t), buffer->GetSizeInBytes()); + + const auto* data = buffer->GetValues(); + for (size_t i = 0; i < buffer->ValueCount(); ++i) { + EXPECT_FLOAT_EQ( + 5.5f, float16::HexFloatToFloat( + reinterpret_cast(&data[i]), 16)); + } +} + TEST_F(AmberScriptParserTest, BufferSeries) { std::string in = "BUFFER my_buffer DATA_TYPE uint8 SIZE 5 SERIES_FROM 2 INC_BY 1"; @@ -461,6 +490,38 @@ TEST_F(AmberScriptParserTest, BufferSeriesFloat) { } } +TEST_F(AmberScriptParserTest, BufferSeriesFloat16) { + std::string in = + "BUFFER my_buffer DATA_TYPE float16 SIZE 5 SERIES_FROM 2.5 INC_BY " + "0.5"; + + Parser parser; + Result r = parser.Parse(in); + ASSERT_TRUE(r.IsSuccess()) << r.Error(); + + auto script = parser.GetScript(); + const auto& buffers = script->GetBuffers(); + ASSERT_EQ(1U, buffers.size()); + + ASSERT_TRUE(buffers[0] != nullptr); + + auto* buffer = buffers[0].get(); + EXPECT_EQ("my_buffer", buffer->GetName()); + EXPECT_EQ(Format::Layout::kStd430, buffer->GetFormat()->GetLayout()); + EXPECT_EQ(5U, buffer->ElementCount()); + EXPECT_EQ(5U, buffer->ValueCount()); + EXPECT_EQ(5U * sizeof(uint16_t), buffer->GetSizeInBytes()); + + std::vector results = {2.5f, 3.0f, 3.5f, 4.0f, 4.5f}; + const auto* data = buffer->GetValues(); + ASSERT_EQ(results.size(), buffer->ValueCount()); + for (size_t i = 0; i < results.size(); ++i) { + EXPECT_FLOAT_EQ( + results[i], float16::HexFloatToFloat( + reinterpret_cast(&data[i]), 16)); + } +} + TEST_F(AmberScriptParserTest, BufferMultipleBuffers) { std::string in = R"( BUFFER color_buffer DATA_TYPE uint8 SIZE 5 FILL 5 diff --git a/src/format.h b/src/format.h index 2e917058..4eac62bb 100644 --- a/src/format.h +++ b/src/format.h @@ -199,6 +199,12 @@ class Format { type::Type::IsUint64(type_->AsNumber()->GetFormatMode(), type_->AsNumber()->NumBits()); } + /// Returns true if all components of this format are a 16 bit float. + bool IsFloat16() const { + return type_->IsNumber() && + type::Type::IsFloat16(type_->AsNumber()->GetFormatMode(), + type_->AsNumber()->NumBits()); + } /// Returns true if all components of this format are a 32 bit float. bool IsFloat32() const { return type_->IsNumber() &&